521 lines
14 KiB
Go
521 lines
14 KiB
Go
package common
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"log"
|
|
"strconv"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// TODO: Add the watchdog goroutine
|
|
// TODO: Add some sort of update method
|
|
var users UserStore
|
|
var errAccountExists = errors.New("this username is already in use")
|
|
|
|
type UserStore interface {
|
|
Get(id int) (*User, error)
|
|
Exists(id int) bool
|
|
//BulkGet(ids []int) ([]*User, error)
|
|
BulkGetMap(ids []int) (map[int]*User, error)
|
|
BypassGet(id int) (*User, error)
|
|
Create(username string, password string, email string, group int, active bool) (int, error)
|
|
GlobalCount() int
|
|
}
|
|
|
|
type UserCache interface {
|
|
CacheGet(id int) (*User, error)
|
|
CacheGetUnsafe(id int) (*User, error)
|
|
CacheSet(item *User) error
|
|
CacheAdd(item *User) error
|
|
CacheAddUnsafe(item *User) error
|
|
CacheRemove(id int) error
|
|
CacheRemoveUnsafe(id int) error
|
|
Flush()
|
|
Reload(id int) error
|
|
Length() int
|
|
SetCapacity(capacity int)
|
|
GetCapacity() int
|
|
}
|
|
|
|
type MemoryUserStore struct {
|
|
items map[int]*User
|
|
length int64
|
|
capacity int
|
|
get *sql.Stmt
|
|
exists *sql.Stmt
|
|
register *sql.Stmt
|
|
usernameExists *sql.Stmt
|
|
userCount *sql.Stmt
|
|
sync.RWMutex
|
|
}
|
|
|
|
// NewMemoryUserStore gives you a new instance of MemoryUserStore
|
|
func NewMemoryUserStore(capacity int) (*MemoryUserStore, error) {
|
|
acc := qgen.Builder.Accumulator()
|
|
// TODO: Add an admin version of registerStmt with more flexibility?
|
|
return &MemoryUserStore{
|
|
items: make(map[int]*User),
|
|
capacity: capacity,
|
|
get: acc.SimpleSelect("users", "name, group, is_super_admin, session, email, avatar, message, url_prefix, url_name, level, score, last_ip, temp_group", "uid = ?", "", ""),
|
|
exists: acc.SimpleSelect("users", "uid", "uid = ?", "", ""),
|
|
register: acc.SimpleInsert("users", "name, email, password, salt, group, is_super_admin, session, active, message, createdAt, lastActiveAt", "?,?,?,?,?,0,'',?,'',UTC_TIMESTAMP(),UTC_TIMESTAMP()"),
|
|
usernameExists: acc.SimpleSelect("users", "name", "name = ?", "", ""),
|
|
userCount: acc.SimpleCount("users", "", ""),
|
|
}, acc.FirstError()
|
|
}
|
|
|
|
func (mus *MemoryUserStore) CacheGet(id int) (*User, error) {
|
|
mus.RLock()
|
|
item, ok := mus.items[id]
|
|
mus.RUnlock()
|
|
if ok {
|
|
return item, nil
|
|
}
|
|
return item, ErrNoRows
|
|
}
|
|
|
|
func (mus *MemoryUserStore) CacheGetUnsafe(id int) (*User, error) {
|
|
item, ok := mus.items[id]
|
|
if ok {
|
|
return item, nil
|
|
}
|
|
return item, ErrNoRows
|
|
}
|
|
|
|
func (mus *MemoryUserStore) Get(id int) (*User, error) {
|
|
mus.RLock()
|
|
user, ok := mus.items[id]
|
|
mus.RUnlock()
|
|
if ok {
|
|
return user, nil
|
|
}
|
|
|
|
user = &User{ID: id, Loggedin: true}
|
|
err := mus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup)
|
|
|
|
user.Init()
|
|
if err == nil {
|
|
mus.CacheSet(user)
|
|
}
|
|
return user, err
|
|
}
|
|
|
|
// WARNING: We did a little hack to make this as thin and quick as possible to reduce lock contention, use the * Cascade* methods instead for normal use
|
|
func (mus *MemoryUserStore) bulkGet(ids []int) (list []*User) {
|
|
list = make([]*User, len(ids))
|
|
mus.RLock()
|
|
for i, id := range ids {
|
|
list[i] = mus.items[id]
|
|
}
|
|
mus.RUnlock()
|
|
return list
|
|
}
|
|
|
|
// TODO: Optimise the query to avoid preparing it on the spot? Maybe, use knowledge of the most common IN() parameter counts?
|
|
// TODO: ID of 0 should always error?
|
|
func (mus *MemoryUserStore) BulkGetMap(ids []int) (list map[int]*User, err error) {
|
|
var idCount = len(ids)
|
|
list = make(map[int]*User)
|
|
if idCount == 0 {
|
|
return list, nil
|
|
}
|
|
|
|
var stillHere []int
|
|
sliceList := mus.bulkGet(ids)
|
|
for i, sliceItem := range sliceList {
|
|
if sliceItem != nil {
|
|
list[sliceItem.ID] = sliceItem
|
|
} else {
|
|
stillHere = append(stillHere, ids[i])
|
|
}
|
|
}
|
|
ids = stillHere
|
|
|
|
// If every user is in the cache, then return immediately
|
|
if len(ids) == 0 {
|
|
return list, nil
|
|
}
|
|
|
|
// TODO: Add a function for the qlist stuff
|
|
var qlist string
|
|
var uidList []interface{}
|
|
for _, id := range ids {
|
|
uidList = append(uidList, strconv.Itoa(id))
|
|
qlist += "?,"
|
|
}
|
|
qlist = qlist[0 : len(qlist)-1]
|
|
|
|
stmt, err := qgen.Builder.SimpleSelect("users", "uid, name, group, is_super_admin, session, email, avatar, message, url_prefix, url_name, level, score, last_ip, temp_group", "uid IN("+qlist+")", "", "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rows, err := stmt.Query(uidList...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for rows.Next() {
|
|
user := &User{Loggedin: true}
|
|
err := rows.Scan(&user.ID, &user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Initialise the user
|
|
user.Init()
|
|
|
|
// Add it to the cache...
|
|
_ = mus.CacheSet(user)
|
|
|
|
// Add it to the list to be returned
|
|
list[user.ID] = user
|
|
}
|
|
|
|
// Did we miss any users?
|
|
if idCount > len(list) {
|
|
var sidList string
|
|
for _, id := range ids {
|
|
_, ok := list[id]
|
|
if !ok {
|
|
sidList += strconv.Itoa(id) + ","
|
|
}
|
|
}
|
|
|
|
// We probably don't need this, but it might be useful in case of bugs in BulkCascadeGetMap
|
|
if sidList == "" {
|
|
if dev.DebugMode {
|
|
log.Print("This data is sampled later in the BulkCascadeGetMap function, so it might miss the cached IDs")
|
|
log.Print("idCount", idCount)
|
|
log.Print("ids", ids)
|
|
log.Print("list", list)
|
|
}
|
|
return list, errors.New("We weren't able to find a user, but we don't know which one")
|
|
}
|
|
sidList = sidList[0 : len(sidList)-1]
|
|
|
|
return list, errors.New("Unable to find the users with the following IDs: " + sidList)
|
|
}
|
|
|
|
return list, nil
|
|
}
|
|
|
|
func (mus *MemoryUserStore) BypassGet(id int) (*User, error) {
|
|
user := &User{ID: id, Loggedin: true}
|
|
err := mus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup)
|
|
|
|
user.Init()
|
|
return user, err
|
|
}
|
|
|
|
func (mus *MemoryUserStore) Reload(id int) error {
|
|
user := &User{ID: id, Loggedin: true}
|
|
err := mus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup)
|
|
if err != nil {
|
|
mus.CacheRemove(id)
|
|
return err
|
|
}
|
|
|
|
user.Init()
|
|
_ = mus.CacheSet(user)
|
|
return nil
|
|
}
|
|
|
|
func (mus *MemoryUserStore) Exists(id int) bool {
|
|
err := mus.exists.QueryRow(id).Scan(&id)
|
|
if err != nil && err != ErrNoRows {
|
|
LogError(err)
|
|
}
|
|
return err != ErrNoRows
|
|
}
|
|
|
|
func (mus *MemoryUserStore) CacheSet(item *User) error {
|
|
mus.Lock()
|
|
user, ok := mus.items[item.ID]
|
|
if ok {
|
|
mus.Unlock()
|
|
*user = *item
|
|
} else if int(mus.length) >= mus.capacity {
|
|
mus.Unlock()
|
|
return ErrStoreCapacityOverflow
|
|
} else {
|
|
mus.items[item.ID] = item
|
|
mus.Unlock()
|
|
atomic.AddInt64(&mus.length, 1)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (mus *MemoryUserStore) CacheAdd(item *User) error {
|
|
if int(mus.length) >= mus.capacity {
|
|
return ErrStoreCapacityOverflow
|
|
}
|
|
mus.Lock()
|
|
mus.items[item.ID] = item
|
|
mus.length = int64(len(mus.items))
|
|
mus.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func (mus *MemoryUserStore) CacheAddUnsafe(item *User) error {
|
|
if int(mus.length) >= mus.capacity {
|
|
return ErrStoreCapacityOverflow
|
|
}
|
|
mus.items[item.ID] = item
|
|
mus.length = int64(len(mus.items))
|
|
return nil
|
|
}
|
|
|
|
func (mus *MemoryUserStore) CacheRemove(id int) error {
|
|
mus.Lock()
|
|
_, ok := mus.items[id]
|
|
if !ok {
|
|
mus.Unlock()
|
|
return ErrNoRows
|
|
}
|
|
delete(mus.items, id)
|
|
mus.Unlock()
|
|
atomic.AddInt64(&mus.length, -1)
|
|
return nil
|
|
}
|
|
|
|
func (mus *MemoryUserStore) CacheRemoveUnsafe(id int) error {
|
|
_, ok := mus.items[id]
|
|
if !ok {
|
|
return ErrNoRows
|
|
}
|
|
delete(mus.items, id)
|
|
atomic.AddInt64(&mus.length, -1)
|
|
return nil
|
|
}
|
|
|
|
// TODO: Change active to a bool?
|
|
func (mus *MemoryUserStore) Create(username string, password string, email string, group int, active bool) (int, error) {
|
|
// Is this username already taken..?
|
|
err := mus.usernameExists.QueryRow(username).Scan(&username)
|
|
if err != ErrNoRows {
|
|
return 0, errAccountExists
|
|
}
|
|
|
|
salt, err := GenerateSafeString(saltLength)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password+salt), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
res, err := mus.register.Exec(username, email, string(hashedPassword), salt, group, active)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
lastID, err := res.LastInsertId()
|
|
return int(lastID), err
|
|
}
|
|
|
|
func (mus *MemoryUserStore) Flush() {
|
|
mus.Lock()
|
|
mus.items = make(map[int]*User)
|
|
mus.length = 0
|
|
mus.Unlock()
|
|
}
|
|
|
|
// ! Is this concurrent?
|
|
// Length returns the number of users in the memory cache
|
|
func (mus *MemoryUserStore) Length() int {
|
|
return int(mus.length)
|
|
}
|
|
|
|
func (mus *MemoryUserStore) SetCapacity(capacity int) {
|
|
mus.capacity = capacity
|
|
}
|
|
|
|
func (mus *MemoryUserStore) GetCapacity() int {
|
|
return mus.capacity
|
|
}
|
|
|
|
// GlobalCount returns the total number of users registered on the forums
|
|
func (mus *MemoryUserStore) GlobalCount() (ucount int) {
|
|
err := mus.userCount.QueryRow().Scan(&ucount)
|
|
if err != nil {
|
|
LogError(err)
|
|
}
|
|
return ucount
|
|
}
|
|
|
|
type SQLUserStore struct {
|
|
get *sql.Stmt
|
|
exists *sql.Stmt
|
|
register *sql.Stmt
|
|
usernameExists *sql.Stmt
|
|
userCount *sql.Stmt
|
|
}
|
|
|
|
func NewSQLUserStore() (*SQLUserStore, error) {
|
|
acc := qgen.Builder.Accumulator()
|
|
// TODO: Add an admin version of registerStmt with more flexibility?
|
|
return &SQLUserStore{
|
|
get: acc.SimpleSelect("users", "name, group, is_super_admin, session, email, avatar, message, url_prefix, url_name, level, score, last_ip, temp_group", "uid = ?", "", ""),
|
|
exists: acc.SimpleSelect("users", "uid", "uid = ?", "", ""),
|
|
register: acc.SimpleInsert("users", "name, email, password, salt, group, is_super_admin, session, active, message, createdAt, lastActiveAt", "?,?,?,?,?,0,'',?,'',UTC_TIMESTAMP(),UTC_TIMESTAMP()"),
|
|
usernameExists: acc.SimpleSelect("users", "name", "name = ?", "", ""),
|
|
userCount: acc.SimpleCount("users", "", ""),
|
|
}, acc.FirstError()
|
|
}
|
|
|
|
func (mus *SQLUserStore) Get(id int) (*User, error) {
|
|
user := &User{ID: id, Loggedin: true}
|
|
err := mus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup)
|
|
|
|
user.Init()
|
|
return user, err
|
|
}
|
|
|
|
// TODO: Optimise the query to avoid preparing it on the spot? Maybe, use knowledge of the most common IN() parameter counts?
|
|
func (mus *SQLUserStore) BulkGetMap(ids []int) (list map[int]*User, err error) {
|
|
var qlist string
|
|
var uidList []interface{}
|
|
for _, id := range ids {
|
|
uidList = append(uidList, strconv.Itoa(id))
|
|
qlist += "?,"
|
|
}
|
|
qlist = qlist[0 : len(qlist)-1]
|
|
|
|
stmt, err := qgen.Builder.SimpleSelect("users", "uid, name, group, is_super_admin, session, email, avatar, message, url_prefix, url_name, level, score, last_ip, temp_group", "uid IN("+qlist+")", "", "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rows, err := stmt.Query(uidList...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
list = make(map[int]*User)
|
|
for rows.Next() {
|
|
user := &User{Loggedin: true}
|
|
err := rows.Scan(&user.ID, &user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Initialise the user
|
|
user.Init()
|
|
|
|
// Add it to the list to be returned
|
|
list[user.ID] = user
|
|
}
|
|
|
|
return list, nil
|
|
}
|
|
|
|
func (mus *SQLUserStore) BypassGet(id int) (*User, error) {
|
|
user := &User{ID: id, Loggedin: true}
|
|
err := mus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup)
|
|
|
|
user.Init()
|
|
return user, err
|
|
}
|
|
|
|
func (mus *SQLUserStore) Exists(id int) bool {
|
|
err := mus.exists.QueryRow(id).Scan(&id)
|
|
if err != nil && err != ErrNoRows {
|
|
LogError(err)
|
|
}
|
|
return err != ErrNoRows
|
|
}
|
|
|
|
func (mus *SQLUserStore) Create(username string, password string, email string, group int, active bool) (int, error) {
|
|
// Is this username already taken..?
|
|
err := mus.usernameExists.QueryRow(username).Scan(&username)
|
|
if err != ErrNoRows {
|
|
return 0, errAccountExists
|
|
}
|
|
|
|
salt, err := GenerateSafeString(saltLength)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password+salt), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
res, err := mus.register.Exec(username, email, string(hashedPassword), salt, group, active)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
lastID, err := res.LastInsertId()
|
|
return int(lastID), err
|
|
}
|
|
|
|
// GlobalCount returns the total number of users registered on the forums
|
|
func (mus *SQLUserStore) GlobalCount() (ucount int) {
|
|
err := mus.userCount.QueryRow().Scan(&ucount)
|
|
if err != nil {
|
|
LogError(err)
|
|
}
|
|
return ucount
|
|
}
|
|
|
|
// TODO: MockUserStore
|
|
|
|
// NullUserStore is here for tests because Go doesn't have short-circuiting
|
|
type NullUserStore struct {
|
|
}
|
|
|
|
func (nus *NullUserStore) CacheGet(_ int) (*User, error) {
|
|
return nil, ErrNoRows
|
|
}
|
|
|
|
func (nus *NullUserStore) CacheGetUnsafe(_ int) (*User, error) {
|
|
return nil, ErrNoRows
|
|
}
|
|
|
|
func (nus *NullUserStore) CacheSet(_ *User) error {
|
|
return ErrStoreCapacityOverflow
|
|
}
|
|
|
|
func (nus *NullUserStore) CacheAdd(_ *User) error {
|
|
return ErrStoreCapacityOverflow
|
|
}
|
|
|
|
func (nus *NullUserStore) CacheAddUnsafe(_ *User) error {
|
|
return ErrStoreCapacityOverflow
|
|
}
|
|
|
|
func (nus *NullUserStore) CacheRemove(_ int) error {
|
|
return ErrNoRows
|
|
}
|
|
|
|
func (nus *NullUserStore) CacheRemoveUnsafe(_ int) error {
|
|
return ErrNoRows
|
|
}
|
|
|
|
func (nus *NullUserStore) Flush() {
|
|
}
|
|
|
|
func (nus *NullUserStore) Reload(_ int) error {
|
|
return ErrNoRows
|
|
}
|
|
|
|
func (nus *NullUserStore) Length() int {
|
|
return 0
|
|
}
|
|
|
|
func (nus *NullUserStore) SetCapacity(_ int) {
|
|
}
|
|
|
|
func (nus *NullUserStore) GetCapacity() int {
|
|
return 0
|
|
}
|