From 02cb64d3d9a83a43754d7bf977a54260b5820e63 Mon Sep 17 00:00:00 2001 From: Azareal Date: Mon, 28 Oct 2019 09:53:16 +1000 Subject: [PATCH] Fix attachment parsing. Shorten some things to reduce boilerplate. Save some allocations. --- common/misc_logs.go | 22 +++++++++++--------- common/parser.go | 47 +++++++++++++++++++++---------------------- common/permissions.go | 2 +- common/pluginlangs.go | 10 ++++++--- 4 files changed, 44 insertions(+), 37 deletions(-) diff --git a/common/misc_logs.go b/common/misc_logs.go index 36ca9b57..875d75c3 100644 --- a/common/misc_logs.go +++ b/common/misc_logs.go @@ -29,9 +29,10 @@ var regLogStmts RegLogStmts func init() { DbInits.Add(func(acc *qgen.Accumulator) error { + rl := "registration_logs" regLogStmts = RegLogStmts{ - update: acc.Update("registration_logs").Set("username = ?, email = ?, failureReason = ?, success = ?").Where("rlid = ?").Prepare(), - create: acc.Insert("registration_logs").Columns("username, email, failureReason, success, ipaddress, doneAt").Fields("?,?,?,?,?,UTC_TIMESTAMP()").Prepare(), + update: acc.Update(rl).Set("username = ?, email = ?, failureReason = ?, success = ?").Where("rlid = ?").Prepare(), + create: acc.Insert(rl).Columns("username, email, failureReason, success, ipaddress, doneAt").Fields("?,?,?,?,?,UTC_TIMESTAMP()").Prepare(), } return acc.FirstError() }) @@ -65,9 +66,10 @@ type SQLRegLogStore struct { } func NewRegLogStore(acc *qgen.Accumulator) (*SQLRegLogStore, error) { + rl := "registration_logs" return &SQLRegLogStore{ - count: acc.Count("registration_logs").Prepare(), - getOffset: acc.Select("registration_logs").Columns("rlid, username, email, failureReason, success, ipaddress, doneAt").Orderby("doneAt DESC").Limit("?,?").Prepare(), + count: acc.Count(rl).Prepare(), + getOffset: acc.Select(rl).Columns("rlid, username, email, failureReason, success, ipaddress, doneAt").Orderby("doneAt DESC").Limit("?,?").Prepare(), }, acc.FirstError() } @@ -116,9 +118,10 @@ var loginLogStmts LoginLogStmts func init() { DbInits.Add(func(acc *qgen.Accumulator) error { + ll := "login_logs" loginLogStmts = LoginLogStmts{ - update: acc.Update("login_logs").Set("uid = ?, success = ?").Where("lid = ?").Prepare(), - create: acc.Insert("login_logs").Columns("uid, success, ipaddress, doneAt").Fields("?,?,?,UTC_TIMESTAMP()").Prepare(), + update: acc.Update(ll).Set("uid = ?, success = ?").Where("lid = ?").Prepare(), + create: acc.Insert(ll).Columns("uid, success, ipaddress, doneAt").Fields("?,?,?,UTC_TIMESTAMP()").Prepare(), } return acc.FirstError() }) @@ -154,10 +157,11 @@ type SQLLoginLogStore struct { } func NewLoginLogStore(acc *qgen.Accumulator) (*SQLLoginLogStore, error) { + ll := "login_logs" return &SQLLoginLogStore{ - count: acc.Count("login_logs").Prepare(), - countForUser: acc.Count("login_logs").Where("uid = ?").Prepare(), - getOffsetByUser: acc.Select("login_logs").Columns("lid, success, ipaddress, doneAt").Where("uid = ?").Orderby("doneAt DESC").Limit("?,?").Prepare(), + count: acc.Count(ll).Prepare(), + countForUser: acc.Count(ll).Where("uid = ?").Prepare(), + getOffsetByUser: acc.Select(ll).Columns("lid, success, ipaddress, doneAt").Where("uid = ?").Orderby("doneAt DESC").Limit("?,?").Prepare(), }, acc.FirstError() } diff --git a/common/parser.go b/common/parser.go index 3dfa0efe..ee8eda4c 100644 --- a/common/parser.go +++ b/common/parser.go @@ -701,8 +701,8 @@ func validateURLString(data string) bool { // ? - There should only be one : and that's only if the URL is on a non-standard port. Same for ?s. for ; len(data) > i; i++ { - char := data[i] - if char != '\\' && char != '_' && char != ':' && char != '?' && char != '&' && char != '=' && char != ';' && char != '@' && char != '#' && char != ']' && !(char > 44 && char < 58) && !(char > 64 && char < 92) && !(char > 96 && char < 123) { // 90 is Z, 91 is [ + ch := data[i] // char + if ch != '\\' && ch != '_' && ch != ':' && ch != '?' && ch != '&' && ch != '=' && ch != ';' && ch != '@' && ch != '#' && ch != ']' && !(ch > 44 && ch < 58) && !(ch > 64 && ch < 92) && !(ch > 96 && ch < 123) { // 90 is Z, 91 is [ return false } } @@ -727,8 +727,8 @@ func validatedURLBytes(data []byte) (url []byte) { // ? - There should only be one : and that's only if the URL is on a non-standard port. Same for ?s. for ; datalen > i; i++ { - char := data[i] - if char != '\\' && char != '_' && char != ':' && char != '?' && char != '&' && char != '=' && char != ';' && char != '@' && char != '#' && char != ']' && !(char > 44 && char < 58) && !(char > 64 && char < 92) && !(char > 96 && char < 123) { // 90 is Z, 91 is [ + ch := data[i] //char + if ch != '\\' && ch != '_' && ch != ':' && ch != '?' && ch != '&' && ch != '=' && ch != ';' && ch != '@' && ch != '#' && ch != ']' && !(ch > 44 && ch < 58) && !(ch > 64 && ch < 92) && !(ch > 96 && ch < 123) { // 90 is Z, 91 is [ return InvalidURL } } @@ -755,8 +755,8 @@ func PartialURLString(data string) (url []byte) { // ? - There should only be one : and that's only if the URL is on a non-standard port. Same for ?s. for ; end >= i; i++ { - char := data[i] - if char != '\\' && char != '_' && char != ':' && char != '?' && char != '&' && char != '=' && char != ';' && char != '@' && char != '#' && char != ']' && !(char > 44 && char < 58) && !(char > 64 && char < 92) && !(char > 96 && char < 123) { // 90 is Z, 91 is [ + ch := data[i] // char + if ch != '\\' && ch != '_' && ch != ':' && ch != '?' && ch != '&' && ch != '=' && ch != ';' && ch != '@' && ch != '#' && ch != ']' && !(ch > 44 && ch < 58) && !(ch > 64 && ch < 92) && !(ch > 96 && ch < 123) { // 90 is Z, 91 is [ end = i } } @@ -791,12 +791,12 @@ func PartialURLStringLen(data string) (int, bool) { f := i //fmt.Println("f:",f) for ; len(data) > i; i++ { - char := data[i] - if char < 33 { // space and invisibles + ch := data[i] //char + if ch < 33 { // space and invisibles //fmt.Println("e2:",i) return i, i != f - } else if char != '\\' && char != '_' && char != ':' && char != '?' && char != '&' && char != '=' && char != ';' && char != '@' && char != '#' && char != ']' && !(char > 44 && char < 58) && !(char > 64 && char < 92) && !(char > 96 && char < 123) { // 90 is Z, 91 is [ - //log.Print("Bad Character: ", char) + } else if ch != '\\' && ch != '_' && ch != ':' && ch != '?' && ch != '&' && ch != '=' && ch != ';' && ch != '@' && ch != '#' && ch != ']' && !(ch > 44 && ch < 58) && !(ch > 64 && ch < 92) && !(ch > 96 && ch < 123) { // 90 is Z, 91 is [ + //log.Print("Bad Character: ", ch) //fmt.Println("e3") return i, false } @@ -829,9 +829,9 @@ func PartialURLStringLen2(data string) int { // ? - There should only be one : and that's only if the URL is on a non-standard port. Same for ?s. for ; len(data) > i; i++ { - char := data[i] - if char != '\\' && char != '_' && char != ':' && char != '?' && char != '&' && char != '=' && char != ';' && char != '@' && char != '#' && !(char > 44 && char < 58) && !(char > 64 && char < 91) && !(char > 96 && char < 123) { // 90 is Z, 91 is [ - //log.Print("Bad Character: ", char) + ch := data[i] //char + if ch != '\\' && ch != '_' && ch != ':' && ch != '?' && ch != '&' && ch != '=' && ch != ';' && ch != '@' && ch != '#' && !(ch > 44 && ch < 58) && !(ch > 64 && ch < 91) && !(ch > 96 && ch < 123) { // 90 is Z, 91 is [ + //log.Print("Bad Character: ", ch) return i } } @@ -857,20 +857,19 @@ func parseMediaString(data string) (media MediaEmbed, ok bool) { if err != nil { return media, false } - - hostname := uurl.Hostname() + host := uurl.Hostname() scheme := uurl.Scheme port := uurl.Port() query, err := url.ParseQuery(uurl.RawQuery) if err != nil { return media, false } - //log.Print("hostname:",hostname) + //fmt.Println("host:", host) //log.Print("Site.URL:",Site.URL) - samesite := hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" || hostname == Site.URL + samesite := host == "localhost" || host == "127.0.0.1" || host == "::1" || host == Site.URL if samesite { - hostname = strings.Split(Site.URL, ":")[0] + host = strings.Split(Site.URL, ":")[0] // ?- Test this as I'm not sure it'll do what it should. If someone's running SSL on port 80 or non-SSL on port 443 then... Well... They're in far worse trouble than this... port = Site.Port if Site.EnableSsl { @@ -885,13 +884,13 @@ func parseMediaString(data string) (media MediaEmbed, ok bool) { path := uurl.EscapedPath() pathFrags := strings.Split(path, "/") if len(pathFrags) >= 2 { - if samesite && pathFrags[1] == "attachs" && (scheme == "http" || scheme == "https") { + if samesite && pathFrags[1] == "attachs" && (scheme == "http:" || scheme == "https:") { var sport string // ? - Assumes the sysadmin hasn't mixed up the two standard ports if port != "443" && port != "80" && port != "" { sport = ":" + port } - media.URL = scheme + "//" + hostname + sport + path + media.URL = scheme + "//" + host + sport + path extarr := strings.Split(path, ".") if len(extarr) == 0 { // TODO: Write a unit test for this @@ -909,7 +908,7 @@ func parseMediaString(data string) (media MediaEmbed, ok bool) { // ? - I don't think this hostname will hit every YT domain // TODO: Make this a more customisable handler rather than hard-coding it in here - if strings.HasSuffix(hostname, ".youtube.com") && path == "/watch" { + if strings.HasSuffix(host, ".youtube.com") && path == "/watch" { video, ok := query["v"] if ok && len(video) >= 1 && video[0] != "" { media.Type = "raw" @@ -929,7 +928,7 @@ func parseMediaString(data string) (media MediaEmbed, ok bool) { if port != "443" && port != "80" && port != "" { sport = ":" + port } - media.URL = scheme + "//" + hostname + sport + path + media.URL = scheme + "//" + host + sport + path return media, true } } @@ -947,8 +946,8 @@ func parseMediaString(data string) (media MediaEmbed, ok bool) { if len(uurl.Fragment) > 0 { frag = "#" + uurl.Fragment } - media.URL = scheme + "//" + hostname + sport + path + q + frag - media.FURL = hostname + sport + path + q + frag + media.URL = scheme + "//" + host + sport + path + q + frag + media.FURL = host + sport + path + q + frag return media, true } diff --git a/common/permissions.go b/common/permissions.go index 90223ada..405f1d38 100644 --- a/common/permissions.go +++ b/common/permissions.go @@ -179,7 +179,7 @@ func RebuildGroupPermissions(group *Group) error { log.Print("Reloading a group") // TODO: Avoid re-initting this all the time - getGroupPerms, err := qgen.Builder.SimpleSelect("users_groups", "permissions", "gid = ?", "", "") + getGroupPerms, err := qgen.Builder.SimpleSelect("users_groups", "permissions", "gid=?", "", "") if err != nil { return err } diff --git a/common/pluginlangs.go b/common/pluginlangs.go index 509d6eca..7a8d2bab 100644 --- a/common/pluginlangs.go +++ b/common/pluginlangs.go @@ -66,14 +66,18 @@ func InitPluginLangs() error { continue } + e := func(field string, name string) error { + return errors.New("The "+field+" field must not be blank on plugin '" + name + "'") + } + if plugin.UName == "" { - return errors.New("The UName field must not be blank on plugin '" + pluginItem + "'") + return e("UName",pluginItem) } if plugin.Name == "" { - return errors.New("The Name field must not be blank on plugin '" + pluginItem + "'") + return e("Name",pluginItem) } if plugin.Author == "" { - return errors.New("The Author field must not be blank on plugin '" + pluginItem + "'") + return e("Author",pluginItem) } if plugin.Main == "" { return errors.New("Couldn't find a main file for plugin '" + pluginItem + "'")