Add BulkIsBlockedBy method to BlockStore.

Add TotalP to accCountBuilder.
Add EachP to AccSelectBuilder.
Add expectf.
Add more conversation test cases.
Refactor conversation tests.
Add block tests.
This commit is contained in:
Azareal 2021-02-27 16:13:03 +10:00
parent 9437561c76
commit f20b0bd936
3 changed files with 205 additions and 57 deletions

View File

@ -12,6 +12,7 @@ var UserBlocks BlockStore
type BlockStore interface { type BlockStore interface {
IsBlockedBy(blocker, blockee int) (bool, error) IsBlockedBy(blocker, blockee int) (bool, error)
BulkIsBlockedBy(blockers []int, blockee int) (bool, error)
Add(blocker, blockee int) error Add(blocker, blockee int) error
Remove(blocker, blockee int) error Remove(blocker, blockee int) error
BlockedByOffset(blocker, offset, perPage int) ([]int, error) BlockedByOffset(blocker, offset, perPage int) ([]int, error)
@ -45,6 +46,22 @@ func (s *DefaultBlockStore) IsBlockedBy(blocker, blockee int) (bool, error) {
return err == nil, err return err == nil, err
} }
// TODO: Optimise the query to avoid preparing it on the spot? Maybe, use knowledge of the most common IN() parameter counts?
func (s *DefaultBlockStore) BulkIsBlockedBy(blockers []int, blockee int) (bool, error) {
if len(blockers) == 0 {
return false, nil
}
if len(blockers) == 1 {
return s.IsBlockedBy(blockers[0], blockee)
}
idList, q := inqbuild(blockers)
count, err := qgen.NewAcc().Count("users_blocks").Where("blocker IN(" + q + ") AND blockedUser=?").TotalP(idList...)
if err == ErrNoRows {
return false, nil
}
return count == 0, err
}
func (s *DefaultBlockStore) Add(blocker, blockee int) error { func (s *DefaultBlockStore) Add(blocker, blockee int) error {
_, err := s.add.Exec(blocker, blockee) _, err := s.add.Exec(blocker, blockee)
return err return err
@ -61,7 +78,6 @@ func (s *DefaultBlockStore) BlockedByOffset(blocker, offset, perPage int) (uids
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var uid int var uid int
err := rows.Scan(&uid) err := rows.Scan(&uid)
@ -70,7 +86,6 @@ func (s *DefaultBlockStore) BlockedByOffset(blocker, offset, perPage int) (uids
} }
uids = append(uids, uid) uids = append(uids, uid)
} }
return uids, rows.Err() return uids, rows.Err()
} }

View File

