diff --git a/common/attachments.go b/common/attachments.go index c93b10b5..32472b4f 100644 --- a/common/attachments.go +++ b/common/attachments.go @@ -25,7 +25,22 @@ type MiniAttachment struct { Ext string } +type Attachment struct { + ID int + SectionTable string + SectionID int + OriginTable string + OriginID int + UploadedBy int + Path string + Extra string + + Image bool + Ext string +} + type AttachmentStore interface { + FGet(id int) (*Attachment, error) Get(id int) (*MiniAttachment, error) MiniGetList(originTable string, originID int) (alist []*MiniAttachment, err error) BulkMiniGetList(originTable string, ids []int) (amap map[int][]*MiniAttachment, err error) @@ -37,10 +52,12 @@ type AttachmentStore interface { CountInPath(path string) int Delete(id int) error - UpdateLinked(otable string, oid int) (err error) + AddLinked(otable string, oid int) (err error) + RemoveLinked(otable string, oid int) (err error) } type DefaultAttachmentStore struct { + fget *sql.Stmt get *sql.Stmt getByObj *sql.Stmt add *sql.Stmt @@ -58,6 +75,7 @@ type DefaultAttachmentStore struct { func NewDefaultAttachmentStore(acc *qgen.Accumulator) (*DefaultAttachmentStore, error) { a := "attachments" return &DefaultAttachmentStore{ + fget: acc.Select(a).Columns("originTable, originID, sectionTable, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(), get: acc.Select(a).Columns("originID, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(), getByObj: acc.Select(a).Columns("attachID, sectionID, uploadedBy, path, extra").Where("originTable=? AND originID=?").Prepare(), add: acc.Insert(a).Columns("sectionID, sectionTable, originID, originTable, uploadedBy, path, extra").Fields("?,?,?,?,?,?,?").Prepare(), @@ -145,6 +163,21 @@ func (s *DefaultAttachmentStore) BulkMiniGetList(originTable string, ids []int) return amap, rows.Err() } +func (s *DefaultAttachmentStore) FGet(id int) (*Attachment, error) { + a := &Attachment{ID: id} + err := s.fget.QueryRow(id).Scan(&a.OriginTable, &a.OriginID, &a.SectionTable, &a.SectionID, &a.UploadedBy, &a.Path, &a.Extra) + if err != nil { + return nil, err + } + extarr := strings.Split(a.Path, ".") + if len(extarr) < 2 { + return nil, errors.New("corrupt attachment path") + } + a.Ext = extarr[len(extarr)-1] + a.Image = ImageFileExts.Contains(a.Ext) + return a, nil +} + func (s *DefaultAttachmentStore) Get(id int) (*MiniAttachment, error) { a := &MiniAttachment{ID: id} err := s.get.QueryRow(id).Scan(&a.OriginID, &a.SectionID, &a.UploadedBy, &a.Path, &a.Extra) @@ -209,7 +242,7 @@ func (s *DefaultAttachmentStore) Delete(id int) error { } // TODO: Split this out of this store -func (s *DefaultAttachmentStore) UpdateLinked(otable string, oid int) (err error) { +func (s *DefaultAttachmentStore) AddLinked(otable string, oid int) (err error) { switch otable { case "topics": _, err = s.topicUpdateAttachs.Exec(s.CountIn(otable, oid), oid) @@ -219,6 +252,37 @@ func (s *DefaultAttachmentStore) UpdateLinked(otable string, oid int) (err error err = Topics.Reload(oid) case "replies": _, err = s.replyUpdateAttachs.Exec(s.CountIn(otable, oid), oid) + if err != nil { + return err + } + err = Rstore.GetCache().Remove(oid) + } + if err == sql.ErrNoRows { + err = nil + } + if err != nil { + return err + } + return nil +} + +// TODO: Split this out of this store +func (s *DefaultAttachmentStore) RemoveLinked(otable string, oid int) (err error) { + switch otable { + case "topics": + _, err = s.topicUpdateAttachs.Exec(s.CountIn(otable, oid), oid) + if err != nil { + return err + } + if tc := Topics.GetCache(); tc != nil { + tc.Remove(oid) + } + case "replies": + _, err = s.replyUpdateAttachs.Exec(s.CountIn(otable, oid), oid) + if err != nil { + return err + } + err = Rstore.GetCache().Remove(oid) } if err != nil { return err @@ -228,26 +292,31 @@ func (s *DefaultAttachmentStore) UpdateLinked(otable string, oid int) (err error // TODO: Add a table for the files and lock the file row when performing tasks related to the file func DeleteAttachment(aid int) error { - attach, err := Attachments.Get(aid) + a, err := Attachments.FGet(aid) if err != nil { - //fmt.Println("o1") return err } - err = Attachments.Delete(aid) + err = deleteAttachment(a) + if err != nil { + return err + } + _ = Attachments.RemoveLinked(a.OriginTable, a.OriginID) + return nil +} + +func deleteAttachment(a *Attachment) error { + err := Attachments.Delete(a.ID) if err != nil { - //fmt.Println("o2") return err } - count := Attachments.CountInPath(attach.Path) + count := Attachments.CountInPath(a.Path) if count == 0 { - err := os.Remove("./attachs/" + attach.Path) + err := os.Remove("./attachs/" + a.Path) if err != nil { - //fmt.Println("o3") return err } } - //fmt.Println("o4") return nil } diff --git a/common/password_reset.go b/common/password_reset.go index 354b1073..e0628e34 100644 --- a/common/password_reset.go +++ b/common/password_reset.go @@ -5,7 +5,7 @@ import ( "database/sql" "errors" - "github.com/Azareal/Gosora/query_gen" + qgen "github.com/Azareal/Gosora/query_gen" ) var PasswordResetter *DefaultPasswordResetter @@ -30,10 +30,10 @@ type DefaultPasswordResetter struct { func NewDefaultPasswordResetter(acc *qgen.Accumulator) (*DefaultPasswordResetter, error) { pr := "password_resets" return &DefaultPasswordResetter{ - getTokens: acc.Select(pr).Columns("token").Where("uid = ?").Prepare(), - create: acc.Insert(pr).Columns("email, uid, validated, token, createdAt").Fields("?,?,0,?,UTC_TIMESTAMP()").Prepare(), + getTokens: acc.Select(pr).Columns("token").Where("uid=?").Prepare(), + create: acc.Insert(pr).Columns("email,uid,validated,token,createdAt").Fields("?,?,0,?,UTC_TIMESTAMP()").Prepare(), //create: acc.Insert(pr).Cols("email,uid,validated=0,token,createdAt=UTC_TIMESTAMP()").Prep(), - delete: acc.Delete(pr).Where("uid=?").Prepare(), + delete: acc.Delete(pr).Where("uid=?").Prepare(), //model: acc.Model(w).Cols("email,uid,validated=0,token").Key("uid").CreatedAt("createdAt").Prep(), }, acc.FirstError() } diff --git a/common/topic.go b/common/topic.go index 4d8dde31..94a387ee 100644 --- a/common/topic.go +++ b/common/topic.go @@ -216,7 +216,7 @@ func init() { DbInits.Add(func(acc *qgen.Accumulator) error { t := "topics" topicStmts = TopicStmts{ - getRids: acc.Select("replies").Columns("rid").Where("tid = ?").Orderby("rid ASC").Limit("?,?").Prepare(), + getRids: acc.Select("replies").Columns("rid").Where("tid=?").Orderby("rid ASC").Limit("?,?").Prepare(), getReplies: acc.SimpleLeftJoin("replies AS r", "users AS u", "r.rid, r.content, r.createdBy, r.createdAt, r.lastEdit, r.lastEditBy, u.avatar, u.name, u.group, u.level, r.ip, r.likeCount, r.attachCount, r.actionType", "r.createdBy = u.uid", "r.tid = ?", "r.rid ASC", "?,?"), addReplies: acc.Update(t).Set("postCount=postCount+?, lastReplyBy=?, lastReplyAt=UTC_TIMESTAMP()").Where("tid=?").Prepare(), updateLastReply: acc.Update(t).Set("lastReplyID=?").Where("lastReplyID > ? AND tid = ?").Prepare(), @@ -390,7 +390,11 @@ func handleAttachments(stmt *sql.Stmt, id int) error { if err != nil { return err } - err = DeleteAttachment(aid) + a, err := Attachments.FGet(aid) + if err != nil { + return err + } + err = deleteAttachment(a) if err != nil && err != sql.ErrNoRows { return err } diff --git a/misc_test.go b/misc_test.go index f0a25d2e..f1bd11c3 100644 --- a/misc_test.go +++ b/misc_test.go @@ -1071,7 +1071,12 @@ func TestAttachments(t *testing.T) { expect(t, c.Attachments.Count() == 0, "the number of attachments should be 0") expect(t, c.Attachments.CountIn("topics", 1) == 0, "the number of attachments in topic 1 should be 0") expect(t, c.Attachments.CountInPath(filename) == 0, fmt.Sprintf("the number of attachments with path '%s' should be 0", filename)) - _, err := c.Attachments.Get(1) + _, err := c.Attachments.FGet(1) + if err != nil && err != sql.ErrNoRows { + t.Error(err) + } + expect(t, err == sql.ErrNoRows, ".FGet should have no results") + _, err = c.Attachments.Get(1) if err != nil && err != sql.ErrNoRows { t.Error(err) } @@ -1087,31 +1092,43 @@ func TestAttachments(t *testing.T) { } expect(t, err == sql.ErrNoRows, ".BulkMiniGetList should have no results") - // Sim an upload, try a proper upload through the proper pathway later on - _, err = os.Stat(destFile) - if err != nil && !os.IsNotExist(err) { + simUpload := func() { + // Sim an upload, try a proper upload through the proper pathway later on + _, err = os.Stat(destFile) + if err != nil && !os.IsNotExist(err) { + expectNilErr(t, err) + } else if err == nil { + err := os.Remove(destFile) + expectNilErr(t, err) + } + + input, err := ioutil.ReadFile(srcFile) expectNilErr(t, err) - } else if err == nil { - err := os.Remove(destFile) + err = ioutil.WriteFile(destFile, input, 0644) expectNilErr(t, err) } - - input, err := ioutil.ReadFile(srcFile) - expectNilErr(t, err) - err = ioutil.WriteFile(destFile, input, 0644) - expectNilErr(t, err) + simUpload() tid, err := c.Topics.Create(2, "Attach Test", "Fillter Body", 1, "") expectNilErr(t, err) aid, err := c.Attachments.Add(2, "forums", tid, "topics", 1, filename, "") expectNilErr(t, err) - expectNilErr(t, c.Attachments.UpdateLinked("topics", tid)) + expectNilErr(t, c.Attachments.AddLinked("topics", tid)) expect(t, c.Attachments.Count() == 1, "the number of attachments should be 1") expect(t, c.Attachments.CountIn("topics", tid) == 1, fmt.Sprintf("the number of attachments in topic %d should be 1", tid)) expect(t, c.Attachments.CountInPath(filename) == 1, fmt.Sprintf("the number of attachments with path '%s' should be 1", filename)) - var a *c.MiniAttachment - f := func(aid, sid, oid, uploadedBy int, path, extra, ext string) { + e := func(a *c.MiniAttachment, aid, sid, oid, uploadedBy int, path, extra, ext string) { + expect(t, a.ID == aid, fmt.Sprintf("ID should be %d not %d", aid, a.ID)) + expect(t, a.SectionID == sid, fmt.Sprintf("SectionID should be %d not %d", sid, a.SectionID)) + expect(t, a.OriginID == oid, fmt.Sprintf("OriginID should be %d not %d", oid, a.OriginID)) + expect(t, a.UploadedBy == uploadedBy, fmt.Sprintf("UploadedBy should be %d not %d", uploadedBy, a.UploadedBy)) + expect(t, a.Path == path, fmt.Sprintf("Path should be %s not %s", path, a.Path)) + expect(t, a.Extra == extra, fmt.Sprintf("Extra should be %s not %s", extra, a.Extra)) + expect(t, a.Image, "Image should be true") + expect(t, a.Ext == ext, fmt.Sprintf("Ext should be %s not %s", ext, a.Ext)) + } + e2 := func(a *c.Attachment, aid, sid, oid, uploadedBy int, path, extra, ext string) { expect(t, a.ID == aid, fmt.Sprintf("ID should be %d not %d", aid, a.ID)) expect(t, a.SectionID == sid, fmt.Sprintf("SectionID should be %d not %d", sid, a.SectionID)) expect(t, a.OriginID == oid, fmt.Sprintf("OriginID should be %d not %d", oid, a.OriginID)) @@ -1129,15 +1146,19 @@ func TestAttachments(t *testing.T) { } else { tbl = "replies" } - a, err = c.Attachments.Get(aid) + fa, err := c.Attachments.FGet(aid) expectNilErr(t, err) - f(aid, 2, oid, 1, filename, extra, "png") + e2(fa, aid, 2, oid, 1, filename, extra, "png") + + a, err := c.Attachments.Get(aid) + expectNilErr(t, err) + e(a, aid, 2, oid, 1, filename, extra, "png") alist, err := c.Attachments.MiniGetList(tbl, oid) expectNilErr(t, err) expect(t, len(alist) == 1, fmt.Sprintf("len(alist) should be 1 not %d", len(alist))) a = alist[0] - f(aid, 2, oid, 1, filename, extra, "png") + e(a, aid, 2, oid, 1, filename, extra, "png") amap, err := c.Attachments.BulkMiniGetList(tbl, []int{oid}) expectNilErr(t, err) @@ -1148,7 +1169,7 @@ func TestAttachments(t *testing.T) { } expect(t, len(alist) == 1, fmt.Sprintf("len(alist) should be 1 not %d", len(alist))) a = alist[0] - f(aid, 2, oid, 1, filename, extra, "png") + e(a, aid, 2, oid, 1, filename, extra, "png") } topic, err := c.Topics.Get(tid) @@ -1156,41 +1177,61 @@ func TestAttachments(t *testing.T) { expect(t, topic.AttachCount == 1, fmt.Sprintf("topic.AttachCount should be 1 not %d", topic.AttachCount)) f2(aid, tid, "", true) - // TODO: Cover the other bits of creation / deletion not covered in the AttachmentStore like updating the reply / topic attachCount - // TODO: Move attachment tests - expectNilErr(t, c.Attachments.Delete(aid)) - expect(t, c.Attachments.Count() == 0, "the number of attachments should be 0") - expect(t, c.Attachments.CountIn("topics", tid) == 0, fmt.Sprintf("the number of attachments in topic %d should be 0", tid)) - expect(t, c.Attachments.CountInPath(filename) == 0, fmt.Sprintf("the number of attachments with path '%s' should be 0", filename)) - _, err = c.Attachments.Get(aid) - if err != nil && err != sql.ErrNoRows { - t.Error(err) + deleteTest := func(aid, oid int, topic bool) { + var tbl string + if topic { + tbl = "topics" + } else { + tbl = "replies" + } + //expectNilErr(t, c.Attachments.Delete(aid)) + expectNilErr(t, c.DeleteAttachment(aid)) + expect(t, c.Attachments.Count() == 0, "the number of attachments should be 0") + expect(t, c.Attachments.CountIn(tbl, oid) == 0, fmt.Sprintf("the number of attachments in topic %d should be 0", tid)) + expect(t, c.Attachments.CountInPath(filename) == 0, fmt.Sprintf("the number of attachments with path '%s' should be 0", filename)) + _, err = c.Attachments.FGet(aid) + if err != nil && err != sql.ErrNoRows { + t.Error(err) + } + expect(t, err == sql.ErrNoRows, ".FGet should have no results") + _, err = c.Attachments.Get(aid) + if err != nil && err != sql.ErrNoRows { + t.Error(err) + } + expect(t, err == sql.ErrNoRows, ".Get should have no results") + _, err = c.Attachments.MiniGetList(tbl, oid) + if err != nil && err != sql.ErrNoRows { + t.Error(err) + } + expect(t, err == sql.ErrNoRows, ".MiniGetList should have no results") + _, err = c.Attachments.BulkMiniGetList(tbl, []int{oid}) + if err != nil && err != sql.ErrNoRows { + t.Error(err) + } + expect(t, err == sql.ErrNoRows, ".BulkMiniGetList should have no results") } - expect(t, err == sql.ErrNoRows, ".Get should have no results") - _, err = c.Attachments.MiniGetList("topics", tid) - if err != nil && err != sql.ErrNoRows { - t.Error(err) - } - expect(t, err == sql.ErrNoRows, ".MiniGetList should have no results") - _, err = c.Attachments.BulkMiniGetList("topics", []int{tid}) - if err != nil && err != sql.ErrNoRows { - t.Error(err) - } - expect(t, err == sql.ErrNoRows, ".BulkMiniGetList should have no results") + deleteTest(aid, tid, true) + topic, err = c.Topics.Get(tid) + expectNilErr(t, err) + expect(t, topic.AttachCount == 0, fmt.Sprintf("topic.AttachCount should be 0 not %d", topic.AttachCount)) + simUpload() rid, err := c.Rstore.Create(topic, "Reply Filler", "", 1) expectNilErr(t, err) aid, err = c.Attachments.Add(2, "forums", rid, "replies", 1, filename, strconv.Itoa(topic.ID)) expectNilErr(t, err) - expectNilErr(t, c.Attachments.UpdateLinked("replies", rid)) + expectNilErr(t, c.Attachments.AddLinked("replies", rid)) r, err := c.Rstore.Get(rid) expectNilErr(t, err) expect(t, r.AttachCount == 1, fmt.Sprintf("r.AttachCount should be 1 not %d", r.AttachCount)) f2(aid, rid, strconv.Itoa(topic.ID), false) + deleteTest(aid, rid, false) + r, err = c.Rstore.Get(rid) + expectNilErr(t, err) + expect(t, r.AttachCount == 0, fmt.Sprintf("r.AttachCount should be 0 not %d", r.AttachCount)) - // TODO: Delete reply attachment // TODO: Path overlap tests } diff --git a/routes/attachments.go b/routes/attachments.go index d74811d0..6afb4a88 100644 --- a/routes/attachments.go +++ b/routes/attachments.go @@ -121,7 +121,7 @@ func uploadAttachment(w http.ResponseWriter, r *http.Request, user c.User, sid i pathMap[filename] = strconv.Itoa(aid) } - err = c.Attachments.UpdateLinked(otable, oid) + err = c.Attachments.AddLinked(otable, oid) if err != nil { return nil, c.InternalError(err, w, r) } diff --git a/routes/reply.go b/routes/reply.go index c3a3bb36..9f71ebb7 100644 --- a/routes/reply.go +++ b/routes/reply.go @@ -383,7 +383,6 @@ func AddAttachToReplySubmit(w http.ResponseWriter, r *http.Request, user c.User, if len(pathMap) == 0 { return c.InternalErrorJS(errors.New("no paths for attachment add"), w, r) } - _ = c.Rstore.GetCache().Remove(reply.ID) skip, rerr := lite.Hooks.VhookSkippable("action_end_add_attach_to_reply", reply.ID, &user) if skip || rerr != nil { @@ -447,7 +446,6 @@ func RemoveAttachFromReplySubmit(w http.ResponseWriter, r *http.Request, user c. return rerr } } - _ = c.Rstore.GetCache().Remove(reply.ID) skip, rerr := lite.Hooks.VhookSkippable("action_end_remove_attach_from_reply", reply.ID, &user) if skip || rerr != nil {