package main import ( "database/sql" "errors" "log" "strconv" "strings" "sync" "./query_gen/lib" "golang.org/x/crypto/bcrypt" ) // TODO: Add the watchdog goroutine var users UserStore var errAccountExists = errors.New("this username is already in use") type UserStore interface { Load(id int) error Get(id int) (*User, error) GetUnsafe(id int) (*User, error) CascadeGet(id int) (*User, error) //BulkCascadeGet(ids []int) ([]*User, error) BulkCascadeGetMap(ids []int) (map[int]*User, error) BypassGet(id int) (*User, error) Set(item *User) error Add(item *User) error AddUnsafe(item *User) error Remove(id int) error RemoveUnsafe(id int) error CreateUser(username string, password string, email string, group int, active int) (int, error) GetLength() int GetCapacity() int GetGlobalCount() int } type MemoryUserStore struct { items map[int]*User length int capacity int get *sql.Stmt register *sql.Stmt usernameExists *sql.Stmt userCount *sql.Stmt sync.RWMutex } // NewMemoryUserStore gives you a new instance of MemoryUserStore func NewMemoryUserStore(capacity int) *MemoryUserStore { getStmt, err := qgen.Builder.SimpleSelect("users", "name, group, is_super_admin, session, email, avatar, message, url_prefix, url_name, level, score, last_ip, temp_group", "uid = ?", "", "") if err != nil { log.Fatal(err) } // Add an admin version of register_stmt with more flexibility? // create_account_stmt, err = db.Prepare("INSERT INTO registerStmt, err := qgen.Builder.SimpleInsert("users", "name, email, password, salt, group, is_super_admin, session, active, message", "?,?,?,?,?,0,'',?,''") if err != nil { log.Fatal(err) } usernameExistsStmt, err := qgen.Builder.SimpleSelect("users", "name", "name = ?", "", "") if err != nil { log.Fatal(err) } userCountStmt, err := qgen.Builder.SimpleCount("users", "", "") if err != nil { log.Fatal(err) } return &MemoryUserStore{ items: make(map[int]*User), capacity: capacity, get: getStmt, register: registerStmt, usernameExists: usernameExistsStmt, userCount: userCountStmt, } } func (sus *MemoryUserStore) Get(id int) (*User, error) { sus.RLock() item, ok := sus.items[id] sus.RUnlock() if ok { return item, nil } return item, ErrNoRows } func (sus *MemoryUserStore) GetUnsafe(id int) (*User, error) { item, ok := sus.items[id] if ok { return item, nil } return item, ErrNoRows } func (sus *MemoryUserStore) CascadeGet(id int) (*User, error) { sus.RLock() user, ok := sus.items[id] sus.RUnlock() if ok { return user, nil } user = &User{ID: id, Loggedin: true} err := sus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup) if user.Avatar != "" { if user.Avatar[0] == '.' { user.Avatar = "/uploads/avatar_" + strconv.Itoa(user.ID) + user.Avatar } } else { user.Avatar = strings.Replace(config.Noavatar, "{id}", strconv.Itoa(user.ID), 1) } user.Link = buildProfileURL(nameToSlug(user.Name), id) user.Tag = groups[user.Group].Tag initUserPerms(user) if err == nil { sus.Set(user) } return user, err } // WARNING: We did a little hack to make this as thin and quick as possible to reduce lock contention, use the * Cascade* methods instead for normal use func (sus *MemoryUserStore) bulkGet(ids []int) (list []*User) { list = make([]*User, len(ids)) sus.RLock() for i, id := range ids { list[i] = sus.items[id] } sus.RUnlock() return list } // TODO: Optimise the query to avoid preparing it on the spot? Maybe, use knowledge of the most common IN() parameter counts? // TODO: ID of 0 should always error? func (sus *MemoryUserStore) BulkCascadeGetMap(ids []int) (list map[int]*User, err error) { var idCount = len(ids) list = make(map[int]*User) if idCount == 0 { return list, nil } var stillHere []int sliceList := sus.bulkGet(ids) for i, sliceItem := range sliceList { if sliceItem != nil { list[sliceItem.ID] = sliceItem } else { stillHere = append(stillHere, ids[i]) } } ids = stillHere // If every user is in the cache, then return immediately if len(ids) == 0 { return list, nil } var qlist string var uidList []interface{} for _, id := range ids { uidList = append(uidList, strconv.Itoa(id)) qlist += "?," } qlist = qlist[0 : len(qlist)-1] stmt, err := qgen.Builder.SimpleSelect("users", "uid, name, group, is_super_admin, session, email, avatar, message, url_prefix, url_name, level, score, last_ip, temp_group", "uid IN("+qlist+")", "", "") if err != nil { return nil, err } rows, err := stmt.Query(uidList...) if err != nil { return nil, err } for rows.Next() { user := &User{Loggedin: true} err := rows.Scan(&user.ID, &user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup) if err != nil { return nil, err } // Initialise the user if user.Avatar != "" { if user.Avatar[0] == '.' { user.Avatar = "/uploads/avatar_" + strconv.Itoa(user.ID) + user.Avatar } } else { user.Avatar = strings.Replace(config.Noavatar, "{id}", strconv.Itoa(user.ID), 1) } user.Link = buildProfileURL(nameToSlug(user.Name), user.ID) user.Tag = groups[user.Group].Tag initUserPerms(user) // Add it to the cache... _ = sus.Set(user) // Add it to the list to be returned list[user.ID] = user } // Did we miss any users? if idCount > len(list) { var sidList string for _, id := range ids { _, ok := list[id] if !ok { sidList += strconv.Itoa(id) + "," } } // We probably don't need this, but it might be useful in case of bugs in BulkCascadeGetMap if sidList == "" { if dev.DebugMode { log.Print("This data is sampled later in the BulkCascadeGetMap function, so it might miss the cached IDs") log.Print("idCount", idCount) log.Print("ids", ids) log.Print("list", list) } return list, errors.New("We weren't able to find a user, but we don't know which one") } sidList = sidList[0 : len(sidList)-1] return list, errors.New("Unable to find the users with the following IDs: " + sidList) } return list, nil } func (sus *MemoryUserStore) BypassGet(id int) (*User, error) { user := &User{ID: id, Loggedin: true} err := sus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup) if user.Avatar != "" { if user.Avatar[0] == '.' { user.Avatar = "/uploads/avatar_" + strconv.Itoa(user.ID) + user.Avatar } } else { user.Avatar = strings.Replace(config.Noavatar, "{id}", strconv.Itoa(user.ID), 1) } user.Link = buildProfileURL(nameToSlug(user.Name), id) user.Tag = groups[user.Group].Tag initUserPerms(user) return user, err } func (sus *MemoryUserStore) Load(id int) error { user := &User{ID: id, Loggedin: true} err := sus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup) if err != nil { sus.Remove(id) return err } if user.Avatar != "" { if user.Avatar[0] == '.' { user.Avatar = "/uploads/avatar_" + strconv.Itoa(user.ID) + user.Avatar } } else { user.Avatar = strings.Replace(config.Noavatar, "{id}", strconv.Itoa(user.ID), 1) } user.Link = buildProfileURL(nameToSlug(user.Name), id) user.Tag = groups[user.Group].Tag initUserPerms(user) _ = sus.Set(user) return nil } func (sus *MemoryUserStore) Set(item *User) error { sus.Lock() user, ok := sus.items[item.ID] if ok { sus.Unlock() *user = *item } else if sus.length >= sus.capacity { sus.Unlock() return ErrStoreCapacityOverflow } else { sus.items[item.ID] = item sus.Unlock() sus.length++ } return nil } func (sus *MemoryUserStore) Add(item *User) error { if sus.length >= sus.capacity { return ErrStoreCapacityOverflow } sus.Lock() sus.items[item.ID] = item sus.Unlock() sus.length++ return nil } func (sus *MemoryUserStore) AddUnsafe(item *User) error { if sus.length >= sus.capacity { return ErrStoreCapacityOverflow } sus.items[item.ID] = item sus.length++ return nil } func (sus *MemoryUserStore) Remove(id int) error { sus.Lock() delete(sus.items, id) sus.Unlock() sus.length-- return nil } func (sus *MemoryUserStore) RemoveUnsafe(id int) error { delete(sus.items, id) sus.length-- return nil } func (sus *MemoryUserStore) CreateUser(username string, password string, email string, group int, active int) (int, error) { // Is this username already taken..? err := sus.usernameExists.QueryRow(username).Scan(&username) if err != ErrNoRows { return 0, errAccountExists } salt, err := GenerateSafeString(saltLength) if err != nil { return 0, err } hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password+salt), bcrypt.DefaultCost) if err != nil { return 0, err } res, err := sus.register.Exec(username, email, string(hashedPassword), salt, group, active) if err != nil { return 0, err } lastID, err := res.LastInsertId() return int(lastID), err } func (sus *MemoryUserStore) GetLength() int { return sus.length } func (sus *MemoryUserStore) SetCapacity(capacity int) { sus.capacity = capacity } func (sus *MemoryUserStore) GetCapacity() int { return sus.capacity } // Return the total number of users registered on the forums func (sus *MemoryUserStore) GetGlobalCount() int { var ucount int err := sus.userCount.QueryRow().Scan(&ucount) if err != nil { LogError(err) } return ucount } type SQLUserStore struct { get *sql.Stmt register *sql.Stmt usernameExists *sql.Stmt userCount *sql.Stmt } func NewSQLUserStore() *SQLUserStore { getStmt, err := qgen.Builder.SimpleSelect("users", "name, group, is_super_admin, session, email, avatar, message, url_prefix, url_name, level, score, last_ip, temp_group", "uid = ?", "", "") if err != nil { log.Fatal(err) } // Add an admin version of register_stmt with more flexibility? // create_account_stmt, err = db.Prepare("INSERT INTO registerStmt, err := qgen.Builder.SimpleInsert("users", "name, email, password, salt, group, is_super_admin, session, active, message", "?,?,?,?,?,0,'',?,''") if err != nil { log.Fatal(err) } usernameExistsStmt, err := qgen.Builder.SimpleSelect("users", "name", "name = ?", "", "") if err != nil { log.Fatal(err) } userCountStmt, err := qgen.Builder.SimpleCount("users", "", "") if err != nil { log.Fatal(err) } return &SQLUserStore{ get: getStmt, register: registerStmt, usernameExists: usernameExistsStmt, userCount: userCountStmt, } } func (sus *SQLUserStore) Get(id int) (*User, error) { user := User{ID: id, Loggedin: true} err := sus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup) if user.Avatar != "" { if user.Avatar[0] == '.' { user.Avatar = "/uploads/avatar_" + strconv.Itoa(user.ID) + user.Avatar } } else { user.Avatar = strings.Replace(config.Noavatar, "{id}", strconv.Itoa(user.ID), 1) } user.Link = buildProfileURL(nameToSlug(user.Name), id) user.Tag = groups[user.Group].Tag initUserPerms(&user) return &user, err } func (sus *SQLUserStore) GetUnsafe(id int) (*User, error) { user := User{ID: id, Loggedin: true} err := sus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup) if user.Avatar != "" { if user.Avatar[0] == '.' { user.Avatar = "/uploads/avatar_" + strconv.Itoa(user.ID) + user.Avatar } } else { user.Avatar = strings.Replace(config.Noavatar, "{id}", strconv.Itoa(user.ID), 1) } user.Link = buildProfileURL(nameToSlug(user.Name), id) user.Tag = groups[user.Group].Tag initUserPerms(&user) return &user, err } func (sus *SQLUserStore) CascadeGet(id int) (*User, error) { user := User{ID: id, Loggedin: true} err := sus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup) if user.Avatar != "" { if user.Avatar[0] == '.' { user.Avatar = "/uploads/avatar_" + strconv.Itoa(user.ID) + user.Avatar } } else { user.Avatar = strings.Replace(config.Noavatar, "{id}", strconv.Itoa(user.ID), 1) } user.Link = buildProfileURL(nameToSlug(user.Name), id) user.Tag = groups[user.Group].Tag initUserPerms(&user) return &user, err } // TODO: Optimise the query to avoid preparing it on the spot? Maybe, use knowledge of the most common IN() parameter counts? func (sus *SQLUserStore) BulkCascadeGetMap(ids []int) (list map[int]*User, err error) { var qlist string var uidList []interface{} for _, id := range ids { uidList = append(uidList, strconv.Itoa(id)) qlist += "?," } qlist = qlist[0 : len(qlist)-1] stmt, err := qgen.Builder.SimpleSelect("users", "uid, name, group, is_super_admin, session, email, avatar, message, url_prefix, url_name, level, score, last_ip, temp_group", "uid IN("+qlist+")", "", "") if err != nil { return nil, err } rows, err := stmt.Query(uidList...) if err != nil { return nil, err } list = make(map[int]*User) for rows.Next() { user := &User{Loggedin: true} err := rows.Scan(&user.ID, &user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup) if err != nil { return nil, err } // Initialise the user if user.Avatar != "" { if user.Avatar[0] == '.' { user.Avatar = "/uploads/avatar_" + strconv.Itoa(user.ID) + user.Avatar } } else { user.Avatar = strings.Replace(config.Noavatar, "{id}", strconv.Itoa(user.ID), 1) } user.Link = buildProfileURL(nameToSlug(user.Name), user.ID) user.Tag = groups[user.Group].Tag initUserPerms(user) // Add it to the list to be returned list[user.ID] = user } return list, nil } func (sus *SQLUserStore) BypassGet(id int) (*User, error) { user := User{ID: id, Loggedin: true} err := sus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup) if user.Avatar != "" { if user.Avatar[0] == '.' { user.Avatar = "/uploads/avatar_" + strconv.Itoa(user.ID) + user.Avatar } } else { user.Avatar = strings.Replace(config.Noavatar, "{id}", strconv.Itoa(user.ID), 1) } user.Link = buildProfileURL(nameToSlug(user.Name), id) user.Tag = groups[user.Group].Tag initUserPerms(&user) return &user, err } func (sus *SQLUserStore) Load(id int) error { user := &User{ID: id} // Simplify this into a quick check to see whether the user exists. Add an Exists method to facilitate this? return sus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup) } func (sus *SQLUserStore) CreateUser(username string, password string, email string, group int, active int) (int, error) { // Is this username already taken..? err := sus.usernameExists.QueryRow(username).Scan(&username) if err != ErrNoRows { return 0, errAccountExists } salt, err := GenerateSafeString(saltLength) if err != nil { return 0, err } hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password+salt), bcrypt.DefaultCost) if err != nil { return 0, err } res, err := sus.register.Exec(username, email, string(hashedPassword), salt, group, active) if err != nil { return 0, err } lastID, err := res.LastInsertId() return int(lastID), err } // Placeholder methods, as we're not don't need to do any cache management with this implementation ofr the UserStore func (sus *SQLUserStore) Set(item *User) error { return nil } func (sus *SQLUserStore) Add(item *User) error { return nil } func (sus *SQLUserStore) AddUnsafe(item *User) error { return nil } func (sus *SQLUserStore) Remove(id int) error { return nil } func (sus *SQLUserStore) RemoveUnsafe(id int) error { return nil } func (sus *SQLUserStore) GetCapacity() int { return 0 } // Return the total number of users registered on the forums func (sus *SQLUserStore) GetLength() int { var ucount int err := sus.userCount.QueryRow().Scan(&ucount) if err != nil { LogError(err) } return ucount } func (sus *SQLUserStore) GetGlobalCount() int { var ucount int err := sus.userCount.QueryRow().Scan(&ucount) if err != nil { LogError(err) } return ucount }