`)
-var topic_alt_26 = []byte(`
Level `)
-var topic_alt_28 = []byte(`
`)
+var topic_alt_29 = []byte(`
Level `)
+var topic_alt_31 = []byte(`
diff --git a/common/poll_store.go b/common/poll_store.go index 25c048f5..e0c7d72f 100644 --- a/common/poll_store.go +++ b/common/poll_store.go @@ -10,8 +10,10 @@ import ( var Polls PollStore type Poll struct { - ID int - Type int // 0: Single choice, 1: Multiple choice, 2: Multiple choice w/ points + ID int + ParentID int + ParentTable string + Type int // 0: Single choice, 1: Multiple choice, 2: Multiple choice w/ points //AntiCheat bool // Apply various mitigations for cheating // GroupPower map[gid]points // The number of points a group can spend in this poll, defaults to 1 @@ -21,6 +23,10 @@ type Poll struct { VoteCount int } +func (poll *Poll) CastVote(optionIndex int, uid int, ipaddress string) error { + return Polls.CastVote(optionIndex, poll.ID, uid, ipaddress) // TODO: Move the query into a pollStmts rather than having it in the store +} + func (poll *Poll) Copy() Poll { return *poll } @@ -31,6 +37,8 @@ type PollOption struct { } type Pollable interface { + GetID() int + GetTable() string SetPoll(pollID int) error } @@ -38,6 +46,7 @@ type PollStore interface { Get(id int) (*Poll, error) Exists(id int) bool Create(parent Pollable, pollType int, pollOptions map[int]string) (int, error) + CastVote(optionIndex int, pollID int, uid int, ipaddress string) error Reload(id int) error //GlobalCount() int @@ -48,10 +57,12 @@ type PollStore interface { type DefaultPollStore struct { cache PollCache - get *sql.Stmt - exists *sql.Stmt - create *sql.Stmt - delete *sql.Stmt + get *sql.Stmt + exists *sql.Stmt + create *sql.Stmt + addVote *sql.Stmt + incrementVoteCount *sql.Stmt + delete *sql.Stmt //pollCount *sql.Stmt } @@ -62,10 +73,12 @@ func NewDefaultPollStore(cache PollCache) (*DefaultPollStore, error) { } // TODO: Add an admin version of registerStmt with more flexibility? return &DefaultPollStore{ - cache: cache, - get: acc.Select("polls").Columns("type, options, votes").Where("pollID = ?").Prepare(), - exists: acc.Select("polls").Columns("pollID").Where("pollID = ?").Prepare(), - create: acc.Insert("polls").Columns("type, options").Fields("?,?").Prepare(), + cache: cache, + get: acc.Select("polls").Columns("parentID, parentTable, type, options, votes").Where("pollID = ?").Prepare(), + exists: acc.Select("polls").Columns("pollID").Where("pollID = ?").Prepare(), + create: acc.Insert("polls").Columns("parentID, parentTable, type, options").Fields("?,?,?,?").Prepare(), + addVote: acc.Insert("polls_votes").Columns("pollID, uid, option, castAt, ipaddress").Fields("?,?,?,UTC_TIMESTAMP(),?").Prepare(), + incrementVoteCount: acc.Update("polls").Set("votes = votes + 1").Where("pollID = ?").Prepare(), //pollCount: acc.SimpleCount("polls", "", ""), }, acc.FirstError() } @@ -86,7 +99,7 @@ func (store *DefaultPollStore) Get(id int) (*Poll, error) { poll = &Poll{ID: id} var optionTxt []byte - err = store.get.QueryRow(id).Scan(&poll.Type, &optionTxt, &poll.VoteCount) + err = store.get.QueryRow(id).Scan(&poll.ParentID, &poll.ParentTable, &poll.Type, &optionTxt, &poll.VoteCount) if err != nil { return nil, err } @@ -102,7 +115,7 @@ func (store *DefaultPollStore) Get(id int) (*Poll, error) { func (store *DefaultPollStore) Reload(id int) error { poll := &Poll{ID: id} var optionTxt []byte - err := store.get.QueryRow(id).Scan(&poll.Type, &optionTxt, &poll.VoteCount) + err := store.get.QueryRow(id).Scan(&poll.ParentID, &poll.ParentTable, &poll.Type, &optionTxt, &poll.VoteCount) if err != nil { store.cache.Remove(id) return err @@ -127,13 +140,23 @@ func (store *DefaultPollStore) unpackOptionsMap(rawOptions map[int]string) []Pol return options } +// TODO: Use a transaction for this? +func (store *DefaultPollStore) CastVote(optionIndex int, pollID int, uid int, ipaddress string) error { + _, err := store.addVote.Exec(pollID, uid, optionIndex, ipaddress) + if err != nil { + return err + } + _, err = store.incrementVoteCount.Exec(pollID) + return err +} + func (store *DefaultPollStore) Create(parent Pollable, pollType int, pollOptions map[int]string) (id int, err error) { pollOptionsTxt, err := json.Marshal(pollOptions) if err != nil { return 0, err } - res, err := store.create.Exec(pollType, pollOptionsTxt) //pollOptionsTxt + res, err := store.create.Exec(parent.GetID(), parent.GetTable(), pollType, pollOptionsTxt) if err != nil { return 0, err } diff --git a/common/topic.go b/common/topic.go index cc62120c..0c2cc14f 100644 --- a/common/topic.go +++ b/common/topic.go @@ -263,6 +263,7 @@ func (topic *Topic) Update(name string, content string) error { func (topic *Topic) SetPoll(pollID int) error { _, err := topicStmts.setPoll.Exec(pollID, topic.ID) // TODO: Sniff if this changed anything to see if we hit an existing poll + topic.cacheRemove() return err } @@ -278,6 +279,13 @@ func (topic *Topic) CreateActionReply(action string, ipaddress string, user User return err } +func (topic *Topic) GetID() int { + return topic.ID +} +func (topic *Topic) GetTable() string { + return "topics" +} + // Copy gives you a non-pointer concurrency safe copy of the topic func (topic *Topic) Copy() Topic { return *topic diff --git a/gen_router.go b/gen_router.go index f717ee21..b26d7365 100644 --- a/gen_router.go +++ b/gen_router.go @@ -100,6 +100,7 @@ var RouteMap = map[string]interface{}{ "routeProfileReplyCreateSubmit": routeProfileReplyCreateSubmit, "routes.ProfileReplyEditSubmit": routes.ProfileReplyEditSubmit, "routes.ProfileReplyDeleteSubmit": routes.ProfileReplyDeleteSubmit, + "routes.PollVote": routes.PollVote, "routeLogin": routeLogin, "routeRegister": routeRegister, "routeLogout": routeLogout, @@ -196,14 +197,15 @@ var routeMapEnum = map[string]int{ "routeProfileReplyCreateSubmit": 81, "routes.ProfileReplyEditSubmit": 82, "routes.ProfileReplyDeleteSubmit": 83, - "routeLogin": 84, - "routeRegister": 85, - "routeLogout": 86, - "routeLoginSubmit": 87, - "routeRegisterSubmit": 88, - "routeDynamic": 89, - "routeUploads": 90, - "BadRoute": 91, + "routes.PollVote": 84, + "routeLogin": 85, + "routeRegister": 86, + "routeLogout": 87, + "routeLoginSubmit": 88, + "routeRegisterSubmit": 89, + "routeDynamic": 90, + "routeUploads": 91, + "BadRoute": 92, } var reverseRouteMapEnum = map[int]string{ 0: "routeAPI", @@ -290,14 +292,15 @@ var reverseRouteMapEnum = map[int]string{ 81: "routeProfileReplyCreateSubmit", 82: "routes.ProfileReplyEditSubmit", 83: "routes.ProfileReplyDeleteSubmit", - 84: "routeLogin", - 85: "routeRegister", - 86: "routeLogout", - 87: "routeLoginSubmit", - 88: "routeRegisterSubmit", - 89: "routeDynamic", - 90: "routeUploads", - 91: "BadRoute", + 84: "routes.PollVote", + 85: "routeLogin", + 86: "routeRegister", + 87: "routeLogout", + 88: "routeLoginSubmit", + 89: "routeRegisterSubmit", + 90: "routeDynamic", + 91: "routeUploads", + 92: "BadRoute", } var agentMapEnum = map[string]int{ "unknown": 0, @@ -396,27 +399,36 @@ func (router *GenRouter) RemoveFunc(pattern string) error { return nil } +func (router *GenRouter) DumpRequest(req *http.Request) { + log.Print("UA: ", req.UserAgent()) + log.Print("Method: ", req.Method) + for key, value := range req.Header { + for _, vvalue := range value { + log.Print("Header '" + key + "': " + vvalue + "!!") + } + } + log.Print("req.Host: ", req.Host) + log.Print("req.URL.Path: ", req.URL.Path) + log.Print("req.URL.RawQuery: ", req.URL.RawQuery) + log.Print("req.Referer(): ", req.Referer()) + log.Print("req.RemoteAddr: ", req.RemoteAddr) +} + +func (router *GenRouter) SuspiciousRequest(req *http.Request) { + log.Print("Supicious Request") + router.DumpRequest(req) + common.AgentViewCounter.Bump(18) +} + // TODO: Pass the default route or config struct to the router rather than accessing it via a package global // TODO: SetDefaultRoute // TODO: GetDefaultRoute - func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { if len(req.URL.Path) == 0 || req.URL.Path[0] != '/' || req.Host != common.Site.Host { - w.WriteHeader(200) // 405 + w.WriteHeader(200) // 400 w.Write([]byte("")) log.Print("Malformed Request") - log.Print("UA: ", req.UserAgent()) - log.Print("Method: ", req.Method) - for key, value := range req.Header { - for _, vvalue := range value { - log.Print("Header '" + key + "': " + vvalue + "!!") - } - } - log.Print("req.Host: ", req.Host) - log.Print("req.URL.Path: ", req.URL.Path) - log.Print("req.URL.RawQuery: ", req.URL.RawQuery) - log.Print("req.Referer(): ", req.Referer()) - log.Print("req.RemoteAddr: ", req.RemoteAddr) + router.DumpRequest(req) common.AgentViewCounter.Bump(17) return } @@ -425,37 +437,14 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { // TODO: Cover more suspicious strings and at a lower layer than this for _, char := range req.URL.Path { if char != '&' && !(char > 44 && char < 58) && char != '=' && char != '?' && !(char > 64 && char < 91) && char != '\\' && char != '_' && !(char > 96 && char < 123) { - log.Print("Suspicious UA: ", req.UserAgent()) - log.Print("Method: ", req.Method) - for key, value := range req.Header { - for _, vvalue := range value { - log.Print("Header '" + key + "': " + vvalue + "!!") - } - } - log.Print("req.Host: ", req.Host) - log.Print("req.URL.Path: ", req.URL.Path) - log.Print("req.URL.RawQuery: ", req.URL.RawQuery) - log.Print("req.Referer(): ", req.Referer()) - log.Print("req.RemoteAddr: ", req.RemoteAddr) - common.AgentViewCounter.Bump(18) + router.SuspiciousRequest(req) break } } + lowerPath := strings.ToLower(req.URL.Path) // TODO: Flag any requests which has a dot with anything but a number after that - if strings.Contains(req.URL.Path,"..") || strings.Contains(req.URL.Path,"--") || strings.Contains(req.URL.Path,".php") || strings.Contains(req.URL.Path,".asp") || strings.Contains(req.URL.Path,".cgi") { - log.Print("Suspicious UA: ", req.UserAgent()) - log.Print("Method: ", req.Method) - for key, value := range req.Header { - for _, vvalue := range value { - log.Print("Header '" + key + "': " + vvalue + "!!") - } - } - log.Print("req.Host: ", req.Host) - log.Print("req.URL.Path: ", req.URL.Path) - log.Print("req.URL.RawQuery: ", req.URL.RawQuery) - log.Print("req.Referer(): ", req.Referer()) - log.Print("req.RemoteAddr: ", req.RemoteAddr) - common.AgentViewCounter.Bump(18) + if strings.Contains(req.URL.Path,"..") || strings.Contains(req.URL.Path,"--") || strings.Contains(lowerPath,".php") || strings.Contains(lowerPath,".asp") || strings.Contains(lowerPath,".cgi") || strings.Contains(lowerPath,".py") || strings.Contains(lowerPath,".sql") { + router.SuspiciousRequest(req) } } @@ -534,38 +523,13 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { common.AgentViewCounter.Bump(16) if common.Dev.DebugMode { log.Print("Blank UA: ", req.UserAgent()) - log.Print("Method: ", req.Method) - - for key, value := range req.Header { - for _, vvalue := range value { - log.Print("Header '" + key + "': " + vvalue + "!!") - } - } - log.Print("prefix: ", prefix) - log.Print("req.Host: ", req.Host) - log.Print("req.URL.Path: ", req.URL.Path) - log.Print("req.URL.RawQuery: ", req.URL.RawQuery) - log.Print("extraData: ", extraData) - log.Print("req.Referer(): ", req.Referer()) - log.Print("req.RemoteAddr: ", req.RemoteAddr) + router.DumpRequest(req) } default: common.AgentViewCounter.Bump(0) if common.Dev.DebugMode { log.Print("Unknown UA: ", req.UserAgent()) - log.Print("Method: ", req.Method) - for key, value := range req.Header { - for _, vvalue := range value { - log.Print("Header '" + key + "': " + vvalue + "!!") - } - } - log.Print("prefix: ", prefix) - log.Print("req.Host: ", req.Host) - log.Print("req.URL.Path: ", req.URL.Path) - log.Print("req.URL.RawQuery: ", req.URL.RawQuery) - log.Print("extraData: ", extraData) - log.Print("req.Referer(): ", req.Referer()) - log.Print("req.RemoteAddr: ", req.RemoteAddr) + router.DumpRequest(req) } } @@ -1422,13 +1386,34 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { if err != nil { router.handleError(err,w,req,user) } + case "/poll": + switch(req.URL.Path) { + case "/poll/vote/": + err = common.NoSessionMismatch(w,req,user) + if err != nil { + router.handleError(err,w,req,user) + return + } + + err = common.MemberOnly(w,req,user) + if err != nil { + router.handleError(err,w,req,user) + return + } + + common.RouteViewCounter.Bump(84) + err = routes.PollVote(w,req,user,extraData) + } + if err != nil { + router.handleError(err,w,req,user) + } case "/accounts": switch(req.URL.Path) { case "/accounts/login/": - common.RouteViewCounter.Bump(84) + common.RouteViewCounter.Bump(85) err = routeLogin(w,req,user) case "/accounts/create/": - common.RouteViewCounter.Bump(85) + common.RouteViewCounter.Bump(86) err = routeRegister(w,req,user) case "/accounts/logout/": err = common.NoSessionMismatch(w,req,user) @@ -1443,7 +1428,7 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - common.RouteViewCounter.Bump(86) + common.RouteViewCounter.Bump(87) err = routeLogout(w,req,user) case "/accounts/login/submit/": err = common.ParseForm(w,req,user) @@ -1452,7 +1437,7 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - common.RouteViewCounter.Bump(87) + common.RouteViewCounter.Bump(88) err = routeLoginSubmit(w,req,user) case "/accounts/create/submit/": err = common.ParseForm(w,req,user) @@ -1461,7 +1446,7 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - common.RouteViewCounter.Bump(88) + common.RouteViewCounter.Bump(89) err = routeRegisterSubmit(w,req,user) } if err != nil { @@ -1478,7 +1463,7 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { common.NotFound(w,req) return } - common.RouteViewCounter.Bump(90) + common.RouteViewCounter.Bump(91) req.URL.Path += extraData // TODO: Find a way to propagate errors up from this? router.UploadHandler(w,req) // TODO: Count these views @@ -1499,7 +1484,6 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { } return*/ } - if extraData != "" { common.NotFound(w,req) return @@ -1522,7 +1506,7 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { router.RUnlock() if ok { - common.RouteViewCounter.Bump(89) // TODO: Be more specific about *which* dynamic route it is + common.RouteViewCounter.Bump(90) // TODO: Be more specific about *which* dynamic route it is req.URL.Path += extraData err = handle(w,req,user) if err != nil { @@ -1531,7 +1515,11 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - common.RouteViewCounter.Bump(91) + lowerPath := strings.ToLower(req.URL.Path) + if strings.Contains(lowerPath,"admin") || strings.Contains(lowerPath,"sql") || strings.Contains(lowerPath,"manage") { + router.SuspiciousRequest(req) + } + common.RouteViewCounter.Bump(92) common.NotFound(w,req) } } diff --git a/query_gen/tables.go b/query_gen/tables.go index 62252473..a910d0ae 100644 --- a/query_gen/tables.go +++ b/query_gen/tables.go @@ -221,6 +221,8 @@ func createTables(adapter qgen.Adapter) error { qgen.Install.CreateTable("polls", "utf8mb4", "utf8mb4_general_ci", []qgen.DBTableColumn{ qgen.DBTableColumn{"pollID", "int", 0, false, true, ""}, + qgen.DBTableColumn{"parentID", "int", 0, false, false, "0"}, + qgen.DBTableColumn{"parentTable", "varchar", 100, false, false, "topics"}, // topics, replies qgen.DBTableColumn{"type", "int", 0, false, false, "0"}, qgen.DBTableColumn{"options", "json", 0, false, false, ""}, qgen.DBTableColumn{"votes", "int", 0, false, false, "0"}, diff --git a/router_gen/main.go b/router_gen/main.go index 7b37b2ae..e1432cb0 100644 --- a/router_gen/main.go +++ b/router_gen/main.go @@ -275,27 +275,36 @@ func (router *GenRouter) RemoveFunc(pattern string) error { return nil } +func (router *GenRouter) DumpRequest(req *http.Request) { + log.Print("UA: ", req.UserAgent()) + log.Print("Method: ", req.Method) + for key, value := range req.Header { + for _, vvalue := range value { + log.Print("Header '" + key + "': " + vvalue + "!!") + } + } + log.Print("req.Host: ", req.Host) + log.Print("req.URL.Path: ", req.URL.Path) + log.Print("req.URL.RawQuery: ", req.URL.RawQuery) + log.Print("req.Referer(): ", req.Referer()) + log.Print("req.RemoteAddr: ", req.RemoteAddr) +} + +func (router *GenRouter) SuspiciousRequest(req *http.Request) { + log.Print("Supicious Request") + router.DumpRequest(req) + common.AgentViewCounter.Bump({{.AllAgentMap.suspicious}}) +} + // TODO: Pass the default route or config struct to the router rather than accessing it via a package global // TODO: SetDefaultRoute // TODO: GetDefaultRoute - func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { if len(req.URL.Path) == 0 || req.URL.Path[0] != '/' || req.Host != common.Site.Host { - w.WriteHeader(200) // 405 + w.WriteHeader(200) // 400 w.Write([]byte("")) log.Print("Malformed Request") - log.Print("UA: ", req.UserAgent()) - log.Print("Method: ", req.Method) - for key, value := range req.Header { - for _, vvalue := range value { - log.Print("Header '" + key + "': " + vvalue + "!!") - } - } - log.Print("req.Host: ", req.Host) - log.Print("req.URL.Path: ", req.URL.Path) - log.Print("req.URL.RawQuery: ", req.URL.RawQuery) - log.Print("req.Referer(): ", req.Referer()) - log.Print("req.RemoteAddr: ", req.RemoteAddr) + router.DumpRequest(req) common.AgentViewCounter.Bump({{.AllAgentMap.malformed}}) return } @@ -304,37 +313,14 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { // TODO: Cover more suspicious strings and at a lower layer than this for _, char := range req.URL.Path { if char != '&' && !(char > 44 && char < 58) && char != '=' && char != '?' && !(char > 64 && char < 91) && char != '\\' && char != '_' && !(char > 96 && char < 123) { - log.Print("Suspicious UA: ", req.UserAgent()) - log.Print("Method: ", req.Method) - for key, value := range req.Header { - for _, vvalue := range value { - log.Print("Header '" + key + "': " + vvalue + "!!") - } - } - log.Print("req.Host: ", req.Host) - log.Print("req.URL.Path: ", req.URL.Path) - log.Print("req.URL.RawQuery: ", req.URL.RawQuery) - log.Print("req.Referer(): ", req.Referer()) - log.Print("req.RemoteAddr: ", req.RemoteAddr) - common.AgentViewCounter.Bump({{.AllAgentMap.suspicious}}) + router.SuspiciousRequest(req) break } } + lowerPath := strings.ToLower(req.URL.Path) // TODO: Flag any requests which has a dot with anything but a number after that - if strings.Contains(req.URL.Path,"..") || strings.Contains(req.URL.Path,"--") || strings.Contains(req.URL.Path,".php") || strings.Contains(req.URL.Path,".asp") || strings.Contains(req.URL.Path,".cgi") { - log.Print("Suspicious UA: ", req.UserAgent()) - log.Print("Method: ", req.Method) - for key, value := range req.Header { - for _, vvalue := range value { - log.Print("Header '" + key + "': " + vvalue + "!!") - } - } - log.Print("req.Host: ", req.Host) - log.Print("req.URL.Path: ", req.URL.Path) - log.Print("req.URL.RawQuery: ", req.URL.RawQuery) - log.Print("req.Referer(): ", req.Referer()) - log.Print("req.RemoteAddr: ", req.RemoteAddr) - common.AgentViewCounter.Bump({{.AllAgentMap.suspicious}}) + if strings.Contains(req.URL.Path,"..") || strings.Contains(req.URL.Path,"--") || strings.Contains(lowerPath,".php") || strings.Contains(lowerPath,".asp") || strings.Contains(lowerPath,".cgi") || strings.Contains(lowerPath,".py") || strings.Contains(lowerPath,".sql") { + router.SuspiciousRequest(req) } } @@ -413,38 +399,13 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { common.AgentViewCounter.Bump({{.AllAgentMap.blank}}) if common.Dev.DebugMode { log.Print("Blank UA: ", req.UserAgent()) - log.Print("Method: ", req.Method) - - for key, value := range req.Header { - for _, vvalue := range value { - log.Print("Header '" + key + "': " + vvalue + "!!") - } - } - log.Print("prefix: ", prefix) - log.Print("req.Host: ", req.Host) - log.Print("req.URL.Path: ", req.URL.Path) - log.Print("req.URL.RawQuery: ", req.URL.RawQuery) - log.Print("extraData: ", extraData) - log.Print("req.Referer(): ", req.Referer()) - log.Print("req.RemoteAddr: ", req.RemoteAddr) + router.DumpRequest(req) } default: common.AgentViewCounter.Bump({{.AllAgentMap.unknown}}) if common.Dev.DebugMode { log.Print("Unknown UA: ", req.UserAgent()) - log.Print("Method: ", req.Method) - for key, value := range req.Header { - for _, vvalue := range value { - log.Print("Header '" + key + "': " + vvalue + "!!") - } - } - log.Print("prefix: ", prefix) - log.Print("req.Host: ", req.Host) - log.Print("req.URL.Path: ", req.URL.Path) - log.Print("req.URL.RawQuery: ", req.URL.RawQuery) - log.Print("extraData: ", extraData) - log.Print("req.Referer(): ", req.Referer()) - log.Print("req.RemoteAddr: ", req.RemoteAddr) + router.DumpRequest(req) } } @@ -492,7 +453,6 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { } return*/ } - if extraData != "" { common.NotFound(w,req) return @@ -524,6 +484,10 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } + lowerPath := strings.ToLower(req.URL.Path) + if strings.Contains(lowerPath,"admin") || strings.Contains(lowerPath,"sql") || strings.Contains(lowerPath,"manage") { + router.SuspiciousRequest(req) + } common.RouteViewCounter.Bump({{.AllRouteMap.BadRoute}}) common.NotFound(w,req) } diff --git a/router_gen/routes.go b/router_gen/routes.go index a0723778..e5c9aa92 100644 --- a/router_gen/routes.go +++ b/router_gen/routes.go @@ -29,6 +29,7 @@ func routes() { buildTopicRoutes() buildReplyRoutes() buildProfileReplyRoutes() + buildPollRoutes() buildAccountRoutes() addRoute(Special("routeWebsockets", "/ws/")) @@ -104,6 +105,14 @@ func buildProfileReplyRoutes() { addRouteGroup(pReplyGroup) } +func buildPollRoutes() { + pollGroup := newRouteGroup("/poll/") + pollGroup.Routes( + Action("routes.PollVote", "/poll/vote/", "extraData"), + ) + addRouteGroup(pollGroup) +} + func buildAccountRoutes() { //router.HandleFunc("/accounts/list/", routeLogin) // Redirect /accounts/ and /user/ to here.. // Get a list of all of the accounts on the forum accReplyGroup := newRouteGroup("/accounts/") diff --git a/routes/poll.go b/routes/poll.go new file mode 100644 index 00000000..7eba777f --- /dev/null +++ b/routes/poll.go @@ -0,0 +1,67 @@ +package routes + +import ( + "database/sql" + "errors" + "net/http" + "strconv" + + "../common" +) + +func PollVote(w http.ResponseWriter, r *http.Request, user common.User, sPollID string) common.RouteError { + pollID, err := strconv.Atoi(sPollID) + if err != nil { + return common.PreError("The provided PollID is not a valid number.", w, r) + } + + poll, err := common.Polls.Get(pollID) + if err == sql.ErrNoRows { + return common.PreError("The poll you tried to vote for doesn't exist.", w, r) + } else if err != nil { + return common.InternalError(err, w, r) + } + + var topic *common.Topic + if poll.ParentTable == "replies" { + reply, err := common.Rstore.Get(poll.ParentID) + if err == sql.ErrNoRows { + return common.PreError("The parent post doesn't exist.", w, r) + } else if err != nil { + return common.InternalError(err, w, r) + } + topic, err = common.Topics.Get(reply.ParentID) + } else if poll.ParentTable == "topics" { + topic, err = common.Topics.Get(poll.ParentID) + } else { + return common.InternalError(errors.New("Unknown parentTable for poll"), w, r) + } + + if err == sql.ErrNoRows { + return common.PreError("The parent topic doesn't exist.", w, r) + } else if err != nil { + return common.InternalError(err, w, r) + } + + // TODO: Add hooks to make use of headerLite + _, ferr := common.SimpleForumUserCheck(w, r, &user, topic.ParentID) + if ferr != nil { + return ferr + } + if !user.Perms.ViewTopic { + return common.NoPermissions(w, r, user) + } + + optionIndex, err := strconv.Atoi(r.PostFormValue("poll_option_input")) + if err != nil { + return common.LocalError("Malformed input", w, r, user) + } + + err = poll.CastVote(optionIndex, user.ID, user.LastIP) + if err != nil { + return common.InternalError(err, w, r) + } + + http.Redirect(w, r, "/topic/"+strconv.Itoa(topic.ID), http.StatusSeeOther) + return nil +} diff --git a/routes/topic.go b/routes/topic.go index a981caf2..368d4e5a 100644 --- a/routes/topic.go +++ b/routes/topic.go @@ -144,7 +144,6 @@ func CreateTopicSubmit(w http.ResponseWriter, r *http.Request, user common.User) if r.PostFormValue("has_poll") == "1" { var maxPollOptions = 10 var pollInputItems = make(map[int]string) - var pollInputCount = 0 for key, values := range r.Form { //if common.Dev.SuperDebug { log.Print("key: ", key) @@ -165,8 +164,7 @@ func CreateTopicSubmit(w http.ResponseWriter, r *http.Request, user common.User) // If there are duplicates, then something has gone horribly wrong, so let's ignore them, this'll likely happen during an attack _, exists := pollInputItems[index] - if !exists { - pollInputCount++ + if !exists && len(html.EscapeString(value)) != 0 { pollInputItems[index] = html.EscapeString(value) if len(pollInputItems) >= maxPollOptions { @@ -177,8 +175,14 @@ func CreateTopicSubmit(w http.ResponseWriter, r *http.Request, user common.User) } } + // Make sure the indices are sequential to avoid out of bounds issues + var seqPollInputItems = make(map[int]string) + for i := 0; i < len(pollInputItems); i++ { + seqPollInputItems[i] = pollInputItems[i] + } + pollType := 0 // Basic single choice - _, err := common.Polls.Create(topic, pollType, pollInputItems) + _, err := common.Polls.Create(topic, pollType, seqPollInputItems) if err != nil { return common.LocalError("Failed to add poll to topic", w, r, user) // TODO: Might need to be an internal error as it could leave phantom polls? } diff --git a/schema/mssql/query_polls.sql b/schema/mssql/query_polls.sql index b6125582..3c32d48f 100644 --- a/schema/mssql/query_polls.sql +++ b/schema/mssql/query_polls.sql @@ -1,5 +1,7 @@ CREATE TABLE [polls] ( [pollID] int not null IDENTITY, + [parentID] int DEFAULT 0 not null, + [parentTable] nvarchar (100) DEFAULT 'topics' not null, [type] int DEFAULT 0 not null, [options] nvarchar (MAX) not null, [votes] int DEFAULT 0 not null, diff --git a/schema/mysql/query_polls.sql b/schema/mysql/query_polls.sql index 475bc9c6..21678499 100644 --- a/schema/mysql/query_polls.sql +++ b/schema/mysql/query_polls.sql @@ -1,5 +1,7 @@ CREATE TABLE `polls` ( `pollID` int not null AUTO_INCREMENT, + `parentID` int DEFAULT 0 not null, + `parentTable` varchar(100) DEFAULT 'topics' not null, `type` int DEFAULT 0 not null, `options` text not null, `votes` int DEFAULT 0 not null, diff --git a/schema/pgsql/query_polls.sql b/schema/pgsql/query_polls.sql index 6c46e6df..811eec1b 100644 --- a/schema/pgsql/query_polls.sql +++ b/schema/pgsql/query_polls.sql @@ -1,5 +1,7 @@ CREATE TABLE `polls` ( `pollID` serial not null, + `parentID` int DEFAULT 0 not null, + `parentTable` varchar (100) DEFAULT 'topics' not null, `type` int DEFAULT 0 not null, `options` json not null, `votes` int DEFAULT 0 not null, diff --git a/template_list.go b/template_list.go index 2322f8d9..6f2b04bf 100644 --- a/template_list.go +++ b/template_list.go @@ -352,207 +352,223 @@ var topic_alt_20 = []byte(`