diff --git a/common/counters/routes.go b/common/counters/routes.go index c9bddbb1..ecb5c6d3 100644 --- a/common/counters/routes.go +++ b/common/counters/routes.go @@ -14,6 +14,7 @@ var RouteViewCounter *DefaultRouteViewCounter type DefaultRouteViewCounter struct { buckets []*RWMutexCounterBucket //[RouteID]count insert *sql.Stmt + insert5 *sql.Stmt } func NewDefaultRouteViewCounter(acc *qgen.Accumulator) (*DefaultRouteViewCounter, error) { @@ -21,9 +22,12 @@ func NewDefaultRouteViewCounter(acc *qgen.Accumulator) (*DefaultRouteViewCounter for bucketID, _ := range routeBuckets { routeBuckets[bucketID] = &RWMutexCounterBucket{counter: 0} } + + fields := "?,UTC_TIMESTAMP(),?" co := &DefaultRouteViewCounter{ buckets: routeBuckets, - insert: acc.Insert("viewchunks").Columns("count, createdAt, route").Fields("?,UTC_TIMESTAMP(),?").Prepare(), + insert: acc.Insert("viewchunks").Columns("count,createdAt,route").Fields(fields).Prepare(), + insert5: acc.BulkInsert("viewchunks").Columns("count,createdAt,route").Fields(fields,fields,fields,fields,fields).Prepare(), } c.AddScheduledFifteenMinuteTask(co.Tick) // There could be a lot of routes, so we don't want to be running this every second //c.AddScheduledSecondTask(co.Tick) @@ -31,39 +35,82 @@ func NewDefaultRouteViewCounter(acc *qgen.Accumulator) (*DefaultRouteViewCounter return co, acc.FirstError() } -func (co *DefaultRouteViewCounter) Tick() error { - for routeID, routeBucket := range co.buckets { - var count int - routeBucket.RLock() - count = routeBucket.counter - routeBucket.counter = 0 - routeBucket.RUnlock() +type RVCount struct { + RouteID int + Count int +} - err := co.insertChunk(count, routeID) // TODO: Bulk insert for speed? +func (co *DefaultRouteViewCounter) Tick() error { + var tb []RVCount + for routeID, b := range co.buckets { + var count int + b.RLock() + count = b.counter + b.counter = 0 + b.RUnlock() + + if count == 0 { + continue + } + tb = append(tb, RVCount{routeID,count}) + } + + // TODO: Expand on this? + var i int + if len(tb) >= 5 { + for ; len(tb) > (i+5); i += 5 { + err := co.insert5Chunk(tb[i:i+5]) + if err != nil { + c.DebugLogf("tb: %+v\n", tb) + c.DebugLog("i: ", i) + return errors.Wrap(errors.WithStack(err), "route counter x 5") + } + } + } + + for ; len(tb) > i; i++ { + err := co.insertChunk(tb[i].Count, tb[i].RouteID) if err != nil { + c.DebugLogf("tb: %+v\n", tb) + c.DebugLog("i: ", i) return errors.Wrap(errors.WithStack(err), "route counter") } } + return nil } -func (co *DefaultRouteViewCounter) insertChunk(count int, route int) error { - if count == 0 { - return nil - } +func (co *DefaultRouteViewCounter) insertChunk(count, route int) error { routeName := reverseRouteMapEnum[route] c.DebugLogf("Inserting a vchunk with a count of %d for route %s (%d)", count, routeName, route) _, err := co.insert.Exec(count, routeName) return err } +func (co *DefaultRouteViewCounter) insert5Chunk(rvs []RVCount) error { + args := make([]interface{}, len(rvs) * 2) + i := 0 + for _, rv := range rvs { + routeName := reverseRouteMapEnum[rv.RouteID] + c.DebugLogf("Queueing a vchunk with a count of %d for routes %s (%d)", rv.Count, routeName, rv.RouteID) + args[i] = rv.Count + args[i+1] = routeName + i += 2 + } + c.DebugLogf("args: %+v\n", args) + _, err := co.insert5.Exec(args...) + return err +} + func (co *DefaultRouteViewCounter) Bump(route int) { // TODO: Test this check - c.DebugDetail("co.buckets[", route, "]: ", co.buckets[route]) + b := co.buckets[route] + c.DebugDetail("co.buckets[", route, "]: ", b) if len(co.buckets) <= route || route < 0 { return } - co.buckets[route].Lock() - co.buckets[route].counter++ - co.buckets[route].Unlock() + // TODO: Avoid lock by using atomic increment? + b.Lock() + b.counter++ + b.Unlock() } diff --git a/query_gen/acc_builders.go b/query_gen/acc_builders.go index 37876b14..ec5aa81a 100644 --- a/query_gen/acc_builders.go +++ b/query_gen/acc_builders.go @@ -2,32 +2,33 @@ package qgen import ( "database/sql" + //"fmt" "strconv" ) type accDeleteBuilder struct { - table string - where string + table string + where string dateCutoff *dateCutoff // We might want to do this in a slightly less hacky way build *Accumulator } -func (b *accDeleteBuilder) Where(where string) *accDeleteBuilder { +func (b *accDeleteBuilder) Where(w string) *accDeleteBuilder { if b.where != "" { b.where += " AND " } - b.where += where + b.where += w return b } -func (b *accDeleteBuilder) DateCutoff(column string, quantity int, unit string) *accDeleteBuilder { - b.dateCutoff = &dateCutoff{column, quantity, unit, 0} +func (b *accDeleteBuilder) DateCutoff(col string, quantity int, unit string) *accDeleteBuilder { + b.dateCutoff = &dateCutoff{col, quantity, unit, 0} return b } -func (b *accDeleteBuilder) DateOlderThan(column string, quantity int, unit string) *accDeleteBuilder { - b.dateCutoff = &dateCutoff{column, quantity, unit, 1} +func (b *accDeleteBuilder) DateOlderThan(col string, quantity int, unit string) *accDeleteBuilder { + b.dateCutoff = &dateCutoff{col, quantity, unit, 1} return b } @@ -78,13 +79,13 @@ func (u *accUpdateBuilder) Where(where string) *accUpdateBuilder { return u } -func (b *accUpdateBuilder) DateCutoff(column string, quantity int, unit string) *accUpdateBuilder { - b.up.dateCutoff = &dateCutoff{column, quantity, unit, 0} +func (b *accUpdateBuilder) DateCutoff(col string, quantity int, unit string) *accUpdateBuilder { + b.up.dateCutoff = &dateCutoff{col, quantity, unit, 0} return b } -func (b *accUpdateBuilder) DateOlderThan(column string, quantity int, unit string) *accUpdateBuilder { - b.up.dateCutoff = &dateCutoff{column, quantity, unit, 1} +func (b *accUpdateBuilder) DateOlderThan(col string, quantity int, unit string) *accUpdateBuilder { + b.up.dateCutoff = &dateCutoff{col, quantity, unit, 1} return b } @@ -105,6 +106,7 @@ func (b *accUpdateBuilder) Exec(args ...interface{}) (res sql.Result, err error) if err != nil { return res, err } + //fmt.Println("query:", query) return b.build.exec(query, args...) } @@ -121,13 +123,13 @@ type AccSelectBuilder struct { build *Accumulator } -func (b *AccSelectBuilder) Columns(columns string) *AccSelectBuilder { - b.columns = columns +func (b *AccSelectBuilder) Columns(cols string) *AccSelectBuilder { + b.columns = cols return b } -func (b *AccSelectBuilder) Cols(columns string) *AccSelectBuilder { - b.columns = columns +func (b *AccSelectBuilder) Cols(cols string) *AccSelectBuilder { + b.columns = cols return b } @@ -140,13 +142,13 @@ func (b *AccSelectBuilder) Where(where string) *AccSelectBuilder { } // TODO: Don't implement the SQL at the accumulator level but the adapter level -func (b *AccSelectBuilder) In(column string, inList []int) *AccSelectBuilder { +func (b *AccSelectBuilder) In(col string, inList []int) *AccSelectBuilder { if len(inList) == 0 { return b } // TODO: Optimise this - where := column + " IN(" + where := col + " IN(" for _, item := range inList { where += strconv.Itoa(item) + "," } @@ -160,19 +162,19 @@ func (b *AccSelectBuilder) In(column string, inList []int) *AccSelectBuilder { } // TODO: Don't implement the SQL at the accumulator level but the adapter level -func (b *AccSelectBuilder) InPQuery(column string, inList []int) (*sql.Rows, error) { +func (b *AccSelectBuilder) InPQuery(col string, inList []int) (*sql.Rows, error) { if len(inList) == 0 { return nil, sql.ErrNoRows } // TODO: Optimise this - where := column + " IN(" + where := col + " IN(" - idList := make([]interface{},len(inList)) + idList := make([]interface{}, len(inList)) for i, id := range inList { idList[i] = strconv.Itoa(id) where += "?," } - where = where[0 : len(where)-1] + ")" + where = where[0:len(where)-1] + ")" if b.where != "" { where += " AND " + b.where @@ -182,14 +184,19 @@ func (b *AccSelectBuilder) InPQuery(column string, inList []int) (*sql.Rows, err return b.Query(idList...) } -func (b *AccSelectBuilder) InQ(column string, subBuilder *AccSelectBuilder) *AccSelectBuilder { - b.inChain = subBuilder - b.inColumn = column +func (b *AccSelectBuilder) InQ(col string, sb *AccSelectBuilder) *AccSelectBuilder { + b.inChain = sb + b.inColumn = col return b } -func (b *AccSelectBuilder) DateCutoff(column string, quantity int, unit string) *AccSelectBuilder { - b.dateCutoff = &dateCutoff{column, quantity, unit, 0} +func (b *AccSelectBuilder) DateCutoff(col string, quantity int, unit string) *AccSelectBuilder { + b.dateCutoff = &dateCutoff{col, quantity, unit, 0} + return b +} + +func (b *AccSelectBuilder) DateOlderThanQ(col, unit string) *AccSelectBuilder { + b.dateCutoff = &dateCutoff{col, 0, unit, 11} return b } @@ -251,7 +258,7 @@ func (b *AccSelectBuilder) QueryRow(args ...interface{}) *AccRowWrap { } // Experimental, reduces lines -func (b *AccSelectBuilder) Each(handle func(*sql.Rows) error) error { +func (b *AccSelectBuilder) Each(h func(*sql.Rows) error) error { query, err := b.query() if err != nil { return err @@ -263,14 +270,13 @@ func (b *AccSelectBuilder) Each(handle func(*sql.Rows) error) error { defer rows.Close() for rows.Next() { - err = handle(rows) - if err != nil { + if err = h(rows); err != nil { return err } } return rows.Err() } -func (b *AccSelectBuilder) EachInt(handle func(int) error) error { +func (b *AccSelectBuilder) EachInt(h func(int) error) error { query, err := b.query() if err != nil { return err @@ -287,8 +293,7 @@ func (b *AccSelectBuilder) EachInt(handle func(int) error) error { if err != nil { return err } - err = handle(theInt) - if err != nil { + if err = h(theInt); err != nil { return err } } @@ -303,8 +308,8 @@ type accInsertBuilder struct { build *Accumulator } -func (b *accInsertBuilder) Columns(columns string) *accInsertBuilder { - b.columns = columns +func (b *accInsertBuilder) Columns(cols string) *accInsertBuilder { + b.columns = cols return b } @@ -334,7 +339,50 @@ func (b *accInsertBuilder) Run(args ...interface{}) (int, error) { if err != nil { return 0, err } + lastID, err := res.LastInsertId() + return int(lastID), err +} + +type accBulkInsertBuilder struct { + table string + columns string + fieldSet []string + + build *Accumulator +} + +func (b *accBulkInsertBuilder) Columns(cols string) *accBulkInsertBuilder { + b.columns = cols + return b +} + +func (b *accBulkInsertBuilder) Fields(fieldSet ...string) *accBulkInsertBuilder { + b.fieldSet = fieldSet + return b +} + +func (b *accBulkInsertBuilder) Prepare() *sql.Stmt { + return b.build.SimpleBulkInsert(b.table, b.columns, b.fieldSet) +} + +func (b *accBulkInsertBuilder) Exec(args ...interface{}) (res sql.Result, err error) { + query, err := b.build.adapter.SimpleBulkInsert("", b.table, b.columns, b.fieldSet) + if err != nil { + return res, err + } + return b.build.exec(query, args...) +} + +func (b *accBulkInsertBuilder) Run(args ...interface{}) (int, error) { + query, err := b.build.adapter.SimpleBulkInsert("", b.table, b.columns, b.fieldSet) + if err != nil { + return 0, err + } + res, err := b.build.exec(query, args...) + if err != nil { + return 0, err + } lastID, err := res.LastInsertId() return int(lastID), err } @@ -350,11 +398,11 @@ type accCountBuilder struct { build *Accumulator } -func (b *accCountBuilder) Where(where string) *accCountBuilder { +func (b *accCountBuilder) Where(w string) *accCountBuilder { if b.where != "" { b.where += " AND " } - b.where += where + b.where += w return b } diff --git a/query_gen/accumulator.go b/query_gen/accumulator.go index 46c6f30b..3b20184e 100644 --- a/query_gen/accumulator.go +++ b/query_gen/accumulator.go @@ -20,233 +20,241 @@ type Accumulator struct { firstErr error } -func (build *Accumulator) SetConn(conn *sql.DB) { - build.conn = conn +func (acc *Accumulator) SetConn(conn *sql.DB) { + acc.conn = conn } -func (build *Accumulator) SetAdapter(name string) error { +func (acc *Accumulator) SetAdapter(name string) error { adap, err := GetAdapter(name) if err != nil { return err } - build.adapter = adap + acc.adapter = adap return nil } -func (build *Accumulator) GetAdapter() Adapter { - return build.adapter +func (acc *Accumulator) GetAdapter() Adapter { + return acc.adapter } -func (build *Accumulator) FirstError() error { - return build.firstErr +func (acc *Accumulator) FirstError() error { + return acc.firstErr } -func (build *Accumulator) RecordError(err error) { +func (acc *Accumulator) RecordError(err error) { if err == nil { return } - if build.firstErr == nil { - build.firstErr = err + if acc.firstErr == nil { + acc.firstErr = err } } -func (build *Accumulator) prepare(res string, err error) *sql.Stmt { +func (acc *Accumulator) prepare(res string, err error) *sql.Stmt { // TODO: Can we make this less noisy on debug mode? if LogPrepares { log.Print("res: ", res) } if err != nil { - build.RecordError(err) + acc.RecordError(err) return nil } - stmt, err := build.conn.Prepare(res) - build.RecordError(err) + stmt, err := acc.conn.Prepare(res) + acc.RecordError(err) return stmt } -func (build *Accumulator) RawPrepare(res string) *sql.Stmt { - return build.prepare(res, nil) +func (acc *Accumulator) RawPrepare(res string) *sql.Stmt { + return acc.prepare(res, nil) } -func (build *Accumulator) query(query string, args ...interface{}) (rows *sql.Rows, err error) { - err = build.FirstError() +func (acc *Accumulator) query(q string, args ...interface{}) (rows *sql.Rows, err error) { + err = acc.FirstError() if err != nil { return rows, err } - return build.conn.Query(query, args...) + return acc.conn.Query(q, args...) } -func (build *Accumulator) exec(query string, args ...interface{}) (res sql.Result, err error) { - err = build.FirstError() +func (acc *Accumulator) exec(q string, args ...interface{}) (res sql.Result, err error) { + err = acc.FirstError() if err != nil { return res, err } - return build.conn.Exec(query, args...) + return acc.conn.Exec(q, args...) } -func (build *Accumulator) Tx(handler func(*TransactionBuilder) error) { - tx, err := build.conn.Begin() +func (acc *Accumulator) Tx(handler func(*TransactionBuilder) error) { + tx, err := acc.conn.Begin() if err != nil { - build.RecordError(err) + acc.RecordError(err) return } - err = handler(&TransactionBuilder{tx, build.adapter, nil}) + err = handler(&TransactionBuilder{tx, acc.adapter, nil}) if err != nil { tx.Rollback() - build.RecordError(err) + acc.RecordError(err) return } - build.RecordError(tx.Commit()) + acc.RecordError(tx.Commit()) } -func (build *Accumulator) SimpleSelect(table, columns, where, orderby, limit string) *sql.Stmt { - return build.prepare(build.adapter.SimpleSelect("", table, columns, where, orderby, limit)) +func (acc *Accumulator) SimpleSelect(table, columns, where, orderby, limit string) *sql.Stmt { + return acc.prepare(acc.adapter.SimpleSelect("", table, columns, where, orderby, limit)) } -func (build *Accumulator) SimpleCount(table, where, limit string) *sql.Stmt { - return build.prepare(build.adapter.SimpleCount("", table, where, limit)) +func (acc *Accumulator) SimpleCount(table, where, limit string) *sql.Stmt { + return acc.prepare(acc.adapter.SimpleCount("", table, where, limit)) } -func (build *Accumulator) SimpleLeftJoin(table1, table2, columns, joiners, where, orderby, limit string) *sql.Stmt { - return build.prepare(build.adapter.SimpleLeftJoin("", table1, table2, columns, joiners, where, orderby, limit)) +func (acc *Accumulator) SimpleLeftJoin(table1, table2, columns, joiners, where, orderby, limit string) *sql.Stmt { + return acc.prepare(acc.adapter.SimpleLeftJoin("", table1, table2, columns, joiners, where, orderby, limit)) } -func (build *Accumulator) SimpleInnerJoin(table1, table2, columns, joiners, where, orderby, limit string) *sql.Stmt { - return build.prepare(build.adapter.SimpleInnerJoin("", table1, table2, columns, joiners, where, orderby, limit)) +func (acc *Accumulator) SimpleInnerJoin(table1, table2, columns, joiners, where, orderby, limit string) *sql.Stmt { + return acc.prepare(acc.adapter.SimpleInnerJoin("", table1, table2, columns, joiners, where, orderby, limit)) } -func (build *Accumulator) CreateTable(table string, charset string, collation string, columns []DBTableColumn, keys []DBTableKey) *sql.Stmt { - return build.prepare(build.adapter.CreateTable("", table, charset, collation, columns, keys)) +func (acc *Accumulator) CreateTable(table, charset, collation string, columns []DBTableColumn, keys []DBTableKey) *sql.Stmt { + return acc.prepare(acc.adapter.CreateTable("", table, charset, collation, columns, keys)) } -func (build *Accumulator) SimpleInsert(table, columns, fields string) *sql.Stmt { - return build.prepare(build.adapter.SimpleInsert("", table, columns, fields)) +func (acc *Accumulator) SimpleInsert(table, columns, fields string) *sql.Stmt { + return acc.prepare(acc.adapter.SimpleInsert("", table, columns, fields)) } -func (build *Accumulator) SimpleInsertSelect(ins DBInsert, sel DBSelect) *sql.Stmt { - return build.prepare(build.adapter.SimpleInsertSelect("", ins, sel)) +func (acc *Accumulator) SimpleBulkInsert(table, cols string, fieldSet []string) *sql.Stmt { + return acc.prepare(acc.adapter.SimpleBulkInsert("", table, cols, fieldSet)) } -func (build *Accumulator) SimpleInsertLeftJoin(ins DBInsert, sel DBJoin) *sql.Stmt { - return build.prepare(build.adapter.SimpleInsertLeftJoin("", ins, sel)) +func (acc *Accumulator) SimpleInsertSelect(ins DBInsert, sel DBSelect) *sql.Stmt { + return acc.prepare(acc.adapter.SimpleInsertSelect("", ins, sel)) } -func (build *Accumulator) SimpleInsertInnerJoin(ins DBInsert, sel DBJoin) *sql.Stmt { - return build.prepare(build.adapter.SimpleInsertInnerJoin("", ins, sel)) +func (acc *Accumulator) SimpleInsertLeftJoin(ins DBInsert, sel DBJoin) *sql.Stmt { + return acc.prepare(acc.adapter.SimpleInsertLeftJoin("", ins, sel)) } -func (build *Accumulator) SimpleUpdate(table, set, where string) *sql.Stmt { - return build.prepare(build.adapter.SimpleUpdate(qUpdate(table, set, where))) +func (acc *Accumulator) SimpleInsertInnerJoin(ins DBInsert, sel DBJoin) *sql.Stmt { + return acc.prepare(acc.adapter.SimpleInsertInnerJoin("", ins, sel)) } -func (build *Accumulator) SimpleUpdateSelect(table, set, table2, cols, where, orderby, limit string) *sql.Stmt { - pre := qUpdate(table, set, "").WhereQ(build.GetAdapter().Builder().Select().Table(table2).Columns(cols).Where(where).Orderby(orderby).Limit(limit)) - return build.prepare(build.adapter.SimpleUpdateSelect(pre)) +func (acc *Accumulator) SimpleUpdate(table, set, where string) *sql.Stmt { + return acc.prepare(acc.adapter.SimpleUpdate(qUpdate(table, set, where))) } -func (build *Accumulator) SimpleDelete(table, where string) *sql.Stmt { - return build.prepare(build.adapter.SimpleDelete("", table, where)) +func (acc *Accumulator) SimpleUpdateSelect(table, set, table2, cols, where, orderby, limit string) *sql.Stmt { + pre := qUpdate(table, set, "").WhereQ(acc.GetAdapter().Builder().Select().Table(table2).Columns(cols).Where(where).Orderby(orderby).Limit(limit)) + return acc.prepare(acc.adapter.SimpleUpdateSelect(pre)) +} + +func (acc *Accumulator) SimpleDelete(table, where string) *sql.Stmt { + return acc.prepare(acc.adapter.SimpleDelete("", table, where)) } // I don't know why you need this, but here it is x.x -func (build *Accumulator) Purge(table string) *sql.Stmt { - return build.prepare(build.adapter.Purge("", table)) +func (acc *Accumulator) Purge(table string) *sql.Stmt { + return acc.prepare(acc.adapter.Purge("", table)) } -func (build *Accumulator) prepareTx(tx *sql.Tx, res string, err error) (stmt *sql.Stmt) { +func (acc *Accumulator) prepareTx(tx *sql.Tx, res string, err error) (stmt *sql.Stmt) { if err != nil { - build.RecordError(err) + acc.RecordError(err) return nil } stmt, err = tx.Prepare(res) - build.RecordError(err) + acc.RecordError(err) return stmt } // These ones support transactions -func (build *Accumulator) SimpleSelectTx(tx *sql.Tx, table string, columns string, where string, orderby string, limit string) (stmt *sql.Stmt) { - res, err := build.adapter.SimpleSelect("", table, columns, where, orderby, limit) - return build.prepareTx(tx, res, err) +func (acc *Accumulator) SimpleSelectTx(tx *sql.Tx, table, columns, where, orderby, limit string) (stmt *sql.Stmt) { + res, err := acc.adapter.SimpleSelect("", table, columns, where, orderby, limit) + return acc.prepareTx(tx, res, err) } -func (build *Accumulator) SimpleCountTx(tx *sql.Tx, table string, where string, limit string) (stmt *sql.Stmt) { - res, err := build.adapter.SimpleCount("", table, where, limit) - return build.prepareTx(tx, res, err) +func (acc *Accumulator) SimpleCountTx(tx *sql.Tx, table, where, limit string) (stmt *sql.Stmt) { + res, err := acc.adapter.SimpleCount("", table, where, limit) + return acc.prepareTx(tx, res, err) } -func (build *Accumulator) SimpleLeftJoinTx(tx *sql.Tx, table1 string, table2 string, columns string, joiners string, where string, orderby string, limit string) (stmt *sql.Stmt) { - res, err := build.adapter.SimpleLeftJoin("", table1, table2, columns, joiners, where, orderby, limit) - return build.prepareTx(tx, res, err) +func (acc *Accumulator) SimpleLeftJoinTx(tx *sql.Tx, table1, table2, columns, joiners, where, orderby, limit string) (stmt *sql.Stmt) { + res, err := acc.adapter.SimpleLeftJoin("", table1, table2, columns, joiners, where, orderby, limit) + return acc.prepareTx(tx, res, err) } -func (build *Accumulator) SimpleInnerJoinTx(tx *sql.Tx, table1 string, table2 string, columns string, joiners string, where string, orderby string, limit string) (stmt *sql.Stmt) { - res, err := build.adapter.SimpleInnerJoin("", table1, table2, columns, joiners, where, orderby, limit) - return build.prepareTx(tx, res, err) +func (acc *Accumulator) SimpleInnerJoinTx(tx *sql.Tx, table1, table2, columns, joiners, where, orderby, limit string) (stmt *sql.Stmt) { + res, err := acc.adapter.SimpleInnerJoin("", table1, table2, columns, joiners, where, orderby, limit) + return acc.prepareTx(tx, res, err) } -func (build *Accumulator) CreateTableTx(tx *sql.Tx, table string, charset string, collation string, columns []DBTableColumn, keys []DBTableKey) (stmt *sql.Stmt) { - res, err := build.adapter.CreateTable("", table, charset, collation, columns, keys) - return build.prepareTx(tx, res, err) +func (acc *Accumulator) CreateTableTx(tx *sql.Tx, table, charset, collation string, columns []DBTableColumn, keys []DBTableKey) (stmt *sql.Stmt) { + res, err := acc.adapter.CreateTable("", table, charset, collation, columns, keys) + return acc.prepareTx(tx, res, err) } -func (build *Accumulator) SimpleInsertTx(tx *sql.Tx, table string, columns string, fields string) (stmt *sql.Stmt) { - res, err := build.adapter.SimpleInsert("", table, columns, fields) - return build.prepareTx(tx, res, err) +func (acc *Accumulator) SimpleInsertTx(tx *sql.Tx, table, columns, fields string) (stmt *sql.Stmt) { + res, err := acc.adapter.SimpleInsert("", table, columns, fields) + return acc.prepareTx(tx, res, err) } -func (build *Accumulator) SimpleInsertSelectTx(tx *sql.Tx, ins DBInsert, sel DBSelect) (stmt *sql.Stmt) { - res, err := build.adapter.SimpleInsertSelect("", ins, sel) - return build.prepareTx(tx, res, err) +func (acc *Accumulator) SimpleInsertSelectTx(tx *sql.Tx, ins DBInsert, sel DBSelect) (stmt *sql.Stmt) { + res, err := acc.adapter.SimpleInsertSelect("", ins, sel) + return acc.prepareTx(tx, res, err) } -func (build *Accumulator) SimpleInsertLeftJoinTx(tx *sql.Tx, ins DBInsert, sel DBJoin) (stmt *sql.Stmt) { - res, err := build.adapter.SimpleInsertLeftJoin("", ins, sel) - return build.prepareTx(tx, res, err) +func (acc *Accumulator) SimpleInsertLeftJoinTx(tx *sql.Tx, ins DBInsert, sel DBJoin) (stmt *sql.Stmt) { + res, err := acc.adapter.SimpleInsertLeftJoin("", ins, sel) + return acc.prepareTx(tx, res, err) } -func (build *Accumulator) SimpleInsertInnerJoinTx(tx *sql.Tx, ins DBInsert, sel DBJoin) (stmt *sql.Stmt) { - res, err := build.adapter.SimpleInsertInnerJoin("", ins, sel) - return build.prepareTx(tx, res, err) +func (acc *Accumulator) SimpleInsertInnerJoinTx(tx *sql.Tx, ins DBInsert, sel DBJoin) (stmt *sql.Stmt) { + res, err := acc.adapter.SimpleInsertInnerJoin("", ins, sel) + return acc.prepareTx(tx, res, err) } -func (build *Accumulator) SimpleUpdateTx(tx *sql.Tx, table string, set string, where string) (stmt *sql.Stmt) { - res, err := build.adapter.SimpleUpdate(qUpdate(table, set, where)) - return build.prepareTx(tx, res, err) +func (acc *Accumulator) SimpleUpdateTx(tx *sql.Tx, table, set, where string) (stmt *sql.Stmt) { + res, err := acc.adapter.SimpleUpdate(qUpdate(table, set, where)) + return acc.prepareTx(tx, res, err) } -func (build *Accumulator) SimpleDeleteTx(tx *sql.Tx, table string, where string) (stmt *sql.Stmt) { - res, err := build.adapter.SimpleDelete("", table, where) - return build.prepareTx(tx, res, err) +func (acc *Accumulator) SimpleDeleteTx(tx *sql.Tx, table, where string) (stmt *sql.Stmt) { + res, err := acc.adapter.SimpleDelete("", table, where) + return acc.prepareTx(tx, res, err) } // I don't know why you need this, but here it is x.x -func (build *Accumulator) PurgeTx(tx *sql.Tx, table string) (stmt *sql.Stmt) { - res, err := build.adapter.Purge("", table) - return build.prepareTx(tx, res, err) +func (acc *Accumulator) PurgeTx(tx *sql.Tx, table string) (stmt *sql.Stmt) { + res, err := acc.adapter.Purge("", table) + return acc.prepareTx(tx, res, err) } -func (build *Accumulator) Delete(table string) *accDeleteBuilder { - return &accDeleteBuilder{table, "", nil, build} +func (acc *Accumulator) Delete(table string) *accDeleteBuilder { + return &accDeleteBuilder{table, "", nil, acc} } -func (build *Accumulator) Update(table string) *accUpdateBuilder { - return &accUpdateBuilder{qUpdate(table, "", ""), build} +func (acc *Accumulator) Update(table string) *accUpdateBuilder { + return &accUpdateBuilder{qUpdate(table, "", ""), acc} } -func (build *Accumulator) Select(table string) *AccSelectBuilder { - return &AccSelectBuilder{table, "", "", "", "", nil, nil, "", build} +func (acc *Accumulator) Select(table string) *AccSelectBuilder { + return &AccSelectBuilder{table, "", "", "", "", nil, nil, "", acc} } -func (build *Accumulator) Exists(tbl, col string) *AccSelectBuilder { - return build.Select(tbl).Columns(col).Where(col + "=?") +func (acc *Accumulator) Exists(tbl, col string) *AccSelectBuilder { + return acc.Select(tbl).Columns(col).Where(col + "=?") } -func (build *Accumulator) Insert(table string) *accInsertBuilder { - return &accInsertBuilder{table, "", "", build} +func (acc *Accumulator) Insert(table string) *accInsertBuilder { + return &accInsertBuilder{table, "", "", acc} } -func (build *Accumulator) Count(table string) *accCountBuilder { - return &accCountBuilder{table, "", "", nil, nil, "", build} +func (acc *Accumulator) BulkInsert(table string) *accBulkInsertBuilder { + return &accBulkInsertBuilder{table, "", nil, acc} +} + +func (acc *Accumulator) Count(table string) *accCountBuilder { + return &accCountBuilder{table, "", "", nil, nil, "", acc} } type SimpleModel struct { @@ -255,7 +263,7 @@ type SimpleModel struct { update *sql.Stmt } -func (build *Accumulator) SimpleModel(tbl, colstr, primary string) SimpleModel { +func (acc *Accumulator) SimpleModel(tbl, colstr, primary string) SimpleModel { var qlist, uplist string for _, col := range strings.Split(colstr, ",") { qlist += "?," @@ -268,9 +276,9 @@ func (build *Accumulator) SimpleModel(tbl, colstr, primary string) SimpleModel { where := primary + "=?" return SimpleModel{ - delete: build.Delete(tbl).Where(where).Prepare(), - create: build.Insert(tbl).Columns(colstr).Fields(qlist).Prepare(), - update: build.Update(tbl).Set(uplist).Where(where).Prepare(), + delete: acc.Delete(tbl).Where(where).Prepare(), + create: acc.Insert(tbl).Columns(colstr).Fields(qlist).Prepare(), + update: acc.Update(tbl).Set(uplist).Where(where).Prepare(), } } @@ -298,8 +306,8 @@ func (m SimpleModel) CreateID(args ...interface{}) (int, error) { return int(lastID), err } -func (build *Accumulator) Model(table string) *accModelBuilder { - return &accModelBuilder{table, "", build} +func (acc *Accumulator) Model(table string) *accModelBuilder { + return &accModelBuilder{table, "", acc} } type accModelBuilder struct { diff --git a/query_gen/mssql.go b/query_gen/mssql.go index bfef89ac..c858fae9 100644 --- a/query_gen/mssql.go +++ b/query_gen/mssql.go @@ -106,7 +106,6 @@ func (a *MssqlAdapter) parseColumn(column DBTableColumn) (col DBTableColumn, siz case "boolean": column.Type = "bit" } - if column.Size > 0 { size = " (" + strconv.Itoa(column.Size) + ")" } @@ -230,6 +229,18 @@ func (a *MssqlAdapter) AddForeignKey(name, table, column, ftable, fcolumn string } func (a *MssqlAdapter) SimpleInsert(name, table, cols, fields string) (string, error) { + q, err := a.simpleBulkInsert(name, table, cols, []string{fields}) + a.pushStatement(name, "insert", q) + return q, err +} + +func (a *MssqlAdapter) SimpleBulkInsert(name, table, cols string, fieldSet []string) (string, error) { + q, err := a.simpleBulkInsert(name, table, cols, fieldSet) + a.pushStatement(name, "bulk-insert", q) + return q, err +} + +func (a *MssqlAdapter) simpleBulkInsert(name, table, cols string, fieldSet []string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } @@ -252,21 +263,24 @@ func (a *MssqlAdapter) SimpleInsert(name, table, cols, fields string) (string, e q = q[0 : len(q)-1] q += ") VALUES (" - for _, field := range processFields(fields) { - field.Name = strings.Replace(field.Name, "UTC_TIMESTAMP()", "GETUTCDATE()", -1) - //log.Print("field.Name ", field.Name) - nameLen := len(field.Name) - if field.Name[0] == '"' && field.Name[nameLen-1] == '"' && nameLen >= 3 { - field.Name = "'" + field.Name[1:nameLen-1] + "'" + for oi, fields := range fieldSet { + if oi != 0 { + q += ",(" } - if field.Name[0] == '\'' && field.Name[nameLen-1] == '\'' && nameLen >= 3 { - field.Name = "'" + strings.Replace(field.Name[1:nameLen-1], "'", "''", -1) + "'" + for _, field := range processFields(fields) { + field.Name = strings.Replace(field.Name, "UTC_TIMESTAMP()", "GETUTCDATE()", -1) + //log.Print("field.Name ", field.Name) + nameLen := len(field.Name) + if field.Name[0] == '"' && field.Name[nameLen-1] == '"' && nameLen >= 3 { + field.Name = "'" + field.Name[1:nameLen-1] + "'" + } + if field.Name[0] == '\'' && field.Name[nameLen-1] == '\'' && nameLen >= 3 { + field.Name = "'" + strings.Replace(field.Name[1:nameLen-1], "'", "''", -1) + "'" + } + q += field.Name + "," } - q += field.Name + "," + q = q[0:len(q)-1] + ")" } - q = q[0:len(q)-1] + ")" - - a.pushStatement(name, "insert", q) return q, nil } diff --git a/query_gen/mysql.go b/query_gen/mysql.go index 99075599..5ad55bba 100644 --- a/query_gen/mysql.go +++ b/query_gen/mysql.go @@ -267,13 +267,25 @@ func (a *MysqlAdapter) AddIndex(name, table, iname, colname string) (string, err // TODO: Test to make sure everything works here // Only supports FULLTEXT right now -func (a *MysqlAdapter) AddKey(name, table, column string, key DBTableKey) (string, error) { +func (a *MysqlAdapter) AddKey(name, table, cols string, key DBTableKey) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } + if cols == "" { + return "", errors.New("You need to specify columns") + } + + var colstr string + for _, col := range strings.Split(cols,",") { + colstr += "`" + col + "`," + } + if len(colstr) > 1 { + colstr = colstr[:len(colstr)-1] + } + var q string if key.Type == "fulltext" { - q = "ALTER TABLE `" + table + "` ADD FULLTEXT(`" + column + "`)" + q = "ALTER TABLE `" + table + "` ADD FULLTEXT(" + colstr + ")" } else { return "", errors.New("Only fulltext is supported by AddKey right now") } @@ -363,6 +375,50 @@ func (a *MysqlAdapter) SimpleInsert(name, table, columns, fields string) (string return q, nil } +func (a *MysqlAdapter) SimpleBulkInsert(name, table, columns string, fieldSet []string) (string, error) { + if table == "" { + return "", errors.New("You need a name for this table") + } + + var sb strings.Builder + sb.Grow(silen1 + len(table)) + sb.WriteString("INSERT INTO `") + sb.WriteString(table) + sb.WriteString("`(") + if columns != "" { + sb.WriteString(a.buildColumns(columns)) + sb.WriteString(") VALUES (") + for oi, fields := range fieldSet { + if oi != 0 { + sb.WriteString(",(") + } + fs := processFields(fields) + sb.Grow(len(fs) * 3) + for i, field := range fs { + if i != 0 { + sb.WriteString(",") + } + nameLen := len(field.Name) + if field.Name[0] == '"' && field.Name[nameLen-1] == '"' && nameLen >= 3 { + field.Name = "'" + field.Name[1:nameLen-1] + "'" + } + if field.Name[0] == '\'' && field.Name[nameLen-1] == '\'' && nameLen >= 3 { + field.Name = "'" + strings.Replace(field.Name[1:nameLen-1], "'", "''", -1) + "'" + } + sb.WriteString(field.Name) + } + sb.WriteString(")") + } + } else { + sb.WriteString(") VALUES ()") + } + + // TODO: Shunt the table name logic and associated stmt list up to the a higher layer to reduce the amount of unnecessary overhead in the builder / accumulator + q := sb.String() + a.pushStatement(name, "bulk-insert", q) + return q, nil +} + func (a *MysqlAdapter) buildColumns(columns string) (q string) { if columns == "" { return "" diff --git a/query_gen/pgsql.go b/query_gen/pgsql.go index a70807a9..9613a07a 100644 --- a/query_gen/pgsql.go +++ b/query_gen/pgsql.go @@ -230,6 +230,11 @@ func (a *PgsqlAdapter) SimpleInsert(name, table, columns, fields string) (string return q, nil } +// TODO: Implement this +func (a *PgsqlAdapter) SimpleBulkInsert(name, table, columns string, fieldSet []string) (string, error) { + return "", nil +} + func (a *PgsqlAdapter) buildColumns(cols string) (q string) { if cols == "" { return "" diff --git a/query_gen/querygen.go b/query_gen/querygen.go index fb6237d2..cf71ca0b 100644 --- a/query_gen/querygen.go +++ b/query_gen/querygen.go @@ -144,6 +144,7 @@ type Adapter interface { RemoveIndex(name, table, column string) (string, error) AddForeignKey(name, table, column, ftable, fcolumn string, cascade bool) (out string, e error) SimpleInsert(name, table, columns, fields string) (string, error) + SimpleBulkInsert(name, table, columns string, fieldSet []string) (string, error) SimpleUpdate(b *updatePrebuilder) (string, error) SimpleUpdateSelect(b *updatePrebuilder) (string, error) // ! Experimental SimpleDelete(name, table, where string) (string, error)