From 504f42760a2f4be453c51798bc075dc7fd414bd5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 9 Mar 2020 17:07:00 +0800 Subject: [PATCH] Refactor clause Writer --- clause/clause.go | 11 ++++--- clause/delete.go | 4 +-- clause/expression.go | 60 +++++++++++++++++++++-------------- clause/from.go | 8 ++--- clause/group_by.go | 2 +- clause/insert.go | 4 +-- clause/limit.go | 8 ++--- clause/locking.go | 8 ++--- clause/order_by.go | 2 +- clause/set.go | 2 +- clause/update.go | 2 +- clause/values.go | 6 ++-- clause/where.go | 12 +++---- dialects/mssql/mssql.go | 10 +++--- dialects/mysql/mysql.go | 10 +++--- dialects/postgres/postgres.go | 10 +++--- dialects/sqlite/sqlite.go | 10 +++--- interfaces.go | 4 +-- statement.go | 41 ++++++++++-------------- tests/dummy_dialecter.go | 11 +++---- 20 files changed, 117 insertions(+), 108 deletions(-) diff --git a/clause/clause.go b/clause/clause.go index df8e3a57e..59b229cec 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -12,13 +12,16 @@ type ClauseBuilder interface { Build(Clause, Builder) } +type Writer interface { + WriteByte(byte) error + WriteString(string) (int, error) +} + // Builder builder interface type Builder interface { - WriteByte(byte) error - Write(sql ...string) error + Writer WriteQuoted(field interface{}) error - AddVar(vars ...interface{}) string - Quote(field interface{}) string + AddVar(Writer, ...interface{}) } // Clause diff --git a/clause/delete.go b/clause/delete.go index 2a622b45e..fc462cd7f 100644 --- a/clause/delete.go +++ b/clause/delete.go @@ -9,11 +9,11 @@ func (d Delete) Name() string { } func (d Delete) Build(builder Builder) { - builder.Write("DELETE") + builder.WriteString("DELETE") if d.Modifier != "" { builder.WriteByte(' ') - builder.Write(d.Modifier) + builder.WriteString(d.Modifier) } } diff --git a/clause/expression.go b/clause/expression.go index d72db08d0..8150f8387 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -1,9 +1,5 @@ package clause -import ( - "strings" -) - // Expression expression interface type Expression interface { Build(builder Builder) @@ -22,11 +18,15 @@ type Expr struct { // Build build raw expression func (expr Expr) Build(builder Builder) { - sql := expr.SQL - for _, v := range expr.Vars { - sql = strings.Replace(sql, "?", builder.AddVar(v), 1) + var idx int + for _, v := range []byte(expr.SQL) { + if v == '?' { + builder.AddVar(builder, expr.Vars[idx]) + idx++ + } else { + builder.WriteByte(v) + } } - builder.Write(sql) } // IN Whether a value is within a set of values @@ -40,11 +40,14 @@ func (in IN) Build(builder Builder) { switch len(in.Values) { case 0: - builder.Write(" IN (NULL)") + builder.WriteString(" IN (NULL)") case 1: - builder.Write(" = ", builder.AddVar(in.Values...)) + builder.WriteString(" = ") + builder.AddVar(builder, in.Values...) default: - builder.Write(" IN (", builder.AddVar(in.Values...), ")") + builder.WriteString(" IN (") + builder.AddVar(builder, in.Values...) + builder.WriteByte(')') } } @@ -52,9 +55,12 @@ func (in IN) NegationBuild(builder Builder) { switch len(in.Values) { case 0: case 1: - builder.Write(" <> ", builder.AddVar(in.Values...)) + builder.WriteString(" <> ") + builder.AddVar(builder, in.Values...) default: - builder.Write(" NOT IN (", builder.AddVar(in.Values...), ")") + builder.WriteString(" NOT IN (") + builder.AddVar(builder, in.Values...) + builder.WriteByte(')') } } @@ -68,9 +74,10 @@ func (eq Eq) Build(builder Builder) { builder.WriteQuoted(eq.Column) if eq.Value == nil { - builder.Write(" IS NULL") + builder.WriteString(" IS NULL") } else { - builder.Write(" = ", builder.AddVar(eq.Value)) + builder.WriteString(" = ") + builder.AddVar(builder, eq.Value) } } @@ -85,9 +92,10 @@ func (neq Neq) Build(builder Builder) { builder.WriteQuoted(neq.Column) if neq.Value == nil { - builder.Write(" IS NOT NULL") + builder.WriteString(" IS NOT NULL") } else { - builder.Write(" <> ", builder.AddVar(neq.Value)) + builder.WriteString(" <> ") + builder.AddVar(builder, neq.Value) } } @@ -100,7 +108,8 @@ type Gt Eq func (gt Gt) Build(builder Builder) { builder.WriteQuoted(gt.Column) - builder.Write(" > ", builder.AddVar(gt.Value)) + builder.WriteString(" > ") + builder.AddVar(builder, gt.Value) } func (gt Gt) NegationBuild(builder Builder) { @@ -112,7 +121,8 @@ type Gte Eq func (gte Gte) Build(builder Builder) { builder.WriteQuoted(gte.Column) - builder.Write(" >= ", builder.AddVar(gte.Value)) + builder.WriteString(" >= ") + builder.AddVar(builder, gte.Value) } func (gte Gte) NegationBuild(builder Builder) { @@ -124,7 +134,8 @@ type Lt Eq func (lt Lt) Build(builder Builder) { builder.WriteQuoted(lt.Column) - builder.Write(" < ", builder.AddVar(lt.Value)) + builder.WriteString(" < ") + builder.AddVar(builder, lt.Value) } func (lt Lt) NegationBuild(builder Builder) { @@ -136,7 +147,8 @@ type Lte Eq func (lte Lte) Build(builder Builder) { builder.WriteQuoted(lte.Column) - builder.Write(" <= ", builder.AddVar(lte.Value)) + builder.WriteString(" <= ") + builder.AddVar(builder, lte.Value) } func (lte Lte) NegationBuild(builder Builder) { @@ -148,12 +160,14 @@ type Like Eq func (like Like) Build(builder Builder) { builder.WriteQuoted(like.Column) - builder.Write(" LIKE ", builder.AddVar(like.Value)) + builder.WriteString(" LIKE ") + builder.AddVar(builder, like.Value) } func (like Like) NegationBuild(builder Builder) { builder.WriteQuoted(like.Column) - builder.Write(" NOT LIKE ", builder.AddVar(like.Value)) + builder.WriteString(" NOT LIKE ") + builder.AddVar(builder, like.Value) } // Map diff --git a/clause/from.go b/clause/from.go index f01065b5b..5e8c5d259 100644 --- a/clause/from.go +++ b/clause/from.go @@ -50,18 +50,18 @@ func (from From) Build(builder Builder) { func (join Join) Build(builder Builder) { if join.Type != "" { - builder.Write(string(join.Type)) + builder.WriteString(string(join.Type)) builder.WriteByte(' ') } - builder.Write("JOIN ") + builder.WriteString("JOIN ") builder.WriteQuoted(join.Table) if len(join.ON.Exprs) > 0 { - builder.Write(" ON ") + builder.WriteString(" ON ") join.ON.Build(builder) } else if len(join.Using) > 0 { - builder.Write(" USING (") + builder.WriteString(" USING (") for idx, c := range join.Using { if idx > 0 { builder.WriteByte(',') diff --git a/clause/group_by.go b/clause/group_by.go index a245d50a3..c1383c36a 100644 --- a/clause/group_by.go +++ b/clause/group_by.go @@ -22,7 +22,7 @@ func (groupBy GroupBy) Build(builder Builder) { } if len(groupBy.Having) > 0 { - builder.Write(" HAVING ") + builder.WriteString(" HAVING ") Where{Exprs: groupBy.Having}.Build(builder) } } diff --git a/clause/insert.go b/clause/insert.go index 3f86c98fe..8efaa0352 100644 --- a/clause/insert.go +++ b/clause/insert.go @@ -13,11 +13,11 @@ func (insert Insert) Name() string { // Build build insert clause func (insert Insert) Build(builder Builder) { if insert.Modifier != "" { - builder.Write(insert.Modifier) + builder.WriteString(insert.Modifier) builder.WriteByte(' ') } - builder.Write("INTO ") + builder.WriteString("INTO ") if insert.Table.Name == "" { builder.WriteQuoted(currentTable) } else { diff --git a/clause/limit.go b/clause/limit.go index e30666afc..ba5cf6c43 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -16,12 +16,12 @@ func (limit Limit) Name() string { // Build build where clause func (limit Limit) Build(builder Builder) { if limit.Limit > 0 { - builder.Write("LIMIT ") - builder.Write(strconv.Itoa(limit.Limit)) + builder.WriteString("LIMIT ") + builder.WriteString(strconv.Itoa(limit.Limit)) if limit.Offset > 0 { - builder.Write(" OFFSET ") - builder.Write(strconv.Itoa(limit.Offset)) + builder.WriteString(" OFFSET ") + builder.WriteString(strconv.Itoa(limit.Offset)) } } } diff --git a/clause/locking.go b/clause/locking.go index 48b84b34b..3be1063b4 100644 --- a/clause/locking.go +++ b/clause/locking.go @@ -22,16 +22,16 @@ func (f For) Build(builder Builder) { builder.WriteByte(' ') } - builder.Write("FOR ") - builder.Write(locking.Strength) + builder.WriteString("FOR ") + builder.WriteString(locking.Strength) if locking.Table.Name != "" { - builder.Write(" OF ") + builder.WriteString(" OF ") builder.WriteQuoted(locking.Table) } if locking.Options != "" { builder.WriteByte(' ') - builder.Write(locking.Options) + builder.WriteString(locking.Options) } } } diff --git a/clause/order_by.go b/clause/order_by.go index 2734f2bc5..307bf9308 100644 --- a/clause/order_by.go +++ b/clause/order_by.go @@ -24,7 +24,7 @@ func (orderBy OrderBy) Build(builder Builder) { builder.WriteQuoted(column.Column) if column.Desc { - builder.Write(" DESC") + builder.WriteString(" DESC") } } } diff --git a/clause/set.go b/clause/set.go index 3b7e972da..de78b1be3 100644 --- a/clause/set.go +++ b/clause/set.go @@ -19,7 +19,7 @@ func (set Set) Build(builder Builder) { } builder.WriteQuoted(assignment.Column) builder.WriteByte('=') - builder.Write(builder.AddVar(assignment.Value)) + builder.AddVar(builder, assignment.Value) } } else { builder.WriteQuoted(PrimaryColumn) diff --git a/clause/update.go b/clause/update.go index c375b3737..f9d68ac67 100644 --- a/clause/update.go +++ b/clause/update.go @@ -13,7 +13,7 @@ func (update Update) Name() string { // Build build update clause func (update Update) Build(builder Builder) { if update.Modifier != "" { - builder.Write(update.Modifier) + builder.WriteString(update.Modifier) builder.WriteByte(' ') } diff --git a/clause/values.go b/clause/values.go index 2c8dcf89f..a997fc26b 100644 --- a/clause/values.go +++ b/clause/values.go @@ -22,7 +22,7 @@ func (values Values) Build(builder Builder) { } builder.WriteByte(')') - builder.Write(" VALUES ") + builder.WriteString(" VALUES ") for idx, value := range values.Values { if idx > 0 { @@ -30,11 +30,11 @@ func (values Values) Build(builder Builder) { } builder.WriteByte('(') - builder.Write(builder.AddVar(value...)) + builder.AddVar(builder, value...) builder.WriteByte(')') } } else { - builder.Write("DEFAULT VALUES") + builder.WriteString("DEFAULT VALUES") } } diff --git a/clause/where.go b/clause/where.go index 0ee1a1415..08c78b220 100644 --- a/clause/where.go +++ b/clause/where.go @@ -26,9 +26,9 @@ func (where Where) Build(builder Builder) { if expr != nil { if idx > 0 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { - builder.Write(" OR ") + builder.WriteString(" OR ") } else { - builder.Write(" AND ") + builder.WriteString(" AND ") } } @@ -65,7 +65,7 @@ func (and AndConditions) Build(builder Builder) { } for idx, c := range and.Exprs { if idx > 0 { - builder.Write(" AND ") + builder.WriteString(" AND ") } c.Build(builder) } @@ -91,7 +91,7 @@ func (or OrConditions) Build(builder Builder) { } for idx, c := range or.Exprs { if idx > 0 { - builder.Write(" OR ") + builder.WriteString(" OR ") } c.Build(builder) } @@ -117,13 +117,13 @@ func (not NotConditions) Build(builder Builder) { } for idx, c := range not.Exprs { if idx > 0 { - builder.Write(" AND ") + builder.WriteString(" AND ") } if negationBuilder, ok := c.(NegationExpressionBuilder); ok { negationBuilder.NegationBuild(builder) } else { - builder.Write(" NOT ") + builder.WriteString(" NOT ") c.Build(builder) } } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 7e51de75f..0842fa792 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -5,11 +5,11 @@ import ( "fmt" "regexp" "strconv" - "strings" _ "github.com/denisenkom/go-mssqldb" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" @@ -42,10 +42,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "@p" + strconv.Itoa(len(stmt.Vars)) } -func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) { - builder.WriteByte('"') - builder.WriteString(str) - builder.WriteByte('"') +func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + writer.WriteByte('"') + writer.WriteString(str) + writer.WriteByte('"') } var numericPlaceholder = regexp.MustCompile("@p(\\d+)") diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index 55b5a53f3..cff779e3c 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -4,11 +4,11 @@ import ( "database/sql" "fmt" "math" - "strings" _ "github.com/go-sql-driver/mysql" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" @@ -40,10 +40,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) { - builder.WriteByte('`') - builder.WriteString(str) - builder.WriteByte('`') +func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + writer.WriteByte('`') + writer.WriteString(str) + writer.WriteByte('`') } func (dialector Dialector) Explain(sql string, vars ...interface{}) string { diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index e90fa4ae4..99569f069 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -5,10 +5,10 @@ import ( "fmt" "regexp" "strconv" - "strings" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" @@ -42,10 +42,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "$" + strconv.Itoa(len(stmt.Vars)) } -func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) { - builder.WriteByte('"') - builder.WriteString(str) - builder.WriteByte('"') +func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + writer.WriteByte('"') + writer.WriteString(str) + writer.WriteByte('"') } var numericPlaceholder = regexp.MustCompile("\\$(\\d+)") diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 8e3cc0589..4105863fb 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -2,10 +2,10 @@ package sqlite import ( "database/sql" - "strings" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" @@ -39,10 +39,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) { - builder.WriteByte('`') - builder.WriteString(str) - builder.WriteByte('`') +func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + writer.WriteByte('`') + writer.WriteString(str) + writer.WriteByte('`') } func (dialector Dialector) Explain(sql string, vars ...interface{}) string { diff --git a/interfaces.go b/interfaces.go index 9859d1fa2..310f801aa 100644 --- a/interfaces.go +++ b/interfaces.go @@ -3,8 +3,8 @@ package gorm import ( "context" "database/sql" - "strings" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/schema" ) @@ -14,7 +14,7 @@ type Dialector interface { Migrator(db *DB) Migrator DataTypeOf(*schema.Field) string BindVar(stmt *Statement, v interface{}) string - QuoteTo(*strings.Builder, string) + QuoteTo(clause.Writer, string) Explain(sql string, vars ...interface{}) string } diff --git a/statement.go b/statement.go index 0190df7c3..e632b4094 100644 --- a/statement.go +++ b/statement.go @@ -34,7 +34,6 @@ type Statement struct { SQL strings.Builder Vars []interface{} NamedVars []sql.NamedArg - placeholders strings.Builder } // StatementOptimizer statement optimizer interface @@ -43,15 +42,12 @@ type StatementOptimizer interface { } // Write write string -func (stmt *Statement) Write(sql ...string) (err error) { - for _, s := range sql { - _, err = stmt.SQL.WriteString(s) - } - return +func (stmt *Statement) WriteString(str string) (int, error) { + return stmt.SQL.WriteString(str) } // Write write string -func (stmt *Statement) WriteByte(c byte) (err error) { +func (stmt *Statement) WriteByte(c byte) error { return stmt.SQL.WriteByte(c) } @@ -62,7 +58,7 @@ func (stmt *Statement) WriteQuoted(value interface{}) error { } // QuoteTo write quoted value to writer -func (stmt Statement) QuoteTo(writer *strings.Builder, field interface{}) { +func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { switch v := field.(type) { case clause.Table: if v.Name == clause.CurrentTable { @@ -110,44 +106,41 @@ func (stmt Statement) Quote(field interface{}) string { } // Write write string -func (stmt *Statement) AddVar(vars ...interface{}) string { - stmt.placeholders = strings.Builder{} - stmt.placeholders.Reset() - +func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { for idx, v := range vars { if idx > 0 { - stmt.placeholders.WriteByte(',') + writer.WriteByte(',') } switch v := v.(type) { case sql.NamedArg: if len(v.Name) > 0 { stmt.NamedVars = append(stmt.NamedVars, v) - stmt.placeholders.WriteByte('@') - stmt.placeholders.WriteString(v.Name) + writer.WriteByte('@') + writer.WriteString(v.Name) } else { stmt.Vars = append(stmt.Vars, v.Value) - stmt.placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value)) + writer.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value)) } case clause.Column, clause.Table: - stmt.placeholders.WriteString(stmt.Quote(v)) + stmt.QuoteTo(writer, v) case clause.Expr: - stmt.placeholders.WriteString(v.SQL) + writer.WriteString(v.SQL) stmt.Vars = append(stmt.Vars, v.Vars...) case []interface{}: if len(v) > 0 { - stmt.placeholders.WriteByte('(') - stmt.placeholders.WriteString(stmt.AddVar(v...)) - stmt.placeholders.WriteByte(')') + writer.WriteByte('(') + stmt.skipResetPlacehodler = true + stmt.AddVar(writer, v...) + writer.WriteByte(')') } else { - stmt.placeholders.WriteString("(NULL)") + writer.WriteString("(NULL)") } default: stmt.Vars = append(stmt.Vars, v) - stmt.placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v)) + writer.WriteString(stmt.DB.Dialector.BindVar(stmt, v)) } } - return stmt.placeholders.String() } // AddClause add clause diff --git a/tests/dummy_dialecter.go b/tests/dummy_dialecter.go index 9e3146fe8..f6e9d9f9b 100644 --- a/tests/dummy_dialecter.go +++ b/tests/dummy_dialecter.go @@ -1,9 +1,8 @@ package tests import ( - "strings" - "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/schema" ) @@ -23,10 +22,10 @@ func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (DummyDialector) QuoteTo(builder *strings.Builder, str string) { - builder.WriteByte('`') - builder.WriteString(str) - builder.WriteByte('`') +func (DummyDialector) QuoteTo(writer clause.Writer, str string) { + writer.WriteByte('`') + writer.WriteString(str) + writer.WriteByte('`') } func (DummyDialector) Explain(sql string, vars ...interface{}) string {