diff --git a/common/conversations.go b/common/conversations.go index 54a173ca..a80c368b 100644 --- a/common/conversations.go +++ b/common/conversations.go @@ -1,13 +1,10 @@ package common import ( - "io" "time" + "errors" + //"strconv" "database/sql" - "encoding/hex" - "crypto/aes" - "crypto/cipher" - "crypto/rand" qgen "github.com/Azareal/Gosora/query_gen" ) @@ -16,103 +13,15 @@ import ( conversations conversations_posts */ - -var ConvoPostProcess ConvoPostProcessor = NewDefaultConvoPostProcessor() - -type ConvoPostProcessor interface { - OnLoad(co *ConversationPost) (*ConversationPost, error) - OnSave(co *ConversationPost) (*ConversationPost, error) -} - -type DefaultConvoPostProcessor struct { -} - -func NewDefaultConvoPostProcessor() *DefaultConvoPostProcessor { - return &DefaultConvoPostProcessor{} -} - -func (pr *DefaultConvoPostProcessor) OnLoad(co *ConversationPost) (*ConversationPost, error) { - return co, nil -} - -func (pr *DefaultConvoPostProcessor) OnSave(co *ConversationPost) (*ConversationPost, error) { - return co, nil -} - -type AesConvoPostProcessor struct { -} - -func NewAesConvoPostProcessor() *AesConvoPostProcessor { - return &AesConvoPostProcessor{} -} - -func (pr *AesConvoPostProcessor) OnLoad(co *ConversationPost) (*ConversationPost, error) { - if co.Post != "aes" { - return co, nil - } - key, _ := hex.DecodeString(Config.ConvoKey) - - ciphertext, err := hex.DecodeString(co.Body) - if err != nil { - return nil, err - } - - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - - aesgcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - - nonceSize := aesgcm.NonceSize() - if len(ciphertext) < nonceSize { - return nil, err - } - - nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] - plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil) - if err != nil { - return nil, err - } - - lco := *co - lco.Body = string(plaintext) - return &lco, nil -} - -func (pr *AesConvoPostProcessor) OnSave(co *ConversationPost) (*ConversationPost, error) { - key, _ := hex.DecodeString(Config.ConvoKey) - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - - nonce := make([]byte, 12) - if _, err := io.ReadFull(rand.Reader, nonce); err != nil { - return nil, err - } - - aesgcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - ciphertext := aesgcm.Seal(nil, nonce, []byte(co.Body), nil) - - lco := *co - lco.Body = hex.EncodeToString(ciphertext) - lco.Post = "aes" - return &lco, nil -} - +var Convos ConversationStore var convoStmts ConvoStmts type ConvoStmts struct { getPosts *sql.Stmt + countPosts *sql.Stmt edit *sql.Stmt create *sql.Stmt + delete *sql.Stmt editPost *sql.Stmt createPost *sql.Stmt @@ -121,12 +30,15 @@ type ConvoStmts struct { /*func init() { DbInits.Add(func(acc *qgen.Accumulator) error { convoStmts = ConvoStmts{ - getPosts: acc.Select("conversations_posts").Columns("pid, body, post").Where("cid = ?").Prepare(), - edit: acc.Update("conversations").Set("participants = ?, lastReplyAt = ?").Where("cid = ?").Prepare(), - create: acc.Insert("conversations").Columns("participants, createdAt, lastReplyAt").Fields("?,UTC_TIMESTAMP(),UTC_TIMESTAMP()").Prepare(), + getPosts: acc.Select("conversations_posts").Columns("pid, body, post, createdBy").Where("cid = ?").Limit("?,?").Prepare(), + countPosts: acc.Count("conversations_posts").Where("cid = ?").Prepare(), + //edit: acc.Update("conversations").Set("participants = ?, lastReplyBy = ?, lastReplyAt = ?").Where("cid = ?").Prepare(), + edit: acc.Update("conversations").Set("lastReplyBy = ?, lastReplyAt = ?").Where("cid = ?").Prepare(), + //create: acc.Insert("conversations").Columns("participants, createdAt, lastReplyAt").Fields("?,UTC_TIMESTAMP(),UTC_TIMESTAMP()").Prepare(), + create: acc.Insert("conversations").Columns("createdAt, lastReplyAt").Fields("UTC_TIMESTAMP(),UTC_TIMESTAMP()").Prepare(), - editPost: acc.Update("conversations_posts").Set("body = ?").Where("cid = ?").Prepare(), - createPost: acc.Insert("conversations_posts").Columns("body").Fields("?").Prepare(), + editPost: acc.Update("conversations_posts").Set("body = ?, post = ?").Where("cid = ?").Prepare(), + createPost: acc.Insert("conversations_posts").Columns("cid, body, post, createdBy").Fields("?,?,?,?").Prepare(), } return acc.FirstError() }) @@ -134,8 +46,10 @@ type ConvoStmts struct { type Conversation struct { ID int - Participants string + //Participants string + CreatedBy int CreatedAt time.Time + LastReplyBy int LastReplyAt time.Time } @@ -148,7 +62,7 @@ func (co *Conversation) Posts(offset int) (posts []*ConversationPost, err error) for rows.Next() { convo := &ConversationPost{CID: co.ID} - err := rows.Scan(&convo.ID, &convo.Body, &convo.Post) + err := rows.Scan(&convo.ID, &convo.Body, &convo.Post, &convo.CreatedBy) if err != nil { return nil, err } @@ -158,53 +72,25 @@ func (co *Conversation) Posts(offset int) (posts []*ConversationPost, err error) } posts = append(posts, convo) } - err = rows.Err() - if err != nil { - return nil, err - } - return posts, err + return posts, rows.Err() +} + +func (co *Conversation) PostsCount() (count int) { + err := convoStmts.countPosts.QueryRow(co.ID).Scan(&count) + if err != nil { + LogError(err) + } + return count } func (co *Conversation) Update() error { - _, err := convoStmts.edit.Exec(co.Participants, co.CreatedAt, co.LastReplyAt, co.ID) + _, err := convoStmts.edit.Exec(/*co.Participants, */co.CreatedAt, co.LastReplyBy, co.LastReplyAt, co.ID) return err } func (co *Conversation) Create() (int, error) { - res, err := convoStmts.create.Exec(co.Participants) - if err != nil { - return 0, err - } - - lastID, err := res.LastInsertId() - return int(lastID), err -} - -type ConversationPost struct { - ID int - CID int - Body string - Post string // aes, '' -} - -func (co *ConversationPost) Update() error { - lco, err := ConvoPostProcess.OnSave(co) - if err != nil { - return err - } - //GetHookTable().VhookNoRet("convo_post_update", lco) - _, err = convoStmts.editPost.Exec(lco.Body, lco.ID) - return err -} - -func (co *ConversationPost) Create() (int, error) { - lco, err := ConvoPostProcess.OnSave(co) - if err != nil { - return 0, err - } - //GetHookTable().VhookNoRet("convo_post_create", lco) - res, err := convoStmts.createPost.Exec(lco.Body) + res, err := convoStmts.create.Exec(/*co.Participants*/) if err != nil { return 0, err } @@ -215,35 +101,129 @@ func (co *ConversationPost) Create() (int, error) { type ConversationStore interface { Get(id int) (*Conversation, error) + GetUser(uid int, offset int) (cos []*Conversation, err error) + GetUserCount(uid int) (count int) Delete(id int) error Count() (count int) + Create(content string, createdBy int, participants []int) (int, error) } type DefaultConversationStore struct { get *sql.Stmt + getUser *sql.Stmt + getUserCount *sql.Stmt delete *sql.Stmt + deletePosts *sql.Stmt + create *sql.Stmt + addParticipant *sql.Stmt count *sql.Stmt } func NewDefaultConversationStore(acc *qgen.Accumulator) (*DefaultConversationStore, error) { return &DefaultConversationStore{ - get: acc.Select("conversations").Columns("participants, createdAt, lastReplyAt").Where("cid = ?").Prepare(), + //get: acc.Select("conversations").Columns("participants, createdBy, createdAt, lastReplyBy, lastReplyAt").Where("cid = ?").Prepare(), + get: acc.Select("conversations").Columns("createdBy, createdAt, lastReplyBy, lastReplyAt").Where("cid = ?").Prepare(), + + //("replies", "users", "replies.rid, replies.content, replies.createdBy, replies.createdAt, replies.lastEdit, replies.lastEditBy, users.avatar, users.name, users.group, users.url_prefix, users.url_name, users.level, replies.ipaddress, replies.likeCount, replies.attachCount, replies.actionType", "replies.createdBy = users.uid", "replies.tid = ?", "replies.rid ASC", "?,?") + //(name string, table1 string, table2 string, columns string, joiners string, where string, orderby string, limit string) + + //getUser: acc.SimpleLeftJoin("conversations_participants AS cp","conversations AS c","c.cid, c.participants, c.createdBy, c.createdAt, c.lastReplyBy, c.lastReplyAt","cp.cid = c.cid","cp.uid = ?","c.lastReplyAt DESC, c.createdAt DESC, c.cid DESC","?,?"), + getUser: acc.SimpleLeftJoin("conversations_participants AS cp","conversations AS c","c.cid, c.createdBy, c.createdAt, c.lastReplyBy, c.lastReplyAt","cp.cid = c.cid","cp.uid = ?","c.lastReplyAt DESC, c.createdAt DESC, c.cid DESC","?,?"), + getUserCount: acc.Count("conversations_participants").Where("uid = ?").Prepare(), delete: acc.Delete("conversations").Where("cid = ?").Prepare(), + deletePosts: acc.Delete("conversations_posts").Where("cid = ?").Prepare(), + //create: acc.Insert("conversations").Columns("participants, createdBy, createdAt, lastReplyAt").Fields("?,?,UTC_TIMESTAMP(),UTC_TIMESTAMP()").Prepare(), + create: acc.Insert("conversations").Columns("createdBy, createdAt, lastReplyAt").Fields("?,UTC_TIMESTAMP(),UTC_TIMESTAMP()").Prepare(), + addParticipant: acc.Insert("conversations_participants").Columns("uid, cid").Fields("?,?").Prepare(), count: acc.Count("conversations").Prepare(), }, acc.FirstError() } func (s *DefaultConversationStore) Get(id int) (*Conversation, error) { convo := &Conversation{ID: id} - err := s.get.QueryRow(id).Scan(&convo.Participants, &convo.CreatedAt, &convo.LastReplyAt) - return nil, err + err := s.get.QueryRow(id).Scan(/*&convo.Participants, */&convo.CreatedBy, &convo.CreatedAt, &convo.LastReplyBy, &convo.LastReplyAt) + return convo, err } +func (s *DefaultConversationStore) GetUser(uid int, offset int) (cos []*Conversation, err error) { + rows, err := s.getUser.Query(uid, offset, Config.ItemsPerPage) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + co := &Conversation{} + err := rows.Scan(&co.ID, /*&co.Participants,*/ &co.CreatedBy, &co.CreatedAt, &co.LastReplyBy, &co.LastReplyAt) + if err != nil { + return nil, err + } + cos = append(cos, co) + } + + return cos, rows.Err() +} + +func (s *DefaultConversationStore) GetUserCount(uid int) (count int) { + err := s.getUserCount.QueryRow(uid).Scan(&count) + if err != nil { + LogError(err) + } + return count +} + +// TODO: Use a foreign key or transaction func (s *DefaultConversationStore) Delete(id int) error { _, err := s.delete.Exec(id) + if err != nil { + return err + } + _, err = s.deletePosts.Exec(id) return err } +func (s *DefaultConversationStore) Create(content string, createdBy int, participants []int) (int, error) { + if len(participants) == 0 { + return 0, errors.New("no participants set") + } + /*var pstr string + for _, parti := range participants { + pstr += strconv.Itoa(parti) + "," + } + pstr = pstr[:len(pstr)-1]*/ + + res, err := s.create.Exec(createdBy/*, pstr*/) + if err != nil { + return 0, err + } + lastID, err := res.LastInsertId() + if err != nil { + return 0, err + } + + post := &ConversationPost{} + post.CID = int(lastID) + post.Body = content + post.CreatedBy = createdBy + _, err = post.Create() + if err != nil { + return 0, err + } + + for _, p := range participants { + _, err := s.addParticipant.Exec(p,lastID) + if err != nil { + return 0, err + } + } + _, err = s.addParticipant.Exec(createdBy,lastID) + if err != nil { + return 0, err + } + + return int(lastID), err +} + // Count returns the total number of topics on these forums func (s *DefaultConversationStore) Count() (count int) { err := s.count.QueryRow().Scan(&count)