From 5edc78116fe46a7d001db52d80a78f97756ac1ad Mon Sep 17 00:00:00 2001 From: sammyrnycreal Date: Mon, 14 Feb 2022 14:13:26 -0500 Subject: [PATCH] Fixed the use of "or" to be " OR ", to account for words that contain "or" or "and" (e.g., 'score', 'band') in a sql statement as the name of a field. --- clause/where.go | 39 ++++++++++++++++++++++----------------- clause/where_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 17 deletions(-) diff --git a/clause/where.go b/clause/where.go index 20a011362..10b6df856 100644 --- a/clause/where.go +++ b/clause/where.go @@ -4,6 +4,11 @@ import ( "strings" ) +const ( + AndWithSpace = " AND " + OrWithSpace = " OR " +) + // Where where clause type Where struct { Exprs []Expression @@ -26,7 +31,7 @@ func (where Where) Build(builder Builder) { } } - buildExprs(where.Exprs, builder, " AND ") + buildExprs(where.Exprs, builder, AndWithSpace) } func buildExprs(exprs []Expression, builder Builder, joinCond string) { @@ -35,7 +40,7 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) { for idx, expr := range exprs { if idx > 0 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { - builder.WriteString(" OR ") + builder.WriteString(OrWithSpace) } else { builder.WriteString(joinCond) } @@ -46,23 +51,23 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) { case OrConditions: if len(v.Exprs) == 1 { if e, ok := v.Exprs[0].(Expr); ok { - sql := strings.ToLower(e.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(e.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) } } case AndConditions: if len(v.Exprs) == 1 { if e, ok := v.Exprs[0].(Expr); ok { - sql := strings.ToLower(e.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(e.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) } } case Expr: - sql := strings.ToLower(v.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(v.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) case NamedExpr: - sql := strings.ToLower(v.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(v.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) } } @@ -110,10 +115,10 @@ type AndConditions struct { func (and AndConditions) Build(builder Builder) { if len(and.Exprs) > 1 { builder.WriteByte('(') - buildExprs(and.Exprs, builder, " AND ") + buildExprs(and.Exprs, builder, AndWithSpace) builder.WriteByte(')') } else { - buildExprs(and.Exprs, builder, " AND ") + buildExprs(and.Exprs, builder, AndWithSpace) } } @@ -131,10 +136,10 @@ type OrConditions struct { func (or OrConditions) Build(builder Builder) { if len(or.Exprs) > 1 { builder.WriteByte('(') - buildExprs(or.Exprs, builder, " OR ") + buildExprs(or.Exprs, builder, OrWithSpace) builder.WriteByte(')') } else { - buildExprs(or.Exprs, builder, " OR ") + buildExprs(or.Exprs, builder, OrWithSpace) } } @@ -156,7 +161,7 @@ func (not NotConditions) Build(builder Builder) { for idx, c := range not.Exprs { if idx > 0 { - builder.WriteString(" AND ") + builder.WriteString(AndWithSpace) } if negationBuilder, ok := c.(NegationExpressionBuilder); ok { @@ -165,8 +170,8 @@ func (not NotConditions) Build(builder Builder) { builder.WriteString("NOT ") e, wrapInParentheses := c.(Expr) if wrapInParentheses { - sql := strings.ToLower(e.SQL) - if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses { + sql := strings.ToUpper(e.SQL) + if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses { builder.WriteByte('(') } } diff --git a/clause/where_test.go b/clause/where_test.go index 272c7b76b..35e3dbeeb 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -66,6 +66,45 @@ func TestWhere(t *testing.T) { "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", []interface{}{18, "jinzhu"}, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{ + clause.And(clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), + clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})), + }, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)", + []interface{}{"1", 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, + clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}))}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)", + []interface{}{"1", 100}, + }, } for idx, result := range results {