370 lines
9.2 KiB
Go
370 lines
9.2 KiB
Go
package common
|
|
|
|
import (
|
|
"errors"
|
|
"time"
|
|
|
|
//"log"
|
|
|
|
"database/sql"
|
|
"strconv"
|
|
|
|
qgen "github.com/Azareal/Gosora/query_gen"
|
|
)
|
|
|
|
var Convos ConversationStore
|
|
var convoStmts ConvoStmts
|
|
|
|
// ConvoStmts bundles the prepared statements used for conversations, their
// posts and their participants.
type ConvoStmts struct {
	// Queries against conversations_posts and conversations that back the
	// Conversation methods.
	fetchPost, getPosts, countPosts *sql.Stmt
	edit, create, delete, has       *sql.Stmt

	// Mutations of individual rows in conversations_posts.
	editPost, createPost, deletePost *sql.Stmt

	// Participant lookup for a conversation.
	getUsers *sql.Stmt
}
|
|
|
|
func init() {
|
|
DbInits.Add(func(acc *qgen.Accumulator) error {
|
|
cpo := "conversations_posts"
|
|
convoStmts = ConvoStmts{
|
|
fetchPost: acc.Select(cpo).Columns("cid,body,post,createdBy").Where("pid=?").Prepare(),
|
|
getPosts: acc.Select(cpo).Columns("pid,body,post,createdBy").Where("cid=?").Limit("?,?").Prepare(),
|
|
countPosts: acc.Count(cpo).Where("cid=?").Prepare(),
|
|
edit: acc.Update("conversations").Set("lastReplyBy=?,lastReplyAt=?").Where("cid=?").Prepare(),
|
|
create: acc.Insert("conversations").Columns("createdAt,lastReplyAt").Fields("UTC_TIMESTAMP(),UTC_TIMESTAMP()").Prepare(),
|
|
has: acc.Count("conversations_participants").Where("uid=? AND cid=?").Prepare(),
|
|
|
|
editPost: acc.Update(cpo).Set("body=?,post=?").Where("pid=?").Prepare(),
|
|
createPost: acc.Insert(cpo).Columns("cid,body,post,createdBy").Fields("?,?,?,?").Prepare(),
|
|
deletePost: acc.Delete(cpo).Where("pid=?").Prepare(),
|
|
|
|
getUsers: acc.Select("conversations_participants").Columns("uid").Where("cid=?").Prepare(),
|
|
}
|
|
return acc.FirstError()
|
|
})
|
|
}
|
|
|
|
// Conversation mirrors one row of the conversations table.
type Conversation struct {
	// ID is the cid column; CreatedBy is the createdBy column.
	ID, CreatedBy int
	// CreatedAt is the createdAt column.
	CreatedAt time.Time
	// LastReplyBy is the lastReplyBy column.
	LastReplyBy int
	// LastReplyAt is the lastReplyAt column.
	LastReplyAt time.Time
}
|
|
|
|
func (co *Conversation) Posts(offset, itemsPerPage int) (posts []*ConversationPost, err error) {
|
|
rows, err := convoStmts.getPosts.Query(co.ID, offset, itemsPerPage)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
p := &ConversationPost{CID: co.ID}
|
|
err := rows.Scan(&p.ID, &p.Body, &p.Post, &p.CreatedBy)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
p, err = ConvoPostProcess.OnLoad(p)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
posts = append(posts, p)
|
|
}
|
|
|
|
return posts, rows.Err()
|
|
}
|
|
|
|
func (co *Conversation) PostsCount() (count int) {
|
|
err := convoStmts.countPosts.QueryRow(co.ID).Scan(&count)
|
|
if err != nil {
|
|
LogError(err)
|
|
}
|
|
return count
|
|
}
|
|
|
|
func (co *Conversation) Uids() (ids []int, err error) {
|
|
rows, err := convoStmts.getUsers.Query(co.ID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var id int
|
|
err := rows.Scan(&id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ids = append(ids, id)
|
|
}
|
|
return ids, rows.Err()
|
|
}
|
|
|
|
func (co *Conversation) Has(uid int) (in bool) {
|
|
var count int
|
|
err := convoStmts.has.QueryRow(uid, co.ID).Scan(&count)
|
|
if err != nil {
|
|
LogError(err)
|
|
}
|
|
return count > 0
|
|
}
|
|
|
|
func (co *Conversation) Update() error {
|
|
_, err := convoStmts.edit.Exec(co.CreatedAt, co.LastReplyBy, co.LastReplyAt, co.ID)
|
|
return err
|
|
}
|
|
|
|
func (co *Conversation) Create() (int, error) {
|
|
res, err := convoStmts.create.Exec()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
lastID, err := res.LastInsertId()
|
|
return int(lastID), err
|
|
}
|
|
|
|
type ConversationExtra struct {
|
|
*Conversation
|
|
Users []*User
|
|
}
|
|
|
|
type ConversationStore interface {
|
|
Get(id int) (*Conversation, error)
|
|
GetUser(uid, offset int) (cos []*Conversation, err error)
|
|
GetUserExtra(uid int, offset int) (cos []*ConversationExtra, err error)
|
|
GetUserCount(uid int) (count int)
|
|
Delete(id int) error
|
|
Count() (count int)
|
|
Create(content string, createdBy int, participants []int) (int, error)
|
|
}
|
|
|
|
// DefaultConversationStore is a conversation store backed by prepared SQL
// statements; NewDefaultConversationStore fills in every field.
type DefaultConversationStore struct {
	// Read statements.
	get, getUser, getUserCount *sql.Stmt

	// Statements used by Delete, one per table.
	delete, deletePosts, deleteParticipants *sql.Stmt

	// Statements used by Create.
	create, addParticipant *sql.Stmt

	// Row count of the conversations table.
	count *sql.Stmt
}
|
|
|
|
func NewDefaultConversationStore(acc *qgen.Accumulator) (*DefaultConversationStore, error) {
|
|
co := "conversations"
|
|
return &DefaultConversationStore{
|
|
get: acc.Select(co).Columns("createdBy, createdAt, lastReplyBy, lastReplyAt").Where("cid=?").Prepare(),
|
|
getUser: acc.SimpleInnerJoin("conversations_participants AS cp", "conversations AS c", "cp.cid, c.createdBy, c.createdAt, c.lastReplyBy, c.lastReplyAt", "cp.cid=c.cid", "cp.uid=?", "c.lastReplyAt DESC, c.createdAt DESC, c.cid DESC", "?,?"),
|
|
getUserCount: acc.Count("conversations_participants").Where("uid=?").Prepare(),
|
|
delete: acc.Delete(co).Where("cid=?").Prepare(),
|
|
deletePosts: acc.Delete("conversations_posts").Where("cid=?").Prepare(),
|
|
deleteParticipants: acc.Delete("conversations_participants").Where("cid=?").Prepare(),
|
|
create: acc.Insert(co).Columns("createdBy, createdAt, lastReplyBy, lastReplyAt").Fields("?,UTC_TIMESTAMP(),?,UTC_TIMESTAMP()").Prepare(),
|
|
addParticipant: acc.Insert("conversations_participants").Columns("uid,cid").Fields("?,?").Prepare(),
|
|
count: acc.Count(co).Prepare(),
|
|
}, acc.FirstError()
|
|
}
|
|
|
|
func (s *DefaultConversationStore) Get(id int) (*Conversation, error) {
|
|
co := &Conversation{ID: id}
|
|
err := s.get.QueryRow(id).Scan(&co.CreatedBy, &co.CreatedAt, &co.LastReplyBy, &co.LastReplyAt)
|
|
return co, err
|
|
}
|
|
|
|
func (s *DefaultConversationStore) GetUser(uid, offset int) (cos []*Conversation, err error) {
|
|
rows, err := s.getUser.Query(uid, offset, Config.ItemsPerPage)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
co := &Conversation{}
|
|
err := rows.Scan(&co.ID, &co.CreatedBy, &co.CreatedAt, &co.LastReplyBy, &co.LastReplyAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cos = append(cos, co)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(cos) == 0 {
|
|
err = sql.ErrNoRows
|
|
}
|
|
return cos, err
|
|
}
|
|
|
|
func (s *DefaultConversationStore) GetUserExtra(uid, offset int) (cos []*ConversationExtra, err error) {
|
|
raw, err := s.GetUser(uid, offset)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
//log.Printf("raw: %+v\n", raw)
|
|
|
|
if len(raw) == 1 {
|
|
//log.Print("r0b2")
|
|
uids, err := raw[0].Uids()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
//log.Println("r1b2")
|
|
umap, err := Users.BulkGetMap(uids)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
//log.Println("r2b2")
|
|
users := make([]*User, len(umap))
|
|
var i int
|
|
for _, user := range umap {
|
|
users[i] = user
|
|
i++
|
|
}
|
|
return []*ConversationExtra{&ConversationExtra{raw[0], users}}, nil
|
|
}
|
|
//log.Println("1")
|
|
|
|
cmap := make(map[int]*ConversationExtra, len(raw))
|
|
for _, co := range raw {
|
|
cmap[co.ID] = &ConversationExtra{co, nil}
|
|
}
|
|
|
|
// TODO: Add a function for the q stuff
|
|
var q string
|
|
idList := make([]interface{}, len(raw))
|
|
for i, co := range raw {
|
|
idList[i] = strconv.Itoa(co.ID)
|
|
q += "?,"
|
|
}
|
|
q = q[0 : len(q)-1]
|
|
|
|
rows, err := qgen.NewAcc().Select("conversations_participants").Columns("uid,cid").Where("cid IN(" + q + ")").Query(idList...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
//log.Println("2")
|
|
|
|
idmap := make(map[int][]int) // cid: []uid
|
|
puidmap := make(map[int]struct{})
|
|
for rows.Next() {
|
|
var uid, cid int
|
|
err := rows.Scan(&uid, &cid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
idmap[cid] = append(idmap[cid], uid)
|
|
puidmap[uid] = struct{}{}
|
|
}
|
|
if err = rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
//log.Println("3")
|
|
//log.Printf("idmap: %+v\n", idmap)
|
|
//log.Printf("puidmap: %+v\n",puidmap)
|
|
|
|
puids := make([]int, len(puidmap))
|
|
var i int
|
|
for puid, _ := range puidmap {
|
|
puids[i] = puid
|
|
i++
|
|
}
|
|
umap, err := Users.BulkGetMap(puids)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
//log.Println("4")
|
|
//log.Printf("umap: %+v\n", umap)
|
|
for cid, uids := range idmap {
|
|
co := cmap[cid]
|
|
for _, uid := range uids {
|
|
co.Users = append(co.Users, umap[uid])
|
|
}
|
|
//log.Printf("co.Conversation: %+v\n", co.Conversation)
|
|
//log.Printf("co.Users: %+v\n", co.Users)
|
|
cmap[cid] = co
|
|
}
|
|
//log.Printf("cmap: %+v\n", cmap)
|
|
for _, ra := range raw {
|
|
cos = append(cos, cmap[ra.ID])
|
|
}
|
|
//log.Printf("cos: %+v\n", cos)
|
|
|
|
return cos, rows.Err()
|
|
}
|
|
|
|
func (s *DefaultConversationStore) GetUserCount(uid int) (count int) {
|
|
err := s.getUserCount.QueryRow(uid).Scan(&count)
|
|
if err != nil {
|
|
LogError(err)
|
|
}
|
|
return count
|
|
}
|
|
|
|
// TODO: Use a foreign key or transaction
|
|
func (s *DefaultConversationStore) Delete(id int) error {
|
|
_, err := s.delete.Exec(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = s.deletePosts.Exec(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = s.deleteParticipants.Exec(id)
|
|
return err
|
|
}
|
|
|
|
func (s *DefaultConversationStore) Create(content string, createdBy int, participants []int) (int, error) {
|
|
if len(participants) == 0 {
|
|
return 0, errors.New("no participants set")
|
|
}
|
|
res, err := s.create.Exec(createdBy,createdBy)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
lastID, err := res.LastInsertId()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
post := &ConversationPost{CID: int(lastID), Body: content, CreatedBy: createdBy}
|
|
_, err = post.Create()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
for _, p := range participants {
|
|
_, err := s.addParticipant.Exec(p, lastID)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
_, err = s.addParticipant.Exec(createdBy, lastID)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return int(lastID), err
|
|
}
|
|
|
|
// Count returns the total number of topics on these forums
|
|
func (s *DefaultConversationStore) Count() (count int) {
|
|
err := s.count.QueryRow().Scan(&count)
|
|
if err != nil {
|
|
LogError(err)
|
|
}
|
|
return count
|
|
}
|