Add GetForRenderRoute to DefaultAttachmentStore.

Use ErrCorruptAttachPath in DefaultAttachmentStore.
Consider the possibility that the requested attachment doesn't have a file extension.
Rename variables to reduce boilerplate.
Add TestThaw.
Avoid an allocation in hookgen.

Add route_attach_start hook.
Add route_attach_post_get hook.
This commit is contained in:
Azareal 2021-03-24 21:45:18 +10:00
parent adfed477a0
commit 26e8bf32a7
8 changed files with 160 additions and 137 deletions

View File

@ -1,9 +1,9 @@
package hookgen package hookgen
import ( import (
"bytes"
"log" "log"
"os" "os"
"bytes"
"text/template" "text/template"
) )
@ -27,34 +27,36 @@ type Hook struct {
func AddHooks(add func(name, params, ret, htype string, multiHook, skip bool, defaultRet, pure string)) { func AddHooks(add func(name, params, ret, htype string, multiHook, skip bool, defaultRet, pure string)) {
vhookskip := func(name, params string) { vhookskip := func(name, params string) {
add(name,params,"(bool,RouteError)","VhookSkippable_",false,true,"false,nil","") add(name, params, "(bool,RouteError)", "VhookSkippable_", false, true, "false,nil", "")
} }
vhookskip("forum_check_pre_perms","w http.ResponseWriter,r *http.Request,u *User,fid *int,h *Header") vhookskip("forum_check_pre_perms", "w http.ResponseWriter,r *http.Request,u *User,fid *int,h *Header")
vhookskip("router_after_filters","w http.ResponseWriter,r *http.Request,prefix string") vhookskip("router_after_filters", "w http.ResponseWriter,r *http.Request,prefix string")
vhookskip("router_pre_route","w http.ResponseWriter,r *http.Request,u *User,prefix string") vhookskip("router_pre_route", "w http.ResponseWriter,r *http.Request,u *User,prefix string")
vhookskip("route_forum_list_start","w http.ResponseWriter,r *http.Request,u *User,h *Header") vhookskip("route_forum_list_start", "w http.ResponseWriter,r *http.Request,u *User,h *Header")
vhookskip("route_topic_list_start","w http.ResponseWriter,r *http.Request,u *User,h *Header") vhookskip("route_topic_list_start", "w http.ResponseWriter,r *http.Request,u *User,h *Header")
vhookskip("route_attach_start", "w http.ResponseWriter,r *http.Request,u *User,fname string")
vhookskip("route_attach_post_get", "w http.ResponseWriter,r *http.Request,u *User,a *Attachment")
vhooknoret := func(name, params string) { vhooknoret := func(name, params string) {
add(name,params,"","Vhooks",false,false,"false,nil","") add(name, params, "", "Vhooks", false, false, "false,nil", "")
} }
vhooknoret("router_end","w http.ResponseWriter,r *http.Request,u *User,prefix string, extraData string") vhooknoret("router_end", "w http.ResponseWriter,r *http.Request,u *User,prefix string,extraData string")
vhooknoret("topic_reply_row_assign","r *ReplyUser") vhooknoret("topic_reply_row_assign", "r *ReplyUser")
//forums_frow_assign //forums_frow_assign
//Hook(name string, data interface{}) interface{} //Hook(name string, data interface{}) interface{}
/*hook := func(name, params, ret, pure string) { /*hook := func(name, params, ret, pure string) {
add(name,params,ret,"Hooks",true,false,ret,pure) add(name,params,ret,"Hooks",true,false,ret,pure)
}*/ }*/
hooknoret := func(name, params string) { hooknoret := func(name, params string) {
add(name,params,"","HooksNoRet",true,false,"","") add(name, params, "", "HooksNoRet", true, false, "", "")
} }
hooknoret("forums_frow_assign","f *Forum") hooknoret("forums_frow_assign", "f *Forum")
hookskip := func(name, params string) { hookskip := func(name, params string) {
add(name,params,"(skip bool)","HooksSkip",true,true,"","") add(name, params, "(skip bool)", "HooksSkip", true, true, "", "")
} }
//hookskip("forums_frow_assign","f *Forum") //hookskip("forums_frow_assign","f *Forum")
hookskip("topic_create_frow_assign","f *Forum") hookskip("topic_create_frow_assign", "f *Forum")
hookss := func(name string) { hookss := func(name string) {
add(name,"d string","string","Sshooks",true,false,"","d") add(name, "d string", "string", "Sshooks", true, false, "", "d")
} }
hookss("topic_ogdesc_assign") hookss("topic_ogdesc_assign")
} }
@ -86,7 +88,7 @@ func H_{{.Name}}_hook(t *HookTable,{{.Params}}) {{.Ret}} { {{if .Any}}
log.Fatal(e) log.Fatal(e)
} }
err := writeFile("./common/gen_extend.go", string(b.Bytes())) err := writeFile("./common/gen_extend.go", b.String())
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -14,6 +14,8 @@ import (
var Attachments AttachmentStore var Attachments AttachmentStore
var ErrCorruptAttachPath = errors.New("corrupt attachment path")
type MiniAttachment struct { type MiniAttachment struct {
ID int ID int
SectionID int SectionID int
@ -41,6 +43,7 @@ type Attachment struct {
} }
type AttachmentStore interface { type AttachmentStore interface {
GetForRenderRoute(filename string, sid int, sectionTable string) (*Attachment, error)
FGet(id int) (*Attachment, error) FGet(id int) (*Attachment, error)
Get(id int) (*MiniAttachment, error) Get(id int) (*MiniAttachment, error)
MiniGetList(originTable string, originID int) (alist []*MiniAttachment, err error) MiniGetList(originTable string, originID int) (alist []*MiniAttachment, err error)
@ -58,6 +61,8 @@ type AttachmentStore interface {
} }
type DefaultAttachmentStore struct { type DefaultAttachmentStore struct {
getForRenderRoute *sql.Stmt
fget *sql.Stmt fget *sql.Stmt
get *sql.Stmt get *sql.Stmt
getByObj *sql.Stmt getByObj *sql.Stmt
@ -76,6 +81,8 @@ type DefaultAttachmentStore struct {
func NewDefaultAttachmentStore(acc *qgen.Accumulator) (*DefaultAttachmentStore, error) { func NewDefaultAttachmentStore(acc *qgen.Accumulator) (*DefaultAttachmentStore, error) {
a := "attachments" a := "attachments"
return &DefaultAttachmentStore{ return &DefaultAttachmentStore{
getForRenderRoute: acc.Select(a).Columns("sectionTable, originID, originTable, uploadedBy, path").Where("path=? AND sectionID=? AND sectionTable=?").Prepare(),
fget: acc.Select(a).Columns("originTable, originID, sectionTable, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(), fget: acc.Select(a).Columns("originTable, originID, sectionTable, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(),
get: acc.Select(a).Columns("originID, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(), get: acc.Select(a).Columns("originID, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(),
getByObj: acc.Select(a).Columns("attachID, sectionID, uploadedBy, path, extra").Where("originTable=? AND originID=?").Prepare(), getByObj: acc.Select(a).Columns("attachID, sectionID, uploadedBy, path, extra").Where("originTable=? AND originID=?").Prepare(),
@ -93,6 +100,15 @@ func NewDefaultAttachmentStore(acc *qgen.Accumulator) (*DefaultAttachmentStore,
}, acc.FirstError() }, acc.FirstError()
} }
// TODO: Revamp this to make it less of a copy-paste from the original code in the route
// ! Lacks some attachment initialisation code
func (s *DefaultAttachmentStore) GetForRenderRoute(filename string, sid int, sectionTable string) (*Attachment, error) {
a := &Attachment{SectionID: sid}
e := s.getForRenderRoute.QueryRow(filename, sid, sectionTable).Scan(&a.SectionTable, &a.OriginID, &a.OriginTable, &a.UploadedBy, &a.Path)
// TODO: Initialise attachment struct fields?
return a, e
}
func (s *DefaultAttachmentStore) MiniGetList(originTable string, originID int) (alist []*MiniAttachment, err error) { func (s *DefaultAttachmentStore) MiniGetList(originTable string, originID int) (alist []*MiniAttachment, err error) {
rows, err := s.getByObj.Query(originTable, originID) rows, err := s.getByObj.Query(originTable, originID)
defer rows.Close() defer rows.Close()
@ -104,7 +120,7 @@ func (s *DefaultAttachmentStore) MiniGetList(originTable string, originID int) (
} }
a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".") a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".")
if len(a.Ext) == 0 { if len(a.Ext) == 0 {
return nil, errors.New("corrupt attachment path") return nil, ErrCorruptAttachPath
} }
a.Image = ImageFileExts.Contains(a.Ext) a.Image = ImageFileExts.Contains(a.Ext)
alist = append(alist, a) alist = append(alist, a)
@ -140,7 +156,7 @@ func (s *DefaultAttachmentStore) BulkMiniGetList(originTable string, ids []int)
} }
a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".") a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".")
if len(a.Ext) == 0 { if len(a.Ext) == 0 {
return nil, errors.New("corrupt attachment path") return nil, ErrCorruptAttachPath
} }
a.Image = ImageFileExts.Contains(a.Ext) a.Image = ImageFileExts.Contains(a.Ext)
if currentID == 0 { if currentID == 0 {
@ -169,7 +185,7 @@ func (s *DefaultAttachmentStore) FGet(id int) (*Attachment, error) {
} }
a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".") a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".")
if len(a.Ext) == 0 { if len(a.Ext) == 0 {
return nil, errors.New("corrupt attachment path") return nil, ErrCorruptAttachPath
} }
a.Image = ImageFileExts.Contains(a.Ext) a.Image = ImageFileExts.Contains(a.Ext)
return a, nil return a, nil
@ -183,7 +199,7 @@ func (s *DefaultAttachmentStore) Get(id int) (*MiniAttachment, error) {
} }
a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".") a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".")
if len(a.Ext) == 0 { if len(a.Ext) == 0 {
return nil, errors.New("corrupt attachment path") return nil, ErrCorruptAttachPath
} }
a.Image = ImageFileExts.Contains(a.Ext) a.Image = ImageFileExts.Contains(a.Ext)
return a, nil return a, nil
@ -209,32 +225,32 @@ func (s *DefaultAttachmentStore) MoveToByExtra(sectionID int, originTable, extra
} }
func (s *DefaultAttachmentStore) Count() (count int) { func (s *DefaultAttachmentStore) Count() (count int) {
err := s.count.QueryRow().Scan(&count) e := s.count.QueryRow().Scan(&count)
if err != nil { if e != nil {
LogError(err) LogError(e)
} }
return count return count
} }
func (s *DefaultAttachmentStore) CountIn(originTable string, oid int) (count int) { func (s *DefaultAttachmentStore) CountIn(originTable string, oid int) (count int) {
err := s.countIn.QueryRow(originTable, oid).Scan(&count) e := s.countIn.QueryRow(originTable, oid).Scan(&count)
if err != nil { if e != nil {
LogError(err) LogError(e)
} }
return count return count
} }
func (s *DefaultAttachmentStore) CountInPath(path string) (count int) { func (s *DefaultAttachmentStore) CountInPath(path string) (count int) {
err := s.countInPath.QueryRow(path).Scan(&count) e := s.countInPath.QueryRow(path).Scan(&count)
if err != nil { if e != nil {
LogError(err) LogError(e)
} }
return count return count
} }
func (s *DefaultAttachmentStore) Delete(id int) error { func (s *DefaultAttachmentStore) Delete(id int) error {
_, err := s.delete.Exec(id) _, e := s.delete.Exec(id)
return err return e
} }
// TODO: Split this out of this store // TODO: Split this out of this store
@ -256,10 +272,7 @@ func (s *DefaultAttachmentStore) AddLinked(otable string, oid int) (err error) {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = nil err = nil
} }
if err != nil {
return err return err
}
return nil
} }
// TODO: Split this out of this store // TODO: Split this out of this store
@ -280,10 +293,7 @@ func (s *DefaultAttachmentStore) RemoveLinked(otable string, oid int) (err error
} }
err = Rstore.GetCache().Remove(oid) err = Rstore.GetCache().Remove(oid)
} }
if err != nil {
return err return err
}
return nil
} }
// TODO: Add a table for the files and lock the file row when performing tasks related to the file // TODO: Add a table for the files and lock the file row when performing tasks related to the file

