diff --git a/CHANGELOG.md b/CHANGELOG.md index bf3b49d..9b762cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,6 +56,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 This is how `AfterSeleect/Insert/Update/DeleteHooks` hooks are now implemented. - Added `Type() QueryType` method to `bob.Query` to get the type of query it is. Available constants are `Unknown, Select, Insert, Update, Delete`. - Postgres and SQLite Update/Delete queries now refresh the models after the query is executed. This is enabled by the `RETURNING` clause, so it is not available in MySQL. +- Added the `Case()` starter to all dialects to build `CASE` expressions. (thanks @k4n4ry) ### Changed diff --git a/dialect/mysql/select_test.go b/dialect/mysql/select_test.go index ac10768..a4c048e 100644 --- a/dialect/mysql/select_test.go +++ b/dialect/mysql/select_test.go @@ -30,6 +30,34 @@ func TestSelect(t *testing.T) { sm.Where(mysql.Quote("id").In(mysql.Arg(100, 200, 300))), ), }, + "case with else": { + ExpectedSQL: "SELECT id, name, (CASE WHEN (`id` = '1') THEN 'A' ELSE 'B' END) AS `C` FROM users", + Query: mysql.Select( + sm.Columns( + "id", + "name", + mysql.Case(). + When(mysql.Quote("id").EQ(mysql.S("1")), mysql.S("A")). + Else(mysql.S("B")). + As("C"), + ), + sm.From("users"), + ), + }, + "case without else": { + ExpectedSQL: "SELECT id, name, (CASE WHEN (`id` = '1') THEN 'A' END) AS `C` FROM users", + Query: mysql.Select( + sm.Columns( + "id", + "name", + mysql.Case(). + When(mysql.Quote("id").EQ(mysql.S("1")), mysql.S("A")). + End(). + As("C"), + ), + sm.From("users"), + ), + }, "select distinct": { ExpectedSQL: "SELECT DISTINCT id, name FROM users WHERE (`id` IN (?, ?, ?))", ExpectedArgs: []any{100, 200, 300}, diff --git a/dialect/mysql/starters.go b/dialect/mysql/starters.go index 870d63e..66d8910 100644 --- a/dialect/mysql/starters.go +++ b/dialect/mysql/starters.go @@ -100,3 +100,9 @@ func Raw(query string, args ...any) Expression { func Cast(exp bob.Expression, typname string) Expression { return bmod.Cast(exp, typname) } + +// SQL: CASE WHEN a THEN b ELSE c END +// Go: mysql.Case().When("a", "b").Else("c") +func Case() expr.CaseChain[Expression, Expression] { + return expr.NewCase[Expression, Expression]() +} diff --git a/dialect/psql/select_test.go b/dialect/psql/select_test.go index 8f1a4cd..82ba0bf 100644 --- a/dialect/psql/select_test.go +++ b/dialect/psql/select_test.go @@ -29,6 +29,34 @@ func TestSelect(t *testing.T) { sm.Where(psql.Quote("id").In(psql.Arg(100, 200, 300))), ), }, + "case with else": { + ExpectedSQL: `SELECT id, name, (CASE WHEN (id = '1') THEN 'A' ELSE 'B' END) AS "C" FROM users`, + Query: psql.Select( + sm.Columns( + "id", + "name", + psql.Case(). + When(psql.Quote("id").EQ(psql.S("1")), psql.S("A")). + Else(psql.S("B")). + As("C"), + ), + sm.From("users"), + ), + }, + "case without else": { + ExpectedSQL: `SELECT id, name, (CASE WHEN (id = '1') THEN 'A' END) AS "C" FROM users`, + Query: psql.Select( + sm.Columns( + "id", + "name", + psql.Case(). + When(psql.Quote("id").EQ(psql.S("1")), psql.S("A")). + End(). + As("C"), + ), + sm.From("users"), + ), + }, "select distinct": { ExpectedSQL: "SELECT DISTINCT id, name FROM users WHERE (id IN ($1, $2, $3))", ExpectedArgs: []any{100, 200, 300}, diff --git a/dialect/psql/starters.go b/dialect/psql/starters.go index 3c1db06..19d1efb 100644 --- a/dialect/psql/starters.go +++ b/dialect/psql/starters.go @@ -100,3 +100,9 @@ func Raw(query string, args ...any) Expression { func Cast(exp bob.Expression, typname string) Expression { return bmod.Cast(exp, typname) } + +// SQL: CASE WHEN a THEN b ELSE c END +// Go: psql.Case().When("a", "b").Else("c") +func Case() expr.CaseChain[Expression, Expression] { + return expr.NewCase[Expression, Expression]() +} diff --git a/dialect/sqlite/select_test.go b/dialect/sqlite/select_test.go index b8033c1..bb6abb9 100644 --- a/dialect/sqlite/select_test.go +++ b/dialect/sqlite/select_test.go @@ -30,6 +30,34 @@ func TestSelect(t *testing.T) { sm.Where(sqlite.Quote("id").In(sqlite.Arg(100, 200, 300))), ), }, + "case with else": { + ExpectedSQL: `SELECT id, name, (CASE WHEN ("id" = '1') THEN 'A' ELSE 'B' END) AS "C" FROM users`, + Query: sqlite.Select( + sm.Columns( + "id", + "name", + sqlite.Case(). + When(sqlite.Quote("id").EQ(sqlite.S("1")), sqlite.S("A")). + Else(sqlite.S("B")). + As("C"), + ), + sm.From("users"), + ), + }, + "case without else": { + ExpectedSQL: `SELECT id, name, (CASE WHEN ("id" = '1') THEN 'A' END) AS "C" FROM users`, + Query: sqlite.Select( + sm.Columns( + "id", + "name", + sqlite.Case(). + When(sqlite.Quote("id").EQ(sqlite.S("1")), sqlite.S("A")). + End(). + As("C"), + ), + sm.From("users"), + ), + }, "select distinct": { ExpectedSQL: `SELECT DISTINCT id, name FROM users WHERE ("id" IN (?1, ?2, ?3))`, ExpectedArgs: []any{100, 200, 300}, diff --git a/dialect/sqlite/starters.go b/dialect/sqlite/starters.go index 33c2cda..04cf5c3 100644 --- a/dialect/sqlite/starters.go +++ b/dialect/sqlite/starters.go @@ -100,3 +100,9 @@ func Raw(query string, args ...any) Expression { func Cast(exp bob.Expression, typname string) Expression { return bmod.Cast(exp, typname) } + +// SQL: CASE WHEN a THEN b ELSE c END +// Go: sqlite.Case().When("a", "b").Else("c") +func Case() expr.CaseChain[Expression, Expression] { + return expr.NewCase[Expression, Expression]() +} diff --git a/expr/case.go b/expr/case.go new file mode 100644 index 0000000..381e927 --- /dev/null +++ b/expr/case.go @@ -0,0 +1,83 @@ +package expr + +import ( + "context" + "errors" + "io" + + "github.com/stephenafamo/bob" +) + +type ( + caseExpr struct { + whens []when + elseExpr bob.Expression + } + when struct { + condition bob.Expression + then bob.Expression + } +) + +func (c caseExpr) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + var args []any + + if len(c.whens) == 0 { + return nil, errors.New("case must have at least one when expression") + } + + w.Write([]byte("CASE")) + for _, when := range c.whens { + w.Write([]byte(" WHEN ")) + whenArgs, err := when.condition.WriteSQL(ctx, w, d, start+len(args)) + if err != nil { + return nil, err + } + args = append(args, whenArgs...) + + w.Write([]byte(" THEN ")) + thenArgs, err := when.then.WriteSQL(ctx, w, d, start+len(args)) + if err != nil { + return nil, err + } + args = append(args, thenArgs...) + } + + if c.elseExpr != nil { + w.Write([]byte(" ELSE ")) + elseArgs, err := c.elseExpr.WriteSQL(ctx, w, d, start+len(args)) + if err != nil { + return nil, err + } + args = append(args, elseArgs...) + } + w.Write([]byte(" END")) + + return args, nil +} + +type CaseChain[T bob.Expression, B builder[T]] func() caseExpr + +func NewCase[T bob.Expression, B builder[T]]() CaseChain[T, B] { + return CaseChain[T, B](func() caseExpr { return caseExpr{} }) +} + +func (cc CaseChain[T, B]) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return cc().WriteSQL(ctx, w, d, start) +} + +func (cc CaseChain[T, B]) When(condition, then bob.Expression) CaseChain[T, B] { + c := cc() + c.whens = append(c.whens, when{condition: condition, then: then}) + return CaseChain[T, B](func() caseExpr { return c }) +} + +func (cc CaseChain[T, B]) Else(then bob.Expression) T { + c := cc() + c.elseExpr = then + return X[T, B](c) +} + +func (cc CaseChain[T, B]) End() T { + return X[T, B](cc()) +}