@ -383,6 +383,13 @@ func expect(t *testing.T, item bool, errmsg string) {
} }
} }
func expectf(t *testing.T, item bool, errmsg string, args ...interface{}) {
if !item {
debug.PrintStack()
t.Fatalf(errmsg, args...)
}
}
func TestPermsMiddleware(t *testing.T) { func TestPermsMiddleware(t *testing.T) {
miscinit(t) miscinit(t)
if !c.PluginsInited { if !c.PluginsInited {
@ -1434,37 +1441,64 @@ func TestConvos(t *testing.T) {
c.InitPlugins() c.InitPlugins()
} }
_, err := c.Convos.Get(-1) sf := func(i interface{}, e error) error {
recordMustNotExist(t, err, "convo -1 should not exist") return e
_, err = c.Convos.Get(0) }
recordMustNotExist(t, err, "convo 0 should not exist") mf := func(e error, msg string, exists bool) {
_, err = c.Convos.Get(1) if !exists {
recordMustNotExist(t, err, "convo 1 should not exist") recordMustNotExist(t, e, msg)
} else {
recordMustExist(t, e, msg)
}
}
gu := func(uid, offset int, exists bool) {
s := ""
if !exists {
s = " not"
}
mf(sf(c.Convos.GetUser(uid, offset)), fmt.Sprintf("convo getuser %d %d should%s exist", uid, offset, s), exists)
}
gue := func(uid, offset int, exists bool) {
s := ""
if !exists {
s = " not"
}
mf(sf(c.Convos.GetUserExtra(uid, offset)), fmt.Sprintf("convo getuserextra %d %d should%s exist", uid, offset, s), exists)
}
_, err = c.Convos.GetUser(-1, -1)
recordMustNotExist(t, err, "convo getuser -1 -1 should not exist")
_, err = c.Convos.GetUser(-1, 0)
recordMustNotExist(t, err, "convo getuser -1 0 should not exist")
_, err = c.Convos.GetUser(0, 0)
recordMustNotExist(t, err, "convo getuser 0 0 should not exist")
_, err = c.Convos.GetUser(1, 0)
recordMustNotExist(t, err, "convos getuser 1 0 should not exist")
expect(t, c.Convos.GetUserCount(-1) == 0, "getusercount should be zero") expect(t, c.Convos.GetUserCount(-1) == 0, "getusercount should be zero")
expect(t, c.Convos.GetUserCount(0) == 0, "getusercount should be zero") expect(t, c.Convos.GetUserCount(0) == 0, "getusercount should be zero")
expect(t, c.Convos.GetUserCount(1) == 0, "getusercount should be zero") mf(sf(c.Convos.Get(-1)), "convo -1 should not exist", false)
mf(sf(c.Convos.Get(0)), "convo 0 should not exist", false)
gu(-1, -1, false)
gu(-1, 0, false)
gu(0, 0, false)
gue(-1, -1, false)
gue(-1, 0, false)
gue(0, 0, false)
_, err = c.Convos.GetUserExtra(-1, -1) nf := func(cid int, count int) {
recordMustNotExist(t, err, "convos getuserextra -1 -1 should not exist") ex := count > 0
_, err = c.Convos.GetUserExtra(-1, 0) s := ""
recordMustNotExist(t, err, "convos getuserextra -1 0 should not exist") if !ex {
_, err = c.Convos.GetUserExtra(0, 0) s = " not"
recordMustNotExist(t, err, "convos getuserextra 0 0 should not exist") }
_, err = c.Convos.GetUserExtra(1, 0) mf(sf(c.Convos.Get(cid)), fmt.Sprintf("convo %d should%s exist", cid, s), ex)
recordMustNotExist(t, err, "convos getuserextra 1 0 should not exist") gu(1, 0, ex)
gu(1, 5, false) // invariant may change in future tests
expect(t, c.Convos.Count() == 0, "convos count should be 0") expectf(t, c.Convos.GetUserCount(1) == count, "getusercount should be %d", count)
gue(1, 0, ex)
gue(1, 5, false) // invariant may change in future tests
expectf(t, c.Convos.Count() == count, "convos count should be %d", count)
}
nf(1, 0)
cid, err := c.Convos.Create("hehe", 1, []int{2}) awaitingActivation := 5
uid, err := c.Users.Create("Saturn", "ReallyBadPassword", "", awaitingActivation, false)
expectNilErr(t, err)
cid, err := c.Convos.Create("hehe", 1, []int{uid})
expectNilErr(t, err) expectNilErr(t, err)
expect(t, cid == 1, "cid should be 1") expect(t, cid == 1, "cid should be 1")
expect(t, c.Convos.Count() == 1, "convos count should be 1") expect(t, c.Convos.Count() == 1, "convos count should be 1")
@ -1476,8 +1510,87 @@ func TestConvos(t *testing.T) {
// TODO: CreatedAt test // TODO: CreatedAt test
expect(t, co.LastReplyBy == 1, "co.LastReplyBy should be 1") expect(t, co.LastReplyBy == 1, "co.LastReplyBy should be 1")
// TODO: LastReplyAt test // TODO: LastReplyAt test
expectIntToBeX(t, co.PostsCount(), 1, "postscount should be 1, not %d")
expect(t, co.Has(uid), "saturn should be in the conversation")
expect(t, !co.Has(9999), "uid 9999 should not be in the conversation")
uids, err := co.Uids()
expectNilErr(t, err)
expectIntToBeX(t, len(uids), 2, "uids length should be 2, not %d")
expect(t, uids[0] == uid, fmt.Sprintf("uids[0] should be %d, not %d", uid, uids[0]))
expect(t, uids[1] == 1, fmt.Sprintf("uids[1] should be %d, not %d", 1, uids[1]))
nf(cid, 1)
expectNilErr(t, c.Convos.Delete(cid))
expectIntToBeX(t, co.PostsCount(), 0, "postscount should be 0, not %d")
expect(t, !co.Has(uid), "saturn should not be in a deleted conversation")
uids, err = co.Uids()
expectNilErr(t, err)
expectIntToBeX(t, len(uids), 0, "uids length should be 0, not %d")
nf(cid, 0)
// TODO: More tests // TODO: More tests
// Block tests
ok, err := c.UserBlocks.IsBlockedBy(1, 1)
expectNilErr(t, err)
expect(t, !ok, "there shouldn't be any blocks")
ok, err = c.UserBlocks.BulkIsBlockedBy([]int{1}, 1)
expectNilErr(t, err)
expect(t, !ok, "there shouldn't be any blocks")
bf := func(blocker, offset, perPage, expectLen, blockee int) {
l, err := c.UserBlocks.BlockedByOffset(blocker, offset, perPage)
expectNilErr(t, err)
expect(t, len(l) == expectLen, fmt.Sprintf("there should be %d users blocked by %d not %d", expectLen, blocker, len(l)))
if len(l) > 0 {
expectf(t, l[0] == blockee, "blocked uid should be %d not %d", blockee, l[0])
}
}
nbf := func(blocker, blockee int) {
ok, err := c.UserBlocks.IsBlockedBy(1, 2)
expectNilErr(t, err)
expect(t, !ok, "there shouldn't be any blocks")
ok, err = c.UserBlocks.BulkIsBlockedBy([]int{1}, 2)
expectNilErr(t, err)
expect(t, !ok, "there shouldn't be any blocks")
expectIntToBeX(t, c.UserBlocks.BlockedByCount(1), 0, "blockedbycount for 1 should be 1, not %d")
bf(1, 0, 1, 0, 0)
bf(1, 0, 15, 0, 0)
bf(1, 1, 15, 0, 0)
bf(1, 5, 15, 0, 0)
}
nbf(1, 2)
expectNilErr(t, c.UserBlocks.Add(1, 2))
ok, err = c.UserBlocks.IsBlockedBy(1, 2)
expectNilErr(t, err)
expect(t, ok, "2 should be blocked by 1")
expectIntToBeX(t, c.UserBlocks.BlockedByCount(1), 1, "blockedbycount for 1 should be 1, not %d")
bf(1, 0, 1, 1, 2)
bf(1, 0, 15, 1, 2)
bf(1, 1, 15, 0, 0)
bf(1, 5, 15, 0, 0)
// Double add test
expectNilErr(t, c.UserBlocks.Add(1, 2))
ok, err = c.UserBlocks.IsBlockedBy(1, 2)
expectNilErr(t, err)
expect(t, ok, "2 should be blocked by 1")
//expectIntToBeX(t, c.UserBlocks.BlockedByCount(1), 1, "blockedbycount for 1 should be 1, not %d") // todo: fix this
//bf(1, 0, 1, 1, 2) // todo: fix this
//bf(1, 0, 15, 1, 2) // todo: fix this
//bf(1, 1, 15, 0, 0) // todo: fix this
bf(1, 5, 15, 0, 0)
expectNilErr(t, c.UserBlocks.Remove(1, 2))
nbf(1, 2)
// Double remove test
expectNilErr(t, c.UserBlocks.Remove(1, 2))
nbf(1, 2)
// TODO: Self-block test
// TODO: More Block tests
} }
func TestActivityStream(t *testing.T) { func TestActivityStream(t *testing.T) {

View File

@ -51,12 +51,10 @@ func (b *accDeleteBuilder) Run(args ...interface{}) (int, error) {
if stmt == nil { if stmt == nil {
return 0, b.build.FirstError() return 0, b.build.FirstError()
} }
res, err := stmt.Exec(args...) res, err := stmt.Exec(args...)
if err != nil { if err != nil {
return 0, err return 0, err
} }
lastID, err := res.LastInsertId() lastID, err := res.LastInsertId()
return int(lastID), err return int(lastID), err
} }
@ -246,11 +244,11 @@ type AccRowWrap struct {
err error err error
} }
func (wrap *AccRowWrap) Scan(dest ...interface{}) error { func (w *AccRowWrap) Scan(dest ...interface{}) error {
if wrap.err != nil { if w.err != nil {
return wrap.err return w.err
} }
return wrap.row.Scan(dest...) return w.row.Scan(dest...)
} }
// TODO: Test to make sure the errors are passed up properly // TODO: Test to make sure the errors are passed up properly
@ -264,42 +262,56 @@ func (b *AccSelectBuilder) QueryRow(args ...interface{}) *AccRowWrap {
// Experimental, reduces lines // Experimental, reduces lines
func (b *AccSelectBuilder) Each(h func(*sql.Rows) error) error { func (b *AccSelectBuilder) Each(h func(*sql.Rows) error) error {
query, err := b.query() query, e := b.query()
if err != nil { if e != nil {
return err return e
} }
rows, err := b.build.query(query) rows, e := b.build.query(query)
if err != nil { if e != nil {
return err return e
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
if err = h(rows); err != nil { if e = h(rows); e != nil {
return err return e
}
}
return rows.Err()
}
func (b *AccSelectBuilder) EachP(h func(*sql.Rows) error, p ...interface{}) error {
query, e := b.query()
if e != nil {
return e
}
rows, e := b.build.query(query, p)
if e != nil {
return e
}
defer rows.Close()
for rows.Next() {
if e = h(rows); e != nil {
return e
} }
} }
return rows.Err() return rows.Err()
} }
func (b *AccSelectBuilder) EachInt(h func(int) error) error { func (b *AccSelectBuilder) EachInt(h func(int) error) error {
query, err := b.query() query, e := b.query()
if err != nil { if e != nil {
return err return e
} }
rows, err := b.build.query(query) rows, e := b.build.query(query)
if err != nil { if e != nil {
return err return e
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var theInt int var theInt int
err = rows.Scan(&theInt) if e = rows.Scan(&theInt); e != nil {
if err != nil { return e
return err
} }
if err = h(theInt); err != nil { if e = h(theInt); e != nil {
return err return e
} }
} }
return rows.Err() return rows.Err()
@ -348,10 +360,9 @@ func (b *accInsertBuilder) Run(args ...interface{}) (int, error) {
return int(lastID), err return int(lastID), err
} }
type accBulkInsertBuilder struct { type accBulkInsertBuilder struct {
table string table string
columns string columns string
fieldSet []string fieldSet []string
build *Accumulator build *Accumulator
@ -441,4 +452,13 @@ func (b *accCountBuilder) Total() (total int, err error) {
return total, err return total, err
} }
func (b *accCountBuilder) TotalP(params ...interface{}) (total int, err error) {
stmt := b.Prepare()
if stmt == nil {
return 0, b.build.FirstError()
}
err = stmt.QueryRow(params).Scan(&total)
return total, err
}
// TODO: Add a Sum builder for summing viewchunks up into one number for the dashboard? // TODO: Add a Sum builder for summing viewchunks up into one number for the dashboard?