From c928c84c95e788245fda631266aebc5fc8e14441 Mon Sep 17 00:00:00 2001 From: Azareal Date: Thu, 31 Oct 2019 17:25:56 +1000 Subject: [PATCH] Reduce boilerplate in the currently defunct mssql.go and pgsql.go --- gen_mssql.go | 4 +- query_gen/mssql.go | 562 ++++++++++++++++++++++----------------------- query_gen/pgsql.go | 175 +++++++------- 3 files changed, 362 insertions(+), 379 deletions(-) diff --git a/gen_mssql.go b/gen_mssql.go index 896661c4..9a35023b 100644 --- a/gen_mssql.go +++ b/gen_mssql.go @@ -45,10 +45,10 @@ func _gen_mssql() (err error) { } common.DebugLog("Preparing getForumTopics statement.") - stmts.getForumTopics, err = db.Prepare("SELECT [topics].[tid],[topics].[title],[topics].[content],[topics].[createdBy],[topics].[is_closed],[topics].[sticky],[topics].[createdAt],[topics].[lastReplyAt],[topics].[parentID],[users].[name],[users].[avatar] FROM [topics] LEFT JOIN [users] ON [topics].[createdBy] = [users].[uid] WHERE [topics].[parentID] = ?1 ORDER BY topics.sticky DESC,topics.lastReplyAt DESC,topics.createdBy DESC") + stmts.getForumTopics, err = db.Prepare("SELECT [topics].[tid],[topics].[title],[topics].[content],[topics].[createdBy],[topics].[is_closed],[topics].[sticky],[topics].[createdAt],[topics].[lastReplyAt],[topics].[parentID],[users].[name],[users].[avatar] FROM [topics] LEFT JOIN [users] ON [topics].[createdBy]=[users].[uid] WHERE [topics].[parentID] = ?1 ORDER BY topics.sticky DESC,topics.lastReplyAt DESC,topics.createdBy DESC") if err != nil { log.Print("Error in getForumTopics statement.") - log.Print("Bad Query: ","SELECT [topics].[tid],[topics].[title],[topics].[content],[topics].[createdBy],[topics].[is_closed],[topics].[sticky],[topics].[createdAt],[topics].[lastReplyAt],[topics].[parentID],[users].[name],[users].[avatar] FROM [topics] LEFT JOIN [users] ON [topics].[createdBy] = [users].[uid] WHERE [topics].[parentID] = ?1 ORDER BY topics.sticky DESC,topics.lastReplyAt DESC,topics.createdBy DESC") + log.Print("Bad Query: ","SELECT [topics].[tid],[topics].[title],[topics].[content],[topics].[createdBy],[topics].[is_closed],[topics].[sticky],[topics].[createdAt],[topics].[lastReplyAt],[topics].[parentID],[users].[name],[users].[avatar] FROM [topics] LEFT JOIN [users] ON [topics].[createdBy]=[users].[uid] WHERE [topics].[parentID] = ?1 ORDER BY topics.sticky DESC,topics.lastReplyAt DESC,topics.createdBy DESC") return err } diff --git a/query_gen/mssql.go b/query_gen/mssql.go index a86c75a4..8818b1de 100644 --- a/query_gen/mssql.go +++ b/query_gen/mssql.go @@ -23,40 +23,40 @@ type MssqlAdapter struct { } // GetName gives you the name of the database adapter. In this case, it's Mssql -func (adapter *MssqlAdapter) GetName() string { - return adapter.Name +func (a *MssqlAdapter) GetName() string { + return a.Name } -func (adapter *MssqlAdapter) GetStmt(name string) DBStmt { - return adapter.Buffer[name] +func (a *MssqlAdapter) GetStmt(name string) DBStmt { + return a.Buffer[name] } -func (adapter *MssqlAdapter) GetStmts() map[string]DBStmt { - return adapter.Buffer +func (a *MssqlAdapter) GetStmts() map[string]DBStmt { + return a.Buffer } // TODO: Implement this -func (adapter *MssqlAdapter) BuildConn(config map[string]string) (*sql.DB, error) { +func (a *MssqlAdapter) BuildConn(config map[string]string) (*sql.DB, error) { return nil, nil } -func (adapter *MssqlAdapter) DbVersion() string { +func (a *MssqlAdapter) DbVersion() string { return "SELECT CONCAT(SERVERPROPERTY('productversion'), SERVERPROPERTY ('productlevel'), SERVERPROPERTY ('edition'))" } -func (adapter *MssqlAdapter) DropTable(name string, table string) (string, error) { +func (a *MssqlAdapter) DropTable(name string, table string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } - querystr := "DROP TABLE IF EXISTS [" + table + "];" - adapter.pushStatement(name, "drop-table", querystr) - return querystr, nil + q := "DROP TABLE IF EXISTS [" + table + "];" + a.pushStatement(name, "drop-table", q) + return q, nil } // TODO: Add support for foreign keys? // TODO: Convert any remaining stringy types to nvarchar // We may need to change the CreateTable API to better suit Mssql and the other database drivers which are coming up -func (adapter *MssqlAdapter) CreateTable(name string, table string, charset string, collation string, columns []DBTableColumn, keys []DBTableKey) (string, error) { +func (a *MssqlAdapter) CreateTable(name string, table string, charset string, collation string, columns []DBTableColumn, keys []DBTableKey) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } @@ -64,32 +64,32 @@ func (adapter *MssqlAdapter) CreateTable(name string, table string, charset stri return "", errors.New("You can't have a table with no columns") } - var querystr = "CREATE TABLE [" + table + "] (" + q := "CREATE TABLE [" + table + "] (" for _, column := range columns { - column, size, end := adapter.parseColumn(column) - querystr += "\n\t[" + column.Name + "] " + column.Type + size + end + "," + column, size, end := a.parseColumn(column) + q += "\n\t[" + column.Name + "] " + column.Type + size + end + "," } if len(keys) > 0 { for _, key := range keys { - querystr += "\n\t" + key.Type + q += "\n\t" + key.Type if key.Type != "unique" { - querystr += " key" + q += " key" } - querystr += "(" + q += "(" for _, column := range strings.Split(key.Columns, ",") { - querystr += "[" + column + "]," + q += "[" + column + "]," } - querystr = querystr[0:len(querystr)-1] + ")," + q = q[0:len(q)-1] + ")," } } - querystr = querystr[0:len(querystr)-1] + "\n);" - adapter.pushStatement(name, "create-table", querystr) - return querystr, nil + q = q[0:len(q)-1] + "\n);" + a.pushStatement(name, "create-table", q) + return q, nil } -func (adapter *MssqlAdapter) parseColumn(column DBTableColumn) (col DBTableColumn, size string, end string) { +func (a *MssqlAdapter) parseColumn(column DBTableColumn) (col DBTableColumn, size string, end string) { var max, createdAt bool switch column.Type { case "createdAt": @@ -118,7 +118,7 @@ func (adapter *MssqlAdapter) parseColumn(column DBTableColumn) (col DBTableColum end = " DEFAULT " if createdAt { end += "GETUTCDATE()" // TODO: Use GETUTCDATE() in updates instead of the neutral format - } else if adapter.stringyType(column.Type) && column.Default != "''" { + } else if a.stringyType(column.Type) && column.Default != "''" { end += "'" + column.Default + "'" } else { end += column.Default @@ -137,20 +137,20 @@ func (adapter *MssqlAdapter) parseColumn(column DBTableColumn) (col DBTableColum // TODO: Test this, not sure if some things work // TODO: Add support for keys -func (adapter *MssqlAdapter) AddColumn(name string, table string, column DBTableColumn, key *DBTableKey) (string, error) { +func (a *MssqlAdapter) AddColumn(name string, table string, column DBTableColumn, key *DBTableKey) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } - column, size, end := adapter.parseColumn(column) - querystr := "ALTER TABLE [" + table + "] ADD [" + column.Name + "] " + column.Type + size + end + ";" - adapter.pushStatement(name, "add-column", querystr) - return querystr, nil + column, size, end := a.parseColumn(column) + q := "ALTER TABLE [" + table + "] ADD [" + column.Name + "] " + column.Type + size + end + ";" + a.pushStatement(name, "add-column", q) + return q, nil } // TODO: Implement this // TODO: Test to make sure everything works here -func (adapter *MssqlAdapter) AddIndex(name string, table string, iname string, colname string) (string, error) { +func (a *MssqlAdapter) AddIndex(name string, table string, iname string, colname string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } @@ -165,7 +165,7 @@ func (adapter *MssqlAdapter) AddIndex(name string, table string, iname string, c // TODO: Implement this // TODO: Test to make sure everything works here -func (adapter *MssqlAdapter) AddKey(name string, table string, column string, key DBTableKey) (string, error) { +func (a *MssqlAdapter) AddKey(name string, table string, column string, key DBTableKey) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } @@ -177,47 +177,47 @@ func (adapter *MssqlAdapter) AddKey(name string, table string, column string, ke // TODO: Implement this // TODO: Test to make sure everything works here -func (adapter *MssqlAdapter) AddForeignKey(name string, table string, column string, ftable string, fcolumn string, cascade bool) (out string, e error) { - var c = func(str string, val bool) { +func (a *MssqlAdapter) AddForeignKey(name string, table string, column string, ftable string, fcolumn string, cascade bool) (out string, e error) { + c := func(str string, val bool) { if e != nil || !val { return } - e = errors.New("You need a "+str+" for this table") + e = errors.New("You need a " + str + " for this table") } - c("name",table=="") - c("column",column=="") - c("ftable",ftable=="") - c("fcolumn",fcolumn=="") + c("name", table == "") + c("column", column == "") + c("ftable", ftable == "") + c("fcolumn", fcolumn == "") if e != nil { return "", e } return "", errors.New("not implemented") } -func (adapter *MssqlAdapter) SimpleInsert(name string, table string, columns string, fields string) (string, error) { +func (a *MssqlAdapter) SimpleInsert(name string, table string, columns string, fields string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } - var querystr = "INSERT INTO [" + table + "] (" + q := "INSERT INTO [" + table + "] (" if columns == "" { - querystr += ") VALUES ()" - adapter.pushStatement(name, "insert", querystr) - return querystr, nil + q += ") VALUES ()" + a.pushStatement(name, "insert", q) + return q, nil } // Escape the column names, just in case we've used a reserved keyword for _, column := range processColumns(columns) { if column.Type == "function" { - querystr += column.Left + "," + q += column.Left + "," } else { - querystr += "[" + column.Left + "]," + q += "[" + column.Left + "]," } } // Remove the trailing comma - querystr = querystr[0 : len(querystr)-1] + q = q[0 : len(q)-1] - querystr += ") VALUES (" + q += ") VALUES (" for _, field := range processFields(fields) { field.Name = strings.Replace(field.Name, "UTC_TIMESTAMP()", "GETUTCDATE()", -1) //log.Print("field.Name ", field.Name) @@ -228,18 +228,18 @@ func (adapter *MssqlAdapter) SimpleInsert(name string, table string, columns str if field.Name[0] == '\'' && field.Name[nameLen-1] == '\'' && nameLen >= 3 { field.Name = "'" + strings.Replace(field.Name[1:nameLen-1], "'", "''", -1) + "'" } - querystr += field.Name + "," + q += field.Name + "," } - querystr = querystr[0 : len(querystr)-1] + q = q[0:len(q)-1] + ")" - adapter.pushStatement(name, "insert", querystr+")") - return querystr + ")", nil + a.pushStatement(name, "insert", q) + return q, nil } // ! DEPRECATED -func (adapter *MssqlAdapter) SimpleReplace(name string, table string, columns string, fields string) (string, error) { +func (a *MssqlAdapter) SimpleReplace(name string, table string, columns string, fields string) (string, error) { log.Print("In SimpleReplace") - key, ok := adapter.keys[table] + key, ok := a.keys[table] if !ok { return "", errors.New("Unable to elide key from table '" + table + "', please use SimpleUpsert (coming soon!) instead") } @@ -269,10 +269,10 @@ func (adapter *MssqlAdapter) SimpleReplace(name string, table string, columns st continue } } - return adapter.SimpleUpsert(name, table, columns, fields, "key = "+keyValue) + return a.SimpleUpsert(name, table, columns, fields, "key = "+keyValue) } -func (adapter *MssqlAdapter) SimpleUpsert(name string, table string, columns string, fields string, where string) (string, error) { +func (a *MssqlAdapter) SimpleUpsert(name string, table string, columns string, fields string, where string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } @@ -285,9 +285,8 @@ func (adapter *MssqlAdapter) SimpleUpsert(name string, table string, columns str var fieldCount int var fieldOutput string - var querystr = "MERGE [" + table + "] WITH(HOLDLOCK) as t1 USING (VALUES(" - - var parsedFields = processFields(fields) + q := "MERGE [" + table + "] WITH(HOLDLOCK) as t1 USING (VALUES(" + parsedFields := processFields(fields) for _, field := range parsedFields { fieldCount++ field.Name = strings.Replace(field.Name, "UTC_TIMESTAMP()", "GETUTCDATE()", -1) @@ -302,14 +301,13 @@ func (adapter *MssqlAdapter) SimpleUpsert(name string, table string, columns str fieldOutput += field.Name + "," } fieldOutput = fieldOutput[0 : len(fieldOutput)-1] - querystr += fieldOutput + ")) AS updates (" + q += fieldOutput + ")) AS updates (" // nolint The linter wants this to be less readable for fieldID, _ := range parsedFields { - querystr += "f" + strconv.Itoa(fieldID) + "," + q += "f" + strconv.Itoa(fieldID) + "," } - querystr = querystr[0 : len(querystr)-1] - querystr += ") ON " + q = q[0:len(q)-1] + ") ON " //querystr += "t1.[" + key + "] = " // Add support for BETWEEN x.x @@ -317,25 +315,25 @@ func (adapter *MssqlAdapter) SimpleUpsert(name string, table string, columns str for _, token := range loc.Expr { switch token.Type { case "substitute": - querystr += " ?" + q += " ?" case "function", "operator", "number", "or": // TODO: Split the function case off to speed things up if strings.ToUpper(token.Contents) == "UTC_TIMESTAMP()" { token.Contents = "GETUTCDATE()" } - querystr += " " + token.Contents + q += " " + token.Contents case "column": - querystr += " [" + token.Contents + "]" + q += " [" + token.Contents + "]" case "string": - querystr += " '" + token.Contents + "'" + q += " '" + token.Contents + "'" default: panic("This token doesn't exist o_o") } } } - var matched = " WHEN MATCHED THEN UPDATE SET " - var notMatched = "WHEN NOT MATCHED THEN INSERT(" + matched := " WHEN MATCHED THEN UPDATE SET " + notMatched := "WHEN NOT MATCHED THEN INSERT(" var fieldList string // Escape the column names, just in case we've used a reserved keyword @@ -355,17 +353,17 @@ func (adapter *MssqlAdapter) SimpleUpsert(name string, table string, columns str fieldList = fieldList[0 : len(fieldList)-1] notMatched += ") VALUES (" + fieldList + ");" - querystr += matched + " " + notMatched + q += matched + " " + notMatched // TODO: Run this on debug mode? if name[0] == '_' { - log.Print(name+" query: ", querystr) + log.Print(name+" query: ", q) } - adapter.pushStatement(name, "upsert", querystr) - return querystr, nil + a.pushStatement(name, "upsert", q) + return q, nil } -func (adapter *MssqlAdapter) SimpleUpdate(up *updatePrebuilder) (string, error) { +func (a *MssqlAdapter) SimpleUpdate(up *updatePrebuilder) (string, error) { if up.table == "" { return "", errors.New("You need a name for this table") } @@ -373,35 +371,35 @@ func (adapter *MssqlAdapter) SimpleUpdate(up *updatePrebuilder) (string, error) return "", errors.New("You need to set data in this update statement") } - var querystr = "UPDATE [" + up.table + "] SET " + q := "UPDATE [" + up.table + "] SET " for _, item := range processSet(up.set) { - querystr += "[" + item.Column + "] =" + q += "[" + item.Column + "] =" for _, token := range item.Expr { switch token.Type { case "substitute": - querystr += " ?" + q += " ?" case "function", "operator", "number", "or": // TODO: Split the function case off to speed things up if strings.ToUpper(token.Contents) == "UTC_TIMESTAMP()" { token.Contents = "GETUTCDATE()" } - querystr += " " + token.Contents + q += " " + token.Contents case "column": - querystr += " [" + token.Contents + "]" + q += " [" + token.Contents + "]" case "string": - querystr += " '" + token.Contents + "'" + q += " '" + token.Contents + "'" default: panic("This token doesn't exist o_o") } } - querystr += "," + q += "," } // Remove the trailing comma - querystr = querystr[0 : len(querystr)-1] + q = q[0 : len(q)-1] // Add support for BETWEEN x.x if len(up.where) != 0 { - querystr += " WHERE" + q += " WHERE" for _, loc := range processWhere(up.where) { for _, token := range loc.Expr { switch token.Type { @@ -410,80 +408,80 @@ func (adapter *MssqlAdapter) SimpleUpdate(up *updatePrebuilder) (string, error) if strings.ToUpper(token.Contents) == "UTC_TIMESTAMP()" { token.Contents = "GETUTCDATE()" } - querystr += " " + token.Contents + q += " " + token.Contents case "column": - querystr += " [" + token.Contents + "]" + q += " [" + token.Contents + "]" case "string": - querystr += " '" + token.Contents + "'" + q += " '" + token.Contents + "'" default: panic("This token doesn't exist o_o") } } - querystr += " AND" + q += " AND" } - querystr = querystr[0 : len(querystr)-4] + q = q[0 : len(q)-4] } - adapter.pushStatement(up.name, "update", querystr) - return querystr, nil + a.pushStatement(up.name, "update", q) + return q, nil } -func (adapter *MssqlAdapter) SimpleUpdateSelect(b *updatePrebuilder) (string, error) { +func (a *MssqlAdapter) SimpleUpdateSelect(b *updatePrebuilder) (string, error) { return "", errors.New("not implemented") } -func (adapter *MssqlAdapter) SimpleDelete(name string, table string, where string) (string, error) { +func (a *MssqlAdapter) SimpleDelete(name string, table string, where string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } if where == "" { return "", errors.New("You need to specify what data you want to delete") } - - var querystr = "DELETE FROM [" + table + "] WHERE" + q := "DELETE FROM [" + table + "] WHERE" // Add support for BETWEEN x.x for _, loc := range processWhere(where) { for _, token := range loc.Expr { switch token.Type { case "substitute": - querystr += " ?" + q += " ?" case "function", "operator", "number", "or": // TODO: Split the function case off to speed things up if strings.ToUpper(token.Contents) == "UTC_TIMESTAMP()" { token.Contents = "GETUTCDATE()" } - querystr += " " + token.Contents + q += " " + token.Contents case "column": - querystr += " [" + token.Contents + "]" + q += " [" + token.Contents + "]" case "string": - querystr += " '" + token.Contents + "'" + q += " '" + token.Contents + "'" default: panic("This token doesn't exist o_o") } } - querystr += " AND" + q += " AND" } - querystr = strings.TrimSpace(querystr[0 : len(querystr)-4]) - adapter.pushStatement(name, "delete", querystr) - return querystr, nil + q = strings.TrimSpace(q[0 : len(q)-4]) + a.pushStatement(name, "delete", q) + return q, nil } -func (adapter *MssqlAdapter) ComplexDelete(b *deletePrebuilder) (string, error) { +func (a *MssqlAdapter) ComplexDelete(b *deletePrebuilder) (string, error) { return "", errors.New("not implemented") } // We don't want to accidentally wipe tables, so we'll have a separate method for purging tables instead -func (adapter *MssqlAdapter) Purge(name string, table string) (string, error) { +func (a *MssqlAdapter) Purge(name string, table string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } - adapter.pushStatement(name, "purge", "DELETE FROM ["+table+"]") - return "DELETE FROM [" + table + "]", nil + q := "DELETE FROM [" + table + "]" + a.pushStatement(name, "purge", q) + return q, nil } -func (adapter *MssqlAdapter) SimpleSelect(name string, table string, columns string, where string, orderby string, limit string) (string, error) { +func (a *MssqlAdapter) SimpleSelect(name string, table string, columns string, where string, orderby string, limit string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } @@ -494,56 +492,53 @@ func (adapter *MssqlAdapter) SimpleSelect(name string, table string, columns str if len(orderby) == 0 && limit != "" { return "", errors.New("Orderby needs to be set to use limit on Mssql") } - - var substituteCount = 0 - var querystr = "" + substituteCount := 0 + q := "" // Escape the column names, just in case we've used a reserved keyword - var colslice = strings.Split(strings.TrimSpace(columns), ",") + colslice := strings.Split(strings.TrimSpace(columns), ",") for _, column := range colslice { - querystr += "[" + strings.TrimSpace(column) + "]," + q += "[" + strings.TrimSpace(column) + "]," } - querystr = querystr[0 : len(querystr)-1] - - querystr += " FROM [" + table + "]" + q = q[0:len(q)-1] + " FROM [" + table + "]" // Add support for BETWEEN x.x if len(where) != 0 { - querystr += " WHERE" + q += " WHERE" for _, loc := range processWhere(where) { for _, token := range loc.Expr { switch token.Type { case "substitute": substituteCount++ - querystr += " ?" + strconv.Itoa(substituteCount) + q += " ?" + strconv.Itoa(substituteCount) case "function", "operator", "number", "or": // TODO: Split the function case off to speed things up // MSSQL seems to convert the formats? so we'll compare it with a regular date. Do this with the other methods too? if strings.ToUpper(token.Contents) == "UTC_TIMESTAMP()" { token.Contents = "GETDATE()" } - querystr += " " + token.Contents + q += " " + token.Contents case "column": - querystr += " [" + token.Contents + "]" + q += " [" + token.Contents + "]" case "string": - querystr += " '" + token.Contents + "'" + q += " '" + token.Contents + "'" default: panic("This token doesn't exist o_o") } } - querystr += " AND" + q += " AND" } - querystr = querystr[0 : len(querystr)-4] + q = q[0 : len(q)-4] } // TODO: MSSQL requires ORDER BY for LIMIT if len(orderby) != 0 { - querystr += " ORDER BY " + q += " ORDER BY " for _, column := range processOrderby(orderby) { // TODO: We might want to escape this column - querystr += column.Column + " " + strings.ToUpper(column.Order) + "," + q += column.Column + " " + strings.ToUpper(column.Order) + "," } - querystr = querystr[0 : len(querystr)-1] + q = q[0 : len(q)-1] } if limit != "" { @@ -552,9 +547,9 @@ func (adapter *MssqlAdapter) SimpleSelect(name string, table string, columns str if limiter.Offset != "" { if limiter.Offset == "?" { substituteCount++ - querystr += " OFFSET ?" + strconv.Itoa(substituteCount) + " ROWS" + q += " OFFSET ?" + strconv.Itoa(substituteCount) + " ROWS" } else { - querystr += " OFFSET " + limiter.Offset + " ROWS" + q += " OFFSET " + limiter.Offset + " ROWS" } } @@ -564,25 +559,25 @@ func (adapter *MssqlAdapter) SimpleSelect(name string, table string, columns str substituteCount++ limiter.MaxCount = "?" + strconv.Itoa(substituteCount) } - querystr += " FETCH NEXT " + limiter.MaxCount + " ROWS ONLY " + q += " FETCH NEXT " + limiter.MaxCount + " ROWS ONLY " } } - querystr = strings.TrimSpace("SELECT " + querystr) + q = strings.TrimSpace("SELECT " + q) // TODO: Run this on debug mode? if name[0] == '_' && limit == "" { - log.Print(name+" query: ", querystr) + log.Print(name+" query: ", q) } - adapter.pushStatement(name, "select", querystr) - return querystr, nil + a.pushStatement(name, "select", q) + return q, nil } // TODO: ComplexSelect -func (adapter *MssqlAdapter) ComplexSelect(preBuilder *selectPrebuilder) (string, error) { +func (a *MssqlAdapter) ComplexSelect(preBuilder *selectPrebuilder) (string, error) { return "", nil } -func (adapter *MssqlAdapter) SimpleLeftJoin(name string, table1 string, table2 string, columns string, joiners string, where string, orderby string, limit string) (string, error) { +func (a *MssqlAdapter) SimpleLeftJoin(name string, table1 string, table2 string, columns string, joiners string, where string, orderby string, limit string) (string, error) { if table1 == "" { return "", errors.New("You need a name for the left table") } @@ -599,12 +594,11 @@ func (adapter *MssqlAdapter) SimpleLeftJoin(name string, table1 string, table2 s if len(orderby) == 0 && limit != "" { return "", errors.New("Orderby needs to be set to use limit on Mssql") } - var substituteCount = 0 - var querystr = "" + substituteCount := 0 + q := "" for _, column := range processColumns(columns) { var source, alias string - // Escape the column names, just in case we've used a reserved keyword if column.Table != "" { source = "[" + column.Table + "].[" + column.Left + "]" @@ -617,64 +611,64 @@ func (adapter *MssqlAdapter) SimpleLeftJoin(name string, table1 string, table2 s if column.Alias != "" { alias = " AS '" + column.Alias + "'" } - querystr += source + alias + "," + q += source + alias + "," } // Remove the trailing comma - querystr = querystr[0 : len(querystr)-1] + q = q[0 : len(q)-1] - querystr += " FROM [" + table1 + "] LEFT JOIN [" + table2 + "] ON " - for _, joiner := range processJoiner(joiners) { - querystr += "[" + joiner.LeftTable + "].[" + joiner.LeftColumn + "] " + joiner.Operator + " [" + joiner.RightTable + "].[" + joiner.RightColumn + "] AND " + q += " FROM [" + table1 + "] LEFT JOIN [" + table2 + "] ON " + for _, j := range processJoiner(joiners) { + q += "[" + j.LeftTable + "].[" + j.LeftColumn + "]" + j.Operator + "[" + j.RightTable + "].[" + j.RightColumn + "] AND " } // Remove the trailing AND - querystr = querystr[0 : len(querystr)-4] + q = q[0 : len(q)-4] // Add support for BETWEEN x.x if len(where) != 0 { - querystr += " WHERE" + q += " WHERE" for _, loc := range processWhere(where) { for _, token := range loc.Expr { switch token.Type { case "substitute": substituteCount++ - querystr += " ?" + strconv.Itoa(substituteCount) + q += " ?" + strconv.Itoa(substituteCount) case "function", "operator", "number", "or": // TODO: Split the function case off to speed things up if strings.ToUpper(token.Contents) == "UTC_TIMESTAMP()" { token.Contents = "GETUTCDATE()" } - querystr += " " + token.Contents + q += " " + token.Contents case "column": halves := strings.Split(token.Contents, ".") if len(halves) == 2 { - querystr += " [" + halves[0] + "].[" + halves[1] + "]" + q += " [" + halves[0] + "].[" + halves[1] + "]" } else { - querystr += " [" + token.Contents + "]" + q += " [" + token.Contents + "]" } case "string": - querystr += " '" + token.Contents + "'" + q += " '" + token.Contents + "'" default: panic("This token doesn't exist o_o") } } - querystr += " AND" + q += " AND" } - querystr = querystr[0 : len(querystr)-4] + q = q[0 : len(q)-4] } // TODO: MSSQL requires ORDER BY for LIMIT if len(orderby) != 0 { - querystr += " ORDER BY " + q += " ORDER BY " for _, column := range processOrderby(orderby) { log.Print("column: ", column) // TODO: We might want to escape this column - querystr += column.Column + " " + strings.ToUpper(column.Order) + "," + q += column.Column + " " + strings.ToUpper(column.Order) + "," } - querystr = querystr[0 : len(querystr)-1] + q = q[0 : len(q)-1] } else if limit != "" { - key, ok := adapter.keys[table1] + key, ok := a.keys[table1] if ok { - querystr += " ORDER BY [" + table1 + "].[" + key + "]" + q += " ORDER BY [" + table1 + "].[" + key + "]" } } @@ -683,9 +677,9 @@ func (adapter *MssqlAdapter) SimpleLeftJoin(name string, table1 string, table2 s if limiter.Offset != "" { if limiter.Offset == "?" { substituteCount++ - querystr += " OFFSET ?" + strconv.Itoa(substituteCount) + " ROWS" + q += " OFFSET ?" + strconv.Itoa(substituteCount) + " ROWS" } else { - querystr += " OFFSET " + limiter.Offset + " ROWS" + q += " OFFSET " + limiter.Offset + " ROWS" } } @@ -695,20 +689,20 @@ func (adapter *MssqlAdapter) SimpleLeftJoin(name string, table1 string, table2 s substituteCount++ limiter.MaxCount = "?" + strconv.Itoa(substituteCount) } - querystr += " FETCH NEXT " + limiter.MaxCount + " ROWS ONLY " + q += " FETCH NEXT " + limiter.MaxCount + " ROWS ONLY " } } - querystr = strings.TrimSpace("SELECT " + querystr) + q = strings.TrimSpace("SELECT " + q) // TODO: Run this on debug mode? if name[0] == '_' && limit == "" { - log.Print(name+" query: ", querystr) + log.Print(name+" query: ", q) } - adapter.pushStatement(name, "select", querystr) - return querystr, nil + a.pushStatement(name, "select", q) + return q, nil } -func (adapter *MssqlAdapter) SimpleInnerJoin(name string, table1 string, table2 string, columns string, joiners string, where string, orderby string, limit string) (string, error) { +func (a *MssqlAdapter) SimpleInnerJoin(name string, table1 string, table2 string, columns string, joiners string, where string, orderby string, limit string) (string, error) { if table1 == "" { return "", errors.New("You need a name for the left table") } @@ -725,13 +719,11 @@ func (adapter *MssqlAdapter) SimpleInnerJoin(name string, table1 string, table2 if len(orderby) == 0 && limit != "" { return "", errors.New("Orderby needs to be set to use limit on Mssql") } - - var substituteCount = 0 - var querystr = "" + substituteCount := 0 + q := "" for _, column := range processColumns(columns) { var source, alias string - // Escape the column names, just in case we've used a reserved keyword if column.Table != "" { source = "[" + column.Table + "].[" + column.Left + "]" @@ -744,65 +736,65 @@ func (adapter *MssqlAdapter) SimpleInnerJoin(name string, table1 string, table2 if column.Alias != "" { alias = " AS '" + column.Alias + "'" } - querystr += source + alias + "," + q += source + alias + "," } // Remove the trailing comma - querystr = querystr[0 : len(querystr)-1] + q = q[0 : len(q)-1] - querystr += " FROM [" + table1 + "] INNER JOIN [" + table2 + "] ON " - for _, joiner := range processJoiner(joiners) { - querystr += "[" + joiner.LeftTable + "].[" + joiner.LeftColumn + "] " + joiner.Operator + " [" + joiner.RightTable + "].[" + joiner.RightColumn + "] AND " + q += " FROM [" + table1 + "] INNER JOIN [" + table2 + "] ON " + for _, j := range processJoiner(joiners) { + q += "[" + j.LeftTable + "].[" + j.LeftColumn + "]" + j.Operator + "[" + j.RightTable + "].[" + j.RightColumn + "] AND " } // Remove the trailing AND - querystr = querystr[0 : len(querystr)-4] + q = q[0 : len(q)-4] // Add support for BETWEEN x.x if len(where) != 0 { - querystr += " WHERE" + q += " WHERE" for _, loc := range processWhere(where) { for _, token := range loc.Expr { switch token.Type { case "substitute": substituteCount++ - querystr += " ?" + strconv.Itoa(substituteCount) + q += " ?" + strconv.Itoa(substituteCount) case "function", "operator", "number", "or": // TODO: Split the function case off to speed things up if strings.ToUpper(token.Contents) == "UTC_TIMESTAMP()" { token.Contents = "GETUTCDATE()" } - querystr += " " + token.Contents + q += " " + token.Contents case "column": halves := strings.Split(token.Contents, ".") if len(halves) == 2 { - querystr += " [" + halves[0] + "].[" + halves[1] + "]" + q += " [" + halves[0] + "].[" + halves[1] + "]" } else { - querystr += " [" + token.Contents + "]" + q += " [" + token.Contents + "]" } case "string": - querystr += " '" + token.Contents + "'" + q += " '" + token.Contents + "'" default: panic("This token doesn't exist o_o") } } - querystr += " AND" + q += " AND" } - querystr = querystr[0 : len(querystr)-4] + q = q[0 : len(q)-4] } // TODO: MSSQL requires ORDER BY for LIMIT if len(orderby) != 0 { - querystr += " ORDER BY " + q += " ORDER BY " for _, column := range processOrderby(orderby) { log.Print("column: ", column) // TODO: We might want to escape this column - querystr += column.Column + " " + strings.ToUpper(column.Order) + "," + q += column.Column + " " + strings.ToUpper(column.Order) + "," } - querystr = querystr[0 : len(querystr)-1] + q = q[0 : len(q)-1] } else if limit != "" { - key, ok := adapter.keys[table1] + key, ok := a.keys[table1] if ok { log.Print("key: ", key) - querystr += " ORDER BY [" + table1 + "].[" + key + "]" + q += " ORDER BY [" + table1 + "].[" + key + "]" } } @@ -811,9 +803,9 @@ func (adapter *MssqlAdapter) SimpleInnerJoin(name string, table1 string, table2 if limiter.Offset != "" { if limiter.Offset == "?" { substituteCount++ - querystr += " OFFSET ?" + strconv.Itoa(substituteCount) + " ROWS" + q += " OFFSET ?" + strconv.Itoa(substituteCount) + " ROWS" } else { - querystr += " OFFSET " + limiter.Offset + " ROWS" + q += " OFFSET " + limiter.Offset + " ROWS" } } @@ -823,20 +815,20 @@ func (adapter *MssqlAdapter) SimpleInnerJoin(name string, table1 string, table2 substituteCount++ limiter.MaxCount = "?" + strconv.Itoa(substituteCount) } - querystr += " FETCH NEXT " + limiter.MaxCount + " ROWS ONLY " + q += " FETCH NEXT " + limiter.MaxCount + " ROWS ONLY " } } - querystr = strings.TrimSpace("SELECT " + querystr) + q = strings.TrimSpace("SELECT " + q) // TODO: Run this on debug mode? if name[0] == '_' && limit == "" { - log.Print(name+" query: ", querystr) + log.Print(name+" query: ", q) } - adapter.pushStatement(name, "select", querystr) - return querystr, nil + a.pushStatement(name, "select", q) + return q, nil } -func (adapter *MssqlAdapter) SimpleInsertSelect(name string, ins DBInsert, sel DBSelect) (string, error) { +func (a *MssqlAdapter) SimpleInsertSelect(name string, ins DBInsert, sel DBSelect) (string, error) { // TODO: More errors. // TODO: Add this to the MySQL adapter in order to make this problem more discoverable? if len(sel.Orderby) == 0 && sel.Limit != "" { @@ -844,79 +836,76 @@ func (adapter *MssqlAdapter) SimpleInsertSelect(name string, ins DBInsert, sel D } /* Insert */ - var querystr = "INSERT INTO [" + ins.Table + "] (" + q := "INSERT INTO [" + ins.Table + "] (" // Escape the column names, just in case we've used a reserved keyword for _, column := range processColumns(ins.Columns) { if column.Type == "function" { - querystr += column.Left + "," + q += column.Left + "," } else { - querystr += "[" + column.Left + "]," + q += "[" + column.Left + "]," } } - querystr = querystr[0:len(querystr)-1] + ") SELECT " + q = q[0:len(q)-1] + ") SELECT " /* Select */ - var substituteCount = 0 + substituteCount := 0 for _, column := range processColumns(sel.Columns) { var source, alias string - // Escape the column names, just in case we've used a reserved keyword if column.Type == "function" || column.Type == "substitute" { source = column.Left } else { source = "[" + column.Left + "]" } - if column.Alias != "" { alias = " AS [" + column.Alias + "]" } - querystr += " " + source + alias + "," + q += " " + source + alias + "," } - querystr = querystr[0 : len(querystr)-1] - querystr += " FROM [" + sel.Table + "] " + q = q[0:len(q)-1] + " FROM [" + sel.Table + "] " // Add support for BETWEEN x.x if len(sel.Where) != 0 { - querystr += " WHERE" + q += " WHERE" for _, loc := range processWhere(sel.Where) { for _, token := range loc.Expr { switch token.Type { case "substitute": substituteCount++ - querystr += " ?" + strconv.Itoa(substituteCount) + q += " ?" + strconv.Itoa(substituteCount) case "function", "operator", "number", "or": // TODO: Split the function case off to speed things up if strings.ToUpper(token.Contents) == "UTC_TIMESTAMP()" { token.Contents = "GETUTCDATE()" } - querystr += " " + token.Contents + q += " " + token.Contents case "column": - querystr += " [" + token.Contents + "]" + q += " [" + token.Contents + "]" case "string": - querystr += " '" + token.Contents + "'" + q += " '" + token.Contents + "'" default: panic("This token doesn't exist o_o") } } - querystr += " AND" + q += " AND" } - querystr = querystr[0 : len(querystr)-4] + q = q[0 : len(q)-4] } // TODO: MSSQL requires ORDER BY for LIMIT if len(sel.Orderby) != 0 { - querystr += " ORDER BY " + q += " ORDER BY " for _, column := range processOrderby(sel.Orderby) { // TODO: We might want to escape this column - querystr += column.Column + " " + strings.ToUpper(column.Order) + "," + q += column.Column + " " + strings.ToUpper(column.Order) + "," } - querystr = querystr[0 : len(querystr)-1] + q = q[0 : len(q)-1] } else if sel.Limit != "" { - key, ok := adapter.keys[sel.Table] + key, ok := a.keys[sel.Table] if ok { - querystr += " ORDER BY [" + sel.Table + "].[" + key + "]" + q += " ORDER BY [" + sel.Table + "].[" + key + "]" } } @@ -925,9 +914,9 @@ func (adapter *MssqlAdapter) SimpleInsertSelect(name string, ins DBInsert, sel D if limiter.Offset != "" { if limiter.Offset == "?" { substituteCount++ - querystr += " OFFSET ?" + strconv.Itoa(substituteCount) + " ROWS" + q += " OFFSET ?" + strconv.Itoa(substituteCount) + " ROWS" } else { - querystr += " OFFSET " + limiter.Offset + " ROWS" + q += " OFFSET " + limiter.Offset + " ROWS" } } @@ -937,21 +926,20 @@ func (adapter *MssqlAdapter) SimpleInsertSelect(name string, ins DBInsert, sel D substituteCount++ limiter.MaxCount = "?" + strconv.Itoa(substituteCount) } - querystr += " FETCH NEXT " + limiter.MaxCount + " ROWS ONLY " + q += " FETCH NEXT " + limiter.MaxCount + " ROWS ONLY " } } - querystr = strings.TrimSpace(querystr) + q = strings.TrimSpace(q) // TODO: Run this on debug mode? if name[0] == '_' && sel.Limit == "" { - log.Print(name+" query: ", querystr) + log.Print(name+" query: ", q) } - - adapter.pushStatement(name, "insert", querystr) - return querystr, nil + a.pushStatement(name, "insert", q) + return q, nil } -func (adapter *MssqlAdapter) simpleJoin(name string, ins DBInsert, sel DBJoin, joinType string) (string, error) { +func (a *MssqlAdapter) simpleJoin(name string, ins DBInsert, sel DBJoin, joinType string) (string, error) { // TODO: More errors. // TODO: Add this to the MySQL adapter in order to make this problem more discoverable? if len(sel.Orderby) == 0 && sel.Limit != "" { @@ -959,24 +947,23 @@ func (adapter *MssqlAdapter) simpleJoin(name string, ins DBInsert, sel DBJoin, j } /* Insert */ - var querystr = "INSERT INTO [" + ins.Table + "] (" + q := "INSERT INTO [" + ins.Table + "] (" // Escape the column names, just in case we've used a reserved keyword for _, column := range processColumns(ins.Columns) { if column.Type == "function" { - querystr += column.Left + "," + q += column.Left + "," } else { - querystr += "[" + column.Left + "]," + q += "[" + column.Left + "]," } } - querystr = querystr[0:len(querystr)-1] + ") SELECT " + q = q[0:len(q)-1] + ") SELECT " /* Select */ - var substituteCount = 0 + substituteCount := 0 for _, column := range processColumns(sel.Columns) { var source, alias string - // Escape the column names, just in case we've used a reserved keyword if column.Table != "" { source = "[" + column.Table + "].[" + column.Left + "]" @@ -985,68 +972,67 @@ func (adapter *MssqlAdapter) simpleJoin(name string, ins DBInsert, sel DBJoin, j } else { source = "[" + column.Left + "]" } - if column.Alias != "" { alias = " AS '" + column.Alias + "'" } - querystr += source + alias + "," + q += source + alias + "," } // Remove the trailing comma - querystr = querystr[0 : len(querystr)-1] + q = q[0 : len(q)-1] - querystr += " FROM [" + sel.Table1 + "] " + joinType + " JOIN [" + sel.Table2 + "] ON " - for _, joiner := range processJoiner(sel.Joiners) { - querystr += "[" + joiner.LeftTable + "].[" + joiner.LeftColumn + "] " + joiner.Operator + " [" + joiner.RightTable + "].[" + joiner.RightColumn + "] AND " + q += " FROM [" + sel.Table1 + "] " + joinType + " JOIN [" + sel.Table2 + "] ON " + for _, j := range processJoiner(sel.Joiners) { + q += "[" + j.LeftTable + "].[" + j.LeftColumn + "] " + j.Operator + " [" + j.RightTable + "].[" + j.RightColumn + "] AND " } // Remove the trailing AND - querystr = querystr[0 : len(querystr)-4] + q = q[0 : len(q)-4] // Add support for BETWEEN x.x if len(sel.Where) != 0 { - querystr += " WHERE" + q += " WHERE" for _, loc := range processWhere(sel.Where) { for _, token := range loc.Expr { switch token.Type { case "substitute": substituteCount++ - querystr += " ?" + strconv.Itoa(substituteCount) + q += " ?" + strconv.Itoa(substituteCount) case "function", "operator", "number", "or": // TODO: Split the function case off to speed things up if strings.ToUpper(token.Contents) == "UTC_TIMESTAMP()" { token.Contents = "GETUTCDATE()" } - querystr += " " + token.Contents + q += " " + token.Contents case "column": halves := strings.Split(token.Contents, ".") if len(halves) == 2 { - querystr += " [" + halves[0] + "].[" + halves[1] + "]" + q += " [" + halves[0] + "].[" + halves[1] + "]" } else { - querystr += " [" + token.Contents + "]" + q += " [" + token.Contents + "]" } case "string": - querystr += " '" + token.Contents + "'" + q += " '" + token.Contents + "'" default: panic("This token doesn't exist o_o") } } - querystr += " AND" + q += " AND" } - querystr = querystr[0 : len(querystr)-4] + q = q[0 : len(q)-4] } // TODO: MSSQL requires ORDER BY for LIMIT if len(sel.Orderby) != 0 { - querystr += " ORDER BY " + q += " ORDER BY " for _, column := range processOrderby(sel.Orderby) { log.Print("column: ", column) // TODO: We might want to escape this column - querystr += column.Column + " " + strings.ToUpper(column.Order) + "," + q += column.Column + " " + strings.ToUpper(column.Order) + "," } - querystr = querystr[0 : len(querystr)-1] + q = q[0 : len(q)-1] } else if sel.Limit != "" { - key, ok := adapter.keys[sel.Table1] + key, ok := a.keys[sel.Table1] if ok { - querystr += " ORDER BY [" + sel.Table1 + "].[" + key + "]" + q += " ORDER BY [" + sel.Table1 + "].[" + key + "]" } } @@ -1055,9 +1041,9 @@ func (adapter *MssqlAdapter) simpleJoin(name string, ins DBInsert, sel DBJoin, j if limiter.Offset != "" { if limiter.Offset == "?" { substituteCount++ - querystr += " OFFSET ?" + strconv.Itoa(substituteCount) + " ROWS" + q += " OFFSET ?" + strconv.Itoa(substituteCount) + " ROWS" } else { - querystr += " OFFSET " + limiter.Offset + " ROWS" + q += " OFFSET " + limiter.Offset + " ROWS" } } @@ -1067,37 +1053,36 @@ func (adapter *MssqlAdapter) simpleJoin(name string, ins DBInsert, sel DBJoin, j substituteCount++ limiter.MaxCount = "?" + strconv.Itoa(substituteCount) } - querystr += " FETCH NEXT " + limiter.MaxCount + " ROWS ONLY " + q += " FETCH NEXT " + limiter.MaxCount + " ROWS ONLY " } } - querystr = strings.TrimSpace(querystr) + q = strings.TrimSpace(q) // TODO: Run this on debug mode? if name[0] == '_' && sel.Limit == "" { - log.Print(name+" query: ", querystr) + log.Print(name+" query: ", q) } - - adapter.pushStatement(name, "insert", querystr) - return querystr, nil + a.pushStatement(name, "insert", q) + return q, nil } -func (adapter *MssqlAdapter) SimpleInsertLeftJoin(name string, ins DBInsert, sel DBJoin) (string, error) { - return adapter.simpleJoin(name, ins, sel, "LEFT") +func (a *MssqlAdapter) SimpleInsertLeftJoin(name string, ins DBInsert, sel DBJoin) (string, error) { + return a.simpleJoin(name, ins, sel, "LEFT") } -func (adapter *MssqlAdapter) SimpleInsertInnerJoin(name string, ins DBInsert, sel DBJoin) (string, error) { - return adapter.simpleJoin(name, ins, sel, "INNER") +func (a *MssqlAdapter) SimpleInsertInnerJoin(name string, ins DBInsert, sel DBJoin) (string, error) { + return a.simpleJoin(name, ins, sel, "INNER") } -func (adapter *MssqlAdapter) SimpleCount(name string, table string, where string, limit string) (string, error) { +func (a *MssqlAdapter) SimpleCount(name string, table string, where string, limit string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } - var querystr = "SELECT COUNT(*) FROM [" + table + "]" + q := "SELECT COUNT(*) FROM [" + table + "]" // TODO: Add support for BETWEEN x.x if len(where) != 0 { - querystr += " WHERE" + q += " WHERE" for _, loc := range processWhere(where) { for _, token := range loc.Expr { switch token.Type { @@ -1105,40 +1090,39 @@ func (adapter *MssqlAdapter) SimpleCount(name string, table string, where string if strings.ToUpper(token.Contents) == "UTC_TIMESTAMP()" { token.Contents = "GETUTCDATE()" } - querystr += " " + token.Contents + q += " " + token.Contents case "column": - querystr += " [" + token.Contents + "]" + q += " [" + token.Contents + "]" case "string": - querystr += " '" + token.Contents + "'" + q += " '" + token.Contents + "'" default: panic("This token doesn't exist o_o") } } - querystr += " AND" + q += " AND" } - querystr = querystr[0 : len(querystr)-4] + q = q[0 : len(q)-4] } - if limit != "" { - querystr += " LIMIT " + limit + q += " LIMIT " + limit } - querystr = strings.TrimSpace(querystr) - adapter.pushStatement(name, "select", querystr) - return querystr, nil + q = strings.TrimSpace(q) + a.pushStatement(name, "select", q) + return q, nil } -func (adapter *MssqlAdapter) Builder() *prebuilder { - return &prebuilder{adapter} +func (a *MssqlAdapter) Builder() *prebuilder { + return &prebuilder{a} } -func (adapter *MssqlAdapter) Write() error { +func (a *MssqlAdapter) Write() error { var stmts, body string - for _, name := range adapter.BufferOrder { + for _, name := range a.BufferOrder { if name == "" { continue } - stmt := adapter.Buffer[name] + stmt := a.Buffer[name] // TODO: Add support for create-table? Table creation might be a little complex for Go to do outside a SQL file :( if stmt.Type != "create-table" { stmts += "\t" + name + " *sql.Stmt\n" @@ -1184,15 +1168,15 @@ func _gen_mssql() (err error) { } // Internal methods, not exposed in the interface -func (adapter *MssqlAdapter) pushStatement(name string, stype string, querystr string) { +func (a *MssqlAdapter) pushStatement(name string, stype string, querystr string) { if name == "" { return } - adapter.Buffer[name] = DBStmt{querystr, stype} - adapter.BufferOrder = append(adapter.BufferOrder, name) + a.Buffer[name] = DBStmt{querystr, stype} + a.BufferOrder = append(a.BufferOrder, name) } -func (adapter *MssqlAdapter) stringyType(ctype string) bool { +func (a *MssqlAdapter) stringyType(ctype string) bool { ctype = strings.ToLower(ctype) return ctype == "char" || ctype == "varchar" || ctype == "datetime" || ctype == "text" || ctype == "nvarchar" } @@ -1201,6 +1185,6 @@ type SetPrimaryKeys interface { SetPrimaryKeys(keys map[string]string) } -func (adapter *MssqlAdapter) SetPrimaryKeys(keys map[string]string) { - adapter.keys = keys +func (a *MssqlAdapter) SetPrimaryKeys(keys map[string]string) { + a.keys = keys } diff --git a/query_gen/pgsql.go b/query_gen/pgsql.go index b8ff7a86..1eecbd3e 100644 --- a/query_gen/pgsql.go +++ b/query_gen/pgsql.go @@ -21,39 +21,39 @@ type PgsqlAdapter struct { } // GetName gives you the name of the database adapter. In this case, it's pgsql -func (adapter *PgsqlAdapter) GetName() string { - return adapter.Name +func (a *PgsqlAdapter) GetName() string { + return a.Name } -func (adapter *PgsqlAdapter) GetStmt(name string) DBStmt { - return adapter.Buffer[name] +func (a *PgsqlAdapter) GetStmt(name string) DBStmt { + return a.Buffer[name] } -func (adapter *PgsqlAdapter) GetStmts() map[string]DBStmt { - return adapter.Buffer +func (a *PgsqlAdapter) GetStmts() map[string]DBStmt { + return a.Buffer } // TODO: Implement this -func (adapter *PgsqlAdapter) BuildConn(config map[string]string) (*sql.DB, error) { +func (a *PgsqlAdapter) BuildConn(config map[string]string) (*sql.DB, error) { return nil, nil } -func (adapter *PgsqlAdapter) DbVersion() string { +func (a *PgsqlAdapter) DbVersion() string { return "SELECT version()" } -func (adapter *PgsqlAdapter) DropTable(name string, table string) (string, error) { +func (a *PgsqlAdapter) DropTable(name string, table string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } - querystr := "DROP TABLE IF EXISTS \"" + table + "\";" - adapter.pushStatement(name, "drop-table", querystr) - return querystr, nil + q := "DROP TABLE IF EXISTS \"" + table + "\";" + a.pushStatement(name, "drop-table", q) + return q, nil } // TODO: Implement this // We may need to change the CreateTable API to better suit PGSQL and the other database drivers which are coming up -func (adapter *PgsqlAdapter) CreateTable(name string, table string, charset string, collation string, columns []DBTableColumn, keys []DBTableKey) (string, error) { +func (a *PgsqlAdapter) CreateTable(name string, table string, charset string, collation string, columns []DBTableColumn, keys []DBTableKey) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } @@ -61,7 +61,7 @@ func (adapter *PgsqlAdapter) CreateTable(name string, table string, charset stri return "", errors.New("You can't have a table with no columns") } - var querystr = "CREATE TABLE \"" + table + "\" (" + q := "CREATE TABLE \"" + table + "\" (" for _, column := range columns { if column.AutoIncrement { column.Type = "serial" @@ -79,41 +79,40 @@ func (adapter *PgsqlAdapter) CreateTable(name string, table string, charset stri var end string if column.Default != "" { end = " DEFAULT " - if adapter.stringyType(column.Type) && column.Default != "''" { + if a.stringyType(column.Type) && column.Default != "''" { end += "'" + column.Default + "'" } else { end += column.Default } } - if !column.Null { end += " not null" } - querystr += "\n\t`" + column.Name + "` " + column.Type + size + end + "," + q += "\n\t`" + column.Name + "` " + column.Type + size + end + "," } if len(keys) > 0 { for _, key := range keys { - querystr += "\n\t" + key.Type + q += "\n\t" + key.Type if key.Type != "unique" { - querystr += " key" + q += " key" } - querystr += "(" + q += "(" for _, column := range strings.Split(key.Columns, ",") { - querystr += "`" + column + "`," + q += "`" + column + "`," } - querystr = querystr[0:len(querystr)-1] + ")," + q = q[0:len(q)-1] + ")," } } - querystr = querystr[0:len(querystr)-1] + "\n);" - adapter.pushStatement(name, "create-table", querystr) - return querystr, nil + q = q[0:len(q)-1] + "\n);" + a.pushStatement(name, "create-table", q) + return q, nil } // TODO: Implement this -func (adapter *PgsqlAdapter) AddColumn(name string, table string, column DBTableColumn, key *DBTableKey) (string, error) { +func (a *PgsqlAdapter) AddColumn(name string, table string, column DBTableColumn, key *DBTableKey) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } @@ -122,7 +121,7 @@ func (adapter *PgsqlAdapter) AddColumn(name string, table string, column DBTable // TODO: Implement this // TODO: Test to make sure everything works here -func (adapter *PgsqlAdapter) AddIndex(name string, table string, iname string, colname string) (string, error) { +func (a *PgsqlAdapter) AddIndex(name string, table string, iname string, colname string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } @@ -137,7 +136,7 @@ func (adapter *PgsqlAdapter) AddIndex(name string, table string, iname string, c // TODO: Implement this // TODO: Test to make sure everything works here -func (adapter *PgsqlAdapter) AddKey(name string, table string, column string, key DBTableKey) (string, error) { +func (a *PgsqlAdapter) AddKey(name string, table string, column string, key DBTableKey) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } @@ -149,17 +148,17 @@ func (adapter *PgsqlAdapter) AddKey(name string, table string, column string, ke // TODO: Implement this // TODO: Test to make sure everything works here -func (adapter *PgsqlAdapter) AddForeignKey(name string, table string, column string, ftable string, fcolumn string, cascade bool) (out string, e error) { +func (a *PgsqlAdapter) AddForeignKey(name string, table string, column string, ftable string, fcolumn string, cascade bool) (out string, e error) { var c = func(str string, val bool) { if e != nil || !val { return } - e = errors.New("You need a "+str+" for this table") + e = errors.New("You need a " + str + " for this table") } - c("name",table=="") - c("column",column=="") - c("ftable",ftable=="") - c("fcolumn",fcolumn=="") + c("name", table == "") + c("column", column == "") + c("ftable", ftable == "") + c("fcolumn", fcolumn == "") if e != nil { return "", e } @@ -168,14 +167,14 @@ func (adapter *PgsqlAdapter) AddForeignKey(name string, table string, column str // TODO: Test this // ! We need to get the last ID out of this somehow, maybe add returning to every query? Might require some sort of wrapper over the sql statements -func (adapter *PgsqlAdapter) SimpleInsert(name string, table string, columns string, fields string) (string, error) { +func (a *PgsqlAdapter) SimpleInsert(name string, table string, columns string, fields string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } - var querystr = "INSERT INTO \"" + table + "\"(" + q := "INSERT INTO \"" + table + "\"(" if columns != "" { - querystr += adapter.buildColumns(columns) + ") VALUES (" + q += a.buildColumns(columns) + ") VALUES (" for _, field := range processFields(fields) { nameLen := len(field.Name) if field.Name[0] == '"' && field.Name[nameLen-1] == '"' && nameLen >= 3 { @@ -184,35 +183,35 @@ func (adapter *PgsqlAdapter) SimpleInsert(name string, table string, columns str if field.Name[0] == '\'' && field.Name[nameLen-1] == '\'' && nameLen >= 3 { field.Name = "'" + strings.Replace(field.Name[1:nameLen-1], "'", "''", -1) + "'" } - querystr += field.Name + "," + q += field.Name + "," } - querystr = querystr[0 : len(querystr)-1] + q = q[0 : len(q)-1] } else { - querystr += ") VALUES (" + q += ") VALUES (" } - querystr += ")" + q += ")" - adapter.pushStatement(name, "insert", querystr) - return querystr, nil + a.pushStatement(name, "insert", q) + return q, nil } -func (adapter *PgsqlAdapter) buildColumns(columns string) (querystr string) { +func (a *PgsqlAdapter) buildColumns(columns string) (q string) { if columns == "" { return "" } // Escape the column names, just in case we've used a reserved keyword for _, column := range processColumns(columns) { if column.Type == "function" { - querystr += column.Left + "," + q += column.Left + "," } else { - querystr += "\"" + column.Left + "\"," + q += "\"" + column.Left + "\"," } } - return querystr[0 : len(querystr)-1] + return q[0 : len(q)-1] } // TODO: Implement this -func (adapter *PgsqlAdapter) SimpleReplace(name string, table string, columns string, fields string) (string, error) { +func (a *PgsqlAdapter) SimpleReplace(name string, table string, columns string, fields string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } @@ -226,7 +225,7 @@ func (adapter *PgsqlAdapter) SimpleReplace(name string, table string, columns st } // TODO: Implement this -func (adapter *PgsqlAdapter) SimpleUpsert(name string, table string, columns string, fields string, where string) (string, error) { +func (a *PgsqlAdapter) SimpleUpsert(name string, table string, columns string, fields string, where string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } @@ -240,7 +239,7 @@ func (adapter *PgsqlAdapter) SimpleUpsert(name string, table string, columns str } // TODO: Implemented, but we need CreateTable and a better installer to *test* it -func (adapter *PgsqlAdapter) SimpleUpdate(up *updatePrebuilder) (string, error) { +func (a *PgsqlAdapter) SimpleUpdate(up *updatePrebuilder) (string, error) { if up.table == "" { return "", errors.New("You need a name for this table") } @@ -248,9 +247,9 @@ func (adapter *PgsqlAdapter) SimpleUpdate(up *updatePrebuilder) (string, error) return "", errors.New("You need to set data in this update statement") } - var querystr = "UPDATE \"" + up.table + "\" SET " + q := "UPDATE \"" + up.table + "\" SET " for _, item := range processSet(up.set) { - querystr += "`" + item.Column + "` =" + q += "`" + item.Column + "` =" for _, token := range item.Expr { switch token.Type { case "function": @@ -258,23 +257,23 @@ func (adapter *PgsqlAdapter) SimpleUpdate(up *updatePrebuilder) (string, error) if strings.ToUpper(token.Contents) == "UTC_TIMESTAMP()" { token.Contents = "LOCALTIMESTAMP()" } - querystr += " " + token.Contents + q += " " + token.Contents case "operator", "number", "substitute", "or": - querystr += " " + token.Contents + q += " " + token.Contents case "column": - querystr += " `" + token.Contents + "`" + q += " `" + token.Contents + "`" case "string": - querystr += " '" + token.Contents + "'" + q += " '" + token.Contents + "'" } } - querystr += "," + q += "," } // Remove the trailing comma - querystr = querystr[0 : len(querystr)-1] + q = q[0 : len(q)-1] // Add support for BETWEEN x.x if len(up.where) != 0 { - querystr += " WHERE" + q += " WHERE" for _, loc := range processWhere(up.where) { for _, token := range loc.Expr { switch token.Type { @@ -283,33 +282,33 @@ func (adapter *PgsqlAdapter) SimpleUpdate(up *updatePrebuilder) (string, error) if strings.ToUpper(token.Contents) == "UTC_TIMESTAMP()" { token.Contents = "LOCALTIMESTAMP()" } - querystr += " " + token.Contents + q += " " + token.Contents case "operator", "number", "substitute", "or": - querystr += " " + token.Contents + q += " " + token.Contents case "column": - querystr += " `" + token.Contents + "`" + q += " `" + token.Contents + "`" case "string": - querystr += " '" + token.Contents + "'" + q += " '" + token.Contents + "'" default: panic("This token doesn't exist o_o") } } - querystr += " AND" + q += " AND" } - querystr = querystr[0 : len(querystr)-4] + q = q[0 : len(q)-4] } - adapter.pushStatement(up.name, "update", querystr) - return querystr, nil + a.pushStatement(up.name, "update", q) + return q, nil } // TODO: Implement this -func (adapter *PgsqlAdapter) SimpleUpdateSelect(up *updatePrebuilder) (string, error) { +func (a *PgsqlAdapter) SimpleUpdateSelect(up *updatePrebuilder) (string, error) { return "", errors.New("not implemented") } // TODO: Implement this -func (adapter *PgsqlAdapter) SimpleDelete(name string, table string, where string) (string, error) { +func (a *PgsqlAdapter) SimpleDelete(name string, table string, where string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } @@ -320,7 +319,7 @@ func (adapter *PgsqlAdapter) SimpleDelete(name string, table string, where strin } // TODO: Implement this -func (adapter *PgsqlAdapter) ComplexDelete(b *deletePrebuilder) (string, error) { +func (a *PgsqlAdapter) ComplexDelete(b *deletePrebuilder) (string, error) { if b.table == "" { return "", errors.New("You need a name for this table") } @@ -332,7 +331,7 @@ func (adapter *PgsqlAdapter) ComplexDelete(b *deletePrebuilder) (string, error) // TODO: Implement this // We don't want to accidentally wipe tables, so we'll have a separate method for purging tables instead -func (adapter *PgsqlAdapter) Purge(name string, table string) (string, error) { +func (a *PgsqlAdapter) Purge(name string, table string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } @@ -340,7 +339,7 @@ func (adapter *PgsqlAdapter) Purge(name string, table string) (string, error) { } // TODO: Implement this -func (adapter *PgsqlAdapter) SimpleSelect(name string, table string, columns string, where string, orderby string, limit string) (string, error) { +func (a *PgsqlAdapter) SimpleSelect(name string, table string, columns string, where string, orderby string, limit string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } @@ -351,7 +350,7 @@ func (adapter *PgsqlAdapter) SimpleSelect(name string, table string, columns str } // TODO: Implement this -func (adapter *PgsqlAdapter) ComplexSelect(prebuilder *selectPrebuilder) (string, error) { +func (a *PgsqlAdapter) ComplexSelect(prebuilder *selectPrebuilder) (string, error) { if prebuilder.table == "" { return "", errors.New("You need a name for this table") } @@ -362,7 +361,7 @@ func (adapter *PgsqlAdapter) ComplexSelect(prebuilder *selectPrebuilder) (string } // TODO: Implement this -func (adapter *PgsqlAdapter) SimpleLeftJoin(name string, table1 string, table2 string, columns string, joiners string, where string, orderby string, limit string) (string, error) { +func (a *PgsqlAdapter) SimpleLeftJoin(name string, table1 string, table2 string, columns string, joiners string, where string, orderby string, limit string) (string, error) { if table1 == "" { return "", errors.New("You need a name for the left table") } @@ -379,7 +378,7 @@ func (adapter *PgsqlAdapter) SimpleLeftJoin(name string, table1 string, table2 s } // TODO: Implement this -func (adapter *PgsqlAdapter) SimpleInnerJoin(name string, table1 string, table2 string, columns string, joiners string, where string, orderby string, limit string) (string, error) { +func (a *PgsqlAdapter) SimpleInnerJoin(name string, table1 string, table2 string, columns string, joiners string, where string, orderby string, limit string) (string, error) { if table1 == "" { return "", errors.New("You need a name for the left table") } @@ -396,39 +395,39 @@ func (adapter *PgsqlAdapter) SimpleInnerJoin(name string, table1 string, table2 } // TODO: Implement this -func (adapter *PgsqlAdapter) SimpleInsertSelect(name string, ins DBInsert, sel DBSelect) (string, error) { +func (a *PgsqlAdapter) SimpleInsertSelect(name string, ins DBInsert, sel DBSelect) (string, error) { return "", nil } // TODO: Implement this -func (adapter *PgsqlAdapter) SimpleInsertLeftJoin(name string, ins DBInsert, sel DBJoin) (string, error) { +func (a *PgsqlAdapter) SimpleInsertLeftJoin(name string, ins DBInsert, sel DBJoin) (string, error) { return "", nil } // TODO: Implement this -func (adapter *PgsqlAdapter) SimpleInsertInnerJoin(name string, ins DBInsert, sel DBJoin) (string, error) { +func (a *PgsqlAdapter) SimpleInsertInnerJoin(name string, ins DBInsert, sel DBJoin) (string, error) { return "", nil } // TODO: Implement this -func (adapter *PgsqlAdapter) SimpleCount(name string, table string, where string, limit string) (string, error) { +func (a *PgsqlAdapter) SimpleCount(name string, table string, where string, limit string) (string, error) { if table == "" { return "", errors.New("You need a name for this table") } return "", nil } -func (adapter *PgsqlAdapter) Builder() *prebuilder { - return &prebuilder{adapter} +func (a *PgsqlAdapter) Builder() *prebuilder { + return &prebuilder{a} } -func (adapter *PgsqlAdapter) Write() error { +func (a *PgsqlAdapter) Write() error { var stmts, body string - for _, name := range adapter.BufferOrder { + for _, name := range a.BufferOrder { if name[0] == '_' { continue } - stmt := adapter.Buffer[name] + stmt := a.Buffer[name] // TODO: Add support for create-table? Table creation might be a little complex for Go to do outside a SQL file :( if stmt.Type != "create-table" { stmts += "\t" + name + " *sql.Stmt\n" @@ -473,15 +472,15 @@ func _gen_pgsql() (err error) { } // Internal methods, not exposed in the interface -func (adapter *PgsqlAdapter) pushStatement(name string, stype string, querystr string) { +func (a *PgsqlAdapter) pushStatement(name string, stype string, q string) { if name == "" { return } - adapter.Buffer[name] = DBStmt{querystr, stype} - adapter.BufferOrder = append(adapter.BufferOrder, name) + a.Buffer[name] = DBStmt{q, stype} + a.BufferOrder = append(a.BufferOrder, name) } -func (adapter *PgsqlAdapter) stringyType(ctype string) bool { +func (a *PgsqlAdapter) stringyType(ctype string) bool { ctype = strings.ToLower(ctype) return ctype == "char" || ctype == "varchar" || ctype == "timestamp" || ctype == "text" }