Skip to content

Commit

Permalink
Add QuoteTo method
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Mar 8, 2020
1 parent 5fce175 commit 078ba75
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 43 deletions.
7 changes: 5 additions & 2 deletions dialects/mssql/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"regexp"
"strconv"
"strings"

_ "github.com/denisenkom/go-mssqldb"
"github.com/jinzhu/gorm"
Expand Down Expand Up @@ -42,8 +43,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "@p" + strconv.Itoa(len(stmt.Vars))
}

func (dialector Dialector) QuoteChars() [2]byte {
return [2]byte{'"', '"'} // `name`
func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) {
builder.WriteByte('"')
builder.WriteString(str)
builder.WriteByte('"')
}

var numericPlaceholder = regexp.MustCompile("@p(\\d+)")
Expand Down
7 changes: 5 additions & 2 deletions dialects/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"fmt"
"math"
"strings"

_ "github.com/go-sql-driver/mysql"
"github.com/jinzhu/gorm"
Expand Down Expand Up @@ -39,8 +40,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?"
}

func (dialector Dialector) QuoteChars() [2]byte {
return [2]byte{'`', '`'} // `name`
func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) {
builder.WriteByte('`')
builder.WriteString(str)
builder.WriteByte('`')
}

func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
Expand Down
7 changes: 5 additions & 2 deletions dialects/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"regexp"
"strconv"
"strings"

"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks"
Expand Down Expand Up @@ -42,8 +43,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "$" + strconv.Itoa(len(stmt.Vars))
}

func (dialector Dialector) QuoteChars() [2]byte {
return [2]byte{'"', '"'} // "name"
func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) {
builder.WriteByte('"')
builder.WriteString(str)
builder.WriteByte('"')
}

var numericPlaceholder = regexp.MustCompile("\\$(\\d+)")
Expand Down
7 changes: 5 additions & 2 deletions dialects/sqlite/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sqlite

import (
"database/sql"
"strings"

"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks"
Expand Down Expand Up @@ -38,8 +39,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?"
}

func (dialector Dialector) QuoteChars() [2]byte {
return [2]byte{'`', '`'} // `name`
func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) {
builder.WriteByte('`')
builder.WriteString(str)
builder.WriteByte('`')
}

func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
Expand Down
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ module github.com/jinzhu/gorm
go 1.13

require (
github.com/denisenkom/go-mssqldb v0.0.0-20200206145737-bbfc9a55622e // indirect
github.com/go-sql-driver/mysql v1.5.0 // indirect
github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.3.0 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
)
1 change: 0 additions & 1 deletion gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {

if dialector != nil {
err = dialector.Initialize(db)
db.quoteChars = dialector.QuoteChars()
}
return
}
Expand Down
3 changes: 2 additions & 1 deletion interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gorm
import (
"context"
"database/sql"
"strings"

"github.com/jinzhu/gorm/schema"
)
Expand All @@ -13,7 +14,7 @@ type Dialector interface {
Migrator(db *DB) Migrator
DataTypeOf(*schema.Field) string
BindVar(stmt *Statement, v interface{}) string
QuoteChars() [2]byte
QuoteTo(*strings.Builder, string)
Explain(sql string, vars ...interface{}) string
}

Expand Down
55 changes: 24 additions & 31 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,65 +76,58 @@ func (stmt *Statement) WriteByte(c byte) (err error) {
return stmt.SQL.WriteByte(c)
}

// WriteQuoted write quoted field
func (stmt *Statement) WriteQuoted(field interface{}) (err error) {
_, err = stmt.SQL.WriteString(stmt.Quote(field))
return
// WriteQuoted write quoted value
func (stmt *Statement) WriteQuoted(value interface{}) error {
stmt.QuoteTo(&stmt.SQL, value)
return nil
}

// Quote returns quoted value
func (stmt Statement) Quote(field interface{}) string {
var str strings.Builder
str.WriteByte(stmt.DB.quoteChars[0])

// QuoteTo write quoted value to writer
func (stmt Statement) QuoteTo(writer *strings.Builder, field interface{}) {
switch v := field.(type) {
case clause.Table:
if v.Name == clause.CurrentTable {
str.WriteString(stmt.Table)
stmt.DB.Dialector.QuoteTo(writer, stmt.Table)
} else {
str.WriteString(v.Name)
stmt.DB.Dialector.QuoteTo(writer, v.Name)
}

if v.Alias != "" {
str.WriteByte(stmt.DB.quoteChars[1])
str.WriteString(" AS ")
str.WriteByte(stmt.DB.quoteChars[0])
str.WriteString(v.Alias)
str.WriteByte(stmt.DB.quoteChars[1])
writer.WriteString(" AS ")
stmt.DB.Dialector.QuoteTo(writer, v.Alias)
}
case clause.Column:
if v.Table != "" {
if v.Table == clause.CurrentTable {
str.WriteString(stmt.Table)
stmt.DB.Dialector.QuoteTo(writer, stmt.Table)
} else {
str.WriteString(v.Table)
stmt.DB.Dialector.QuoteTo(writer, v.Table)
}
str.WriteByte(stmt.DB.quoteChars[1])
str.WriteByte('.')
str.WriteByte(stmt.DB.quoteChars[0])
writer.WriteByte('.')
}

if v.Name == clause.PrimaryKey {
if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil {
str.WriteString(stmt.Schema.PrioritizedPrimaryField.DBName)
stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName)
}
} else {
str.WriteString(v.Name)
stmt.DB.Dialector.QuoteTo(writer, v.Name)
}

if v.Alias != "" {
str.WriteByte(stmt.DB.quoteChars[1])
str.WriteString(" AS ")
str.WriteByte(stmt.DB.quoteChars[0])
str.WriteString(v.Alias)
str.WriteByte(stmt.DB.quoteChars[1])
writer.WriteString(" AS ")
stmt.DB.Dialector.QuoteTo(writer, v.Alias)
}
default:
str.WriteString(fmt.Sprint(field))
stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
}
}

str.WriteByte(stmt.DB.quoteChars[1])
return str.String()
// Quote returns quoted value
func (stmt Statement) Quote(field interface{}) string {
var builder strings.Builder
stmt.QuoteTo(&builder, field)
return builder.String()
}

// Write write string
Expand Down
8 changes: 6 additions & 2 deletions tests/dummy_dialecter.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package tests

import (
"strings"

"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/schema"
Expand All @@ -21,8 +23,10 @@ func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?"
}

func (DummyDialector) QuoteChars() [2]byte {
return [2]byte{'`', '`'} // `name`
func (DummyDialector) QuoteTo(builder *strings.Builder, str string) {
builder.WriteByte('`')
builder.WriteString(str)
builder.WriteByte('`')
}

func (DummyDialector) Explain(sql string, vars ...interface{}) string {
Expand Down

0 comments on commit 078ba75

Please sign in to comment.