View File

@ -1,7 +1,7 @@
/* /*
* *
* Gosora Plugin System * Gosora Plugin System
* Copyright Azareal 2016 - 2020 * Copyright Azareal 2016 - 2021
* *
*/ */
package common package common
@ -25,9 +25,9 @@ type PluginList map[string]*Plugin
// TODO: Have a proper store rather than a map? // TODO: Have a proper store rather than a map?
var Plugins PluginList = make(map[string]*Plugin) var Plugins PluginList = make(map[string]*Plugin)
func (list PluginList) Add(pl *Plugin) { func (l PluginList) Add(pl *Plugin) {
buildPlugin(pl) buildPlugin(pl)
list[pl.UName] = pl l[pl.UName] = pl
} }
func buildPlugin(pl *Plugin) { func buildPlugin(pl *Plugin) {
@ -94,6 +94,8 @@ var hookTable = &HookTable{
"route_topic_list_start": nil, "route_topic_list_start": nil,
"route_topic_list_mostviewed_start": nil, "route_topic_list_mostviewed_start": nil,
"route_forum_list_start": nil, "route_forum_list_start": nil,
"route_attach_start": nil,
"route_attach_post_get": nil,
"action_end_create_topic": nil, "action_end_create_topic": nil,
"action_end_edit_topic": nil, "action_end_edit_topic": nil,

View File

@ -9,7 +9,7 @@ import (
var IPSearch IPSearcher var IPSearch IPSearcher
type IPSearcher interface { type IPSearcher interface {
Lookup(ip string) (uids []int, err error) Lookup(ip string) (uids []int, e error)
} }
type DefaultIPSearcher struct { type DefaultIPSearcher struct {
@ -23,27 +23,29 @@ type DefaultIPSearcher struct {
func NewDefaultIPSearcher() (*DefaultIPSearcher, error) { func NewDefaultIPSearcher() (*DefaultIPSearcher, error) {
acc := qgen.NewAcc() acc := qgen.NewAcc()
uu := "users" uu := "users"
q := func(tbl string) *sql.Stmt {
return acc.Select(uu).Columns("uid").InQ("uid", acc.Select(tbl).Columns("createdBy").Where("ip=?")).Prepare()
}
return &DefaultIPSearcher{ return &DefaultIPSearcher{
searchUsers: acc.Select(uu).Columns("uid").Where("last_ip=? OR last_ip LIKE CONCAT('%-',?)").Prepare(), searchUsers: acc.Select(uu).Columns("uid").Where("last_ip=? OR last_ip LIKE CONCAT('%-',?)").Prepare(),
searchTopics: acc.Select(uu).Columns("uid").InQ("uid", acc.Select("topics").Columns("createdBy").Where("ip=?")).Prepare(), searchTopics: q("topics"),
searchReplies: acc.Select(uu).Columns("uid").InQ("uid", acc.Select("replies").Columns("createdBy").Where("ip=?")).Prepare(), searchReplies: q("replies"),
searchUsersReplies: acc.Select(uu).Columns("uid").InQ("uid", acc.Select("users_replies").Columns("createdBy").Where("ip=?")).Prepare(), searchUsersReplies: q("users_replies"),
}, acc.FirstError() }, acc.FirstError()
} }
func (s *DefaultIPSearcher) Lookup(ip string) (uids []int, err error) { func (s *DefaultIPSearcher) Lookup(ip string) (uids []int, e error) {
var uid int var uid int
reqUserList := make(map[int]bool) reqUserList := make(map[int]bool)
runQuery2 := func(rows *sql.Rows, err error) error { runQuery2 := func(rows *sql.Rows, e error) error {
if err != nil { if e != nil {
return err return e
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
err := rows.Scan(&uid) if e := rows.Scan(&uid); e != nil {
if err != nil { return e
return err
} }
reqUserList[uid] = true reqUserList[uid] = true
} }
@ -53,21 +55,21 @@ func (s *DefaultIPSearcher) Lookup(ip string) (uids []int, err error) {
return runQuery2(stmt.Query(ip)) return runQuery2(stmt.Query(ip))
} }
err = runQuery2(s.searchUsers.Query(ip, ip)) e = runQuery2(s.searchUsers.Query(ip, ip))
if err != nil { if e != nil {
return uids, err return uids, e
} }
err = runQuery(s.searchTopics) e = runQuery(s.searchTopics)
if err != nil { if e != nil {
return uids, err return uids, e
} }
err = runQuery(s.searchReplies) e = runQuery(s.searchReplies)
if err != nil { if e != nil {
return uids, err return uids, e
} }
err = runQuery(s.searchUsersReplies) e = runQuery(s.searchUsersReplies)
if err != nil { if e != nil {
return uids, err return uids, e
} }
// Convert the user ID map to a slice, then bulk load the users // Convert the user ID map to a slice, then bulk load the users

View File

@ -44,13 +44,13 @@ func simpleForumUserCheck(w http.ResponseWriter, r *http.Request, u *User, fid i
return h, rerr return h, rerr
} }
fperms, err := FPStore.Get(fid, u.Group) fp, err := FPStore.Get(fid, u.Group)
if err == ErrNoRows { if err == ErrNoRows {
fperms = BlankForumPerms() fp = BlankForumPerms()
} else if err != nil { } else if err != nil {
return h, InternalError(err, w, r) return h, InternalError(err, w, r)
} }
cascadeForumPerms(fperms, u) cascadeForumPerms(fp, u)
return h, nil return h, nil
} }
@ -72,13 +72,13 @@ func forumUserCheck(h *Header, w http.ResponseWriter, r *http.Request, u *User,
return rerr return rerr
} }
fperms, err := FPStore.Get(fid, u.Group) fp, err := FPStore.Get(fid, u.Group)
if err == ErrNoRows { if err == ErrNoRows {
fperms = BlankForumPerms() fp = BlankForumPerms()
} else if err != nil { } else if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
cascadeForumPerms(fperms, u) cascadeForumPerms(fp, u)
h.CurrentUser = u // TODO: Use a pointer instead for CurrentUser, so we don't have to do this h.CurrentUser = u // TODO: Use a pointer instead for CurrentUser, so we don't have to do this
return rerr return rerr
} }

View File

@ -44,19 +44,18 @@ func NewSQLSearcher(acc *qgen.Accumulator) (*SQLSearcher, error) {
func (s *SQLSearcher) queryAll(q string) ([]int, error) { func (s *SQLSearcher) queryAll(q string) ([]int, error) {
var ids []int var ids []int
run := func(stmt *sql.Stmt, q ...interface{}) error { run := func(stmt *sql.Stmt, q ...interface{}) error {
rows, err := stmt.Query(q...) rows, e := stmt.Query(q...)
if err == sql.ErrNoRows { if e == sql.ErrNoRows {
return nil return nil
} else if err != nil { } else if e != nil {
return err return e
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var id int var id int
err := rows.Scan(&id) if e := rows.Scan(&id); e != nil {
if err != nil { return e
return err
} }
ids = append(ids, id) ids = append(ids, id)
} }
@ -81,19 +80,18 @@ func (s *SQLSearcher) Query(q string, zones []int) (ids []int, err error) {
if len(zones) == 0 { if len(zones) == 0 {
return nil, nil return nil, nil
} }
run := func(rows *sql.Rows, err error) error { run := func(rows *sql.Rows, e error) error {
/*if err == sql.ErrNoRows { /*if e == sql.ErrNoRows {
return nil return nil
} else */if err != nil { } else */if e != nil {
return err return e
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var id int var id int
err := rows.Scan(&id) if e := rows.Scan(&id); e != nil {
if err != nil { return e
return err
} }
ids = append(ids, id) ids = append(ids, id)
} }
@ -116,14 +114,12 @@ func (s *SQLSearcher) Query(q string, zones []int) (ids []int, err error) {
acc := qgen.NewAcc() acc := qgen.NewAcc()
/*stmt := acc.RawPrepare("SELECT topics.tid FROM topics INNER JOIN replies ON topics.tid = replies.tid WHERE (MATCH(topics.title) AGAINST (? IN BOOLEAN MODE) OR MATCH(topics.content) AGAINST (? IN BOOLEAN MODE) OR MATCH(replies.content) AGAINST (? IN BOOLEAN MODE) OR topics.title=? OR topics.content=? OR replies.content=?) AND topics.parentID IN(" + zList + ")") /*stmt := acc.RawPrepare("SELECT topics.tid FROM topics INNER JOIN replies ON topics.tid = replies.tid WHERE (MATCH(topics.title) AGAINST (? IN BOOLEAN MODE) OR MATCH(topics.content) AGAINST (? IN BOOLEAN MODE) OR MATCH(replies.content) AGAINST (? IN BOOLEAN MODE) OR topics.title=? OR topics.content=? OR replies.content=?) AND topics.parentID IN(" + zList + ")")
err = acc.FirstError() if err = acc.FirstError(); err != nil {
if err != nil {
return nil, err return nil, err
}*/ }*/
// TODO: Cache common IN counts // TODO: Cache common IN counts
stmt := acc.RawPrepare("SELECT tid FROM topics WHERE (MATCH(topics.title) AGAINST (? IN BOOLEAN MODE) OR MATCH(topics.content) AGAINST (? IN BOOLEAN MODE)) AND parentID IN(" + zList + ")") stmt := acc.RawPrepare("SELECT tid FROM topics WHERE (MATCH(topics.title) AGAINST (? IN BOOLEAN MODE) OR MATCH(topics.content) AGAINST (? IN BOOLEAN MODE)) AND parentID IN(" + zList + ")")
err = acc.FirstError() if err = acc.FirstError(); err != nil {
if err != nil {
return nil, err return nil, err
} }
err = run(stmt.Query(q, q)) err = run(stmt.Query(q, q))
@ -131,8 +127,7 @@ func (s *SQLSearcher) Query(q string, zones []int) (ids []int, err error) {
return nil, err return nil, err
} }
stmt = acc.RawPrepare("SELECT tid FROM replies WHERE MATCH(replies.content) AGAINST (? IN BOOLEAN MODE) AND tid IN(" + zList + ")") stmt = acc.RawPrepare("SELECT tid FROM replies WHERE MATCH(replies.content) AGAINST (? IN BOOLEAN MODE) AND tid IN(" + zList + ")")
err = acc.FirstError() if err = acc.FirstError(); err != nil {
if err != nil {
return nil, err return nil, err
} }
err = run(stmt.Query(q)) err = run(stmt.Query(q))

View File

@ -38,6 +38,21 @@ func (t *SingleServerThaw) Thaw() {
} }
} }
type TestThaw struct {
}
func NewTestThaw() *TestThaw {
return &TestThaw{}
}
func (t *TestThaw) Thawed() bool {
return true
}
func (t *TestThaw) Thaw() {
}
func (t *TestThaw) Tick() error {
return nil
}
type DefaultThaw struct { type DefaultThaw struct {
thawed int64 thawed int64
} }

View File

@ -8,50 +8,47 @@ import (
"strings" "strings"
c "github.com/Azareal/Gosora/common" c "github.com/Azareal/Gosora/common"
qgen "github.com/Azareal/Gosora/query_gen"
) )
type AttachmentStmts struct {
get *sql.Stmt
}
var attachmentStmts AttachmentStmts
// TODO: Abstract this with an attachment store
func init() {
c.DbInits.Add(func(acc *qgen.Accumulator) error {
attachmentStmts = AttachmentStmts{
get: acc.Select("attachments").Columns("sectionID, sectionTable, originID, originTable, uploadedBy, path").Where("path=? AND sectionID=? AND sectionTable=?").Prepare(),
}
return acc.FirstError()
})
}
var maxAgeYear = "max-age=" + strconv.Itoa(int(c.Year)) var maxAgeYear = "max-age=" + strconv.Itoa(int(c.Year))
func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename string) c.RouteError { func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename string) c.RouteError {
filename = c.Stripslashes(filename)
ext := filepath.Ext("./attachs/" + filename)
if !c.AllowedFileExts.Contains(strings.TrimPrefix(ext, ".")) {
return c.LocalError("Bad extension", w, r, u)
}
sid, err := strconv.Atoi(r.FormValue("sid")) sid, err := strconv.Atoi(r.FormValue("sid"))
if err != nil { if err != nil {
return c.LocalError("The sid is not an integer", w, r, u) return c.LocalError("The sid is not an integer", w, r, u)
} }
sectionTable := r.FormValue("stype") sectionTable := r.FormValue("stype")
var originTable string filename = c.Stripslashes(filename)
var originID, uploadedBy int if filename == "" {
err = attachmentStmts.get.QueryRow(filename, sid, sectionTable).Scan(&sid, &sectionTable, &originID, &originTable, &uploadedBy, &filename) return c.LocalError("Bad filename", w, r, u)
}
ext := filepath.Ext(filename)
if ext == "" || !c.AllowedFileExts.Contains(strings.TrimPrefix(ext, ".")) {
return c.LocalError("Bad extension", w, r, u)
}
// TODO: Use the same hook table as upstream
hTbl := c.GetHookTable()
skip, rerr := c.H_route_attach_start_hook(hTbl, w, r, u, filename)
if skip || rerr != nil {
return rerr
}
a, err := c.Attachments.GetForRenderRoute(filename, sid, sectionTable)
// ErrCorruptAttachPath is a possibility now
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return c.NotFound(w, r, nil) return c.NotFound(w, r, nil)
} else if err != nil { } else if err != nil {
return c.InternalError(err, w, r) return c.InternalError(err, w, r)
} }
if sectionTable == "forums" { skip, rerr = c.H_route_attach_post_get_hook(hTbl, w, r, u, a)
if skip || rerr != nil {
return rerr
}
if a.SectionTable == "forums" {
_, ferr := c.SimpleForumUserCheck(w, r, u, sid) _, ferr := c.SimpleForumUserCheck(w, r, u, sid)
if ferr != nil { if ferr != nil {
return ferr return ferr
@ -63,7 +60,7 @@ func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename
return c.LocalError("Unknown section", w, r, u) return c.LocalError("Unknown section", w, r, u)
} }
if originTable != "topics" && originTable != "replies" { if a.OriginTable != "topics" && a.OriginTable != "replies" {
return c.LocalError("Unknown origin", w, r, u) return c.LocalError("Unknown origin", w, r, u)
} }
@ -89,11 +86,11 @@ func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename
} }
func deleteAttachment(w http.ResponseWriter, r *http.Request, u *c.User, aid int, js bool) c.RouteError { func deleteAttachment(w http.ResponseWriter, r *http.Request, u *c.User, aid int, js bool) c.RouteError {
err := c.DeleteAttachment(aid) e := c.DeleteAttachment(aid)
if err == sql.ErrNoRows { if e == sql.ErrNoRows {
return c.NotFoundJSQ(w, r, nil, js) return c.NotFoundJSQ(w, r, nil, js)
} else if err != nil { } else if e != nil {
return c.InternalErrorJSQ(err, w, r, js) return c.InternalErrorJSQ(e, w, r, js)
} }
return nil return nil
} }