I've revamped the query generator to reduce the number of globals, this'll help us split up the software and plugins in the future.

Refactored the router.
Added the MemberOnly middleware and tests for it.
This commit is contained in:
Azareal 2017-11-05 09:55:34 +00:00
parent f5190e83ba
commit 3fa9bf7373
38 changed files with 1065 additions and 833 deletions

View File

@ -122,7 +122,7 @@ func buildAlert(asid int, event string, elementType string, actorID int, targetU
} }
func notifyWatchers(asid int64) { func notifyWatchers(asid int64) {
rows, err := getWatchersStmt.Query(asid) rows, err := stmts.getWatchers.Query(asid)
if err != nil && err != ErrNoRows { if err != nil && err != ErrNoRows {
log.Fatal(err.Error()) log.Fatal(err.Error())
return return
@ -147,7 +147,7 @@ func notifyWatchers(asid int64) {
var actorID, targetUserID, elementID int var actorID, targetUserID, elementID int
var event, elementType string var event, elementType string
err = getActivityEntryStmt.QueryRow(asid).Scan(&actorID, &targetUserID, &event, &elementType, &elementID) err = stmts.getActivityEntry.QueryRow(asid).Scan(&actorID, &targetUserID, &event, &elementType, &elementID)
if err != nil && err != ErrNoRows { if err != nil && err != ErrNoRows {
log.Fatal(err.Error()) log.Fatal(err.Error())
return return

View File

@ -170,7 +170,7 @@ func (auth *DefaultAuth) CreateSession(uid int) (session string, err error) {
return "", err return "", err
} }
_, err = updateSessionStmt.Exec(session, uid) _, err = stmts.updateSession.Exec(session, uid)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -4,6 +4,8 @@ import "log"
import "database/sql" import "database/sql"
var stmts *Stmts
var db *sql.DB var db *sql.DB
var dbVersion string var dbVersion string
var dbAdapter string var dbAdapter string
@ -14,11 +16,14 @@ var ErrNoRows = sql.ErrNoRows
var _initDatabase func() error var _initDatabase func() error
func initDatabase() (err error) { func initDatabase() (err error) {
stmts = &Stmts{Mocks: false}
// Engine specific code // Engine specific code
err = _initDatabase() err = _initDatabase()
if err != nil { if err != nil {
return err return err
} }
globs = &Globs{stmts}
log.Print("Loading the usergroups.") log.Print("Loading the usergroups.")
gstore, err = NewMemoryGroupStore() gstore, err = NewMemoryGroupStore()

View File

@ -143,7 +143,7 @@ func initExtend() (err error) {
// LoadPlugins polls the database to see which plugins have been activated and which have been installed // LoadPlugins polls the database to see which plugins have been activated and which have been installed
func LoadPlugins() error { func LoadPlugins() error {
rows, err := getPluginsStmt.Query() rows, err := stmts.getPlugins.Query()
if err != nil { if err != nil {
return err return err
} }

View File

@ -82,7 +82,7 @@ func (forum *Forum) Update(name string, desc string, active bool, preset string)
name = forum.Name name = forum.Name
} }
preset = strings.TrimSpace(preset) preset = strings.TrimSpace(preset)
_, err := updateForumStmt.Exec(name, desc, active, preset, forum.ID) _, err := stmts.updateForum.Exec(name, desc, active, preset, forum.ID)
if err != nil { if err != nil {
return err return err
} }

View File

@ -23,7 +23,7 @@ func (fps *ForumPermsStore) Init() error {
log.Print("fids: ", fids) log.Print("fids: ", fids)
} }
rows, err := getForumsPermissionsStmt.Query() rows, err := stmts.getForumsPermissions.Query()
if err != nil { if err != nil {
return err return err
} }

View File

@ -102,7 +102,8 @@ func (mfs *MemoryForumStore) LoadForums() error {
} }
} }
rows, err := getForumsStmt.Query() // TODO: Move this statement into the store
rows, err := stmts.getForums.Query()
if err != nil { if err != nil {
return err return err
} }
@ -327,11 +328,11 @@ func (mfs *MemoryForumStore) Delete(id int) error {
} }
func (mfs *MemoryForumStore) AddTopic(tid int, uid int, fid int) error { func (mfs *MemoryForumStore) AddTopic(tid int, uid int, fid int) error {
_, err := updateForumCacheStmt.Exec(tid, uid, fid) _, err := stmts.updateForumCache.Exec(tid, uid, fid)
if err != nil { if err != nil {
return err return err
} }
_, err = addTopicsToForumStmt.Exec(1, fid) _, err = stmts.addTopicsToForum.Exec(1, fid)
if err != nil { if err != nil {
return err return err
} }
@ -341,7 +342,7 @@ func (mfs *MemoryForumStore) AddTopic(tid int, uid int, fid int) error {
// TODO: Update the forum cache with the latest topic // TODO: Update the forum cache with the latest topic
func (mfs *MemoryForumStore) RemoveTopic(fid int) error { func (mfs *MemoryForumStore) RemoveTopic(fid int) error {
_, err := removeTopicsFromForumStmt.Exec(1, fid) _, err := stmts.removeTopicsFromForum.Exec(1, fid)
if err != nil { if err != nil {
return err return err
} }
@ -353,7 +354,7 @@ func (mfs *MemoryForumStore) RemoveTopic(fid int) error {
// DEPRECATED. forum.Update() will be the way to do this in the future, once it's completed // DEPRECATED. forum.Update() will be the way to do this in the future, once it's completed
// TODO: Have a pointer to the last topic rather than storing it on the forum itself // TODO: Have a pointer to the last topic rather than storing it on the forum itself
func (mfs *MemoryForumStore) UpdateLastTopic(tid int, uid int, fid int) error { func (mfs *MemoryForumStore) UpdateLastTopic(tid int, uid int, fid int) error {
_, err := updateForumCacheStmt.Exec(tid, uid, fid) _, err := stmts.updateForumCache.Exec(tid, uid, fid)
if err != nil { if err != nil {
return err return err
} }
@ -363,7 +364,8 @@ func (mfs *MemoryForumStore) UpdateLastTopic(tid int, uid int, fid int) error {
func (mfs *MemoryForumStore) Create(forumName string, forumDesc string, active bool, preset string) (int, error) { func (mfs *MemoryForumStore) Create(forumName string, forumDesc string, active bool, preset string) (int, error) {
forumCreateMutex.Lock() forumCreateMutex.Lock()
res, err := createForumStmt.Exec(forumName, forumDesc, active, preset) // TODO: Move this query into the store
res, err := stmts.createForum.Exec(forumName, forumDesc, active, preset)
if err != nil { if err != nil {
return 0, err return 0, err
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -7,48 +7,62 @@ import "log"
import "database/sql" import "database/sql"
// nolint // nolint
var addRepliesToTopicStmt *sql.Stmt type Stmts struct {
var removeRepliesFromTopicStmt *sql.Stmt addRepliesToTopic *sql.Stmt
var addTopicsToForumStmt *sql.Stmt removeRepliesFromTopic *sql.Stmt
var removeTopicsFromForumStmt *sql.Stmt addTopicsToForum *sql.Stmt
var updateForumCacheStmt *sql.Stmt removeTopicsFromForum *sql.Stmt
var addLikesToTopicStmt *sql.Stmt updateForumCache *sql.Stmt
var addLikesToReplyStmt *sql.Stmt addLikesToTopic *sql.Stmt
var editTopicStmt *sql.Stmt addLikesToReply *sql.Stmt
var editReplyStmt *sql.Stmt editTopic *sql.Stmt
var stickTopicStmt *sql.Stmt editReply *sql.Stmt
var unstickTopicStmt *sql.Stmt stickTopic *sql.Stmt
var lockTopicStmt *sql.Stmt unstickTopic *sql.Stmt
var unlockTopicStmt *sql.Stmt lockTopic *sql.Stmt
var updateLastIPStmt *sql.Stmt unlockTopic *sql.Stmt
var updateSessionStmt *sql.Stmt updateLastIP *sql.Stmt
var setPasswordStmt *sql.Stmt updateSession *sql.Stmt
var setAvatarStmt *sql.Stmt setPassword *sql.Stmt
var setUsernameStmt *sql.Stmt setAvatar *sql.Stmt
var changeGroupStmt *sql.Stmt setUsername *sql.Stmt
var activateUserStmt *sql.Stmt changeGroup *sql.Stmt
var updateUserLevelStmt *sql.Stmt activateUser *sql.Stmt
var incrementUserScoreStmt *sql.Stmt updateUserLevel *sql.Stmt
var incrementUserPostsStmt *sql.Stmt incrementUserScore *sql.Stmt
var incrementUserBigpostsStmt *sql.Stmt incrementUserPosts *sql.Stmt
var incrementUserMegapostsStmt *sql.Stmt incrementUserBigposts *sql.Stmt
var incrementUserTopicsStmt *sql.Stmt incrementUserMegaposts *sql.Stmt
var editProfileReplyStmt *sql.Stmt incrementUserTopics *sql.Stmt
var updateForumStmt *sql.Stmt editProfileReply *sql.Stmt
var updateSettingStmt *sql.Stmt updateForum *sql.Stmt
var updatePluginStmt *sql.Stmt updateSetting *sql.Stmt
var updatePluginInstallStmt *sql.Stmt updatePlugin *sql.Stmt
var updateThemeStmt *sql.Stmt updatePluginInstall *sql.Stmt
var updateUserStmt *sql.Stmt updateTheme *sql.Stmt
var updateUserGroupStmt *sql.Stmt updateUser *sql.Stmt
var updateGroupPermsStmt *sql.Stmt updateUserGroup *sql.Stmt
var updateGroupRankStmt *sql.Stmt updateGroupPerms *sql.Stmt
var updateGroupStmt *sql.Stmt updateGroupRank *sql.Stmt
var updateEmailStmt *sql.Stmt updateGroup *sql.Stmt
var verifyEmailStmt *sql.Stmt updateEmail *sql.Stmt
var setTempGroupStmt *sql.Stmt verifyEmail *sql.Stmt
var updateWordFilterStmt *sql.Stmt setTempGroup *sql.Stmt
var bumpSyncStmt *sql.Stmt updateWordFilter *sql.Stmt
bumpSync *sql.Stmt
getActivityFeedByWatcher *sql.Stmt
getActivityCountByWatcher *sql.Stmt
todaysPostCount *sql.Stmt
todaysTopicCount *sql.Stmt
todaysReportCount *sql.Stmt
todaysNewUserCount *sql.Stmt
findUsersByIPUsers *sql.Stmt
findUsersByIPTopics *sql.Stmt
findUsersByIPReplies *sql.Stmt
Mocks bool
}
// nolint // nolint
func _gen_pgsql() (err error) { func _gen_pgsql() (err error) {
@ -57,253 +71,253 @@ func _gen_pgsql() (err error) {
} }
log.Print("Preparing addRepliesToTopic statement.") log.Print("Preparing addRepliesToTopic statement.")
addRepliesToTopicStmt, err = db.Prepare("UPDATE `topics` SET `postCount` = `postCount` + ?,`lastReplyBy` = ?,`lastReplyAt` = LOCALTIMESTAMP() WHERE `tid` = ?") stmts.addRepliesToTopic, err = db.Prepare("UPDATE `topics` SET `postCount` = `postCount` + ?,`lastReplyBy` = ?,`lastReplyAt` = LOCALTIMESTAMP() WHERE `tid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing removeRepliesFromTopic statement.") log.Print("Preparing removeRepliesFromTopic statement.")
removeRepliesFromTopicStmt, err = db.Prepare("UPDATE `topics` SET `postCount` = `postCount` - ? WHERE `tid` = ?") stmts.removeRepliesFromTopic, err = db.Prepare("UPDATE `topics` SET `postCount` = `postCount` - ? WHERE `tid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing addTopicsToForum statement.") log.Print("Preparing addTopicsToForum statement.")
addTopicsToForumStmt, err = db.Prepare("UPDATE `forums` SET `topicCount` = `topicCount` + ? WHERE `fid` = ?") stmts.addTopicsToForum, err = db.Prepare("UPDATE `forums` SET `topicCount` = `topicCount` + ? WHERE `fid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing removeTopicsFromForum statement.") log.Print("Preparing removeTopicsFromForum statement.")
removeTopicsFromForumStmt, err = db.Prepare("UPDATE `forums` SET `topicCount` = `topicCount` - ? WHERE `fid` = ?") stmts.removeTopicsFromForum, err = db.Prepare("UPDATE `forums` SET `topicCount` = `topicCount` - ? WHERE `fid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing updateForumCache statement.") log.Print("Preparing updateForumCache statement.")
updateForumCacheStmt, err = db.Prepare("UPDATE `forums` SET `lastTopicID` = ?,`lastReplyerID` = ? WHERE `fid` = ?") stmts.updateForumCache, err = db.Prepare("UPDATE `forums` SET `lastTopicID` = ?,`lastReplyerID` = ? WHERE `fid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing addLikesToTopic statement.") log.Print("Preparing addLikesToTopic statement.")
addLikesToTopicStmt, err = db.Prepare("UPDATE `topics` SET `likeCount` = `likeCount` + ? WHERE `tid` = ?") stmts.addLikesToTopic, err = db.Prepare("UPDATE `topics` SET `likeCount` = `likeCount` + ? WHERE `tid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing addLikesToReply statement.") log.Print("Preparing addLikesToReply statement.")
addLikesToReplyStmt, err = db.Prepare("UPDATE `replies` SET `likeCount` = `likeCount` + ? WHERE `rid` = ?") stmts.addLikesToReply, err = db.Prepare("UPDATE `replies` SET `likeCount` = `likeCount` + ? WHERE `rid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing editTopic statement.") log.Print("Preparing editTopic statement.")
editTopicStmt, err = db.Prepare("UPDATE `topics` SET `title` = ?,`content` = ?,`parsed_content` = ? WHERE `tid` = ?") stmts.editTopic, err = db.Prepare("UPDATE `topics` SET `title` = ?,`content` = ?,`parsed_content` = ? WHERE `tid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing editReply statement.") log.Print("Preparing editReply statement.")
editReplyStmt, err = db.Prepare("UPDATE `replies` SET `content` = ?,`parsed_content` = ? WHERE `rid` = ?") stmts.editReply, err = db.Prepare("UPDATE `replies` SET `content` = ?,`parsed_content` = ? WHERE `rid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing stickTopic statement.") log.Print("Preparing stickTopic statement.")
stickTopicStmt, err = db.Prepare("UPDATE `topics` SET `sticky` = 1 WHERE `tid` = ?") stmts.stickTopic, err = db.Prepare("UPDATE `topics` SET `sticky` = 1 WHERE `tid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing unstickTopic statement.") log.Print("Preparing unstickTopic statement.")
unstickTopicStmt, err = db.Prepare("UPDATE `topics` SET `sticky` = 0 WHERE `tid` = ?") stmts.unstickTopic, err = db.Prepare("UPDATE `topics` SET `sticky` = 0 WHERE `tid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing lockTopic statement.") log.Print("Preparing lockTopic statement.")
lockTopicStmt, err = db.Prepare("UPDATE `topics` SET `is_closed` = 1 WHERE `tid` = ?") stmts.lockTopic, err = db.Prepare("UPDATE `topics` SET `is_closed` = 1 WHERE `tid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing unlockTopic statement.") log.Print("Preparing unlockTopic statement.")
unlockTopicStmt, err = db.Prepare("UPDATE `topics` SET `is_closed` = 0 WHERE `tid` = ?") stmts.unlockTopic, err = db.Prepare("UPDATE `topics` SET `is_closed` = 0 WHERE `tid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing updateLastIP statement.") log.Print("Preparing updateLastIP statement.")
updateLastIPStmt, err = db.Prepare("UPDATE `users` SET `last_ip` = ? WHERE `uid` = ?") stmts.updateLastIP, err = db.Prepare("UPDATE `users` SET `last_ip` = ? WHERE `uid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing updateSession statement.") log.Print("Preparing updateSession statement.")
updateSessionStmt, err = db.Prepare("UPDATE `users` SET `session` = ? WHERE `uid` = ?") stmts.updateSession, err = db.Prepare("UPDATE `users` SET `session` = ? WHERE `uid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing setPassword statement.") log.Print("Preparing setPassword statement.")
setPasswordStmt, err = db.Prepare("UPDATE `users` SET `password` = ?,`salt` = ? WHERE `uid` = ?") stmts.setPassword, err = db.Prepare("UPDATE `users` SET `password` = ?,`salt` = ? WHERE `uid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing setAvatar statement.") log.Print("Preparing setAvatar statement.")
setAvatarStmt, err = db.Prepare("UPDATE `users` SET `avatar` = ? WHERE `uid` = ?") stmts.setAvatar, err = db.Prepare("UPDATE `users` SET `avatar` = ? WHERE `uid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing setUsername statement.") log.Print("Preparing setUsername statement.")
setUsernameStmt, err = db.Prepare("UPDATE `users` SET `name` = ? WHERE `uid` = ?") stmts.setUsername, err = db.Prepare("UPDATE `users` SET `name` = ? WHERE `uid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing changeGroup statement.") log.Print("Preparing changeGroup statement.")
changeGroupStmt, err = db.Prepare("UPDATE `users` SET `group` = ? WHERE `uid` = ?") stmts.changeGroup, err = db.Prepare("UPDATE `users` SET `group` = ? WHERE `uid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing activateUser statement.") log.Print("Preparing activateUser statement.")
activateUserStmt, err = db.Prepare("UPDATE `users` SET `active` = 1 WHERE `uid` = ?") stmts.activateUser, err = db.Prepare("UPDATE `users` SET `active` = 1 WHERE `uid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing updateUserLevel statement.") log.Print("Preparing updateUserLevel statement.")
updateUserLevelStmt, err = db.Prepare("UPDATE `users` SET `level` = ? WHERE `uid` = ?") stmts.updateUserLevel, err = db.Prepare("UPDATE `users` SET `level` = ? WHERE `uid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing incrementUserScore statement.") log.Print("Preparing incrementUserScore statement.")
incrementUserScoreStmt, err = db.Prepare("UPDATE `users` SET `score` = `score` + ? WHERE `uid` = ?") stmts.incrementUserScore, err = db.Prepare("UPDATE `users` SET `score` = `score` + ? WHERE `uid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing incrementUserPosts statement.") log.Print("Preparing incrementUserPosts statement.")
incrementUserPostsStmt, err = db.Prepare("UPDATE `users` SET `posts` = `posts` + ? WHERE `uid` = ?") stmts.incrementUserPosts, err = db.Prepare("UPDATE `users` SET `posts` = `posts` + ? WHERE `uid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing incrementUserBigposts statement.") log.Print("Preparing incrementUserBigposts statement.")
incrementUserBigpostsStmt, err = db.Prepare("UPDATE `users` SET `posts` = `posts` + ?,`bigposts` = `bigposts` + ? WHERE `uid` = ?") stmts.incrementUserBigposts, err = db.Prepare("UPDATE `users` SET `posts` = `posts` + ?,`bigposts` = `bigposts` + ? WHERE `uid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing incrementUserMegaposts statement.") log.Print("Preparing incrementUserMegaposts statement.")
incrementUserMegapostsStmt, err = db.Prepare("UPDATE `users` SET `posts` = `posts` + ?,`bigposts` = `bigposts` + ?,`megaposts` = `megaposts` + ? WHERE `uid` = ?") stmts.incrementUserMegaposts, err = db.Prepare("UPDATE `users` SET `posts` = `posts` + ?,`bigposts` = `bigposts` + ?,`megaposts` = `megaposts` + ? WHERE `uid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing incrementUserTopics statement.") log.Print("Preparing incrementUserTopics statement.")
incrementUserTopicsStmt, err = db.Prepare("UPDATE `users` SET `topics` = `topics` + ? WHERE `uid` = ?") stmts.incrementUserTopics, err = db.Prepare("UPDATE `users` SET `topics` = `topics` + ? WHERE `uid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing editProfileReply statement.") log.Print("Preparing editProfileReply statement.")
editProfileReplyStmt, err = db.Prepare("UPDATE `users_replies` SET `content` = ?,`parsed_content` = ? WHERE `rid` = ?") stmts.editProfileReply, err = db.Prepare("UPDATE `users_replies` SET `content` = ?,`parsed_content` = ? WHERE `rid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing updateForum statement.") log.Print("Preparing updateForum statement.")
updateForumStmt, err = db.Prepare("UPDATE `forums` SET `name` = ?,`desc` = ?,`active` = ?,`preset` = ? WHERE `fid` = ?") stmts.updateForum, err = db.Prepare("UPDATE `forums` SET `name` = ?,`desc` = ?,`active` = ?,`preset` = ? WHERE `fid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing updateSetting statement.") log.Print("Preparing updateSetting statement.")
updateSettingStmt, err = db.Prepare("UPDATE `settings` SET `content` = ? WHERE `name` = ?") stmts.updateSetting, err = db.Prepare("UPDATE `settings` SET `content` = ? WHERE `name` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing updatePlugin statement.") log.Print("Preparing updatePlugin statement.")
updatePluginStmt, err = db.Prepare("UPDATE `plugins` SET `active` = ? WHERE `uname` = ?") stmts.updatePlugin, err = db.Prepare("UPDATE `plugins` SET `active` = ? WHERE `uname` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing updatePluginInstall statement.") log.Print("Preparing updatePluginInstall statement.")
updatePluginInstallStmt, err = db.Prepare("UPDATE `plugins` SET `installed` = ? WHERE `uname` = ?") stmts.updatePluginInstall, err = db.Prepare("UPDATE `plugins` SET `installed` = ? WHERE `uname` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing updateTheme statement.") log.Print("Preparing updateTheme statement.")
updateThemeStmt, err = db.Prepare("UPDATE `themes` SET `default` = ? WHERE `uname` = ?") stmts.updateTheme, err = db.Prepare("UPDATE `themes` SET `default` = ? WHERE `uname` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing updateUser statement.") log.Print("Preparing updateUser statement.")
updateUserStmt, err = db.Prepare("UPDATE `users` SET `name` = ?,`email` = ?,`group` = ? WHERE `uid` = ?") stmts.updateUser, err = db.Prepare("UPDATE `users` SET `name` = ?,`email` = ?,`group` = ? WHERE `uid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing updateUserGroup statement.") log.Print("Preparing updateUserGroup statement.")
updateUserGroupStmt, err = db.Prepare("UPDATE `users` SET `group` = ? WHERE `uid` = ?") stmts.updateUserGroup, err = db.Prepare("UPDATE `users` SET `group` = ? WHERE `uid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing updateGroupPerms statement.") log.Print("Preparing updateGroupPerms statement.")
updateGroupPermsStmt, err = db.Prepare("UPDATE `users_groups` SET `permissions` = ? WHERE `gid` = ?") stmts.updateGroupPerms, err = db.Prepare("UPDATE `users_groups` SET `permissions` = ? WHERE `gid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing updateGroupRank statement.") log.Print("Preparing updateGroupRank statement.")
updateGroupRankStmt, err = db.Prepare("UPDATE `users_groups` SET `is_admin` = ?,`is_mod` = ?,`is_banned` = ? WHERE `gid` = ?") stmts.updateGroupRank, err = db.Prepare("UPDATE `users_groups` SET `is_admin` = ?,`is_mod` = ?,`is_banned` = ? WHERE `gid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing updateGroup statement.") log.Print("Preparing updateGroup statement.")
updateGroupStmt, err = db.Prepare("UPDATE `users_groups` SET `name` = ?,`tag` = ? WHERE `gid` = ?") stmts.updateGroup, err = db.Prepare("UPDATE `users_groups` SET `name` = ?,`tag` = ? WHERE `gid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing updateEmail statement.") log.Print("Preparing updateEmail statement.")
updateEmailStmt, err = db.Prepare("UPDATE `emails` SET `email` = ?,`uid` = ?,`validated` = ?,`token` = ? WHERE `email` = ?") stmts.updateEmail, err = db.Prepare("UPDATE `emails` SET `email` = ?,`uid` = ?,`validated` = ?,`token` = ? WHERE `email` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing verifyEmail statement.") log.Print("Preparing verifyEmail statement.")
verifyEmailStmt, err = db.Prepare("UPDATE `emails` SET `validated` = 1,`token` = '' WHERE `email` = ?") stmts.verifyEmail, err = db.Prepare("UPDATE `emails` SET `validated` = 1,`token` = '' WHERE `email` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing setTempGroup statement.") log.Print("Preparing setTempGroup statement.")
setTempGroupStmt, err = db.Prepare("UPDATE `users` SET `temp_group` = ? WHERE `uid` = ?") stmts.setTempGroup, err = db.Prepare("UPDATE `users` SET `temp_group` = ? WHERE `uid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing updateWordFilter statement.") log.Print("Preparing updateWordFilter statement.")
updateWordFilterStmt, err = db.Prepare("UPDATE `word_filters` SET `find` = ?,`replacement` = ? WHERE `wfid` = ?") stmts.updateWordFilter, err = db.Prepare("UPDATE `word_filters` SET `find` = ?,`replacement` = ? WHERE `wfid` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing bumpSync statement.") log.Print("Preparing bumpSync statement.")
bumpSyncStmt, err = db.Prepare("UPDATE `sync` SET `last_update` = LOCALTIMESTAMP()") stmts.bumpSync, err = db.Prepare("UPDATE `sync` SET `last_update` = LOCALTIMESTAMP()")
if err != nil { if err != nil {
return err return err
} }

View File

@ -136,6 +136,12 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
router.handleError(err,w,req,user) router.handleError(err,w,req,user)
} }
case "/report": case "/report":
err = MemberOnly(w,req,user)
if err != nil {
router.handleError(err,w,req,user)
return
}
switch(req.URL.Path) { switch(req.URL.Path) {
case "/report/submit/": case "/report/submit/":
err = routeReportSubmit(w,req,user,extra_data) err = routeReportSubmit(w,req,user,extra_data)
@ -146,6 +152,12 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
case "/topics": case "/topics":
switch(req.URL.Path) { switch(req.URL.Path) {
case "/topics/create/": case "/topics/create/":
err = MemberOnly(w,req,user)
if err != nil {
router.handleError(err,w,req,user)
return
}
err = routeTopicCreate(w,req,user,extra_data) err = routeTopicCreate(w,req,user,extra_data)
default: default:
err = routeTopics(w,req,user) err = routeTopics(w,req,user)
@ -233,6 +245,79 @@ func (router *GenRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if err != nil { if err != nil {
router.handleError(err,w,req,user) router.handleError(err,w,req,user)
} }
case "/user":
switch(req.URL.Path) {
case "/user/edit/critical/":
err = MemberOnly(w,req,user)
if err != nil {
router.handleError(err,w,req,user)
return
}
err = routeAccountOwnEditCritical(w,req,user)
case "/user/edit/critical/submit/":
err = MemberOnly(w,req,user)
if err != nil {
router.handleError(err,w,req,user)
return
}
err = routeAccountOwnEditCriticalSubmit(w,req,user)
case "/user/edit/avatar/":
err = MemberOnly(w,req,user)
if err != nil {
router.handleError(err,w,req,user)
return
}
err = routeAccountOwnEditAvatar(w,req,user)
case "/user/edit/avatar/submit/":
err = MemberOnly(w,req,user)
if err != nil {
router.handleError(err,w,req,user)
return
}
err = routeAccountOwnEditAvatarSubmit(w,req,user)
case "/user/edit/username/":
err = MemberOnly(w,req,user)
if err != nil {
router.handleError(err,w,req,user)
return
}
err = routeAccountOwnEditUsername(w,req,user)
case "/user/edit/username/submit/":
err = MemberOnly(w,req,user)
if err != nil {
router.handleError(err,w,req,user)
return
}
err = routeAccountOwnEditUsernameSubmit(w,req,user)
case "/user/edit/email/":
err = MemberOnly(w,req,user)
if err != nil {
router.handleError(err,w,req,user)
return
}
err = routeAccountOwnEditEmail(w,req,user)
case "/user/edit/token/":
err = MemberOnly(w,req,user)
if err != nil {
router.handleError(err,w,req,user)
return
}
err = routeAccountOwnEditEmailTokenSubmit(w,req,user,extra_data)
default:
req.URL.Path += extra_data
err = routeProfile(w,req,user)
}
if err != nil {
router.handleError(err,w,req,user)
}
case "/uploads": case "/uploads":
if extra_data == "" { if extra_data == "" {
NotFound(w,req) NotFound(w,req)

View File

@ -28,7 +28,7 @@ type Group struct {
} }
func (group *Group) ChangeRank(isAdmin bool, isMod bool, isBanned bool) (err error) { func (group *Group) ChangeRank(isAdmin bool, isMod bool, isBanned bool) (err error) {
_, err = updateGroupRankStmt.Exec(isAdmin, isMod, isBanned, group.ID) _, err = stmts.updateGroupRank.Exec(isAdmin, isMod, isBanned, group.ID)
if err != nil { if err != nil {
return err return err
} }

View File

@ -54,12 +54,13 @@ func NewMemoryGroupStore() (*MemoryGroupStore, error) {
}, nil }, nil
} }
// TODO: Move this query from the global stmt store into this store
func (mgs *MemoryGroupStore) LoadGroups() error { func (mgs *MemoryGroupStore) LoadGroups() error {
mgs.Lock() mgs.Lock()
defer mgs.Unlock() defer mgs.Unlock()
mgs.groups[0] = &Group{ID: 0, Name: "Unknown"} mgs.groups[0] = &Group{ID: 0, Name: "Unknown"}
rows, err := getGroupsStmt.Query() rows, err := stmts.getGroups.Query()
if err != nil { if err != nil {
return err return err
} }

28
main.go
View File

@ -75,6 +75,13 @@ func (slice StringList) Contains(needle string) bool {
var staticFiles = make(map[string]SFile) var staticFiles = make(map[string]SFile)
var logWriter = io.MultiWriter(os.Stderr) var logWriter = io.MultiWriter(os.Stderr)
// TODO: Wrap the globals in here so we can pass pointers to them to subpackages
var globs *Globs
type Globs struct {
stmts *Stmts
}
func main() { func main() {
// TODO: Recover from panics // TODO: Recover from panics
/*defer func() { /*defer func() {
@ -238,18 +245,21 @@ func main() {
//router.HandleFunc("/accounts/list/", routeLogin) // Redirect /accounts/ and /user/ to here.. // Get a list of all of the accounts on the forum //router.HandleFunc("/accounts/list/", routeLogin) // Redirect /accounts/ and /user/ to here.. // Get a list of all of the accounts on the forum
//router.HandleFunc("/accounts/create/full/", routeLogout) // Advanced account creator for admins? //router.HandleFunc("/accounts/create/full/", routeLogout) // Advanced account creator for admins?
//router.HandleFunc("/user/edit/", routeLogout) //router.HandleFunc("/user/edit/", routeLogout)
router.HandleFunc("/user/edit/critical/", routeAccountOwnEditCritical) // Password & Email ////router.HandleFunc("/user/edit/critical/", routeAccountOwnEditCritical) // Password & Email
router.HandleFunc("/user/edit/critical/submit/", routeAccountOwnEditCriticalSubmit) ////router.HandleFunc("/user/edit/critical/submit/", routeAccountOwnEditCriticalSubmit)
router.HandleFunc("/user/edit/avatar/", routeAccountOwnEditAvatar) ////router.HandleFunc("/user/edit/avatar/", routeAccountOwnEditAvatar)
router.HandleFunc("/user/edit/avatar/submit/", routeAccountOwnEditAvatarSubmit) ////router.HandleFunc("/user/edit/avatar/submit/", routeAccountOwnEditAvatarSubmit)
router.HandleFunc("/user/edit/username/", routeAccountOwnEditUsername) ////router.HandleFunc("/user/edit/username/", routeAccountOwnEditUsername)
router.HandleFunc("/user/edit/username/submit/", routeAccountOwnEditUsernameSubmit) ////router.HandleFunc("/user/edit/username/submit/", routeAccountOwnEditUsernameSubmit)
router.HandleFunc("/user/edit/email/", routeAccountOwnEditEmail) ////router.HandleFunc("/user/edit/email/", routeAccountOwnEditEmail)
router.HandleFunc("/user/edit/token/", routeAccountOwnEditEmailTokenSubmit) ////router.HandleFunc("/user/edit/token/", routeAccountOwnEditEmailTokenSubmit)
router.HandleFunc("/user/", routeProfile) ////router.HandleFunc("/user/", routeProfile)
// TODO: Move these into /user/?
router.HandleFunc("/profile/reply/create/", routeProfileReplyCreate) router.HandleFunc("/profile/reply/create/", routeProfileReplyCreate)
router.HandleFunc("/profile/reply/edit/submit/", routeProfileReplyEditSubmit) router.HandleFunc("/profile/reply/edit/submit/", routeProfileReplyEditSubmit)
router.HandleFunc("/profile/reply/delete/submit/", routeProfileReplyDeleteSubmit) router.HandleFunc("/profile/reply/delete/submit/", routeProfileReplyDeleteSubmit)
//router.HandleFunc("/user/edit/submit/", routeLogout) // routeLogout? what on earth? o.o //router.HandleFunc("/user/edit/submit/", routeLogout) // routeLogout? what on earth? o.o
//router.HandleFunc("/users/ban/", routeBan) //router.HandleFunc("/users/ban/", routeBan)
router.HandleFunc("/users/ban/submit/", routeBanSubmit) router.HandleFunc("/users/ban/submit/", routeBanSubmit)

View File

@ -153,7 +153,7 @@ func routeTopicCreateSubmit(w http.ResponseWriter, r *http.Request, user User) R
} }
} }
_, err = addSubscriptionStmt.Exec(user.ID, tid, "topic") _, err = stmts.addSubscription.Exec(user.ID, tid, "topic")
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -224,7 +224,7 @@ func routeTopicCreateSubmit(w http.ResponseWriter, r *http.Request, user User) R
return LocalError("Upload failed [Copy Failed]", w, r, user) return LocalError("Upload failed [Copy Failed]", w, r, user)
} }
_, err = addAttachmentStmt.Exec(fid, "forums", tid, "topics", user.ID, filename) _, err = stmts.addAttachment.Exec(fid, "forums", tid, "topics", user.ID, filename)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -329,7 +329,7 @@ func routeCreateReply(w http.ResponseWriter, r *http.Request, user User) RouteEr
return LocalError("Upload failed [Copy Failed]", w, r, user) return LocalError("Upload failed [Copy Failed]", w, r, user)
} }
_, err = addAttachmentStmt.Exec(topic.ParentID, "forums", tid, "replies", user.ID, filename) _, err = stmts.addAttachment.Exec(topic.ParentID, "forums", tid, "replies", user.ID, filename)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -353,7 +353,7 @@ func routeCreateReply(w http.ResponseWriter, r *http.Request, user User) RouteEr
return InternalError(err, w, r) return InternalError(err, w, r)
} }
res, err := addActivityStmt.Exec(user.ID, topic.CreatedBy, "reply", "topic", tid) res, err := stmts.addActivity.Exec(user.ID, topic.CreatedBy, "reply", "topic", tid)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -362,7 +362,7 @@ func routeCreateReply(w http.ResponseWriter, r *http.Request, user User) RouteEr
return InternalError(err, w, r) return InternalError(err, w, r)
} }
_, err = notifyWatchersStmt.Exec(lastID) _, err = stmts.notifyWatchers.Exec(lastID)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -414,7 +414,7 @@ func routeLikeTopic(w http.ResponseWriter, r *http.Request, user User) RouteErro
return LocalError("You can't like your own topics", w, r, user) return LocalError("You can't like your own topics", w, r, user)
} }
err = hasLikedTopicStmt.QueryRow(user.ID, tid).Scan(&tid) err = stmts.hasLikedTopic.QueryRow(user.ID, tid).Scan(&tid)
if err != nil && err != ErrNoRows { if err != nil && err != ErrNoRows {
return InternalError(err, w, r) return InternalError(err, w, r)
} else if err != ErrNoRows { } else if err != ErrNoRows {
@ -429,17 +429,17 @@ func routeLikeTopic(w http.ResponseWriter, r *http.Request, user User) RouteErro
} }
score := 1 score := 1
_, err = createLikeStmt.Exec(score, tid, "topics", user.ID) _, err = stmts.createLike.Exec(score, tid, "topics", user.ID)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
_, err = addLikesToTopicStmt.Exec(1, tid) _, err = stmts.addLikesToTopic.Exec(1, tid)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
res, err := addActivityStmt.Exec(user.ID, topic.CreatedBy, "like", "topic", tid) res, err := stmts.addActivity.Exec(user.ID, topic.CreatedBy, "like", "topic", tid)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -448,7 +448,7 @@ func routeLikeTopic(w http.ResponseWriter, r *http.Request, user User) RouteErro
return InternalError(err, w, r) return InternalError(err, w, r)
} }
_, err = notifyOneStmt.Exec(topic.CreatedBy, lastID) _, err = stmts.notifyOne.Exec(topic.CreatedBy, lastID)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -484,7 +484,7 @@ func routeReplyLikeSubmit(w http.ResponseWriter, r *http.Request, user User) Rou
} }
var fid int var fid int
err = getTopicFIDStmt.QueryRow(reply.ParentID).Scan(&fid) err = stmts.getTopicFID.QueryRow(reply.ParentID).Scan(&fid)
if err == ErrNoRows { if err == ErrNoRows {
return PreError("The parent topic doesn't exist.", w, r) return PreError("The parent topic doesn't exist.", w, r)
} else if err != nil { } else if err != nil {
@ -518,7 +518,7 @@ func routeReplyLikeSubmit(w http.ResponseWriter, r *http.Request, user User) Rou
return InternalError(err, w, r) return InternalError(err, w, r)
} }
res, err := addActivityStmt.Exec(user.ID, reply.CreatedBy, "like", "post", rid) res, err := stmts.addActivity.Exec(user.ID, reply.CreatedBy, "like", "post", rid)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -527,7 +527,7 @@ func routeReplyLikeSubmit(w http.ResponseWriter, r *http.Request, user User) Rou
return InternalError(err, w, r) return InternalError(err, w, r)
} }
_, err = notifyOneStmt.Exec(reply.CreatedBy, lastID) _, err = stmts.notifyOne.Exec(reply.CreatedBy, lastID)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -559,13 +559,13 @@ func routeProfileReplyCreate(w http.ResponseWriter, r *http.Request, user User)
} }
content := html.EscapeString(preparseMessage(r.PostFormValue("reply-content"))) content := html.EscapeString(preparseMessage(r.PostFormValue("reply-content")))
_, err = createProfileReplyStmt.Exec(uid, content, parseMessage(content, 0, ""), user.ID, ipaddress) _, err = stmts.createProfileReply.Exec(uid, content, parseMessage(content, 0, ""), user.ID, ipaddress)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
var userName string var userName string
err = getUserNameStmt.QueryRow(uid).Scan(&userName) err = stmts.getUserName.QueryRow(uid).Scan(&userName)
if err == ErrNoRows { if err == ErrNoRows {
return LocalError("The profile you're trying to post on doesn't exist.", w, r, user) return LocalError("The profile you're trying to post on doesn't exist.", w, r, user)
} else if err != nil { } else if err != nil {
@ -626,7 +626,7 @@ func routeReportSubmit(w http.ResponseWriter, r *http.Request, user User, sitemI
return InternalError(err, w, r) return InternalError(err, w, r)
} }
err = getUserNameStmt.QueryRow(userReply.ParentID).Scan(&title) err = stmts.getUserName.QueryRow(userReply.ParentID).Scan(&title)
if err == ErrNoRows { if err == ErrNoRows {
return LocalError("We weren't able to find the profile the reported post is supposed to be on", w, r, user) return LocalError("We weren't able to find the profile the reported post is supposed to be on", w, r, user)
} else if err != nil { } else if err != nil {
@ -635,7 +635,7 @@ func routeReportSubmit(w http.ResponseWriter, r *http.Request, user User, sitemI
title = "Profile: " + title title = "Profile: " + title
content = userReply.Content + "\n\nOriginal Post: @" + strconv.Itoa(userReply.ParentID) content = userReply.Content + "\n\nOriginal Post: @" + strconv.Itoa(userReply.ParentID)
} else if itemType == "topic" { } else if itemType == "topic" {
err = getTopicBasicStmt.QueryRow(itemID).Scan(&title, &content) err = stmts.getTopicBasic.QueryRow(itemID).Scan(&title, &content)
if err == ErrNoRows { if err == ErrNoRows {
return NotFound(w, r) return NotFound(w, r)
} else if err != nil { } else if err != nil {
@ -653,7 +653,7 @@ func routeReportSubmit(w http.ResponseWriter, r *http.Request, user User, sitemI
} }
var count int var count int
rows, err := reportExistsStmt.Query(itemType + "_" + strconv.Itoa(itemID)) rows, err := stmts.reportExists.Query(itemType + "_" + strconv.Itoa(itemID))
if err != nil && err != ErrNoRows { if err != nil && err != ErrNoRows {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -670,7 +670,7 @@ func routeReportSubmit(w http.ResponseWriter, r *http.Request, user User, sitemI
// TODO: Repost attachments in the reports forum, so that the mods can see them // TODO: Repost attachments in the reports forum, so that the mods can see them
// ? - Can we do this via the TopicStore? // ? - Can we do this via the TopicStore?
res, err := createReportStmt.Exec(title, content, parseMessage(content, 0, ""), user.ID, user.ID, itemType+"_"+strconv.Itoa(itemID)) res, err := stmts.createReport.Exec(title, content, parseMessage(content, 0, ""), user.ID, user.ID, itemType+"_"+strconv.Itoa(itemID))
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -680,7 +680,7 @@ func routeReportSubmit(w http.ResponseWriter, r *http.Request, user User, sitemI
return InternalError(err, w, r) return InternalError(err, w, r)
} }
_, err = addTopicsToForumStmt.Exec(1, fid) _, err = stmts.addTopicsToForum.Exec(1, fid)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -698,9 +698,6 @@ func routeAccountOwnEditCritical(w http.ResponseWriter, r *http.Request, user Us
if ferr != nil { if ferr != nil {
return ferr return ferr
} }
if !user.Loggedin {
return LocalError("You need to login to edit your account.", w, r, user)
}
pi := Page{"Edit Password", user, headerVars, tList, nil} pi := Page{"Edit Password", user, headerVars, tList, nil}
if preRenderHooks["pre_render_account_own_edit_critical"] != nil { if preRenderHooks["pre_render_account_own_edit_critical"] != nil {
@ -720,9 +717,6 @@ func routeAccountOwnEditCriticalSubmit(w http.ResponseWriter, r *http.Request, u
if ferr != nil { if ferr != nil {
return ferr return ferr
} }
if !user.Loggedin {
return LocalError("You need to login to edit your account.", w, r, user)
}
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
@ -734,7 +728,7 @@ func routeAccountOwnEditCriticalSubmit(w http.ResponseWriter, r *http.Request, u
newPassword := r.PostFormValue("account-new-password") newPassword := r.PostFormValue("account-new-password")
confirmPassword := r.PostFormValue("account-confirm-password") confirmPassword := r.PostFormValue("account-confirm-password")
err = getPasswordStmt.QueryRow(user.ID).Scan(&realPassword, &salt) err = stmts.getPassword.QueryRow(user.ID).Scan(&realPassword, &salt)
if err == ErrNoRows { if err == ErrNoRows {
return LocalError("Your account no longer exists.", w, r, user) return LocalError("Your account no longer exists.", w, r, user)
} else if err != nil { } else if err != nil {
@ -774,9 +768,7 @@ func routeAccountOwnEditAvatar(w http.ResponseWriter, r *http.Request, user User
if ferr != nil { if ferr != nil {
return ferr return ferr
} }
if !user.Loggedin {
return LocalError("You need to login to edit your account.", w, r, user)
}
pi := Page{"Edit Avatar", user, headerVars, tList, nil} pi := Page{"Edit Avatar", user, headerVars, tList, nil}
if preRenderHooks["pre_render_account_own_edit_avatar"] != nil { if preRenderHooks["pre_render_account_own_edit_avatar"] != nil {
if runPreRenderHook("pre_render_account_own_edit_avatar", w, r, &user, &pi) { if runPreRenderHook("pre_render_account_own_edit_avatar", w, r, &user, &pi) {
@ -801,9 +793,6 @@ func routeAccountOwnEditAvatarSubmit(w http.ResponseWriter, r *http.Request, use
if ferr != nil { if ferr != nil {
return ferr return ferr
} }
if !user.Loggedin {
return LocalError("You need to login to edit your account.", w, r, user)
}
err := r.ParseMultipartForm(int64(megabyte)) err := r.ParseMultipartForm(int64(megabyte))
if err != nil { if err != nil {
@ -884,9 +873,7 @@ func routeAccountOwnEditUsername(w http.ResponseWriter, r *http.Request, user Us
if ferr != nil { if ferr != nil {
return ferr return ferr
} }
if !user.Loggedin {
return LocalError("You need to login to edit your account.", w, r, user)
}
pi := Page{"Edit Username", user, headerVars, tList, user.Name} pi := Page{"Edit Username", user, headerVars, tList, user.Name}
if preRenderHooks["pre_render_account_own_edit_username"] != nil { if preRenderHooks["pre_render_account_own_edit_username"] != nil {
if runPreRenderHook("pre_render_account_own_edit_username", w, r, &user, &pi) { if runPreRenderHook("pre_render_account_own_edit_username", w, r, &user, &pi) {
@ -905,9 +892,6 @@ func routeAccountOwnEditUsernameSubmit(w http.ResponseWriter, r *http.Request, u
if ferr != nil { if ferr != nil {
return ferr return ferr
} }
if !user.Loggedin {
return LocalError("You need to login to edit your account.", w, r, user)
}
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
return LocalError("Bad Form", w, r, user) return LocalError("Bad Form", w, r, user)
@ -939,22 +923,19 @@ func routeAccountOwnEditEmail(w http.ResponseWriter, r *http.Request, user User)
if ferr != nil { if ferr != nil {
return ferr return ferr
} }
if !user.Loggedin {
return LocalError("You need to login to edit your account.", w, r, user)
}
email := Email{UserID: user.ID} email := Email{UserID: user.ID}
var emailList []interface{} var emailList []interface{}
rows, err := getEmailsByUserStmt.Query(user.ID) rows, err := stmts.getEmailsByUser.Query(user.ID)
if err != nil { if err != nil {
log.Fatal(err) return InternalError(err, w, r)
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
err := rows.Scan(&email.Email, &email.Validated, &email.Token) err := rows.Scan(&email.Email, &email.Validated, &email.Token)
if err != nil { if err != nil {
log.Fatal(err) return InternalError(err, w, r)
} }
if email.Email == user.Email { if email.Email == user.Email {
@ -964,7 +945,7 @@ func routeAccountOwnEditEmail(w http.ResponseWriter, r *http.Request, user User)
} }
err = rows.Err() err = rows.Err()
if err != nil { if err != nil {
log.Fatal(err) return InternalError(err, w, r)
} }
// Was this site migrated from another forum software? Most of them don't have multiple emails for a single user. // Was this site migrated from another forum software? Most of them don't have multiple emails for a single user.
@ -992,20 +973,16 @@ func routeAccountOwnEditEmail(w http.ResponseWriter, r *http.Request, user User)
return nil return nil
} }
func routeAccountOwnEditEmailTokenSubmit(w http.ResponseWriter, r *http.Request, user User) RouteError { func routeAccountOwnEditEmailTokenSubmit(w http.ResponseWriter, r *http.Request, user User, token string) RouteError {
headerVars, ferr := UserCheck(w, r, &user) headerVars, ferr := UserCheck(w, r, &user)
if ferr != nil { if ferr != nil {
return ferr return ferr
} }
if !user.Loggedin {
return LocalError("You need to login to edit your account.", w, r, user)
}
token := r.URL.Path[len("/user/edit/token/"):]
email := Email{UserID: user.ID} email := Email{UserID: user.ID}
targetEmail := Email{UserID: user.ID} targetEmail := Email{UserID: user.ID}
var emailList []interface{} var emailList []interface{}
rows, err := getEmailsByUserStmt.Query(user.ID) rows, err := stmts.getEmailsByUser.Query(user.ID)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -1037,14 +1014,14 @@ func routeAccountOwnEditEmailTokenSubmit(w http.ResponseWriter, r *http.Request,
return LocalError("That's not a valid token!", w, r, user) return LocalError("That's not a valid token!", w, r, user)
} }
_, err = verifyEmailStmt.Exec(user.Email) _, err = stmts.verifyEmail.Exec(user.Email)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
// If Email Activation is on, then activate the account while we're here // If Email Activation is on, then activate the account while we're here
if headerVars.Settings["activation_type"] == 2 { if headerVars.Settings["activation_type"] == 2 {
_, err = activateUserStmt.Exec(user.ID) _, err = stmts.activateUser.Exec(user.ID)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -1098,7 +1075,7 @@ func routeShowAttachment(w http.ResponseWriter, r *http.Request, user User, file
var originTable string var originTable string
var originID, uploadedBy int var originID, uploadedBy int
err = getAttachmentStmt.QueryRow(filename, sectionID, sectionTable).Scan(&sectionID, &sectionTable, &originID, &originTable, &uploadedBy, &filename) err = stmts.getAttachment.QueryRow(filename, sectionID, sectionTable).Scan(&sectionID, &sectionTable, &originID, &originTable, &uploadedBy, &filename)
if err == ErrNoRows { if err == ErrNoRows {
return NotFound(w, r) return NotFound(w, r)
} else if err != nil { } else if err != nil {

View File

@ -475,6 +475,21 @@ func TestPermsMiddleware(t *testing.T) {
expect(t, ferr == nil, "Supermods should be allowed through supermod gates") expect(t, ferr == nil, "Supermods should be allowed through supermod gates")
// TODO: Loop over the Control Panel routes and make sure only supermods can get in // TODO: Loop over the Control Panel routes and make sure only supermods can get in
user = getDummyUser()
ferr = MemberOnly(dummyResponseRecorder, dummyRequest, *user)
expect(t, ferr != nil, "Blank users shouldn't be considered loggedin")
user.Loggedin = false
ferr = MemberOnly(dummyResponseRecorder, dummyRequest, *user)
expect(t, ferr != nil, "Guests shouldn't be able to access member areas")
user.Loggedin = true
ferr = MemberOnly(dummyResponseRecorder, dummyRequest, *user)
expect(t, ferr == nil, "Logged in users should be able to access member areas")
// TODO: Loop over the /user/ routes and make sure only members can access the ones other than /user/username
} }
func TestTopicStore(t *testing.T) { func TestTopicStore(t *testing.T) {

View File

@ -123,7 +123,7 @@ func routeDeleteTopic(w http.ResponseWriter, r *http.Request, user User) RouteEr
} }
// ? - We might need to add soft-delete before we can do an action reply for this // ? - We might need to add soft-delete before we can do an action reply for this
/*_, err = createActionReplyStmt.Exec(tid,"delete",ipaddress,user.ID) /*_, err = stmts.createActionReply.Exec(tid,"delete",ipaddress,user.ID)
if err != nil { if err != nil {
return InternalErrorJSQ(err,w,r,isJs) return InternalErrorJSQ(err,w,r,isJs)
}*/ }*/
@ -350,13 +350,13 @@ func routeReplyEditSubmit(w http.ResponseWriter, r *http.Request, user User) Rou
// Get the Reply ID.. // Get the Reply ID..
var tid int var tid int
err = getReplyTIDStmt.QueryRow(rid).Scan(&tid) err = stmts.getReplyTID.QueryRow(rid).Scan(&tid)
if err != nil { if err != nil {
return InternalErrorJSQ(err, w, r, isJs) return InternalErrorJSQ(err, w, r, isJs)
} }
var fid int var fid int
err = getTopicFIDStmt.QueryRow(tid).Scan(&fid) err = stmts.getTopicFID.QueryRow(tid).Scan(&fid)
if err == ErrNoRows { if err == ErrNoRows {
return PreErrorJSQ("The parent topic doesn't exist.", w, r, isJs) return PreErrorJSQ("The parent topic doesn't exist.", w, r, isJs)
} else if err != nil { } else if err != nil {
@ -373,7 +373,7 @@ func routeReplyEditSubmit(w http.ResponseWriter, r *http.Request, user User) Rou
} }
content := html.EscapeString(preparseMessage(r.PostFormValue("edit_item"))) content := html.EscapeString(preparseMessage(r.PostFormValue("edit_item")))
_, err = editReplyStmt.Exec(content, parseMessage(content, fid, "forums"), rid) _, err = stmts.editReply.Exec(content, parseMessage(content, fid, "forums"), rid)
if err != nil { if err != nil {
return InternalErrorJSQ(err, w, r, isJs) return InternalErrorJSQ(err, w, r, isJs)
} }
@ -408,7 +408,7 @@ func routeReplyDeleteSubmit(w http.ResponseWriter, r *http.Request, user User) R
} }
var fid int var fid int
err = getTopicFIDStmt.QueryRow(reply.ParentID).Scan(&fid) err = stmts.getTopicFID.QueryRow(reply.ParentID).Scan(&fid)
if err == ErrNoRows { if err == ErrNoRows {
return PreErrorJSQ("The parent topic doesn't exist.", w, r, isJs) return PreErrorJSQ("The parent topic doesn't exist.", w, r, isJs)
} else if err != nil { } else if err != nil {
@ -472,7 +472,7 @@ func routeProfileReplyEditSubmit(w http.ResponseWriter, r *http.Request, user Us
// Get the Reply ID.. // Get the Reply ID..
var uid int var uid int
err = getUserReplyUIDStmt.QueryRow(rid).Scan(&uid) err = stmts.getUserReplyUID.QueryRow(rid).Scan(&uid)
if err != nil { if err != nil {
return InternalErrorJSQ(err, w, r, isJs) return InternalErrorJSQ(err, w, r, isJs)
} }
@ -482,7 +482,7 @@ func routeProfileReplyEditSubmit(w http.ResponseWriter, r *http.Request, user Us
} }
content := html.EscapeString(preparseMessage(r.PostFormValue("edit_item"))) content := html.EscapeString(preparseMessage(r.PostFormValue("edit_item")))
_, err = editProfileReplyStmt.Exec(content, parseMessage(content, 0, ""), rid) _, err = stmts.editProfileReply.Exec(content, parseMessage(content, 0, ""), rid)
if err != nil { if err != nil {
return InternalErrorJSQ(err, w, r, isJs) return InternalErrorJSQ(err, w, r, isJs)
} }
@ -508,7 +508,7 @@ func routeProfileReplyDeleteSubmit(w http.ResponseWriter, r *http.Request, user
} }
var uid int var uid int
err = getUserReplyUIDStmt.QueryRow(rid).Scan(&uid) err = stmts.getUserReplyUID.QueryRow(rid).Scan(&uid)
if err == ErrNoRows { if err == ErrNoRows {
return LocalErrorJSQ("The reply you tried to delete doesn't exist.", w, r, user, isJs) return LocalErrorJSQ("The reply you tried to delete doesn't exist.", w, r, user, isJs)
} else if err != nil { } else if err != nil {
@ -519,7 +519,7 @@ func routeProfileReplyDeleteSubmit(w http.ResponseWriter, r *http.Request, user
return NoPermissionsJSQ(w, r, user, isJs) return NoPermissionsJSQ(w, r, user, isJs)
} }
_, err = deleteProfileReplyStmt.Exec(rid) _, err = stmts.deleteProfileReply.Exec(rid)
if err != nil { if err != nil {
return InternalErrorJSQ(err, w, r, isJs) return InternalErrorJSQ(err, w, r, isJs)
} }
@ -546,7 +546,7 @@ func routeIps(w http.ResponseWriter, r *http.Request, user User) RouteError {
var uid int var uid int
var reqUserList = make(map[int]bool) var reqUserList = make(map[int]bool)
rows, err := findUsersByIPUsersStmt.Query(ip) rows, err := stmts.findUsersByIPUsers.Query(ip)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -564,7 +564,7 @@ func routeIps(w http.ResponseWriter, r *http.Request, user User) RouteError {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
rows2, err := findUsersByIPTopicsStmt.Query(ip) rows2, err := stmts.findUsersByIPTopics.Query(ip)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -582,7 +582,7 @@ func routeIps(w http.ResponseWriter, r *http.Request, user User) RouteError {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
rows3, err := findUsersByIPRepliesStmt.Query(ip) rows3, err := stmts.findUsersByIPReplies.Query(ip)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }

View File

@ -20,16 +20,6 @@ import (
var dbInstance string = "" var dbInstance string = ""
var getActivityFeedByWatcherStmt *sql.Stmt
var getActivityCountByWatcherStmt *sql.Stmt
var todaysPostCountStmt *sql.Stmt
var todaysTopicCountStmt *sql.Stmt
var todaysReportCountStmt *sql.Stmt
var todaysNewUserCountStmt *sql.Stmt
var findUsersByIPUsersStmt *sql.Stmt
var findUsersByIPTopicsStmt *sql.Stmt
var findUsersByIPRepliesStmt *sql.Stmt
func init() { func init() {
dbAdapter = "mssql" dbAdapter = "mssql"
_initDatabase = initMSSQL _initDatabase = initMSSQL

View File

@ -16,15 +16,6 @@ import _ "github.com/go-sql-driver/mysql"
import "./query_gen/lib" import "./query_gen/lib"
var dbCollation = "utf8mb4_general_ci" var dbCollation = "utf8mb4_general_ci"
var getActivityFeedByWatcherStmt *sql.Stmt
var getActivityCountByWatcherStmt *sql.Stmt
var todaysPostCountStmt *sql.Stmt
var todaysTopicCountStmt *sql.Stmt
var todaysReportCountStmt *sql.Stmt
var todaysNewUserCountStmt *sql.Stmt
var findUsersByIPUsersStmt *sql.Stmt
var findUsersByIPTopicsStmt *sql.Stmt
var findUsersByIPRepliesStmt *sql.Stmt
func init() { func init() {
dbAdapter = "mysql" dbAdapter = "mysql"
@ -75,54 +66,54 @@ func initMySQL() (err error) {
// TODO: Is there a less noisy way of doing this for tests? // TODO: Is there a less noisy way of doing this for tests?
log.Print("Preparing get_activity_feed_by_watcher statement.") log.Print("Preparing get_activity_feed_by_watcher statement.")
getActivityFeedByWatcherStmt, err = db.Prepare("SELECT activity_stream_matches.asid, activity_stream.actor, activity_stream.targetUser, activity_stream.event, activity_stream.elementType, activity_stream.elementID FROM `activity_stream_matches` INNER JOIN `activity_stream` ON activity_stream_matches.asid = activity_stream.asid AND activity_stream_matches.watcher != activity_stream.actor WHERE `watcher` = ? ORDER BY activity_stream.asid ASC LIMIT 8") stmts.getActivityFeedByWatcher, err = db.Prepare("SELECT activity_stream_matches.asid, activity_stream.actor, activity_stream.targetUser, activity_stream.event, activity_stream.elementType, activity_stream.elementID FROM `activity_stream_matches` INNER JOIN `activity_stream` ON activity_stream_matches.asid = activity_stream.asid AND activity_stream_matches.watcher != activity_stream.actor WHERE `watcher` = ? ORDER BY activity_stream.asid ASC LIMIT 8")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing get_activity_count_by_watcher statement.") log.Print("Preparing get_activity_count_by_watcher statement.")
getActivityCountByWatcherStmt, err = db.Prepare("SELECT count(*) FROM `activity_stream_matches` INNER JOIN `activity_stream` ON activity_stream_matches.asid = activity_stream.asid AND activity_stream_matches.watcher != activity_stream.actor WHERE `watcher` = ?") stmts.getActivityCountByWatcher, err = db.Prepare("SELECT count(*) FROM `activity_stream_matches` INNER JOIN `activity_stream` ON activity_stream_matches.asid = activity_stream.asid AND activity_stream_matches.watcher != activity_stream.actor WHERE `watcher` = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing todays_post_count statement.") log.Print("Preparing todays_post_count statement.")
todaysPostCountStmt, err = db.Prepare("select count(*) from replies where createdAt BETWEEN (utc_timestamp() - interval 1 day) and utc_timestamp()") stmts.todaysPostCount, err = db.Prepare("select count(*) from replies where createdAt BETWEEN (utc_timestamp() - interval 1 day) and utc_timestamp()")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing todays_topic_count statement.") log.Print("Preparing todays_topic_count statement.")
todaysTopicCountStmt, err = db.Prepare("select count(*) from topics where createdAt BETWEEN (utc_timestamp() - interval 1 day) and utc_timestamp()") stmts.todaysTopicCount, err = db.Prepare("select count(*) from topics where createdAt BETWEEN (utc_timestamp() - interval 1 day) and utc_timestamp()")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing todays_report_count statement.") log.Print("Preparing todays_report_count statement.")
todaysReportCountStmt, err = db.Prepare("select count(*) from topics where createdAt BETWEEN (utc_timestamp() - interval 1 day) and utc_timestamp() and parentID = 1") stmts.todaysReportCount, err = db.Prepare("select count(*) from topics where createdAt BETWEEN (utc_timestamp() - interval 1 day) and utc_timestamp() and parentID = 1")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing todays_newuser_count statement.") log.Print("Preparing todays_newuser_count statement.")
todaysNewUserCountStmt, err = db.Prepare("select count(*) from users where createdAt BETWEEN (utc_timestamp() - interval 1 day) and utc_timestamp()") stmts.todaysNewUserCount, err = db.Prepare("select count(*) from users where createdAt BETWEEN (utc_timestamp() - interval 1 day) and utc_timestamp()")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing find_users_by_ip_users statement.") log.Print("Preparing find_users_by_ip_users statement.")
findUsersByIPUsersStmt, err = db.Prepare("select uid from users where last_ip = ?") stmts.findUsersByIPUsers, err = db.Prepare("select uid from users where last_ip = ?")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing find_users_by_ip_topics statement.") log.Print("Preparing find_users_by_ip_topics statement.")
findUsersByIPTopicsStmt, err = db.Prepare("select uid from users where uid in(select createdBy from topics where ipaddress = ?)") stmts.findUsersByIPTopics, err = db.Prepare("select uid from users where uid in(select createdBy from topics where ipaddress = ?)")
if err != nil { if err != nil {
return err return err
} }
log.Print("Preparing find_users_by_ip_replies statement.") log.Print("Preparing find_users_by_ip_replies statement.")
findUsersByIPRepliesStmt, err = db.Prepare("select uid from users where uid in(select createdBy from replies where ipaddress = ?)") stmts.findUsersByIPReplies, err = db.Prepare("select uid from users where uid in(select createdBy from replies where ipaddress = ?)")
return err return err
} }

View File

@ -63,7 +63,7 @@ func routePanel(w http.ResponseWriter, r *http.Request, user User) RouteError {
} }
var postCount int var postCount int
err = todaysPostCountStmt.QueryRow().Scan(&postCount) err = stmts.todaysPostCount.QueryRow().Scan(&postCount)
if err != nil && err != ErrNoRows { if err != nil && err != ErrNoRows {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -79,7 +79,7 @@ func routePanel(w http.ResponseWriter, r *http.Request, user User) RouteError {
} }
var topicCount int var topicCount int
err = todaysTopicCountStmt.QueryRow().Scan(&topicCount) err = stmts.todaysTopicCount.QueryRow().Scan(&topicCount)
if err != nil && err != ErrNoRows { if err != nil && err != ErrNoRows {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -95,14 +95,14 @@ func routePanel(w http.ResponseWriter, r *http.Request, user User) RouteError {
} }
var reportCount int var reportCount int
err = todaysReportCountStmt.QueryRow().Scan(&reportCount) err = stmts.todaysReportCount.QueryRow().Scan(&reportCount)
if err != nil && err != ErrNoRows { if err != nil && err != ErrNoRows {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
var reportInterval = "week" var reportInterval = "week"
var newUserCount int var newUserCount int
err = todaysNewUserCountStmt.QueryRow().Scan(&newUserCount) err = stmts.todaysNewUserCount.QueryRow().Scan(&newUserCount)
if err != nil && err != ErrNoRows { if err != nil && err != ErrNoRows {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -483,7 +483,7 @@ func routePanelForumsEditPermsSubmit(w http.ResponseWriter, r *http.Request, use
} }
// TODO: Add this and replaceForumPermsForGroup into a transaction? // TODO: Add this and replaceForumPermsForGroup into a transaction?
_, err = updateForumStmt.Exec(forum.Name, forum.Desc, forum.Active, "", fid) _, err = stmts.updateForum.Exec(forum.Name, forum.Desc, forum.Active, "", fid)
if err != nil { if err != nil {
return InternalErrorJSQ(err, w, r, isJs) return InternalErrorJSQ(err, w, r, isJs)
} }
@ -512,7 +512,7 @@ func routePanelSettings(w http.ResponseWriter, r *http.Request, user User) Route
//log.Print("headerVars.Settings",headerVars.Settings) //log.Print("headerVars.Settings",headerVars.Settings)
var settingList = make(map[string]interface{}) var settingList = make(map[string]interface{})
rows, err := getSettingsStmt.Query() rows, err := stmts.getSettings.Query()
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -572,7 +572,7 @@ func routePanelSetting(w http.ResponseWriter, r *http.Request, user User, sname
} }
setting := Setting{sname, "", "", ""} setting := Setting{sname, "", "", ""}
err := getSettingStmt.QueryRow(setting.Name).Scan(&setting.Content, &setting.Type) err := stmts.getSetting.QueryRow(setting.Name).Scan(&setting.Content, &setting.Type)
if err == ErrNoRows { if err == ErrNoRows {
return LocalError("The setting you want to edit doesn't exist.", w, r, user) return LocalError("The setting you want to edit doesn't exist.", w, r, user)
} else if err != nil { } else if err != nil {
@ -630,7 +630,7 @@ func routePanelSettingEdit(w http.ResponseWriter, r *http.Request, user User, sn
var stype, sconstraints string var stype, sconstraints string
scontent := r.PostFormValue("setting-value") scontent := r.PostFormValue("setting-value")
err = getFullSettingStmt.QueryRow(sname).Scan(&sname, &stype, &sconstraints) err = stmts.getFullSetting.QueryRow(sname).Scan(&sname, &stype, &sconstraints)
if err == ErrNoRows { if err == ErrNoRows {
return LocalError("The setting you want to edit doesn't exist.", w, r, user) return LocalError("The setting you want to edit doesn't exist.", w, r, user)
} else if err != nil { } else if err != nil {
@ -646,7 +646,7 @@ func routePanelSettingEdit(w http.ResponseWriter, r *http.Request, user User, sn
} }
// TODO: Make this a method or function? // TODO: Make this a method or function?
_, err = updateSettingStmt.Exec(scontent, sname) _, err = stmts.updateSetting.Exec(scontent, sname)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -707,7 +707,7 @@ func routePanelWordFiltersCreate(w http.ResponseWriter, r *http.Request, user Us
// Unlike with find, it's okay if we leave this blank, as this means that the admin wants to remove the word entirely with no replacement // Unlike with find, it's okay if we leave this blank, as this means that the admin wants to remove the word entirely with no replacement
replacement := strings.TrimSpace(r.PostFormValue("replacement")) replacement := strings.TrimSpace(r.PostFormValue("replacement"))
res, err := createWordFilterStmt.Exec(find, replacement) res, err := stmts.createWordFilter.Exec(find, replacement)
if err != nil { if err != nil {
return InternalErrorJSQ(err, w, r, isJs) return InternalErrorJSQ(err, w, r, isJs)
} }
@ -778,7 +778,7 @@ func routePanelWordFiltersEditSubmit(w http.ResponseWriter, r *http.Request, use
// Unlike with find, it's okay if we leave this blank, as this means that the admin wants to remove the word entirely with no replacement // Unlike with find, it's okay if we leave this blank, as this means that the admin wants to remove the word entirely with no replacement
replacement := strings.TrimSpace(r.PostFormValue("replacement")) replacement := strings.TrimSpace(r.PostFormValue("replacement"))
_, err = updateWordFilterStmt.Exec(find, replacement, id) _, err = stmts.updateWordFilter.Exec(find, replacement, id)
if err != nil { if err != nil {
return InternalErrorJSQ(err, w, r, isJs) return InternalErrorJSQ(err, w, r, isJs)
} }
@ -811,7 +811,7 @@ func routePanelWordFiltersDeleteSubmit(w http.ResponseWriter, r *http.Request, u
return LocalErrorJSQ("The word filter ID must be an integer.", w, r, user, isJs) return LocalErrorJSQ("The word filter ID must be an integer.", w, r, user, isJs)
} }
_, err = deleteWordFilterStmt.Exec(id) _, err = stmts.deleteWordFilter.Exec(id)
if err != nil { if err != nil {
return InternalErrorJSQ(err, w, r, isJs) return InternalErrorJSQ(err, w, r, isJs)
} }
@ -874,7 +874,7 @@ func routePanelPluginsActivate(w http.ResponseWriter, r *http.Request, user User
} }
var active bool var active bool
err := isPluginActiveStmt.QueryRow(uname).Scan(&active) err := stmts.isPluginActive.QueryRow(uname).Scan(&active)
if err != nil && err != ErrNoRows { if err != nil && err != ErrNoRows {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -894,13 +894,13 @@ func routePanelPluginsActivate(w http.ResponseWriter, r *http.Request, user User
return LocalError("The plugin is already active", w, r, user) return LocalError("The plugin is already active", w, r, user)
} }
//log.Print("updatePlugin") //log.Print("updatePlugin")
_, err = updatePluginStmt.Exec(1, uname) _, err = stmts.updatePlugin.Exec(1, uname)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
} else { } else {
//log.Print("addPlugin") //log.Print("addPlugin")
_, err := addPluginStmt.Exec(uname, 1, 0) _, err := stmts.addPlugin.Exec(uname, 1, 0)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -936,7 +936,7 @@ func routePanelPluginsDeactivate(w http.ResponseWriter, r *http.Request, user Us
} }
var active bool var active bool
err := isPluginActiveStmt.QueryRow(uname).Scan(&active) err := stmts.isPluginActive.QueryRow(uname).Scan(&active)
if err == ErrNoRows { if err == ErrNoRows {
return LocalError("The plugin you're trying to deactivate isn't active", w, r, user) return LocalError("The plugin you're trying to deactivate isn't active", w, r, user)
} else if err != nil { } else if err != nil {
@ -946,7 +946,7 @@ func routePanelPluginsDeactivate(w http.ResponseWriter, r *http.Request, user Us
if !active { if !active {
return LocalError("The plugin you're trying to deactivate isn't active", w, r, user) return LocalError("The plugin you're trying to deactivate isn't active", w, r, user)
} }
_, err = updatePluginStmt.Exec(0, uname) _, err = stmts.updatePlugin.Exec(0, uname)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -985,7 +985,7 @@ func routePanelPluginsInstall(w http.ResponseWriter, r *http.Request, user User,
} }
var active bool var active bool
err := isPluginActiveStmt.QueryRow(uname).Scan(&active) err := stmts.isPluginActive.QueryRow(uname).Scan(&active)
if err != nil && err != ErrNoRows { if err != nil && err != ErrNoRows {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -1006,16 +1006,16 @@ func routePanelPluginsInstall(w http.ResponseWriter, r *http.Request, user User,
} }
if hasPlugin { if hasPlugin {
_, err = updatePluginInstallStmt.Exec(1, uname) _, err = stmts.updatePluginInstall.Exec(1, uname)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
_, err = updatePluginStmt.Exec(1, uname) _, err = stmts.updatePlugin.Exec(1, uname)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
} else { } else {
_, err := addPluginStmt.Exec(uname, 1, 1) _, err := stmts.addPlugin.Exec(uname, 1, 1)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -1046,7 +1046,7 @@ func routePanelUsers(w http.ResponseWriter, r *http.Request, user User) RouteErr
var userList []User var userList []User
// TODO: Move this into the UserStore // TODO: Move this into the UserStore
rows, err := getUsersOffsetStmt.Query(offset, perPage) rows, err := stmts.getUsersOffset.Query(offset, perPage)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -1216,7 +1216,7 @@ func routePanelUsersEditSubmit(w http.ResponseWriter, r *http.Request, user User
return LocalError("You need the EditUserGroupSuperMod permission to assign someone to a super mod group.", w, r, user) return LocalError("You need the EditUserGroupSuperMod permission to assign someone to a super mod group.", w, r, user)
} }
_, err = updateUserStmt.Exec(newname, newemail, newgroup, targetUser.ID) _, err = stmts.updateUser.Exec(newname, newemail, newgroup, targetUser.ID)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -1516,7 +1516,7 @@ func routePanelGroupsEditSubmit(w http.ResponseWriter, r *http.Request, user Use
} }
// TODO: Move this to *Group // TODO: Move this to *Group
_, err = updateGroupStmt.Exec(gname, gtag, gid) _, err = stmts.updateGroup.Exec(gname, gtag, gid)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -1581,7 +1581,7 @@ func routePanelGroupsEditPermsSubmit(w http.ResponseWriter, r *http.Request, use
if err != nil { if err != nil {
return LocalError("Unable to marshal the data", w, r, user) return LocalError("Unable to marshal the data", w, r, user)
} }
_, err = updateGroupPermsStmt.Exec(pjson, gid) _, err = stmts.updateGroupPerms.Exec(pjson, gid)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -1696,7 +1696,7 @@ func routePanelThemesSetDefault(w http.ResponseWriter, r *http.Request, user Use
var isDefault bool var isDefault bool
log.Print("uname", uname) // TODO: Do we need to log this? log.Print("uname", uname) // TODO: Do we need to log this?
err := isThemeDefaultStmt.QueryRow(uname).Scan(&isDefault) err := stmts.isThemeDefault.QueryRow(uname).Scan(&isDefault)
if err != nil && err != ErrNoRows { if err != nil && err != ErrNoRows {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -1707,12 +1707,12 @@ func routePanelThemesSetDefault(w http.ResponseWriter, r *http.Request, user Use
if isDefault { if isDefault {
return LocalError("The theme is already active", w, r, user) return LocalError("The theme is already active", w, r, user)
} }
_, err = updateThemeStmt.Exec(1, uname) _, err = stmts.updateTheme.Exec(1, uname)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
} else { } else {
_, err := addThemeStmt.Exec(uname, 1) _, err := stmts.addTheme.Exec(uname, 1)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -1721,7 +1721,7 @@ func routePanelThemesSetDefault(w http.ResponseWriter, r *http.Request, user Use
// TODO: Make this less racey // TODO: Make this less racey
changeDefaultThemeMutex.Lock() changeDefaultThemeMutex.Lock()
defaultTheme := defaultThemeBox.Load().(string) defaultTheme := defaultThemeBox.Load().(string)
_, err = updateThemeStmt.Exec(0, defaultTheme) _, err = stmts.updateTheme.Exec(0, defaultTheme)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -1803,7 +1803,7 @@ func routePanelLogsMod(w http.ResponseWriter, r *http.Request, user User) RouteE
} }
var logCount int var logCount int
err := modlogCountStmt.QueryRow().Scan(&logCount) err := stmts.modlogCount.QueryRow().Scan(&logCount)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -1812,7 +1812,7 @@ func routePanelLogsMod(w http.ResponseWriter, r *http.Request, user User) RouteE
perPage := 10 perPage := 10
offset, page, lastPage := pageOffset(logCount, page, perPage) offset, page, lastPage := pageOffset(logCount, page, perPage)
rows, err := getModlogsOffsetStmt.Query(offset, perPage) rows, err := stmts.getModlogsOffset.Query(offset, perPage)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }

View File

@ -13,12 +13,6 @@ import "./query_gen/lib"
// TODO: Add support for SSL for all database drivers, not just pgsql // TODO: Add support for SSL for all database drivers, not just pgsql
var db_sslmode = "disable" // verify-full var db_sslmode = "disable" // verify-full
var get_activity_feed_by_watcher_stmt *sql.Stmt
var get_activity_count_by_watcher_stmt *sql.Stmt
var todays_post_count_stmt *sql.Stmt
var todays_topic_count_stmt *sql.Stmt
var todays_report_count_stmt *sql.Stmt
var todays_newuser_count_stmt *sql.Stmt
func init() { func init() {
db_adapter = "pgsql" db_adapter = "pgsql"

View File

@ -491,14 +491,6 @@ func (adapter *MssqlAdapter) SimpleSelect(name string, table string, columns str
} }
} }
/*if limiter.MaxCount != "" {
if limiter.MaxCount == "?" {
substituteCount++
limiter.MaxCount = "?" + strconv.Itoa(substituteCount)
}
querystr = "TOP " + limiter.MaxCount + " " + querystr
}*/
// ! Does this work without an offset? // ! Does this work without an offset?
if limiter.MaxCount != "" { if limiter.MaxCount != "" {
if limiter.MaxCount == "?" { if limiter.MaxCount == "?" {
@ -1086,10 +1078,10 @@ func (adapter *MssqlAdapter) Write() error {
stmt := adapter.Buffer[name] stmt := adapter.Buffer[name]
// TODO: Add support for create-table? Table creation might be a little complex for Go to do outside a SQL file :( // TODO: Add support for create-table? Table creation might be a little complex for Go to do outside a SQL file :(
if stmt.Type != "create-table" { if stmt.Type != "create-table" {
stmts += "var " + name + "Stmt *sql.Stmt\n" stmts += "\t" + name + " *sql.Stmt\n"
body += ` body += `
log.Print("Preparing ` + name + ` statement.") log.Print("Preparing ` + name + ` statement.")
` + name + `Stmt, err = db.Prepare("` + stmt.Contents + `") stmts.` + name + `, err = db.Prepare("` + stmt.Contents + `")
if err != nil { if err != nil {
log.Print("Bad Query: ","` + stmt.Contents + `") log.Print("Bad Query: ","` + stmt.Contents + `")
return err return err
@ -1098,6 +1090,7 @@ func (adapter *MssqlAdapter) Write() error {
} }
} }
// TODO: Move these custom queries out of this file
out := `// +build mssql out := `// +build mssql
// This file was generated by Gosora's Query Generator. Please try to avoid modifying this file, as it might change at any time. // This file was generated by Gosora's Query Generator. Please try to avoid modifying this file, as it might change at any time.
@ -1107,7 +1100,21 @@ import "log"
import "database/sql" import "database/sql"
// nolint // nolint
type Stmts struct {
` + stmts + ` ` + stmts + `
getActivityFeedByWatcher *sql.Stmt
getActivityCountByWatcher *sql.Stmt
todaysPostCount *sql.Stmt
todaysTopicCount *sql.Stmt
todaysReportCount *sql.Stmt
todaysNewUserCount *sql.Stmt
findUsersByIPUsers *sql.Stmt
findUsersByIPTopics *sql.Stmt
findUsersByIPReplies *sql.Stmt
Mocks bool
}
// nolint // nolint
func _gen_mssql() (err error) { func _gen_mssql() (err error) {
if dev.DebugMode { if dev.DebugMode {

View File

@ -129,10 +129,10 @@ func (adapter *MysqlAdapter) SimpleInsert(name string, table string, columns str
} }
querystr += field.Name + "," querystr += field.Name + ","
} }
querystr = querystr[0 : len(querystr)-1] querystr = querystr[0:len(querystr)-1] + ")"
adapter.pushStatement(name, "insert", querystr+")") adapter.pushStatement(name, "insert", querystr)
return querystr + ")", nil return querystr, nil
} }
func (adapter *MysqlAdapter) buildColumns(columns string) (querystr string) { func (adapter *MysqlAdapter) buildColumns(columns string) (querystr string) {
@ -307,7 +307,7 @@ func (adapter *MysqlAdapter) Purge(name string, table string) (string, error) {
// TODO: Add support for BETWEEN x.x // TODO: Add support for BETWEEN x.x
func (adapter *MysqlAdapter) buildWhere(where string) (querystr string, err error) { func (adapter *MysqlAdapter) buildWhere(where string) (querystr string, err error) {
if len(where) != 0 { if len(where) != 0 {
querystr += " WHERE" querystr = " WHERE"
for _, loc := range processWhere(where) { for _, loc := range processWhere(where) {
for _, token := range loc.Expr { for _, token := range loc.Expr {
switch token.Type { switch token.Type {
@ -330,7 +330,7 @@ func (adapter *MysqlAdapter) buildWhere(where string) (querystr string, err erro
func (adapter *MysqlAdapter) buildOrderby(orderby string) (querystr string) { func (adapter *MysqlAdapter) buildOrderby(orderby string) (querystr string) {
if len(orderby) != 0 { if len(orderby) != 0 {
querystr += " ORDER BY " querystr = " ORDER BY "
for _, column := range processOrderby(orderby) { for _, column := range processOrderby(orderby) {
// TODO: We might want to escape this column // TODO: We might want to escape this column
querystr += column.Column + " " + strings.ToUpper(column.Order) + "," querystr += column.Column + " " + strings.ToUpper(column.Order) + ","
@ -468,7 +468,7 @@ func (adapter *MysqlAdapter) buildJoiners(joiners string) (querystr string) {
// Add support for BETWEEN x.x // Add support for BETWEEN x.x
func (adapter *MysqlAdapter) buildJoinWhere(where string) (querystr string, err error) { func (adapter *MysqlAdapter) buildJoinWhere(where string) (querystr string, err error) {
if len(where) != 0 { if len(where) != 0 {
querystr += " WHERE" querystr = " WHERE"
for _, loc := range processWhere(where) { for _, loc := range processWhere(where) {
for _, token := range loc.Expr { for _, token := range loc.Expr {
switch token.Type { switch token.Type {
@ -496,24 +496,22 @@ func (adapter *MysqlAdapter) buildJoinWhere(where string) (querystr string, err
func (adapter *MysqlAdapter) buildLimit(limit string) (querystr string) { func (adapter *MysqlAdapter) buildLimit(limit string) (querystr string) {
if limit != "" { if limit != "" {
querystr += " LIMIT " + limit querystr = " LIMIT " + limit
} }
return querystr return querystr
} }
func (adapter *MysqlAdapter) buildJoinColumns(columns string) (querystr string) { func (adapter *MysqlAdapter) buildJoinColumns(columns string) (querystr string) {
for _, column := range processColumns(columns) { for _, column := range processColumns(columns) {
var source, alias string
// Escape the column names, just in case we've used a reserved keyword // Escape the column names, just in case we've used a reserved keyword
var source = column.Left
if column.Table != "" { if column.Table != "" {
source = "`" + column.Table + "`.`" + column.Left + "`" source = "`" + column.Table + "`.`" + source + "`"
} else if column.Type == "function" { } else if column.Type != "function" {
source = column.Left source = "`" + source + "`"
} else {
source = "`" + column.Left + "`"
} }
var alias string
if column.Alias != "" { if column.Alias != "" {
alias = " AS `" + column.Alias + "`" alias = " AS `" + column.Alias + "`"
} }
@ -563,19 +561,19 @@ func (adapter *MysqlAdapter) Write() error {
stmt := adapter.Buffer[name] stmt := adapter.Buffer[name]
// ? - Table creation might be a little complex for Go to do outside a SQL file :( // ? - Table creation might be a little complex for Go to do outside a SQL file :(
if stmt.Type == "upsert" { if stmt.Type == "upsert" {
stmts += "var " + name + "Stmt *qgen.MySQLUpsertCallback\n" stmts += "\t" + name + " *qgen.MySQLUpsertCallback\n"
body += ` body += `
log.Print("Preparing ` + name + ` statement.") log.Print("Preparing ` + name + ` statement.")
` + name + `Stmt, err = qgen.PrepareMySQLUpsertCallback(db, "` + stmt.Contents + `") stmts.` + name + `, err = qgen.PrepareMySQLUpsertCallback(db, "` + stmt.Contents + `")
if err != nil { if err != nil {
return err return err
} }
` `
} else if stmt.Type != "create-table" { } else if stmt.Type != "create-table" {
stmts += "var " + name + "Stmt *sql.Stmt\n" stmts += "\t" + name + " *sql.Stmt\n"
body += ` body += `
log.Print("Preparing ` + name + ` statement.") log.Print("Preparing ` + name + ` statement.")
` + name + `Stmt, err = db.Prepare("` + stmt.Contents + `") stmts.` + name + `, err = db.Prepare("` + stmt.Contents + `")
if err != nil { if err != nil {
return err return err
} }
@ -583,6 +581,7 @@ func (adapter *MysqlAdapter) Write() error {
} }
} }
// TODO: Move these custom queries out of this file
out := `// +build !pgsql, !sqlite, !mssql out := `// +build !pgsql, !sqlite, !mssql
/* This file was generated by Gosora's Query Generator. Please try to avoid modifying this file, as it might change at any time. */ /* This file was generated by Gosora's Query Generator. Please try to avoid modifying this file, as it might change at any time. */
@ -594,7 +593,21 @@ import "database/sql"
//import "./query_gen/lib" //import "./query_gen/lib"
// nolint // nolint
type Stmts struct {
` + stmts + ` ` + stmts + `
getActivityFeedByWatcher *sql.Stmt
getActivityCountByWatcher *sql.Stmt
todaysPostCount *sql.Stmt
todaysTopicCount *sql.Stmt
todaysReportCount *sql.Stmt
todaysNewUserCount *sql.Stmt
findUsersByIPUsers *sql.Stmt
findUsersByIPTopics *sql.Stmt
findUsersByIPReplies *sql.Stmt
Mocks bool
}
// nolint // nolint
func _gen_mysql() (err error) { func _gen_mysql() (err error) {
if dev.DebugMode { if dev.DebugMode {

View File

@ -327,10 +327,10 @@ func (adapter *PgsqlAdapter) Write() error {
stmt := adapter.Buffer[name] stmt := adapter.Buffer[name]
// TODO: Add support for create-table? Table creation might be a little complex for Go to do outside a SQL file :( // TODO: Add support for create-table? Table creation might be a little complex for Go to do outside a SQL file :(
if stmt.Type != "create-table" { if stmt.Type != "create-table" {
stmts += "var " + name + "Stmt *sql.Stmt\n" stmts += "\t" + name + " *sql.Stmt\n"
body += ` body += `
log.Print("Preparing ` + name + ` statement.") log.Print("Preparing ` + name + ` statement.")
` + name + `Stmt, err = db.Prepare("` + stmt.Contents + `") stmts.` + name + `, err = db.Prepare("` + stmt.Contents + `")
if err != nil { if err != nil {
return err return err
} }
@ -338,6 +338,7 @@ func (adapter *PgsqlAdapter) Write() error {
} }
} }
// TODO: Move these custom queries out of this file
out := `// +build pgsql out := `// +build pgsql
// This file was generated by Gosora's Query Generator. Please try to avoid modifying this file, as it might change at any time. // This file was generated by Gosora's Query Generator. Please try to avoid modifying this file, as it might change at any time.
@ -347,7 +348,21 @@ import "log"
import "database/sql" import "database/sql"
// nolint // nolint
type Stmts struct {
` + stmts + ` ` + stmts + `
getActivityFeedByWatcher *sql.Stmt
getActivityCountByWatcher *sql.Stmt
todaysPostCount *sql.Stmt
todaysTopicCount *sql.Stmt
todaysReportCount *sql.Stmt
todaysNewUserCount *sql.Stmt
findUsersByIPUsers *sql.Stmt
findUsersByIPTopics *sql.Stmt
findUsersByIPReplies *sql.Stmt
Mocks bool
}
// nolint // nolint
func _gen_pgsql() (err error) { func _gen_pgsql() (err error) {
if dev.DebugMode { if dev.DebugMode {

View File

@ -65,7 +65,7 @@ var ErrAlreadyLiked = errors.New("You already liked this!")
// TODO: Wrap these queries in a transaction to make sure the state is consistent // TODO: Wrap these queries in a transaction to make sure the state is consistent
func (reply *Reply) Like(uid int) (err error) { func (reply *Reply) Like(uid int) (err error) {
var rid int // unused, just here to avoid mutating reply.ID var rid int // unused, just here to avoid mutating reply.ID
err = hasLikedReplyStmt.QueryRow(uid, reply.ID).Scan(&rid) err = stmts.hasLikedReply.QueryRow(uid, reply.ID).Scan(&rid)
if err != nil && err != ErrNoRows { if err != nil && err != ErrNoRows {
return err return err
} else if err != ErrNoRows { } else if err != ErrNoRows {
@ -73,21 +73,21 @@ func (reply *Reply) Like(uid int) (err error) {
} }
score := 1 score := 1
_, err = createLikeStmt.Exec(score, reply.ID, "replies", uid) _, err = stmts.createLike.Exec(score, reply.ID, "replies", uid)
if err != nil { if err != nil {
return err return err
} }
_, err = addLikesToReplyStmt.Exec(1, reply.ID) _, err = stmts.addLikesToReply.Exec(1, reply.ID)
return err return err
} }
// TODO: Write tests for this // TODO: Write tests for this
func (reply *Reply) Delete() error { func (reply *Reply) Delete() error {
_, err := deleteReplyStmt.Exec(reply.ID) _, err := stmts.deleteReply.Exec(reply.ID)
if err != nil { if err != nil {
return err return err
} }
_, err = removeRepliesFromTopicStmt.Exec(1, reply.ParentID) _, err = stmts.removeRepliesFromTopic.Exec(1, reply.ParentID)
tcache, ok := topics.(TopicCache) tcache, ok := topics.(TopicCache)
if ok { if ok {
tcache.CacheRemove(reply.ParentID) tcache.CacheRemove(reply.ParentID)
@ -100,6 +100,7 @@ func (reply *Reply) Copy() Reply {
return *reply return *reply
} }
// TODO: Refactor this to stop hitting the global stmt store
type ReplyStore interface { type ReplyStore interface {
Get(id int) (*Reply, error) Get(id int) (*Reply, error)
Create(tid int, content string, ipaddress string, fid int, uid int) (id int, err error) Create(tid int, content string, ipaddress string, fid int, uid int) (id int, err error)
@ -114,14 +115,14 @@ func NewSQLReplyStore() *SQLReplyStore {
func (store *SQLReplyStore) Get(id int) (*Reply, error) { func (store *SQLReplyStore) Get(id int) (*Reply, error) {
reply := Reply{ID: id} reply := Reply{ID: id}
err := getReplyStmt.QueryRow(id).Scan(&reply.ParentID, &reply.Content, &reply.CreatedBy, &reply.CreatedAt, &reply.LastEdit, &reply.LastEditBy, &reply.IPAddress, &reply.LikeCount) err := stmts.getReply.QueryRow(id).Scan(&reply.ParentID, &reply.Content, &reply.CreatedBy, &reply.CreatedAt, &reply.LastEdit, &reply.LastEditBy, &reply.IPAddress, &reply.LikeCount)
return &reply, err return &reply, err
} }
// TODO: Write a test for this // TODO: Write a test for this
func (store *SQLReplyStore) Create(tid int, content string, ipaddress string, fid int, uid int) (id int, err error) { func (store *SQLReplyStore) Create(tid int, content string, ipaddress string, fid int, uid int) (id int, err error) {
wcount := wordCount(content) wcount := wordCount(content)
res, err := createReplyStmt.Exec(tid, content, parseMessage(content, fid, "forums"), ipaddress, wcount, uid) res, err := stmts.createReply.Exec(tid, content, parseMessage(content, fid, "forums"), ipaddress, wcount, uid)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -130,7 +131,7 @@ func (store *SQLReplyStore) Create(tid int, content string, ipaddress string, fi
return 0, err return 0, err
} }
_, err = addRepliesToTopicStmt.Exec(1, uid, tid) _, err = stmts.addRepliesToTopic.Exec(1, uid, tid)
if err != nil { if err != nil {
return int(lastID), err return int(lastID), err
} }
@ -145,6 +146,7 @@ type ProfileReplyStore interface {
Get(id int) (*Reply, error) Get(id int) (*Reply, error)
} }
// TODO: Refactor this to stop using the global stmt store
type SQLProfileReplyStore struct { type SQLProfileReplyStore struct {
} }
@ -154,6 +156,6 @@ func NewSQLProfileReplyStore() *SQLProfileReplyStore {
func (store *SQLProfileReplyStore) Get(id int) (*Reply, error) { func (store *SQLProfileReplyStore) Get(id int) (*Reply, error) {
reply := Reply{ID: id} reply := Reply{ID: id}
err := getUserReplyStmt.QueryRow(id).Scan(&reply.ParentID, &reply.Content, &reply.CreatedBy, &reply.CreatedAt, &reply.LastEdit, &reply.LastEditBy, &reply.IPAddress) err := stmts.getUserReply.QueryRow(id).Scan(&reply.ParentID, &reply.Content, &reply.CreatedBy, &reply.CreatedAt, &reply.LastEdit, &reply.LastEditBy, &reply.IPAddress)
return &reply, err return &reply, err
} }

View File

@ -26,8 +26,20 @@ func main() {
end = len(route.Path) - 1 end = len(route.Path) - 1
} }
out += "\n\t\tcase \"" + route.Path[0:end] + "\":" out += "\n\t\tcase \"" + route.Path[0:end] + "\":"
if route.Before != "" { if len(route.RunBefore) > 0 {
out += "\n\t\t\t" + route.Before for _, runnable := range route.RunBefore {
if runnable.Literal {
out += "\n\t\t\t\t\t" + runnable.Contents
} else {
out += `
err = ` + runnable.Contents + `(w,req,user)
if err != nil {
router.handleError(err,w,req,user)
return
}
`
}
}
} }
out += "\n\t\t\terr = " + route.Name + "(w,req,user" out += "\n\t\t\terr = " + route.Name + "(w,req,user"
for _, item := range route.Vars { for _, item := range route.Vars {
@ -48,14 +60,18 @@ func main() {
} }
out += ` out += `
case "` + group.Path[0:end] + `":` case "` + group.Path[0:end] + `":`
for _, callback := range group.Before { for _, runnable := range group.RunBefore {
out += ` if runnable.Literal {
err = ` + callback + `(w,req,user) out += "\t\t\t" + runnable.Contents
} else {
out += `
err = ` + runnable.Contents + `(w,req,user)
if err != nil { if err != nil {
router.handleError(err,w,req,user) router.handleError(err,w,req,user)
return return
} }
` `
}
} }
out += "\n\t\t\tswitch(req.URL.Path) {" out += "\n\t\t\tswitch(req.URL.Path) {"
@ -67,8 +83,20 @@ func main() {
} }
out += "\n\t\t\t\tcase \"" + route.Path + "\":" out += "\n\t\t\t\tcase \"" + route.Path + "\":"
if route.Before != "" { if len(route.RunBefore) > 0 {
out += "\n\t\t\t\t\t" + route.Before for _, runnable := range route.RunBefore {
if runnable.Literal {
out += "\n\t\t\t\t\t" + runnable.Contents
} else {
out += `
err = ` + runnable.Contents + `(w,req,user)
if err != nil {
router.handleError(err,w,req,user)
return
}
`
}
}
} }
out += "\n\t\t\t\t\terr = " + route.Name + "(w,req,user" out += "\n\t\t\t\t\terr = " + route.Name + "(w,req,user"
for _, item := range route.Vars { for _, item := range route.Vars {
@ -79,8 +107,20 @@ func main() {
if defaultRoute.Name != "" { if defaultRoute.Name != "" {
out += "\n\t\t\t\tdefault:" out += "\n\t\t\t\tdefault:"
if defaultRoute.Before != "" { if len(defaultRoute.RunBefore) > 0 {
out += "\n\t\t\t\t\t" + defaultRoute.Before for _, runnable := range defaultRoute.RunBefore {
if runnable.Literal {
out += "\n\t\t\t\t\t" + runnable.Contents
} else {
out += `
err = ` + runnable.Contents + `(w,req,user)
if err != nil {
router.handleError(err,w,req,user)
return
}
`
}
}
} }
out += "\n\t\t\t\t\terr = " + defaultRoute.Name + "(w,req,user" out += "\n\t\t\t\t\terr = " + defaultRoute.Name + "(w,req,user"
for _, item := range defaultRoute.Vars { for _, item := range defaultRoute.Vars {

View File

@ -1,32 +1,51 @@
package main package main
type RouteImpl struct { type RouteImpl struct {
Name string Name string
Path string Path string
Before string Vars []string
Vars []string RunBefore []Runnable
} }
type RouteGroup struct { type RouteGroup struct {
Path string Path string
RouteList []*RouteImpl RouteList []*RouteImpl
Before []string RunBefore []Runnable
} }
func addRoute(fname string, path string, before string, vars ...string) { type Runnable struct {
routeList = append(routeList, &RouteImpl{fname, path, before, vars}) Contents string
Literal bool
}
func addRoute(route *RouteImpl) {
routeList = append(routeList, route)
}
func (route *RouteImpl) Before(item string, literal ...bool) *RouteImpl {
var litItem bool
if len(literal) > 0 {
litItem = literal[0]
}
route.RunBefore = append(route.RunBefore, Runnable{item, litItem})
return route
} }
func newRouteGroup(path string, routes ...*RouteImpl) *RouteGroup { func newRouteGroup(path string, routes ...*RouteImpl) *RouteGroup {
return &RouteGroup{path, routes, []string{}} return &RouteGroup{path, routes, []Runnable{}}
} }
func addRouteGroup(routeGroup *RouteGroup) { func addRouteGroup(routeGroup *RouteGroup) {
routeGroups = append(routeGroups, routeGroup) routeGroups = append(routeGroups, routeGroup)
} }
func (group *RouteGroup) RunBefore(line string) { func (group *RouteGroup) Before(line string, literal ...bool) *RouteGroup {
group.Before = append(group.Before, line) var litItem bool
if len(literal) > 0 {
litItem = literal[0]
}
group.RunBefore = append(group.RunBefore, Runnable{line, litItem})
return group
} }
func (group *RouteGroup) Routes(routes ...*RouteImpl) { func (group *RouteGroup) Routes(routes ...*RouteImpl) {
@ -34,88 +53,101 @@ func (group *RouteGroup) Routes(routes ...*RouteImpl) {
} }
func blankRoute() *RouteImpl { func blankRoute() *RouteImpl {
return &RouteImpl{"", "", "", []string{}} return &RouteImpl{"", "", []string{}, []Runnable{}}
} }
func Route(fname string, path string, args ...string) *RouteImpl { func Route(fname string, path string, args ...string) *RouteImpl {
var before = "" return &RouteImpl{fname, path, args, []Runnable{}}
if len(args) > 0 {
before = args[0]
args = args[1:]
}
return &RouteImpl{fname, path, before, args}
} }
func routes() { func routes() {
//addRoute("default_route","","") //addRoute("default_route","","")
addRoute("routeAPI", "/api/", "") addRoute(Route("routeAPI", "/api/"))
///addRoute("routeStatic","/static/","req.URL.Path += extra_data") ///addRoute("routeStatic","/static/","req.URL.Path += extra_data")
addRoute("routeOverview", "/overview/", "") addRoute(Route("routeOverview", "/overview/"))
//addRoute("routeCustomPage","/pages/",""/*,"&extra_data"*/) //addRoute("routeCustomPage","/pages/",""/*,"&extra_data"*/)
addRoute("routeForums", "/forums/", "" /*,"&forums"*/) addRoute(Route("routeForums", "/forums/" /*,"&forums"*/))
addRoute("routeForum", "/forum/", "", "extra_data") addRoute(Route("routeForum", "/forum/", "extra_data"))
//addRoute("routeTopicCreate","/topics/create/","","extra_data") //addRoute("routeTopicCreate","/topics/create/","","extra_data")
//addRoute("routeTopics","/topics/",""/*,"&groups","&forums"*/) //addRoute("routeTopics","/topics/",""/*,"&groups","&forums"*/)
addRoute("routeChangeTheme", "/theme/", "") addRoute(Route("routeChangeTheme", "/theme/"))
addRoute("routeShowAttachment", "/attachs/", "", "extra_data") addRoute(Route("routeShowAttachment", "/attachs/", "extra_data"))
reportGroup := newRouteGroup("/report/", reportGroup := newRouteGroup("/report/",
Route("routeReportSubmit", "/report/submit/", "", "extra_data"), Route("routeReportSubmit", "/report/submit/", "extra_data"),
) ).Before("MemberOnly")
addRouteGroup(reportGroup) addRouteGroup(reportGroup)
topicGroup := newRouteGroup("/topics/", topicGroup := newRouteGroup("/topics/",
Route("routeTopics", "/topics/"), Route("routeTopics", "/topics/"),
Route("routeTopicCreate", "/topics/create/", "", "extra_data"), Route("routeTopicCreate", "/topics/create/", "extra_data").Before("MemberOnly"),
) )
addRouteGroup(topicGroup) addRouteGroup(topicGroup)
buildPanelRoutes() buildPanelRoutes()
buildUserRoutes()
}
// TODO: Test the email token route
// TODO: Add a BeforeExcept method?
func buildUserRoutes() {
userGroup := newRouteGroup("/user/") //.Before("MemberOnly")
userGroup.Routes(
Route("routeProfile", "/user/").Before("req.URL.Path += extra_data", true),
Route("routeAccountOwnEditCritical", "/user/edit/critical/").Before("MemberOnly"),
Route("routeAccountOwnEditCriticalSubmit", "/user/edit/critical/submit/").Before("MemberOnly"),
Route("routeAccountOwnEditAvatar", "/user/edit/avatar/").Before("MemberOnly"),
Route("routeAccountOwnEditAvatarSubmit", "/user/edit/avatar/submit/").Before("MemberOnly"),
Route("routeAccountOwnEditUsername", "/user/edit/username/").Before("MemberOnly"),
Route("routeAccountOwnEditUsernameSubmit", "/user/edit/username/submit/").Before("MemberOnly"),
Route("routeAccountOwnEditEmail", "/user/edit/email/").Before("MemberOnly"),
Route("routeAccountOwnEditEmailTokenSubmit", "/user/edit/token/", "extra_data").Before("MemberOnly"),
)
addRouteGroup(userGroup)
} }
func buildPanelRoutes() { func buildPanelRoutes() {
panelGroup := newRouteGroup("/panel/") panelGroup := newRouteGroup("/panel/").Before("SuperModOnly")
panelGroup.RunBefore("SuperModOnly")
panelGroup.Routes( panelGroup.Routes(
Route("routePanel", "/panel/"), Route("routePanel", "/panel/"),
Route("routePanelForums", "/panel/forums/"), Route("routePanelForums", "/panel/forums/"),
Route("routePanelForumsCreateSubmit", "/panel/forums/create/"), Route("routePanelForumsCreateSubmit", "/panel/forums/create/"),
Route("routePanelForumsDelete", "/panel/forums/delete/", "", "extra_data"), Route("routePanelForumsDelete", "/panel/forums/delete/", "extra_data"),
Route("routePanelForumsDeleteSubmit", "/panel/forums/delete/submit/", "", "extra_data"), Route("routePanelForumsDeleteSubmit", "/panel/forums/delete/submit/", "extra_data"),
Route("routePanelForumsEdit", "/panel/forums/edit/", "", "extra_data"), Route("routePanelForumsEdit", "/panel/forums/edit/", "extra_data"),
Route("routePanelForumsEditSubmit", "/panel/forums/edit/submit/", "", "extra_data"), Route("routePanelForumsEditSubmit", "/panel/forums/edit/submit/", "extra_data"),
Route("routePanelForumsEditPermsSubmit", "/panel/forums/edit/perms/submit/", "", "extra_data"), Route("routePanelForumsEditPermsSubmit", "/panel/forums/edit/perms/submit/", "extra_data"),
Route("routePanelSettings", "/panel/settings/"), Route("routePanelSettings", "/panel/settings/"),
Route("routePanelSetting", "/panel/settings/edit/", "", "extra_data"), Route("routePanelSetting", "/panel/settings/edit/", "extra_data"),
Route("routePanelSettingEdit", "/panel/settings/edit/submit/", "", "extra_data"), Route("routePanelSettingEdit", "/panel/settings/edit/submit/", "extra_data"),
Route("routePanelWordFilters", "/panel/settings/word-filters/"), Route("routePanelWordFilters", "/panel/settings/word-filters/"),
Route("routePanelWordFiltersCreate", "/panel/settings/word-filters/create/"), Route("routePanelWordFiltersCreate", "/panel/settings/word-filters/create/"),
Route("routePanelWordFiltersEdit", "/panel/settings/word-filters/edit/", "", "extra_data"), Route("routePanelWordFiltersEdit", "/panel/settings/word-filters/edit/", "extra_data"),
Route("routePanelWordFiltersEditSubmit", "/panel/settings/word-filters/edit/submit/", "", "extra_data"), Route("routePanelWordFiltersEditSubmit", "/panel/settings/word-filters/edit/submit/", "extra_data"),
Route("routePanelWordFiltersDeleteSubmit", "/panel/settings/word-filters/delete/submit/", "", "extra_data"), Route("routePanelWordFiltersDeleteSubmit", "/panel/settings/word-filters/delete/submit/", "extra_data"),
Route("routePanelThemes", "/panel/themes/"), Route("routePanelThemes", "/panel/themes/"),
Route("routePanelThemesSetDefault", "/panel/themes/default/", "", "extra_data"), Route("routePanelThemesSetDefault", "/panel/themes/default/", "extra_data"),
Route("routePanelPlugins", "/panel/plugins/"), Route("routePanelPlugins", "/panel/plugins/"),
Route("routePanelPluginsActivate", "/panel/plugins/activate/", "", "extra_data"), Route("routePanelPluginsActivate", "/panel/plugins/activate/", "extra_data"),
Route("routePanelPluginsDeactivate", "/panel/plugins/deactivate/", "", "extra_data"), Route("routePanelPluginsDeactivate", "/panel/plugins/deactivate/", "extra_data"),
Route("routePanelPluginsInstall", "/panel/plugins/install/", "", "extra_data"), Route("routePanelPluginsInstall", "/panel/plugins/install/", "extra_data"),
Route("routePanelUsers", "/panel/users/"), Route("routePanelUsers", "/panel/users/"),
Route("routePanelUsersEdit", "/panel/users/edit/", "", "extra_data"), Route("routePanelUsersEdit", "/panel/users/edit/", "extra_data"),
Route("routePanelUsersEditSubmit", "/panel/users/edit/submit/", "", "extra_data"), Route("routePanelUsersEditSubmit", "/panel/users/edit/submit/", "extra_data"),
Route("routePanelGroups", "/panel/groups/"), Route("routePanelGroups", "/panel/groups/"),
Route("routePanelGroupsEdit", "/panel/groups/edit/", "", "extra_data"), Route("routePanelGroupsEdit", "/panel/groups/edit/", "extra_data"),
Route("routePanelGroupsEditPerms", "/panel/groups/edit/perms/", "", "extra_data"), Route("routePanelGroupsEditPerms", "/panel/groups/edit/perms/", "extra_data"),
Route("routePanelGroupsEditSubmit", "/panel/groups/edit/submit/", "", "extra_data"), Route("routePanelGroupsEditSubmit", "/panel/groups/edit/submit/", "extra_data"),
Route("routePanelGroupsEditPermsSubmit", "/panel/groups/edit/perms/submit/", "", "extra_data"), Route("routePanelGroupsEditPermsSubmit", "/panel/groups/edit/perms/submit/", "extra_data"),
Route("routePanelGroupsCreateSubmit", "/panel/groups/create/"), Route("routePanelGroupsCreateSubmit", "/panel/groups/create/"),
Route("routePanelBackups", "/panel/backups/", "", "extra_data"), Route("routePanelBackups", "/panel/backups/", "extra_data"),
Route("routePanelLogsMod", "/panel/logs/mod/"), Route("routePanelLogsMod", "/panel/logs/mod/"),
Route("routePanelDebug", "/panel/debug/"), Route("routePanelDebug", "/panel/debug/"),
) )

View File

@ -351,7 +351,7 @@ func routeForum(w http.ResponseWriter, r *http.Request, user User, sfid string)
} }
// TODO: Move this to *Forum // TODO: Move this to *Forum
rows, err := getForumTopicsOffsetStmt.Query(fid, offset, config.ItemsPerPage) rows, err := stmts.getForumTopicsOffset.Query(fid, offset, config.ItemsPerPage)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -558,7 +558,7 @@ func routeTopicID(w http.ResponseWriter, r *http.Request, user User) RouteError
tpage := TopicPage{topic.Title, user, headerVars, replyList, topic, page, lastPage} tpage := TopicPage{topic.Title, user, headerVars, replyList, topic, page, lastPage}
// Get the replies.. // Get the replies..
rows, err := getTopicRepliesOffsetStmt.Query(topic.ID, offset, config.ItemsPerPage) rows, err := stmts.getTopicRepliesOffset.Query(topic.ID, offset, config.ItemsPerPage)
if err == ErrNoRows { if err == ErrNoRows {
return LocalError("Bad Page. Some of the posts may have been deleted or you got here by directly typing in the page number.", w, r, user) return LocalError("Bad Page. Some of the posts may have been deleted or you got here by directly typing in the page number.", w, r, user)
} else if err != nil { } else if err != nil {
@ -684,7 +684,7 @@ func routeProfile(w http.ResponseWriter, r *http.Request, user User) RouteError
} }
// Get the replies.. // Get the replies..
rows, err := getProfileRepliesStmt.Query(puser.ID) rows, err := stmts.getProfileReplies.Query(puser.ID)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -904,7 +904,7 @@ func routeRegisterSubmit(w http.ResponseWriter, r *http.Request, user User) Rout
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
_, err = addEmailStmt.Exec(email, uid, 0, token) _, err = stmts.addEmail.Exec(email, uid, 0, token)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -946,7 +946,7 @@ func routeChangeTheme(w http.ResponseWriter, r *http.Request, user User) RouteEr
// TODO: Store the current theme in the user's account? // TODO: Store the current theme in the user's account?
/*if user.Loggedin { /*if user.Loggedin {
_, err = change_theme_stmt.Exec(newTheme, user.ID) _, err = stmts.changeTheme.Exec(newTheme, user.ID)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -986,7 +986,7 @@ func routeAPI(w http.ResponseWriter, r *http.Request, user User) RouteError {
return PreErrorJS("Invalid asid", w, r) return PreErrorJS("Invalid asid", w, r)
} }
_, err = deleteActivityStreamMatchStmt.Exec(user.ID, asid) _, err = stmts.deleteActivityStreamMatch.Exec(user.ID, asid)
if err != nil { if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
@ -1000,14 +1000,14 @@ func routeAPI(w http.ResponseWriter, r *http.Request, user User) RouteError {
var asid, actorID, targetUserID, elementID int var asid, actorID, targetUserID, elementID int
var msgCount int var msgCount int
err = getActivityCountByWatcherStmt.QueryRow(user.ID).Scan(&msgCount) err = stmts.getActivityCountByWatcher.QueryRow(user.ID).Scan(&msgCount)
if err == ErrNoRows { if err == ErrNoRows {
return PreErrorJS("Couldn't find the parent topic", w, r) return PreErrorJS("Couldn't find the parent topic", w, r)
} else if err != nil { } else if err != nil {
return InternalErrorJS(err, w, r) return InternalErrorJS(err, w, r)
} }
rows, err := getActivityFeedByWatcherStmt.Query(user.ID) rows, err := stmts.getActivityFeedByWatcher.Query(user.ID)
if err != nil { if err != nil {
return InternalErrorJS(err, w, r) return InternalErrorJS(err, w, r)
} }

View File

@ -166,7 +166,7 @@ func panelUserCheck(w http.ResponseWriter, r *http.Request, user *User) (headerV
} }
} }
err = groupCountStmt.QueryRow().Scan(&stats.Groups) err = stmts.groupCount.QueryRow().Scan(&stats.Groups)
if err != nil { if err != nil {
return headerVars, stats, InternalError(err, w, r) return headerVars, stats, InternalError(err, w, r)
} }
@ -284,7 +284,7 @@ func preRoute(w http.ResponseWriter, r *http.Request) (User, bool) {
return *user, false return *user, false
} }
if host != user.LastIP { if host != user.LastIP {
_, err = updateLastIPStmt.Exec(host, user.ID) _, err = stmts.updateLastIP.Exec(host, user.ID)
if err != nil { if err != nil {
InternalError(err, w, r) InternalError(err, w, r)
return *user, false return *user, false
@ -306,3 +306,11 @@ func SuperModOnly(w http.ResponseWriter, r *http.Request, user User) RouteError
} }
return nil return nil
} }
// MemberOnly makes sure that only logged in users can access this route
func MemberOnly(w http.ResponseWriter, r *http.Request, user User) RouteError {
if !user.Loggedin {
return NoPermissions(w, r, user) // TODO: Do an error telling them to login instead?
}
return nil
}

View File

@ -27,7 +27,7 @@ func init() {
} }
func LoadSettings() error { func LoadSettings() error {
rows, err := getFullSettingsStmt.Query() rows, err := stmts.getFullSettings.Query()
if err != nil { if err != nil {
return err return err
} }

View File

@ -18,7 +18,7 @@ func init() {
} }
func handleExpiredScheduledGroups() error { func handleExpiredScheduledGroups() error {
rows, err := getExpiredScheduledGroupsStmt.Query() rows, err := stmts.getExpiredScheduledGroups.Query()
if err != nil { if err != nil {
return err return err
} }
@ -44,7 +44,7 @@ func handleExpiredScheduledGroups() error {
func handleServerSync() error { func handleServerSync() error {
var lastUpdate time.Time var lastUpdate time.Time
err := getSyncStmt.QueryRow().Scan(&lastUpdate) err := stmts.getSync.QueryRow().Scan(&lastUpdate)
if err != nil { if err != nil {
return err return err
} }

View File

@ -77,7 +77,7 @@ func init() {
// ? - Delete themes which no longer exist in the themes folder from the database? // ? - Delete themes which no longer exist in the themes folder from the database?
func LoadThemes() error { func LoadThemes() error {
changeDefaultThemeMutex.Lock() changeDefaultThemeMutex.Lock()
rows, err := getThemesStmt.Query() rows, err := stmts.getThemes.Query()
if err != nil { if err != nil {
return err return err
} }

View File

@ -103,7 +103,7 @@ type TopicsRow struct {
} }
func (topic *Topic) Lock() (err error) { func (topic *Topic) Lock() (err error) {
_, err = lockTopicStmt.Exec(topic.ID) _, err = stmts.lockTopic.Exec(topic.ID)
tcache, ok := topics.(TopicCache) tcache, ok := topics.(TopicCache)
if ok { if ok {
tcache.CacheRemove(topic.ID) tcache.CacheRemove(topic.ID)
@ -112,7 +112,7 @@ func (topic *Topic) Lock() (err error) {
} }
func (topic *Topic) Unlock() (err error) { func (topic *Topic) Unlock() (err error) {
_, err = unlockTopicStmt.Exec(topic.ID) _, err = stmts.unlockTopic.Exec(topic.ID)
tcache, ok := topics.(TopicCache) tcache, ok := topics.(TopicCache)
if ok { if ok {
tcache.CacheRemove(topic.ID) tcache.CacheRemove(topic.ID)
@ -123,7 +123,7 @@ func (topic *Topic) Unlock() (err error) {
// TODO: We might want more consistent terminology rather than using stick in some places and pin in others. If you don't understand the difference, there is none, they are one and the same. // TODO: We might want more consistent terminology rather than using stick in some places and pin in others. If you don't understand the difference, there is none, they are one and the same.
// ? - We do a CacheDelete() here instead of mutating the pointer to avoid creating a race condition // ? - We do a CacheDelete() here instead of mutating the pointer to avoid creating a race condition
func (topic *Topic) Stick() (err error) { func (topic *Topic) Stick() (err error) {
_, err = stickTopicStmt.Exec(topic.ID) _, err = stmts.stickTopic.Exec(topic.ID)
tcache, ok := topics.(TopicCache) tcache, ok := topics.(TopicCache)
if ok { if ok {
tcache.CacheRemove(topic.ID) tcache.CacheRemove(topic.ID)
@ -132,7 +132,7 @@ func (topic *Topic) Stick() (err error) {
} }
func (topic *Topic) Unstick() (err error) { func (topic *Topic) Unstick() (err error) {
_, err = unstickTopicStmt.Exec(topic.ID) _, err = stmts.unstickTopic.Exec(topic.ID)
tcache, ok := topics.(TopicCache) tcache, ok := topics.(TopicCache)
if ok { if ok {
tcache.CacheRemove(topic.ID) tcache.CacheRemove(topic.ID)
@ -168,7 +168,7 @@ func (topic *Topic) Delete() error {
return err return err
} }
_, err = deleteTopicStmt.Exec(topic.ID) _, err = stmts.deleteTopic.Exec(topic.ID)
tcache, ok := topics.(TopicCache) tcache, ok := topics.(TopicCache)
if ok { if ok {
tcache.CacheRemove(topic.ID) tcache.CacheRemove(topic.ID)
@ -179,7 +179,7 @@ func (topic *Topic) Delete() error {
func (topic *Topic) Update(name string, content string) error { func (topic *Topic) Update(name string, content string) error {
content = preparseMessage(content) content = preparseMessage(content)
parsed_content := parseMessage(html.EscapeString(content), topic.ParentID, "forums") parsed_content := parseMessage(html.EscapeString(content), topic.ParentID, "forums")
_, err := editTopicStmt.Exec(name, content, parsed_content, topic.ID) _, err := stmts.editTopic.Exec(name, content, parsed_content, topic.ID)
tcache, ok := topics.(TopicCache) tcache, ok := topics.(TopicCache)
if ok { if ok {
@ -189,11 +189,11 @@ func (topic *Topic) Update(name string, content string) error {
} }
func (topic *Topic) CreateActionReply(action string, ipaddress string, user User) (err error) { func (topic *Topic) CreateActionReply(action string, ipaddress string, user User) (err error) {
_, err = createActionReplyStmt.Exec(topic.ID, action, ipaddress, user.ID) _, err = stmts.createActionReply.Exec(topic.ID, action, ipaddress, user.ID)
if err != nil { if err != nil {
return err return err
} }
_, err = addRepliesToTopicStmt.Exec(1, user.ID, topic.ID) _, err = stmts.addRepliesToTopic.Exec(1, user.ID, topic.ID)
tcache, ok := topics.(TopicCache) tcache, ok := topics.(TopicCache)
if ok { if ok {
tcache.CacheRemove(topic.ID) tcache.CacheRemove(topic.ID)
@ -235,7 +235,7 @@ func getTopicUser(tid int) (TopicUser, error) {
} }
tu := TopicUser{ID: tid} tu := TopicUser{ID: tid}
err := getTopicUserStmt.QueryRow(tid).Scan(&tu.Title, &tu.Content, &tu.CreatedBy, &tu.CreatedAt, &tu.IsClosed, &tu.Sticky, &tu.ParentID, &tu.IPAddress, &tu.PostCount, &tu.LikeCount, &tu.CreatedByName, &tu.Avatar, &tu.Group, &tu.URLPrefix, &tu.URLName, &tu.Level) err := stmts.getTopicUser.QueryRow(tid).Scan(&tu.Title, &tu.Content, &tu.CreatedBy, &tu.CreatedAt, &tu.IsClosed, &tu.Sticky, &tu.ParentID, &tu.IPAddress, &tu.PostCount, &tu.LikeCount, &tu.CreatedByName, &tu.Avatar, &tu.Group, &tu.URLPrefix, &tu.URLName, &tu.Level)
tu.Link = buildTopicURL(nameToSlug(tu.Title), tu.ID) tu.Link = buildTopicURL(nameToSlug(tu.Title), tu.ID)
tu.UserLink = buildProfileURL(nameToSlug(tu.CreatedByName), tu.CreatedBy) tu.UserLink = buildProfileURL(nameToSlug(tu.CreatedByName), tu.CreatedBy)
tu.Tag = gstore.DirtyGet(tu.Group).Tag tu.Tag = gstore.DirtyGet(tu.Group).Tag
@ -282,7 +282,7 @@ func getDummyTopic() *Topic {
func getTopicByReply(rid int) (*Topic, error) { func getTopicByReply(rid int) (*Topic, error) {
topic := Topic{ID: 0} topic := Topic{ID: 0}
err := getTopicByReplyStmt.QueryRow(rid).Scan(&topic.ID, &topic.Title, &topic.Content, &topic.CreatedBy, &topic.CreatedAt, &topic.IsClosed, &topic.Sticky, &topic.ParentID, &topic.IPAddress, &topic.PostCount, &topic.LikeCount, &topic.Data) err := stmts.getTopicByReply.QueryRow(rid).Scan(&topic.ID, &topic.Title, &topic.Content, &topic.CreatedBy, &topic.CreatedAt, &topic.IsClosed, &topic.Sticky, &topic.ParentID, &topic.IPAddress, &topic.PostCount, &topic.LikeCount, &topic.Data)
topic.Link = buildTopicURL(nameToSlug(topic.Title), topic.ID) topic.Link = buildTopicURL(nameToSlug(topic.Title), topic.ID)
return &topic, err return &topic, err
} }

View File

@ -157,7 +157,7 @@ func (mts *MemoryTopicStore) Create(fid int, topicName string, content string, u
wcount := wordCount(content) wcount := wordCount(content)
// TODO: Move this statement into the topic store // TODO: Move this statement into the topic store
res, err := createTopicStmt.Exec(fid, topicName, content, parsedContent, uid, ipaddress, wcount, uid) res, err := stmts.createTopic.Exec(fid, topicName, content, parsedContent, uid, ipaddress, wcount, uid)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -319,7 +319,7 @@ func (sts *SQLTopicStore) Create(fid int, topicName string, content string, uid
wcount := wordCount(content) wcount := wordCount(content)
// TODO: Move this statement into the topic store // TODO: Move this statement into the topic store
res, err := createTopicStmt.Exec(fid, topicName, content, parsedContent, uid, ipaddress, wcount, uid) res, err := stmts.createTopic.Exec(fid, topicName, content, parsedContent, uid, ipaddress, wcount, uid)
if err != nil { if err != nil {
return 0, err return 0, err
} }

38
user.go
View File

@ -174,11 +174,11 @@ func (user *User) RevertGroupUpdate() error {
// TODO: Use a transaction here // TODO: Use a transaction here
// ? - Add a Deactivate method? Not really needed, if someone's been bad you could do a ban, I guess it might be useful, if someone says that email x isn't actually owned by the user in question? // ? - Add a Deactivate method? Not really needed, if someone's been bad you could do a ban, I guess it might be useful, if someone says that email x isn't actually owned by the user in question?
func (user *User) Activate() (err error) { func (user *User) Activate() (err error) {
_, err = activateUserStmt.Exec(user.ID) _, err = stmts.activateUser.Exec(user.ID)
if err != nil { if err != nil {
return err return err
} }
_, err = changeGroupStmt.Exec(config.DefaultGroup, user.ID) _, err = stmts.changeGroup.Exec(config.DefaultGroup, user.ID)
ucache, ok := users.(UserCache) ucache, ok := users.(UserCache)
if ok { if ok {
ucache.CacheRemove(user.ID) ucache.CacheRemove(user.ID)
@ -190,7 +190,7 @@ func (user *User) Activate() (err error) {
// TODO: Delete this user's content too? // TODO: Delete this user's content too?
// TODO: Expose this to the admin? // TODO: Expose this to the admin?
func (user *User) Delete() error { func (user *User) Delete() error {
_, err := deleteUserStmt.Exec(user.ID) _, err := stmts.deleteUser.Exec(user.ID)
if err != nil { if err != nil {
return err return err
} }
@ -202,7 +202,7 @@ func (user *User) Delete() error {
} }
func (user *User) ChangeName(username string) (err error) { func (user *User) ChangeName(username string) (err error) {
_, err = setUsernameStmt.Exec(username, user.ID) _, err = stmts.setUsername.Exec(username, user.ID)
ucache, ok := users.(UserCache) ucache, ok := users.(UserCache)
if ok { if ok {
ucache.CacheRemove(user.ID) ucache.CacheRemove(user.ID)
@ -211,7 +211,7 @@ func (user *User) ChangeName(username string) (err error) {
} }
func (user *User) ChangeAvatar(avatar string) (err error) { func (user *User) ChangeAvatar(avatar string) (err error) {
_, err = setAvatarStmt.Exec(avatar, user.ID) _, err = stmts.setAvatar.Exec(avatar, user.ID)
ucache, ok := users.(UserCache) ucache, ok := users.(UserCache)
if ok { if ok {
ucache.CacheRemove(user.ID) ucache.CacheRemove(user.ID)
@ -220,7 +220,7 @@ func (user *User) ChangeAvatar(avatar string) (err error) {
} }
func (user *User) ChangeGroup(group int) (err error) { func (user *User) ChangeGroup(group int) (err error) {
_, err = updateUserGroupStmt.Exec(group, user.ID) _, err = stmts.updateUserGroup.Exec(group, user.ID)
ucache, ok := users.(UserCache) ucache, ok := users.(UserCache)
if ok { if ok {
ucache.CacheRemove(user.ID) ucache.CacheRemove(user.ID)
@ -232,7 +232,7 @@ func (user *User) increasePostStats(wcount int, topic bool) (err error) {
var mod int var mod int
baseScore := 1 baseScore := 1
if topic { if topic {
_, err = incrementUserTopicsStmt.Exec(1, user.ID) _, err = stmts.incrementUserTopics.Exec(1, user.ID)
if err != nil { if err != nil {
return err return err
} }
@ -241,26 +241,26 @@ func (user *User) increasePostStats(wcount int, topic bool) (err error) {
settings := settingBox.Load().(SettingBox) settings := settingBox.Load().(SettingBox)
if wcount >= settings["megapost_min_words"].(int) { if wcount >= settings["megapost_min_words"].(int) {
_, err = incrementUserMegapostsStmt.Exec(1, 1, 1, user.ID) _, err = stmts.incrementUserMegaposts.Exec(1, 1, 1, user.ID)
mod = 4 mod = 4
} else if wcount >= settings["bigpost_min_words"].(int) { } else if wcount >= settings["bigpost_min_words"].(int) {
_, err = incrementUserBigpostsStmt.Exec(1, 1, user.ID) _, err = stmts.incrementUserBigposts.Exec(1, 1, user.ID)
mod = 1 mod = 1
} else { } else {
_, err = incrementUserPostsStmt.Exec(1, user.ID) _, err = stmts.incrementUserPosts.Exec(1, user.ID)
} }
if err != nil { if err != nil {
return err return err
} }
_, err = incrementUserScoreStmt.Exec(baseScore+mod, user.ID) _, err = stmts.incrementUserScore.Exec(baseScore+mod, user.ID)
if err != nil { if err != nil {
return err return err
} }
//log.Print(user.Score + base_score + mod) //log.Print(user.Score + base_score + mod)
//log.Print(getLevel(user.Score + base_score + mod)) //log.Print(getLevel(user.Score + base_score + mod))
// TODO: Use a transaction to prevent level desyncs? // TODO: Use a transaction to prevent level desyncs?
_, err = updateUserLevelStmt.Exec(getLevel(user.Score+baseScore+mod), user.ID) _, err = stmts.updateUserLevel.Exec(getLevel(user.Score+baseScore+mod), user.ID)
return err return err
} }
@ -268,7 +268,7 @@ func (user *User) decreasePostStats(wcount int, topic bool) (err error) {
var mod int var mod int
baseScore := -1 baseScore := -1
if topic { if topic {
_, err = incrementUserTopicsStmt.Exec(-1, user.ID) _, err = stmts.incrementUserTopics.Exec(-1, user.ID)
if err != nil { if err != nil {
return err return err
} }
@ -277,24 +277,24 @@ func (user *User) decreasePostStats(wcount int, topic bool) (err error) {
settings := settingBox.Load().(SettingBox) settings := settingBox.Load().(SettingBox)
if wcount >= settings["megapost_min_words"].(int) { if wcount >= settings["megapost_min_words"].(int) {
_, err = incrementUserMegapostsStmt.Exec(-1, -1, -1, user.ID) _, err = stmts.incrementUserMegaposts.Exec(-1, -1, -1, user.ID)
mod = 4 mod = 4
} else if wcount >= settings["bigpost_min_words"].(int) { } else if wcount >= settings["bigpost_min_words"].(int) {
_, err = incrementUserBigpostsStmt.Exec(-1, -1, user.ID) _, err = stmts.incrementUserBigposts.Exec(-1, -1, user.ID)
mod = 1 mod = 1
} else { } else {
_, err = incrementUserPostsStmt.Exec(-1, user.ID) _, err = stmts.incrementUserPosts.Exec(-1, user.ID)
} }
if err != nil { if err != nil {
return err return err
} }
_, err = incrementUserScoreStmt.Exec(baseScore-mod, user.ID) _, err = stmts.incrementUserScore.Exec(baseScore-mod, user.ID)
if err != nil { if err != nil {
return err return err
} }
// TODO: Use a transaction to prevent level desyncs? // TODO: Use a transaction to prevent level desyncs?
_, err = updateUserLevelStmt.Exec(getLevel(user.Score-baseScore-mod), user.ID) _, err = stmts.updateUserLevel.Exec(getLevel(user.Score-baseScore-mod), user.ID)
return err return err
} }
@ -359,7 +359,7 @@ func SetPassword(uid int, password string) error {
if err != nil { if err != nil {
return err return err
} }
_, err = setPasswordStmt.Exec(hashedPassword, salt, uid) _, err = stmts.setPassword.Exec(hashedPassword, salt, uid)
return err return err
} }

View File

@ -423,12 +423,14 @@ func buildSlug(slug string, id int) string {
return slug + "." + strconv.Itoa(id) return slug + "." + strconv.Itoa(id)
} }
// TODO: Make a store for this?
func addModLog(action string, elementID int, elementType string, ipaddress string, actorID int) (err error) { func addModLog(action string, elementID int, elementType string, ipaddress string, actorID int) (err error) {
_, err = addModlogEntryStmt.Exec(action, elementID, elementType, ipaddress, actorID) _, err = stmts.addModlogEntry.Exec(action, elementID, elementType, ipaddress, actorID)
return err return err
} }
// TODO: Make a store for this?
func addAdminLog(action string, elementID string, elementType int, ipaddress string, actorID int) (err error) { func addAdminLog(action string, elementID string, elementType int, ipaddress string, actorID int) (err error) {
_, err = addAdminlogEntryStmt.Exec(action, elementID, elementType, ipaddress, actorID) _, err = stmts.addAdminlogEntry.Exec(action, elementID, elementType, ipaddress, actorID)
return err return err
} }

View File

@ -40,8 +40,9 @@ type NameTextPair struct {
Text string Text string
} }
// TODO: Make a store for this?
func initWidgets() error { func initWidgets() error {
rows, err := getWidgetsStmt.Query() rows, err := stmts.getWidgets.Query()
if err != nil { if err != nil {
return err return err
} }

View File

@ -16,7 +16,7 @@ func init() {
} }
func LoadWordFilters() error { func LoadWordFilters() error {
rows, err := getWordFiltersStmt.Query() rows, err := stmts.getWordFilters.Query()
if err != nil { if err != nil {
return err return err
} }