diff --git a/CHANGELOG.md b/CHANGELOG.md index 81064917..212cc0ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added error constants for matching against both specific and generic unique constraint errors raised by the underlying database driver. (thanks @mbezhanov) - Added support for regular expressions in the `only` and `except` table filters. (thanks @mbezhanov) +### Changed + +- `context.Context` is now passed to `Query.WriteQuery()` and `Expression.WriteSQL()` methods. This allows for more control over how the query is built and executed. + This change made is possible to delete some hacks and simplify the codebase. + - The `Name()` and `NameAs()` methods of Views/Tables no longer need the context argument since the context will be passed when writing the expression. The API then becomes cleaner. + - Preloading mods no longer need to store a context internally. `SetLoadContext()` and `GetLoadContext()` have removed. + - The `ToExpr` field in `orm.RelSide` which was used for preloading is no longer needed and has been removed. + ### Removed - Remove MS SQL artifacts. (thanks @mbezhanov) diff --git a/build.go b/build.go index d5e178bc..7ef21488 100644 --- a/build.go +++ b/build.go @@ -1,15 +1,18 @@ package bob -import "bytes" +import ( + "bytes" + "context" +) // MustBuild builds a query and panics on error // useful for initializing queries that need to be reused -func MustBuild(q Query) (string, []any) { - return MustBuildN(q, 1) +func MustBuild(ctx context.Context, q Query) (string, []any) { + return MustBuildN(ctx, q, 1) } -func MustBuildN(q Query, start int) (string, []any) { - sql, args, err := BuildN(q, start) +func MustBuildN(ctx context.Context, q Query, start int) (string, []any) { + sql, args, err := BuildN(ctx, q, start) if err != nil { panic(err) } @@ -18,14 +21,14 @@ func MustBuildN(q Query, start int) (string, []any) { } // Convinient function to build query from start -func Build(q Query) (string, []any, error) { - return BuildN(q, 1) +func Build(ctx context.Context, q Query) (string, []any, error) { + return BuildN(ctx, q, 1) } // Convinient function to build query from a point -func BuildN(q Query, start int) (string, []any, error) { +func BuildN(ctx context.Context, q Query, start int) (string, []any, error) { b := &bytes.Buffer{} - args, err := q.WriteQuery(b, start) + args, err := q.WriteQuery(ctx, b, start) return b.String(), args, err } diff --git a/cached.go b/cached.go index 574942ca..4479c571 100644 --- a/cached.go +++ b/cached.go @@ -1,16 +1,17 @@ package bob import ( + "context" "fmt" "io" ) -func Cache(q Query) (BaseQuery[*cached], error) { - return CacheN(q, 1) +func Cache(ctx context.Context, q Query) (BaseQuery[*cached], error) { + return CacheN(ctx, q, 1) } -func CacheN(q Query, start int) (BaseQuery[*cached], error) { - query, args, err := BuildN(q, start) +func CacheN(ctx context.Context, q Query, start int) (BaseQuery[*cached], error) { + query, args, err := BuildN(ctx, q, start) if err != nil { return BaseQuery[*cached]{}, err } @@ -40,7 +41,7 @@ type cached struct { } // WriteSQL implements Expression. -func (c *cached) WriteSQL(w io.Writer, d Dialect, start int) ([]any, error) { +func (c *cached) WriteSQL(ctx context.Context, w io.Writer, d Dialect, start int) ([]any, error) { if start != c.start { return nil, WrongStartError{Expected: c.start, Got: start} } diff --git a/clause/combine.go b/clause/combine.go index c1d317fd..ea93b873 100644 --- a/clause/combine.go +++ b/clause/combine.go @@ -1,6 +1,7 @@ package clause import ( + "context" "errors" "io" @@ -25,7 +26,7 @@ func (s *Combine) SetCombine(c Combine) { *s = c } -func (s Combine) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (s Combine) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if s.Strategy == "" { return nil, ErrNoCombinationStrategy } @@ -38,7 +39,7 @@ func (s Combine) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) w.Write([]byte(" ")) } - args, err := bob.Express(w, d, start, s.Query) + args, err := bob.Express(ctx, w, d, start, s.Query) if err != nil { return nil, err } diff --git a/clause/conflict.go b/clause/conflict.go index 234c8eec..061f052e 100644 --- a/clause/conflict.go +++ b/clause/conflict.go @@ -1,6 +1,7 @@ package clause import ( + "context" "io" "github.com/stephenafamo/bob" @@ -17,10 +18,10 @@ func (c *Conflict) SetConflict(conflict Conflict) { *c = conflict } -func (c Conflict) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (c Conflict) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write([]byte("ON CONFLICT")) - args, err := bob.ExpressIf(w, d, start, c.Target, true, "", "") + args, err := bob.ExpressIf(ctx, w, d, start, c.Target, true, "", "") if err != nil { return nil, err } @@ -28,13 +29,13 @@ func (c Conflict) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) w.Write([]byte(" DO ")) w.Write([]byte(c.Do)) - setArgs, err := bob.ExpressIf(w, d, start+len(args), c.Set, len(c.Set.Set) > 0, " SET\n", "") + setArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), c.Set, len(c.Set.Set) > 0, " SET\n", "") if err != nil { return nil, err } args = append(args, setArgs...) - whereArgs, err := bob.ExpressIf(w, d, start+len(args), c.Where, + whereArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), c.Where, len(c.Where.Conditions) > 0, "\n", "") if err != nil { return nil, err @@ -50,17 +51,17 @@ type ConflictTarget struct { Where []any } -func (c ConflictTarget) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (c ConflictTarget) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if c.Constraint != "" { - return bob.ExpressIf(w, d, start, c.Constraint, true, " ON CONSTRAINT ", "") + return bob.ExpressIf(ctx, w, d, start, c.Constraint, true, " ON CONSTRAINT ", "") } - args, err := bob.ExpressSlice(w, d, start, c.Columns, " (", ", ", ")") + args, err := bob.ExpressSlice(ctx, w, d, start, c.Columns, " (", ", ", ")") if err != nil { return nil, err } - whereArgs, err := bob.ExpressSlice(w, d, start+len(args), c.Where, " WHERE ", " AND ", "") + whereArgs, err := bob.ExpressSlice(ctx, w, d, start+len(args), c.Where, " WHERE ", " AND ", "") if err != nil { return nil, err } diff --git a/clause/cte.go b/clause/cte.go index 4d974023..769caea1 100644 --- a/clause/cte.go +++ b/clause/cte.go @@ -1,6 +1,7 @@ package clause import ( + "context" "fmt" "io" @@ -16,9 +17,9 @@ type CTE struct { Cycle CTECycle } -func (c CTE) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (c CTE) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write([]byte(c.Name)) - _, err := bob.ExpressSlice(w, d, start, c.Columns, "(", ", ", ")") + _, err := bob.ExpressSlice(ctx, w, d, start, c.Columns, "(", ", ", ")") if err != nil { return nil, err } @@ -36,20 +37,20 @@ func (c CTE) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { } w.Write([]byte("(")) - args, err := c.Query.WriteQuery(w, start) + args, err := c.Query.WriteQuery(ctx, w, start) if err != nil { return nil, err } w.Write([]byte(")")) - searchArgs, err := bob.ExpressIf(w, d, start+len(args), c.Search, + searchArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), c.Search, len(c.Search.Columns) > 0, "\n", "") if err != nil { return nil, err } args = append(args, searchArgs...) - cycleArgs, err := bob.ExpressIf(w, d, start+len(args), c.Cycle, + cycleArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), c.Cycle, len(c.Cycle.Columns) > 0, "\n", "") if err != nil { return nil, err @@ -70,11 +71,11 @@ type CTESearch struct { Set string } -func (c CTESearch) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (c CTESearch) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { // [ SEARCH { BREADTH | DEPTH } FIRST BY column_name [, ...] SET search_seq_col_name ] fmt.Fprintf(w, "SEARCH %s FIRST BY ", c.Order) - args, err := bob.ExpressSlice(w, d, start, c.Columns, "", ", ", "") + args, err := bob.ExpressSlice(ctx, w, d, start, c.Columns, "", ", ", "") if err != nil { return nil, err } @@ -92,25 +93,25 @@ type CTECycle struct { DefaultVal any } -func (c CTECycle) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (c CTECycle) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { //[ CYCLE column_name [, ...] SET cycle_mark_col_name [ TO cycle_mark_value DEFAULT cycle_mark_default ] USING cycle_path_col_name ] w.Write([]byte("CYCLE ")) - args, err := bob.ExpressSlice(w, d, start, c.Columns, "", ", ", "") + args, err := bob.ExpressSlice(ctx, w, d, start, c.Columns, "", ", ", "") if err != nil { return nil, err } fmt.Fprintf(w, " SET %s", c.Set) - markArgs, err := bob.ExpressIf(w, d, start+len(args), c.SetVal, + markArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), c.SetVal, c.SetVal != nil, " TO ", "") if err != nil { return nil, err } args = append(args, markArgs...) - defaultArgs, err := bob.ExpressIf(w, d, start+len(args), c.DefaultVal, + defaultArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), c.DefaultVal, c.DefaultVal != nil, " DEFAULT ", "") if err != nil { return nil, err diff --git a/clause/fetch.go b/clause/fetch.go index 12d0cd5e..1e9bbfdc 100644 --- a/clause/fetch.go +++ b/clause/fetch.go @@ -1,6 +1,7 @@ package clause import ( + "context" "io" "strconv" @@ -16,7 +17,7 @@ func (f *Fetch) SetFetch(fetch Fetch) { *f = fetch } -func (f Fetch) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (f Fetch) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if f.Count == nil { return nil, nil } diff --git a/clause/for.go b/clause/for.go index 651f8aac..d6a936df 100644 --- a/clause/for.go +++ b/clause/for.go @@ -1,6 +1,7 @@ package clause import ( + "context" "errors" "fmt" "io" @@ -32,7 +33,7 @@ func (f *For) SetFor(lock For) { *f = lock } -func (f For) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (f For) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if f.Strength == "" { return nil, nil } @@ -42,7 +43,7 @@ func (f For) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { fmt.Fprintf(w, "%s ", f.Strength) } - args, err := bob.ExpressSlice(w, d, start, f.Tables, "OF ", ", ", "") + args, err := bob.ExpressSlice(ctx, w, d, start, f.Tables, "OF ", ", ", "") if err != nil { return nil, err } diff --git a/clause/frame.go b/clause/frame.go index d07c87f3..b71989e8 100644 --- a/clause/frame.go +++ b/clause/frame.go @@ -1,6 +1,7 @@ package clause import ( + "context" "io" "github.com/stephenafamo/bob" @@ -34,7 +35,7 @@ func (f *Frame) SetExclusion(excl string) { f.Exclusion = excl } -func (f Frame) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (f Frame) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if f.Mode == "" { f.Mode = "RANGE" } @@ -52,19 +53,19 @@ func (f Frame) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write([]byte("BETWEEN ")) } - startArgs, err := bob.Express(w, d, start, f.Start) + startArgs, err := bob.Express(ctx, w, d, start, f.Start) if err != nil { return nil, err } args = append(args, startArgs...) - endArgs, err := bob.ExpressIf(w, d, start, f.End, f.End != nil, " AND ", "") + endArgs, err := bob.ExpressIf(ctx, w, d, start, f.End, f.End != nil, " AND ", "") if err != nil { return nil, err } args = append(args, endArgs...) - _, err = bob.ExpressIf(w, d, start, f.Exclusion, f.Exclusion != "", " EXCLUDE ", "") + _, err = bob.ExpressIf(ctx, w, d, start, f.Exclusion, f.Exclusion != "", " EXCLUDE ", "") if err != nil { return nil, err } diff --git a/clause/from.go b/clause/from.go index e4a75a02..eab2bce2 100644 --- a/clause/from.go +++ b/clause/from.go @@ -1,6 +1,7 @@ package clause import ( + "context" "fmt" "io" @@ -86,7 +87,7 @@ func (f *From) AppendIndexHint(i IndexHint) { f.IndexHints = append(f.IndexHints, i) } -func (f From) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (f From) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if f.Table == nil { return nil, nil } @@ -99,7 +100,7 @@ func (f From) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write([]byte("LATERAL ")) } - args, err := bob.Express(w, d, start, f.Table) + args, err := bob.Express(ctx, w, d, start, f.Table) if err != nil { return nil, err } @@ -108,7 +109,7 @@ func (f From) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write([]byte(" WITH ORDINALITY")) } - _, err = bob.ExpressSlice(w, d, start, f.Partitions, " PARTITION (", ", ", ")") + _, err = bob.ExpressSlice(ctx, w, d, start, f.Partitions, " PARTITION (", ", ", ")") if err != nil { return nil, err } @@ -131,7 +132,7 @@ func (f From) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { } // No args for index hints - _, err = bob.ExpressSlice(w, d, start+len(args), f.IndexHints, "\n", " ", "") + _, err = bob.ExpressSlice(ctx, w, d, start+len(args), f.IndexHints, "\n", " ", "") if err != nil { return nil, err } @@ -146,7 +147,7 @@ func (f From) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write([]byte(*f.IndexedBy)) } - joinArgs, err := bob.ExpressSlice(w, d, start+len(args), f.Joins, "\n", "\n", "") + joinArgs, err := bob.ExpressSlice(ctx, w, d, start+len(args), f.Joins, "\n", "\n", "") if err != nil { return nil, err } @@ -161,20 +162,20 @@ type IndexHint struct { For string // JOIN, ORDER BY or GROUP BY } -func (f IndexHint) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (f IndexHint) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if f.Type == "" { return nil, nil } fmt.Fprintf(w, "%s INDEX ", f.Type) - _, err := bob.ExpressIf(w, d, start, f.For, f.For != "", " FOR ", "") + _, err := bob.ExpressIf(ctx, w, d, start, f.For, f.For != "", " FOR ", "") if err != nil { return nil, err } // Always include the brackets fmt.Fprint(w, " (") - _, err = bob.ExpressSlice(w, d, start, f.Indexes, "", ", ", "") + _, err = bob.ExpressSlice(ctx, w, d, start, f.Indexes, "", ", ", "") if err != nil { return nil, err } diff --git a/clause/group_by.go b/clause/group_by.go index e32e05e0..d9ceaddb 100644 --- a/clause/group_by.go +++ b/clause/group_by.go @@ -1,6 +1,7 @@ package clause import ( + "context" "io" "github.com/stephenafamo/bob" @@ -28,7 +29,7 @@ func (g *GroupBy) SetGroupByDistinct(distinct bool) { g.Distinct = distinct } -func (g GroupBy) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (g GroupBy) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { var args []any // don't write anything if there are no groups @@ -41,7 +42,7 @@ func (g GroupBy) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) w.Write([]byte("DISTINCT ")) } - args, err := bob.ExpressSlice(w, d, start, g.Groups, "", ", ", "") + args, err := bob.ExpressSlice(ctx, w, d, start, g.Groups, "", ", ", "") if err != nil { return nil, err } @@ -59,9 +60,9 @@ type GroupingSet struct { Type string // GROUPING SET | CUBE | ROLLUP } -func (g GroupingSet) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (g GroupingSet) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write([]byte(g.Type)) - args, err := bob.ExpressSlice(w, d, start, g.Groups, " (", ", ", ")") + args, err := bob.ExpressSlice(ctx, w, d, start, g.Groups, " (", ", ", ")") if err != nil { return nil, err } diff --git a/clause/having.go b/clause/having.go index 9056fc19..b432d13e 100644 --- a/clause/having.go +++ b/clause/having.go @@ -1,6 +1,7 @@ package clause import ( + "context" "io" "github.com/stephenafamo/bob" @@ -14,8 +15,8 @@ func (h *Having) AppendHaving(e ...any) { h.Conditions = append(h.Conditions, e...) } -func (h Having) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - args, err := bob.ExpressSlice(w, d, start, h.Conditions, "HAVING ", " AND ", "") +func (h Having) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + args, err := bob.ExpressSlice(ctx, w, d, start, h.Conditions, "HAVING ", " AND ", "") if err != nil { return nil, err } diff --git a/clause/join.go b/clause/join.go index 92717b48..a9c03ef3 100644 --- a/clause/join.go +++ b/clause/join.go @@ -1,6 +1,7 @@ package clause import ( + "context" "io" "github.com/stephenafamo/bob" @@ -26,7 +27,7 @@ type Join struct { Using []string } -func (j Join) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (j Join) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if j.Natural { w.Write([]byte("NATURAL ")) } @@ -34,12 +35,12 @@ func (j Join) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write([]byte(j.Type)) w.Write([]byte(" ")) - args, err := bob.Express(w, d, start, j.To) + args, err := bob.Express(ctx, w, d, start, j.To) if err != nil { return nil, err } - onArgs, err := bob.ExpressSlice(w, d, start+len(args), j.On, " ON ", " AND ", "") + onArgs, err := bob.ExpressSlice(ctx, w, d, start+len(args), j.On, " ON ", " AND ", "") if err != nil { return nil, err } @@ -52,7 +53,7 @@ func (j Join) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write([]byte(", ")) } - _, err = expr.Quote(col).WriteSQL(w, d, 1) // start does not matter + _, err = expr.Quote(col).WriteSQL(ctx, w, d, 1) // start does not matter if err != nil { return nil, err } diff --git a/clause/limit.go b/clause/limit.go index b1d2b2e1..ca6da1cf 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -1,6 +1,7 @@ package clause import ( + "context" "io" "github.com/stephenafamo/bob" @@ -16,6 +17,6 @@ func (l *Limit) SetLimit(limit any) { l.Count = limit } -func (l Limit) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressIf(w, d, start, l.Count, l.Count != nil, "LIMIT ", "") +func (l Limit) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressIf(ctx, w, d, start, l.Count, l.Count != nil, "LIMIT ", "") } diff --git a/clause/offset.go b/clause/offset.go index f2de67cb..37af4c8c 100644 --- a/clause/offset.go +++ b/clause/offset.go @@ -1,6 +1,7 @@ package clause import ( + "context" "io" "github.com/stephenafamo/bob" @@ -16,6 +17,6 @@ func (o *Offset) SetOffset(offset any) { o.Count = offset } -func (o Offset) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressIf(w, d, start, o.Count, o.Count != nil, "OFFSET ", "") +func (o Offset) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressIf(ctx, w, d, start, o.Count, o.Count != nil, "OFFSET ", "") } diff --git a/clause/order_by.go b/clause/order_by.go index dee34ae4..6659189e 100644 --- a/clause/order_by.go +++ b/clause/order_by.go @@ -1,6 +1,7 @@ package clause import ( + "context" "fmt" "io" @@ -19,8 +20,8 @@ func (o *OrderBy) AppendOrder(order OrderDef) { o.Expressions = append(o.Expressions, order) } -func (o OrderBy) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressSlice(w, d, start, o.Expressions, "ORDER BY ", ", ", "") +func (o OrderBy) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressSlice(ctx, w, d, start, o.Expressions, "ORDER BY ", ", ", "") } type OrderDef struct { @@ -30,14 +31,14 @@ type OrderDef struct { Collation bob.Expression } -func (o OrderDef) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - args, err := bob.Express(w, d, start, o.Expression) +func (o OrderDef) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + args, err := bob.Express(ctx, w, d, start, o.Expression) if err != nil { return nil, err } if o.Collation != nil { - _, err = o.Collation.WriteSQL(w, d, start) + _, err = o.Collation.WriteSQL(ctx, w, d, start) if err != nil { return nil, err } diff --git a/clause/returning.go b/clause/returning.go index 73fc730c..37ccf595 100644 --- a/clause/returning.go +++ b/clause/returning.go @@ -1,6 +1,7 @@ package clause import ( + "context" "io" "github.com/stephenafamo/bob" @@ -18,6 +19,6 @@ func (r *Returning) AppendReturning(columns ...any) { r.Expressions = append(r.Expressions, columns...) } -func (r Returning) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressSlice(w, d, start, r.Expressions, "RETURNING ", ", ", "") +func (r Returning) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressSlice(ctx, w, d, start, r.Expressions, "RETURNING ", ", ", "") } diff --git a/clause/select.go b/clause/select.go index 54eaa319..73de3495 100644 --- a/clause/select.go +++ b/clause/select.go @@ -1,6 +1,7 @@ package clause import ( + "context" "io" "github.com/stephenafamo/bob" @@ -33,12 +34,12 @@ func (s *SelectList) AppendPreloadSelect(columns ...any) { s.PreloadColumns = append(s.PreloadColumns, columns...) } -func (s SelectList) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (s SelectList) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { var args []any all := append(s.Columns, s.PreloadColumns...) if len(all) > 0 { - colArgs, err := bob.ExpressSlice(w, d, start+len(args), all, "", ", ", "") + colArgs, err := bob.ExpressSlice(ctx, w, d, start+len(args), all, "", ", ", "") if err != nil { return nil, err } diff --git a/clause/set.go b/clause/set.go index 66f61871..2dbaba29 100644 --- a/clause/set.go +++ b/clause/set.go @@ -1,6 +1,7 @@ package clause import ( + "context" "io" "github.com/stephenafamo/bob" @@ -14,6 +15,6 @@ func (s *Set) AppendSet(exprs ...any) { s.Set = append(s.Set, exprs...) } -func (s Set) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressSlice(w, d, start, s.Set, "", ",\n", "") +func (s Set) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressSlice(ctx, w, d, start, s.Set, "", ",\n", "") } diff --git a/clause/table.go b/clause/table.go index 611d432f..794b77ff 100644 --- a/clause/table.go +++ b/clause/table.go @@ -1,6 +1,7 @@ package clause import ( + "context" "io" "github.com/stephenafamo/bob" @@ -21,8 +22,8 @@ func (t Table) As(alias string, columns ...string) Table { return t } -func (t Table) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - args, err := bob.Express(w, d, start, t.Expression) +func (t Table) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + args, err := bob.Express(ctx, w, d, start, t.Expression) if err != nil { return nil, err } @@ -44,7 +45,7 @@ func (t Table) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write([]byte(")")) } - _, err = bob.ExpressSlice(w, d, start, t.Partitions, " PARTITION (", ", ", ")") + _, err = bob.ExpressSlice(ctx, w, d, start, t.Partitions, " PARTITION (", ", ", ")") if err != nil { return nil, err } diff --git a/clause/values.go b/clause/values.go index 3c85ef36..76a544d1 100644 --- a/clause/values.go +++ b/clause/values.go @@ -1,6 +1,7 @@ package clause import ( + "context" "io" "github.com/stephenafamo/bob" @@ -18,8 +19,8 @@ type Values struct { type value []bob.Expression -func (v value) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressSlice(w, d, start, v, "(", ", ", ")") +func (v value) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressSlice(ctx, w, d, start, v, "(", ", ", ")") } func (v *Values) AppendValues(vals ...bob.Expression) { @@ -30,15 +31,15 @@ func (v *Values) AppendValues(vals ...bob.Expression) { v.Vals = append(v.Vals, vals) } -func (v Values) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (v Values) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { // If a query is present, use it if v.Query != nil { - return v.Query.WriteQuery(w, start) + return v.Query.WriteQuery(ctx, w, start) } // If values are present, use them if len(v.Vals) > 0 { - return bob.ExpressSlice(w, d, start, v.Vals, "VALUES ", ", ", "") + return bob.ExpressSlice(ctx, w, d, start, v.Vals, "VALUES ", ", ", "") } // If no value was present, use default value diff --git a/clause/where.go b/clause/where.go index ebe47e31..db1f3a3b 100644 --- a/clause/where.go +++ b/clause/where.go @@ -1,6 +1,7 @@ package clause import ( + "context" "io" "github.com/stephenafamo/bob" @@ -14,8 +15,8 @@ func (wh *Where) AppendWhere(e ...any) { wh.Conditions = append(wh.Conditions, e...) } -func (wh Where) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - args, err := bob.ExpressSlice(w, d, start, wh.Conditions, "WHERE ", " AND ", "") +func (wh Where) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + args, err := bob.ExpressSlice(ctx, w, d, start, wh.Conditions, "WHERE ", " AND ", "") if err != nil { return nil, err } diff --git a/clause/window.go b/clause/window.go index a5fa30b7..3745765a 100644 --- a/clause/window.go +++ b/clause/window.go @@ -1,6 +1,7 @@ package clause import ( + "context" "io" "github.com/stephenafamo/bob" @@ -35,24 +36,24 @@ func (wi *Window) AddOrderBy(order ...any) { wi.orderBy = append(wi.orderBy, order...) } -func (wi Window) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (wi Window) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if wi.From != "" { w.Write([]byte(wi.From)) w.Write([]byte(" ")) } - args, err := bob.ExpressSlice(w, d, start, wi.partitionBy, "PARTITION BY ", ", ", " ") + args, err := bob.ExpressSlice(ctx, w, d, start, wi.partitionBy, "PARTITION BY ", ", ", " ") if err != nil { return nil, err } - orderArgs, err := bob.ExpressSlice(w, d, start, wi.orderBy, "ORDER BY ", ", ", "") + orderArgs, err := bob.ExpressSlice(ctx, w, d, start, wi.orderBy, "ORDER BY ", ", ", "") if err != nil { return nil, err } args = append(args, orderArgs...) - frameArgs, err := bob.ExpressIf(w, d, start, wi.Frame, wi.Frame.Defined, " ", "") + frameArgs, err := bob.ExpressIf(ctx, w, d, start, wi.Frame, wi.Frame.Defined, " ", "") if err != nil { return nil, err } @@ -66,10 +67,10 @@ type NamedWindow struct { Definition any } -func (n NamedWindow) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (n NamedWindow) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write([]byte(n.Name)) w.Write([]byte(" AS (")) - args, err := bob.Express(w, d, start, n.Definition) + args, err := bob.Express(ctx, w, d, start, n.Definition) w.Write([]byte(")")) return args, err @@ -83,6 +84,6 @@ func (wi *Windows) AppendWindow(w NamedWindow) { wi.Windows = append(wi.Windows, w) } -func (wi Windows) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressSlice(w, d, start, wi.Windows, "WINDOW ", ", ", "") +func (wi Windows) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressSlice(ctx, w, d, start, wi.Windows, "WINDOW ", ", ", "") } diff --git a/clause/with.go b/clause/with.go index ad83ad52..5fe4739e 100644 --- a/clause/with.go +++ b/clause/with.go @@ -1,6 +1,7 @@ package clause import ( + "context" "io" "github.com/stephenafamo/bob" @@ -19,10 +20,10 @@ func (w *With) SetRecursive(r bool) { w.Recursive = r } -func (w With) WriteSQL(wr io.Writer, d bob.Dialect, start int) ([]any, error) { +func (w With) WriteSQL(ctx context.Context, wr io.Writer, d bob.Dialect, start int) ([]any, error) { prefix := "WITH\n" if w.Recursive { prefix = "WITH RECURSIVE\n" } - return bob.ExpressSlice(wr, d, start, w.CTEs, prefix, ",\n", "") + return bob.ExpressSlice(ctx, wr, d, start, w.CTEs, prefix, ",\n", "") } diff --git a/dialect/mysql/dialect/builder.go b/dialect/mysql/dialect/builder.go index 99b8c365..0c6adfe2 100644 --- a/dialect/mysql/dialect/builder.go +++ b/dialect/mysql/dialect/builder.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "strings" "github.com/stephenafamo/bob" @@ -20,6 +21,6 @@ func (Expression) New(exp bob.Expression) Expression { // Implements fmt.Stringer() func (x Expression) String() string { w := strings.Builder{} - x.WriteSQL(&w, Dialect, 1) //nolint:errcheck + x.WriteSQL(context.Background(), &w, Dialect, 1) //nolint:errcheck return w.String() } diff --git a/dialect/mysql/dialect/clauses.go b/dialect/mysql/dialect/clauses.go index 13df90e5..3c670b94 100644 --- a/dialect/mysql/dialect/clauses.go +++ b/dialect/mysql/dialect/clauses.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -15,8 +16,8 @@ func (h *modifiers[T]) AppendModifier(modifier T) { h.modifiers = append(h.modifiers, modifier) } -func (h modifiers[T]) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressSlice(w, d, start, h.modifiers, "", " ", "") +func (h modifiers[T]) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressSlice(ctx, w, d, start, h.modifiers, "", " ", "") } type Set struct { @@ -24,8 +25,8 @@ type Set struct { Val any } -func (s Set) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.Express(w, d, start, expr.OP("=", expr.Quote(s.Col), s.Val)) +func (s Set) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.Express(ctx, w, d, start, expr.OP("=", expr.Quote(s.Col), s.Val)) } type partitions struct { @@ -36,6 +37,6 @@ func (h *partitions) AppendPartition(partitions ...string) { h.partitions = append(h.partitions, partitions...) } -func (h partitions) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressSlice(w, d, start, h.partitions, "PARTITION (", ", ", ")") +func (h partitions) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressSlice(ctx, w, d, start, h.partitions, "PARTITION (", ", ", ")") } diff --git a/dialect/mysql/dialect/delete.go b/dialect/mysql/dialect/delete.go index d3eaba09..f54dc8fe 100644 --- a/dialect/mysql/dialect/delete.go +++ b/dialect/mysql/dialect/delete.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -21,10 +22,10 @@ type DeleteQuery struct { clause.Limit } -func (d DeleteQuery) WriteSQL(w io.Writer, dl bob.Dialect, start int) ([]any, error) { +func (d DeleteQuery) WriteSQL(ctx context.Context, w io.Writer, dl bob.Dialect, start int) ([]any, error) { var args []any - withArgs, err := bob.ExpressIf(w, dl, start+len(args), d.With, + withArgs, err := bob.ExpressIf(ctx, w, dl, start+len(args), d.With, len(d.With.CTEs) > 0, "\n", "") if err != nil { return nil, err @@ -34,47 +35,47 @@ func (d DeleteQuery) WriteSQL(w io.Writer, dl bob.Dialect, start int) ([]any, er w.Write([]byte("DELETE ")) // no optimizer hint args - _, err = bob.ExpressIf(w, dl, start+len(args), d.hints, + _, err = bob.ExpressIf(ctx, w, dl, start+len(args), d.hints, len(d.hints.hints) > 0, "\n", "\n") if err != nil { return nil, err } // no modifiers args - _, err = bob.ExpressIf(w, dl, start+len(args), d.modifiers, + _, err = bob.ExpressIf(ctx, w, dl, start+len(args), d.modifiers, len(d.modifiers.modifiers) > 0, "", " ") if err != nil { return nil, err } - tableArgs, err := bob.ExpressSlice(w, dl, start+len(args), d.Tables, "FROM ", ", ", "") + tableArgs, err := bob.ExpressSlice(ctx, w, dl, start+len(args), d.Tables, "FROM ", ", ", "") if err != nil { return nil, err } args = append(args, tableArgs...) - usingArgs, err := bob.ExpressIf(w, dl, start+len(args), d.From, + usingArgs, err := bob.ExpressIf(ctx, w, dl, start+len(args), d.From, d.From.Table != nil, "\nUSING ", "") if err != nil { return nil, err } args = append(args, usingArgs...) - whereArgs, err := bob.ExpressIf(w, dl, start+len(args), d.Where, + whereArgs, err := bob.ExpressIf(ctx, w, dl, start+len(args), d.Where, len(d.Where.Conditions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, whereArgs...) - orderArgs, err := bob.ExpressIf(w, dl, start+len(args), d.OrderBy, + orderArgs, err := bob.ExpressIf(ctx, w, dl, start+len(args), d.OrderBy, len(d.OrderBy.Expressions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, orderArgs...) - _, err = bob.ExpressIf(w, dl, start+len(args), d.Limit, + _, err = bob.ExpressIf(ctx, w, dl, start+len(args), d.Limit, d.Limit.Count != nil, "\n", "") if err != nil { return nil, err diff --git a/dialect/mysql/dialect/function.go b/dialect/mysql/dialect/function.go index cb5b3eeb..d9386213 100644 --- a/dialect/mysql/dialect/function.go +++ b/dialect/mysql/dialect/function.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -33,7 +34,7 @@ func (f *Function) SetWindow(w clause.Window) { f.w = &w } -func (f Function) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (f Function) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if f.name == "" { return nil, nil } @@ -45,12 +46,12 @@ func (f Function) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) w.Write([]byte("DISTINCT ")) } - args, err := bob.ExpressSlice(w, d, start, f.args, "", ", ", "") + args, err := bob.ExpressSlice(ctx, w, d, start, f.args, "", ", ", "") if err != nil { return nil, err } - orderArgs, err := bob.ExpressIf(w, d, start+len(args), f.OrderBy, + orderArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), f.OrderBy, len(f.OrderBy.Expressions) > 0, " ", "") if err != nil { return nil, err @@ -59,13 +60,13 @@ func (f Function) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) w.Write([]byte(")")) - filterArgs, err := bob.ExpressSlice(w, d, start, f.Filter, " FILTER (WHERE ", " AND ", ")") + filterArgs, err := bob.ExpressSlice(ctx, w, d, start, f.Filter, " FILTER (WHERE ", " AND ", ")") if err != nil { return nil, err } args = append(args, filterArgs...) - winargs, err := bob.ExpressIf(w, d, start+len(args), f.w, f.w != nil, "OVER (", ")") + winargs, err := bob.ExpressIf(ctx, w, d, start+len(args), f.w, f.w != nil, "OVER (", ")") if err != nil { return nil, err } diff --git a/dialect/mysql/dialect/hints.go b/dialect/mysql/dialect/hints.go index 1ac81c10..592220a5 100644 --- a/dialect/mysql/dialect/hints.go +++ b/dialect/mysql/dialect/hints.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "fmt" "io" "strings" @@ -17,8 +18,8 @@ func (h *hints) AppendHint(hint string) { h.hints = append(h.hints, hint) } -func (h hints) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressSlice(w, d, start, h.hints, "/*+ ", "\n ", " */") +func (h hints) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressSlice(ctx, w, d, start, h.hints, "/*+ ", "\n ", " */") } type hintable interface{ AppendHint(string) } diff --git a/dialect/mysql/dialect/insert.go b/dialect/mysql/dialect/insert.go index 4521d3ef..c7b53e6b 100644 --- a/dialect/mysql/dialect/insert.go +++ b/dialect/mysql/dialect/insert.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "fmt" "io" @@ -24,34 +25,34 @@ type InsertQuery struct { DuplicateKeyUpdate clause.Set } -func (i InsertQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (i InsertQuery) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { var args []any var err error w.Write([]byte("INSERT ")) // no optimizer hint args - _, err = bob.ExpressIf(w, d, start+len(args), i.hints, + _, err = bob.ExpressIf(ctx, w, d, start+len(args), i.hints, len(i.hints.hints) > 0, "\n", "\n") if err != nil { return nil, err } // no modifiers args - _, err = bob.ExpressIf(w, d, start+len(args), i.modifiers, + _, err = bob.ExpressIf(ctx, w, d, start+len(args), i.modifiers, len(i.modifiers.modifiers) > 0, "", " ") if err != nil { return nil, err } // no expected table args - _, err = bob.ExpressIf(w, d, start+len(args), i.Table, true, "INTO ", " ") + _, err = bob.ExpressIf(ctx, w, d, start+len(args), i.Table, true, "INTO ", " ") if err != nil { return nil, err } // no partition args - _, err = bob.ExpressIf(w, d, start+len(args), i.partitions, + _, err = bob.ExpressIf(ctx, w, d, start+len(args), i.partitions, len(i.partitions.partitions) > 0, "", " ") if err != nil { return nil, err @@ -71,14 +72,14 @@ func (i InsertQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err } // Either this or the values will get expressed - setArgs, err := bob.ExpressSlice(w, d, start+len(args), i.Sets, "\nSET ", "\n", " ") + setArgs, err := bob.ExpressSlice(ctx, w, d, start+len(args), i.Sets, "\nSET ", "\n", " ") if err != nil { return nil, err } args = append(args, setArgs...) // Either this or SET will get expressed - valArgs, err := bob.ExpressIf(w, d, start+len(args), i.Values, len(i.Sets) == 0, "\n", " ") + valArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), i.Values, len(i.Sets) == 0, "\n", " ") if err != nil { return nil, err } @@ -104,7 +105,7 @@ func (i InsertQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err } } - updateArgs, err := bob.ExpressSlice(w, d, start+len(args), i.DuplicateKeyUpdate.Set, + updateArgs, err := bob.ExpressSlice(ctx, w, d, start+len(args), i.DuplicateKeyUpdate.Set, "\nON DUPLICATE KEY UPDATE\n", ",\n", "") if err != nil { return nil, err diff --git a/dialect/mysql/dialect/mods.go b/dialect/mysql/dialect/mods.go index 5a63b4d2..a4be9ab4 100644 --- a/dialect/mysql/dialect/mods.go +++ b/dialect/mysql/dialect/mods.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "fmt" "io" @@ -264,7 +265,7 @@ type collation struct { name string } -func (c collation) WriteSQL(w io.Writer, d bob.Dialect, _ int) ([]any, error) { +func (c collation) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, _ int) ([]any, error) { if _, err := fmt.Fprintf(w, " COLLATE %s", c.name); err != nil { return nil, err } @@ -395,10 +396,10 @@ func (w *WindowChain[T]) FromUnboundedPreceding() T { return w.Wrap } -func (w *WindowChain[T]) FromPreceding(exp any) T { +func (w *WindowChain[T]) FromPreceding(ctx context.Context, exp any) T { w.def.SetStart(bob.ExpressionFunc( - func(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressIf(w, d, start, exp, true, "", " PRECEDING") + func(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressIf(ctx, w, d, start, exp, true, "", " PRECEDING") }), ) return w.Wrap @@ -411,8 +412,8 @@ func (w *WindowChain[T]) FromCurrentRow() T { func (w *WindowChain[T]) FromFollowing(exp any) T { w.def.SetStart(bob.ExpressionFunc( - func(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressIf(w, d, start, exp, true, "", " FOLLOWING") + func(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressIf(ctx, w, d, start, exp, true, "", " FOLLOWING") }), ) return w.Wrap @@ -420,8 +421,8 @@ func (w *WindowChain[T]) FromFollowing(exp any) T { func (w *WindowChain[T]) ToPreceding(exp any) T { w.def.SetEnd(bob.ExpressionFunc( - func(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressIf(w, d, start, exp, true, "", " PRECEDING") + func(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressIf(ctx, w, d, start, exp, true, "", " PRECEDING") }), ) return w.Wrap @@ -434,8 +435,8 @@ func (w *WindowChain[T]) ToCurrentRow(count int) T { func (w *WindowChain[T]) ToFollowing(exp any) T { w.def.SetEnd(bob.ExpressionFunc( - func(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressIf(w, d, start, exp, true, "", " FOLLOWING") + func(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressIf(ctx, w, d, start, exp, true, "", " FOLLOWING") }), ) return w.Wrap diff --git a/dialect/mysql/dialect/select.go b/dialect/mysql/dialect/select.go index c9848274..51a249a9 100644 --- a/dialect/mysql/dialect/select.go +++ b/dialect/mysql/dialect/select.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -34,11 +35,11 @@ func (s *SelectQuery) SetInto(i any) { s.into = i } -func (s SelectQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (s SelectQuery) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { var args []any var err error - withArgs, err := bob.ExpressIf(w, d, start+len(args), s.With, + withArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.With, len(s.With.CTEs) > 0, "\n", "") if err != nil { return nil, err @@ -48,93 +49,93 @@ func (s SelectQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err w.Write([]byte("SELECT ")) // no optimizer hint args - _, err = bob.ExpressIf(w, d, start+len(args), s.hints, + _, err = bob.ExpressIf(ctx, w, d, start+len(args), s.hints, len(s.hints.hints) > 0, "\n", "\n") if err != nil { return nil, err } // no modifiers args - _, err = bob.ExpressIf(w, d, start+len(args), s.modifiers, + _, err = bob.ExpressIf(ctx, w, d, start+len(args), s.modifiers, len(s.modifiers.modifiers) > 0, "", " ") if err != nil { return nil, err } - selArgs, err := bob.ExpressIf(w, d, start+len(args), s.SelectList, true, "\n", "") + selArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.SelectList, true, "\n", "") if err != nil { return nil, err } args = append(args, selArgs...) - fromArgs, err := bob.ExpressIf(w, d, start+len(args), s.From, s.From.Table != nil, "\nFROM ", "") + fromArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.From, s.From.Table != nil, "\nFROM ", "") if err != nil { return nil, err } args = append(args, fromArgs...) - whereArgs, err := bob.ExpressIf(w, d, start+len(args), s.Where, + whereArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Where, len(s.Where.Conditions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, whereArgs...) - groupByArgs, err := bob.ExpressIf(w, d, start+len(args), s.GroupBy, + groupByArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.GroupBy, len(s.GroupBy.Groups) > 0, "\n", "") if err != nil { return nil, err } args = append(args, groupByArgs...) - havingArgs, err := bob.ExpressIf(w, d, start+len(args), s.Having, + havingArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Having, len(s.Having.Conditions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, havingArgs...) - windowArgs, err := bob.ExpressIf(w, d, start+len(args), s.Windows, + windowArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Windows, len(s.Windows.Windows) > 0, "\n", "") if err != nil { return nil, err } args = append(args, windowArgs...) - combineArgs, err := bob.ExpressIf(w, d, start+len(args), s.Combine, + combineArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Combine, s.Combine.Query != nil, "\n", "") if err != nil { return nil, err } args = append(args, combineArgs...) - orderArgs, err := bob.ExpressIf(w, d, start+len(args), s.OrderBy, + orderArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.OrderBy, len(s.OrderBy.Expressions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, orderArgs...) - _, err = bob.ExpressIf(w, d, start+len(args), s.Limit, + _, err = bob.ExpressIf(ctx, w, d, start+len(args), s.Limit, s.Limit.Count != nil, "\n", "") if err != nil { return nil, err } - _, err = bob.ExpressIf(w, d, start+len(args), s.Offset, + _, err = bob.ExpressIf(ctx, w, d, start+len(args), s.Offset, s.Offset.Count != nil, "\n", "") if err != nil { return nil, err } - forArgs, err := bob.ExpressIf(w, d, start+len(args), s.For, + forArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.For, s.For.Strength != "", "\n", "") if err != nil { return nil, err } args = append(args, forArgs...) - intoArgs, err := bob.ExpressIf(w, d, start+len(args), s.into, + intoArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.into, s.into != nil, "\n", "") if err != nil { return nil, err diff --git a/dialect/mysql/dialect/update.go b/dialect/mysql/dialect/update.go index da982965..9c8c938c 100644 --- a/dialect/mysql/dialect/update.go +++ b/dialect/mysql/dialect/update.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -21,10 +22,10 @@ type UpdateQuery struct { clause.Limit } -func (u UpdateQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (u UpdateQuery) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { var args []any - withArgs, err := bob.ExpressIf(w, d, start+len(args), u.With, + withArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.With, len(u.With.CTEs) > 0, "\n", "") if err != nil { return nil, err @@ -34,47 +35,47 @@ func (u UpdateQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err w.Write([]byte("UPDATE")) // no optimizer hint args - _, err = bob.ExpressIf(w, d, start+len(args), u.hints, + _, err = bob.ExpressIf(ctx, w, d, start+len(args), u.hints, len(u.hints.hints) > 0, "\n", "\n") if err != nil { return nil, err } // no modifiers args - _, err = bob.ExpressIf(w, d, start+len(args), u.modifiers, + _, err = bob.ExpressIf(ctx, w, d, start+len(args), u.modifiers, len(u.modifiers.modifiers) > 0, " ", "") if err != nil { return nil, err } - fromArgs, err := bob.ExpressIf(w, d, start+len(args), u.From, + fromArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.From, u.From.Table != nil, " ", "") if err != nil { return nil, err } args = append(args, fromArgs...) - setArgs, err := bob.ExpressIf(w, d, start+len(args), u.Set, true, " SET\n", "") + setArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.Set, true, " SET\n", "") if err != nil { return nil, err } args = append(args, setArgs...) - whereArgs, err := bob.ExpressIf(w, d, start+len(args), u.Where, + whereArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.Where, len(u.Where.Conditions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, whereArgs...) - orderArgs, err := bob.ExpressIf(w, d, start+len(args), u.OrderBy, + orderArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.OrderBy, len(u.OrderBy.Expressions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, orderArgs...) - _, err = bob.ExpressIf(w, d, start+len(args), u.Limit, + _, err = bob.ExpressIf(ctx, w, d, start+len(args), u.Limit, u.Limit.Count != nil, "\n", "") if err != nil { return nil, err diff --git a/dialect/mysql/load.go b/dialect/mysql/load.go index 70d8a609..f690f8da 100644 --- a/dialect/mysql/load.go +++ b/dialect/mysql/load.go @@ -57,8 +57,7 @@ func Preload[T any, Ts ~[]T](rel orm.Relationship, cols []string, opts ...Preloa o.ModifyPreloadSettings(&settings) } - return buildPreloader[T](func(ctx context.Context) (string, mods.QueryMods[*dialect.SelectQuery]) { - parent, _ := ctx.Value(orm.CtxLoadParentAlias).(string) + return buildPreloader[T](func(parent string) (string, mods.QueryMods[*dialect.SelectQuery]) { if parent == "" { parent = rel.Sides[0].From } @@ -89,7 +88,7 @@ func Preload[T any, Ts ~[]T](rel orm.Relationship, cols []string, opts ...Preloa } queryMods = append(queryMods, sm. - LeftJoin(side.ToExpr(ctx)). + LeftJoin(orm.SchemaTable(side.To)). As(alias). On(on...)) @@ -104,17 +103,16 @@ func Preload[T any, Ts ~[]T](rel orm.Relationship, cols []string, opts ...Preloa }, rel.Name, settings) } -func buildPreloader[T any](f func(context.Context) (string, mods.QueryMods[*dialect.SelectQuery]), name string, opt PreloadSettings) Preloader { - return func(ctx context.Context) (bob.Mod[*dialect.SelectQuery], scan.MapperMod, []bob.Loader) { - alias, queryMods := f(ctx) +func buildPreloader[T any](f func(string) (string, mods.QueryMods[*dialect.SelectQuery]), name string, opt PreloadSettings) Preloader { + return func(parent string) (bob.Mod[*dialect.SelectQuery], scan.MapperMod, []bob.Loader) { + alias, queryMods := f(parent) prefix := alias + "." var mapperMods []scan.MapperMod extraLoaders := []bob.Loader{opt.ExtraLoader} - ctx = context.WithValue(ctx, orm.CtxLoadParentAlias, alias) for _, l := range opt.SubLoaders { - queryMod, mapperMod, extraLoader := l(ctx) + queryMod, mapperMod, extraLoader := l(alias) if queryMod != nil { queryMods = append(queryMods, queryMod) } diff --git a/dialect/mysql/sm/into.go b/dialect/mysql/sm/into.go index 5f88043b..b1d7b7ba 100644 --- a/dialect/mysql/sm/into.go +++ b/dialect/mysql/sm/into.go @@ -1,6 +1,7 @@ package sm import ( + "context" "fmt" "io" @@ -72,10 +73,10 @@ type into struct { lineOptions lineOptions } -func (i into) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (i into) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { // If it has vars, use INTO var_name, var_name ... if len(i.vars) > 0 { - return bob.ExpressSlice(w, d, start, i.vars, "INTO @", ", @", "") + return bob.ExpressSlice(ctx, w, d, start, i.vars, "INTO @", ", @", "") } // If dumpfile is present, use INTO DUMPFILE 'file_name' @@ -94,13 +95,13 @@ func (i into) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { return nil, err } - _, err = bob.ExpressIf(w, d, start, i.characterSet, + _, err = bob.ExpressIf(ctx, w, d, start, i.characterSet, i.characterSet != "", "\nCHARACTER SET ", "") if err != nil { return nil, err } - _, err = bob.ExpressIf(w, d, start, i.fieldOptions, i.hasFieldOpt, "\n", "") + _, err = bob.ExpressIf(ctx, w, d, start, i.fieldOptions, i.hasFieldOpt, "\n", "") if err != nil { return nil, err } @@ -115,7 +116,7 @@ type fieldOptions struct { enclosedByOptional bool } -func (f fieldOptions) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (f fieldOptions) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write([]byte("FIELDS")) if f.terminatedBy != "" { @@ -141,7 +142,7 @@ type lineOptions struct { terminatedBy string } -func (l lineOptions) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (l lineOptions) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write([]byte("LINES")) if l.startingBy != "" { diff --git a/dialect/mysql/table.go b/dialect/mysql/table.go index 9f5a0212..9084ca07 100644 --- a/dialect/mysql/table.go +++ b/dialect/mysql/table.go @@ -159,7 +159,7 @@ func (t *Table[T, Tslice, Tset]) InsertMany(ctx context.Context, exec bob.Execut } q := Insert( - im.Into(t.Name(ctx), internal.FilterNonZero(t.setterMapping.NonGenerated)...), + im.Into(t.Name(), internal.FilterNonZero(t.setterMapping.NonGenerated)...), ) // To prevent unnecessary work, we will do this before we add the rows @@ -220,7 +220,7 @@ func (t *Table[T, Tslice, Tset]) Update(ctx context.Context, exec bob.Executor, pkPairs[i] = row.PrimaryKeyVals() } - q := Update(um.Table(t.NameAs(ctx)), vals, um.Where(t.pkExpr.In(pkPairs...))) + q := Update(um.Table(t.NameAs()), vals, um.Where(t.pkExpr.In(pkPairs...))) ctx, err = t.UpdateQueryHooks.Do(ctx, exec, q.Expression) if err != nil { @@ -288,7 +288,7 @@ func (t *Table[T, Tslice, Tset]) UpsertMany(ctx context.Context, exec bob.Execut } q := Insert( - im.Into(t.Name(ctx), columns...), + im.Into(t.Name(), columns...), conflictQM, ) @@ -350,7 +350,7 @@ func (t *Table[T, Tslice, Tset]) Delete(ctx context.Context, exec bob.Executor, pkPairs[i] = row.PrimaryKeyVals() } - q := Delete(dm.From(t.NameAs(ctx)), dm.Where(t.pkExpr.In(pkPairs...))) + q := Delete(dm.From(t.NameAs()), dm.Where(t.pkExpr.In(pkPairs...))) ctx, err = t.DeleteQueryHooks.Do(ctx, exec, q.Expression) if err != nil { @@ -432,7 +432,7 @@ func (t *Table[T, Tslice, Tset]) uniqueSet(row Tset) ([]string, []any) { // Starts an insert query for this table func (t *Table[T, Tslice, Tset]) InsertQ(ctx context.Context, exec bob.Executor, queryMods ...bob.Mod[*dialect.InsertQuery]) *TQuery[*dialect.InsertQuery, T, Tslice] { q := &TQuery[*dialect.InsertQuery, T, Tslice]{ - BaseQuery: Insert(im.Into(t.NameAs(ctx), internal.FilterNonZero(t.setterMapping.NonGenerated)...)), + BaseQuery: Insert(im.Into(t.NameAs(), internal.FilterNonZero(t.setterMapping.NonGenerated)...)), ctx: ctx, exec: exec, view: t.View, @@ -447,7 +447,7 @@ func (t *Table[T, Tslice, Tset]) InsertQ(ctx context.Context, exec bob.Executor, // Starts an update query for this table func (t *Table[T, Tslice, Tset]) UpdateQ(ctx context.Context, exec bob.Executor, queryMods ...bob.Mod[*dialect.UpdateQuery]) *TQuery[*dialect.UpdateQuery, T, Tslice] { q := &TQuery[*dialect.UpdateQuery, T, Tslice]{ - BaseQuery: Update(um.Table(t.NameAs(ctx))), + BaseQuery: Update(um.Table(t.NameAs())), ctx: ctx, exec: exec, view: t.View, @@ -462,7 +462,7 @@ func (t *Table[T, Tslice, Tset]) UpdateQ(ctx context.Context, exec bob.Executor, // Starts a delete query for this table func (t *Table[T, Tslice, Tset]) DeleteQ(ctx context.Context, exec bob.Executor, queryMods ...bob.Mod[*dialect.DeleteQuery]) *TQuery[*dialect.DeleteQuery, T, Tslice] { q := &TQuery[*dialect.DeleteQuery, T, Tslice]{ - BaseQuery: Delete(dm.From(t.NameAs(ctx))), + BaseQuery: Delete(dm.From(t.NameAs())), ctx: ctx, exec: exec, view: t.View, diff --git a/dialect/mysql/view.go b/dialect/mysql/view.go index 0118bcaa..9b78ee77 100644 --- a/dialect/mysql/view.go +++ b/dialect/mysql/view.go @@ -49,12 +49,12 @@ type View[T any, Tslice ~[]T] struct { SelectQueryHooks orm.Hooks[*dialect.SelectQuery, orm.SkipQueryHooksKey] } -func (v *View[T, Tslice]) Name(ctx context.Context) Expression { +func (v *View[T, Tslice]) Name() Expression { return Quote(v.name) } -func (v *View[T, Tslice]) NameAs(ctx context.Context) bob.Expression { - return v.Name(ctx).As(v.alias) +func (v *View[T, Tslice]) NameAs() bob.Expression { + return v.Name().As(v.alias) } // Returns a column list @@ -66,13 +66,12 @@ func (v *View[T, Tslice]) Columns() orm.Columns { // Adds table name et al func (v *View[T, Tslice]) Query(ctx context.Context, exec bob.Executor, queryMods ...bob.Mod[*dialect.SelectQuery]) *ViewQuery[T, Tslice] { q := &ViewQuery[T, Tslice]{ - BaseQuery: Select(sm.From(v.NameAs(ctx))), + BaseQuery: Select(sm.From(v.NameAs())), ctx: ctx, exec: exec, view: v, } - q.Expression.SetLoadContext(ctx) q.Apply(queryMods...) return q @@ -109,23 +108,23 @@ type ViewQuery[T any, Ts ~[]T] struct { } // it is necessary to override this method to be able to add columns if not set -func (v ViewQuery[T, Ts]) WriteSQL(w io.Writer, _ bob.Dialect, start int) ([]any, error) { +func (v ViewQuery[T, Ts]) WriteSQL(ctx context.Context, w io.Writer, _ bob.Dialect, start int) ([]any, error) { // Append the table columns if len(v.BaseQuery.Expression.SelectList.Columns) == 0 { v.BaseQuery.Expression.AppendSelect(v.view.Columns()) } - return v.Expression.WriteSQL(w, v.Dialect, start) + return v.Expression.WriteSQL(ctx, w, v.Dialect, start) } // it is necessary to override this method to be able to add columns if not set -func (v ViewQuery[T, Ts]) WriteQuery(w io.Writer, start int) ([]any, error) { +func (v ViewQuery[T, Ts]) WriteQuery(ctx context.Context, w io.Writer, start int) ([]any, error) { // Append the table columns if len(v.BaseQuery.Expression.SelectList.Columns) == 0 { v.BaseQuery.Expression.AppendSelect(v.view.Columns()) } - return v.BaseQuery.WriteQuery(w, start) + return v.BaseQuery.WriteQuery(ctx, w, start) } func (v *ViewQuery[T, Ts]) hook() error { diff --git a/dialect/psql/dialect/builder.go b/dialect/psql/dialect/builder.go index 4d4ad06e..319a5fdf 100644 --- a/dialect/psql/dialect/builder.go +++ b/dialect/psql/dialect/builder.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "strings" "github.com/stephenafamo/bob" @@ -15,6 +16,10 @@ var ( iLike = expr.Raw("ILIKE") ) +func NewExpression(exp bob.Expression) Expression { + return Expression{}.New(exp) +} + type Expression struct { expr.Chain[Expression, Expression] } @@ -28,7 +33,7 @@ func (Expression) New(exp bob.Expression) Expression { // Implements fmt.Stringer() func (x Expression) String() string { w := strings.Builder{} - x.WriteSQL(&w, Dialect, 1) //nolint:errcheck + x.WriteSQL(context.Background(), &w, Dialect, 1) //nolint:errcheck return w.String() } diff --git a/dialect/psql/dialect/delete.go b/dialect/psql/dialect/delete.go index 96498d5d..c2873c0b 100644 --- a/dialect/psql/dialect/delete.go +++ b/dialect/psql/dialect/delete.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -18,10 +19,10 @@ type DeleteQuery struct { clause.Returning } -func (d DeleteQuery) WriteSQL(w io.Writer, dl bob.Dialect, start int) ([]any, error) { +func (d DeleteQuery) WriteSQL(ctx context.Context, w io.Writer, dl bob.Dialect, start int) ([]any, error) { var args []any - withArgs, err := bob.ExpressIf(w, dl, start+len(args), d.With, + withArgs, err := bob.ExpressIf(ctx, w, dl, start+len(args), d.With, len(d.With.CTEs) > 0, "\n", "") if err != nil { return nil, err @@ -34,27 +35,27 @@ func (d DeleteQuery) WriteSQL(w io.Writer, dl bob.Dialect, start int) ([]any, er w.Write([]byte("ONLY ")) } - tableArgs, err := bob.ExpressIf(w, dl, start+len(args), d.Table, true, "", "") + tableArgs, err := bob.ExpressIf(ctx, w, dl, start+len(args), d.Table, true, "", "") if err != nil { return nil, err } args = append(args, tableArgs...) - usingArgs, err := bob.ExpressIf(w, dl, start+len(args), d.From, + usingArgs, err := bob.ExpressIf(ctx, w, dl, start+len(args), d.From, d.From.Table != nil, "\nUSING ", "") if err != nil { return nil, err } args = append(args, usingArgs...) - whereArgs, err := bob.ExpressIf(w, dl, start+len(args), d.Where, + whereArgs, err := bob.ExpressIf(ctx, w, dl, start+len(args), d.Where, len(d.Where.Conditions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, whereArgs...) - retArgs, err := bob.ExpressIf(w, dl, start+len(args), d.Returning, + retArgs, err := bob.ExpressIf(ctx, w, dl, start+len(args), d.Returning, len(d.Returning.Expressions) > 0, "\n", "") if err != nil { return nil, err diff --git a/dialect/psql/dialect/function.go b/dialect/psql/dialect/function.go index 80422355..ed30db7d 100644 --- a/dialect/psql/dialect/function.go +++ b/dialect/psql/dialect/function.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -42,7 +43,7 @@ func (f *Function) AppendColumn(name, datatype string) { }) } -func (f *Function) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (f *Function) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if f.name == "" { return nil, nil } @@ -54,13 +55,13 @@ func (f *Function) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error w.Write([]byte("DISTINCT ")) } - args, err := bob.ExpressSlice(w, d, start, f.args, "", ", ", "") + args, err := bob.ExpressSlice(ctx, w, d, start, f.args, "", ", ", "") if err != nil { return nil, err } if !f.WithinGroup { - orderArgs, err := bob.ExpressIf(w, d, start+len(args), f.OrderBy, + orderArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), f.OrderBy, len(f.OrderBy.Expressions) > 0, " ", "") if err != nil { return nil, err @@ -70,7 +71,7 @@ func (f *Function) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error w.Write([]byte(")")) if f.WithinGroup { - orderArgs, err := bob.ExpressIf(w, d, start+len(args), f.OrderBy, + orderArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), f.OrderBy, len(f.OrderBy.Expressions) > 0, " WITHIN GROUP (", ")") if err != nil { return nil, err @@ -78,7 +79,7 @@ func (f *Function) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error args = append(args, orderArgs...) } - filterArgs, err := bob.ExpressSlice(w, d, start, f.Filter, " FILTER (WHERE ", " AND ", ")") + filterArgs, err := bob.ExpressSlice(ctx, w, d, start, f.Filter, " FILTER (WHERE ", " AND ", ")") if err != nil { return nil, err } @@ -93,13 +94,13 @@ func (f *Function) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error w.Write([]byte(" ")) } - colArgs, err := bob.ExpressSlice(w, d, start+len(args), f.Columns, "(", ", ", ")") + colArgs, err := bob.ExpressSlice(ctx, w, d, start+len(args), f.Columns, "(", ", ", ")") if err != nil { return nil, err } args = append(args, colArgs...) - winargs, err := bob.ExpressIf(w, d, start+len(args), f.w, f.w != nil, "OVER (", ")") + winargs, err := bob.ExpressIf(ctx, w, d, start+len(args), f.w, f.w != nil, "OVER (", ")") if err != nil { return nil, err } @@ -113,7 +114,7 @@ type columnDef struct { dataType string } -func (c columnDef) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (c columnDef) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write([]byte(c.name + " " + c.dataType)) return nil, nil @@ -121,12 +122,12 @@ func (c columnDef) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error type Functions []*Function -func (f Functions) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (f Functions) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if len(f) > 1 { w.Write([]byte("ROWS FROM (")) } - args, err := bob.ExpressSlice(w, d, start, f, "", ", ", "") + args, err := bob.ExpressSlice(ctx, w, d, start, f, "", ", ", "") if err != nil { return nil, err } diff --git a/dialect/psql/dialect/insert.go b/dialect/psql/dialect/insert.go index f768011f..772f41f9 100644 --- a/dialect/psql/dialect/insert.go +++ b/dialect/psql/dialect/insert.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -18,43 +19,43 @@ type InsertQuery struct { clause.Returning } -func (i InsertQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (i InsertQuery) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { var args []any - withArgs, err := bob.ExpressIf(w, d, start+len(args), i.With, + withArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), i.With, len(i.With.CTEs) > 0, "", "\n") if err != nil { return nil, err } args = append(args, withArgs...) - tableArgs, err := bob.ExpressIf(w, d, start+len(args), i.Table, + tableArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), i.Table, true, "INSERT INTO ", "") if err != nil { return nil, err } args = append(args, tableArgs...) - _, err = bob.ExpressIf(w, d, start+len(args), i.Overriding, + _, err = bob.ExpressIf(ctx, w, d, start+len(args), i.Overriding, i.Overriding != "", "\nOVERRIDING ", " VALUE") if err != nil { return nil, err } - valArgs, err := bob.ExpressIf(w, d, start+len(args), i.Values, true, "\n", "") + valArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), i.Values, true, "\n", "") if err != nil { return nil, err } args = append(args, valArgs...) - conflictArgs, err := bob.ExpressIf(w, d, start+len(args), i.Conflict, + conflictArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), i.Conflict, i.Conflict.Do != "", "\n", "") if err != nil { return nil, err } args = append(args, conflictArgs...) - retArgs, err := bob.ExpressIf(w, d, start+len(args), i.Returning, + retArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), i.Returning, len(i.Returning.Expressions) > 0, "\n", "") if err != nil { return nil, err diff --git a/dialect/psql/dialect/mods.go b/dialect/psql/dialect/mods.go index 99cde8c7..6c56f4d4 100644 --- a/dialect/psql/dialect/mods.go +++ b/dialect/psql/dialect/mods.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -13,9 +14,9 @@ type Distinct struct { On []any } -func (di Distinct) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (di Distinct) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write([]byte("DISTINCT")) - return bob.ExpressSlice(w, d, start, di.On, " ON (", ", ", ")") + return bob.ExpressSlice(ctx, w, d, start, di.On, " ON (", ", ", ")") } func With[Q interface{ AppendWith(clause.CTE) }](name string, columns ...string) CTEChain[Q] { @@ -201,7 +202,7 @@ type collation struct { name string } -func (c collation) WriteSQL(w io.Writer, d bob.Dialect, _ int) ([]any, error) { +func (c collation) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, _ int) ([]any, error) { if _, err := w.Write([]byte(" COLLATE ")); err != nil { return nil, err } @@ -428,8 +429,8 @@ func (w *WindowChain[T]) FromUnboundedPreceding() T { func (w *WindowChain[T]) FromPreceding(exp any) T { w.def.SetStart(bob.ExpressionFunc( - func(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressIf(w, d, start, exp, true, "", " PRECEDING") + func(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressIf(ctx, w, d, start, exp, true, "", " PRECEDING") }), ) return w.Wrap @@ -442,8 +443,8 @@ func (w *WindowChain[T]) FromCurrentRow() T { func (w *WindowChain[T]) FromFollowing(exp any) T { w.def.SetStart(bob.ExpressionFunc( - func(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressIf(w, d, start, exp, true, "", " FOLLOWING") + func(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressIf(ctx, w, d, start, exp, true, "", " FOLLOWING") }), ) return w.Wrap @@ -451,8 +452,8 @@ func (w *WindowChain[T]) FromFollowing(exp any) T { func (w *WindowChain[T]) ToPreceding(exp any) T { w.def.SetEnd(bob.ExpressionFunc( - func(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressIf(w, d, start, exp, true, "", " PRECEDING") + func(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressIf(ctx, w, d, start, exp, true, "", " PRECEDING") }), ) return w.Wrap @@ -465,8 +466,8 @@ func (w *WindowChain[T]) ToCurrentRow(count int) T { func (w *WindowChain[T]) ToFollowing(exp any) T { w.def.SetEnd(bob.ExpressionFunc( - func(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressIf(w, d, start, exp, true, "", " FOLLOWING") + func(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressIf(ctx, w, d, start, exp, true, "", " FOLLOWING") }), ) return w.Wrap diff --git a/dialect/psql/dialect/select.go b/dialect/psql/dialect/select.go index 82dfb9de..ab4dd101 100644 --- a/dialect/psql/dialect/select.go +++ b/dialect/psql/dialect/select.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -27,10 +28,10 @@ type SelectQuery struct { bob.Load } -func (s SelectQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (s SelectQuery) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { var args []any - withArgs, err := bob.ExpressIf(w, d, start+len(args), s.With, + withArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.With, len(s.With.CTEs) > 0, "\n", "") if err != nil { return nil, err @@ -39,88 +40,88 @@ func (s SelectQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err w.Write([]byte("SELECT ")) - distinctArgs, err := bob.ExpressIf(w, d, start+len(args), s.Distinct, + distinctArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Distinct, s.Distinct.On != nil, "", " ") if err != nil { return nil, err } args = append(args, distinctArgs...) - selArgs, err := bob.ExpressIf(w, d, start+len(args), s.SelectList, true, "\n", "") + selArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.SelectList, true, "\n", "") if err != nil { return nil, err } args = append(args, selArgs...) - fromArgs, err := bob.ExpressIf(w, d, start+len(args), s.From, s.From.Table != nil, "\nFROM ", "") + fromArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.From, s.From.Table != nil, "\nFROM ", "") if err != nil { return nil, err } args = append(args, fromArgs...) - whereArgs, err := bob.ExpressIf(w, d, start+len(args), s.Where, + whereArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Where, len(s.Where.Conditions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, whereArgs...) - groupByArgs, err := bob.ExpressIf(w, d, start+len(args), s.GroupBy, + groupByArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.GroupBy, len(s.GroupBy.Groups) > 0, "\n", "") if err != nil { return nil, err } args = append(args, groupByArgs...) - havingArgs, err := bob.ExpressIf(w, d, start+len(args), s.Having, + havingArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Having, len(s.Having.Conditions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, havingArgs...) - windowArgs, err := bob.ExpressIf(w, d, start+len(args), s.Windows, + windowArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Windows, len(s.Windows.Windows) > 0, "\n", "") if err != nil { return nil, err } args = append(args, windowArgs...) - combineArgs, err := bob.ExpressIf(w, d, start+len(args), s.Combine, + combineArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Combine, s.Combine.Query != nil, "\n", "") if err != nil { return nil, err } args = append(args, combineArgs...) - orderArgs, err := bob.ExpressIf(w, d, start+len(args), s.OrderBy, + orderArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.OrderBy, len(s.OrderBy.Expressions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, orderArgs...) - limitArgs, err := bob.ExpressIf(w, d, start+len(args), s.Limit, + limitArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Limit, s.Limit.Count != nil, "\n", "") if err != nil { return nil, err } args = append(args, limitArgs...) - offsetArgs, err := bob.ExpressIf(w, d, start+len(args), s.Offset, + offsetArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Offset, s.Offset.Count != nil, "\n", "") if err != nil { return nil, err } args = append(args, offsetArgs...) - _, err = bob.ExpressIf(w, d, start+len(args), s.Fetch, + _, err = bob.ExpressIf(ctx, w, d, start+len(args), s.Fetch, s.Fetch.Count != nil, "\n", "") if err != nil { return nil, err } - forArgs, err := bob.ExpressIf(w, d, start+len(args), s.For, + forArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.For, s.For.Strength != "", "\n", "") if err != nil { return nil, err diff --git a/dialect/psql/dialect/update.go b/dialect/psql/dialect/update.go index 5b7f0bb1..fdb6aa30 100644 --- a/dialect/psql/dialect/update.go +++ b/dialect/psql/dialect/update.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -19,10 +20,10 @@ type UpdateQuery struct { clause.Returning } -func (u UpdateQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (u UpdateQuery) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { var args []any - withArgs, err := bob.ExpressIf(w, d, start+len(args), u.With, + withArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.With, len(u.With.CTEs) > 0, "\n", "") if err != nil { return nil, err @@ -35,33 +36,33 @@ func (u UpdateQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err w.Write([]byte("ONLY ")) } - tableArgs, err := bob.ExpressIf(w, d, start+len(args), u.Table, true, "", "") + tableArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.Table, true, "", "") if err != nil { return nil, err } args = append(args, tableArgs...) - setArgs, err := bob.ExpressIf(w, d, start+len(args), u.Set, true, " SET\n", "") + setArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.Set, true, " SET\n", "") if err != nil { return nil, err } args = append(args, setArgs...) - fromArgs, err := bob.ExpressIf(w, d, start+len(args), u.From, + fromArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.From, u.From.Table != nil, "\nFROM ", "") if err != nil { return nil, err } args = append(args, fromArgs...) - whereArgs, err := bob.ExpressIf(w, d, start+len(args), u.Where, + whereArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.Where, len(u.Where.Conditions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, whereArgs...) - retArgs, err := bob.ExpressIf(w, d, start+len(args), u.Returning, + retArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.Returning, len(u.Returning.Expressions) > 0, "\n", "") if err != nil { return nil, err diff --git a/dialect/psql/load.go b/dialect/psql/load.go index 8a6ba9df..4ef4f74f 100644 --- a/dialect/psql/load.go +++ b/dialect/psql/load.go @@ -57,8 +57,7 @@ func Preload[T any, Ts ~[]T](rel orm.Relationship, cols []string, opts ...Preloa o.ModifyPreloadSettings(&settings) } - return buildPreloader[T](func(ctx context.Context) (string, mods.QueryMods[*dialect.SelectQuery]) { - parent, _ := ctx.Value(orm.CtxLoadParentAlias).(string) + return buildPreloader[T](func(parent string) (string, mods.QueryMods[*dialect.SelectQuery]) { if parent == "" { parent = rel.Sides[0].From } @@ -90,7 +89,7 @@ func Preload[T any, Ts ~[]T](rel orm.Relationship, cols []string, opts ...Preloa } queryMods = append(queryMods, sm. - LeftJoin(side.ToExpr(ctx)). + LeftJoin(orm.SchemaTable(side.To)). As(alias). On(on...)) @@ -105,17 +104,16 @@ func Preload[T any, Ts ~[]T](rel orm.Relationship, cols []string, opts ...Preloa }, rel.Name, settings) } -func buildPreloader[T any](f func(context.Context) (string, mods.QueryMods[*dialect.SelectQuery]), name string, opt PreloadSettings) Preloader { - return func(ctx context.Context) (bob.Mod[*dialect.SelectQuery], scan.MapperMod, []bob.Loader) { - alias, queryMods := f(ctx) +func buildPreloader[T any](f func(string) (string, mods.QueryMods[*dialect.SelectQuery]), name string, opt PreloadSettings) Preloader { + return func(parent string) (bob.Mod[*dialect.SelectQuery], scan.MapperMod, []bob.Loader) { + alias, queryMods := f(parent) prefix := alias + "." var mapperMods []scan.MapperMod extraLoaders := []bob.Loader{opt.ExtraLoader} - ctx = context.WithValue(ctx, orm.CtxLoadParentAlias, alias) for _, l := range opt.SubLoaders { - queryMod, mapperMod, extraLoader := l(ctx) + queryMod, mapperMod, extraLoader := l(alias) if queryMod != nil { queryMods = append(queryMods, queryMod) } diff --git a/dialect/psql/table.go b/dialect/psql/table.go index 61a282da..6f71e7d2 100644 --- a/dialect/psql/table.go +++ b/dialect/psql/table.go @@ -101,7 +101,7 @@ func (t *Table[T, Tslice, Tset]) InsertMany(ctx context.Context, exec bob.Execut } q := Insert( - im.Into(t.NameAs(ctx), internal.FilterNonZero(t.setterMapping.NonGenerated)...), + im.Into(t.NameAs(), internal.FilterNonZero(t.setterMapping.NonGenerated)...), im.Returning(t.Columns()), ) @@ -145,7 +145,7 @@ func (t *Table[T, Tslice, Tset]) Update(ctx context.Context, exec bob.Executor, pkPairs[i] = row.PrimaryKeyVals() } - q := Update(um.Table(t.NameAs(ctx)), vals, um.Where(t.pkExpr.In(pkPairs...))) + q := Update(um.Table(t.NameAs()), vals, um.Where(t.pkExpr.In(pkPairs...))) ctx, err = t.UpdateQueryHooks.Do(ctx, exec, q.Expression) if err != nil { @@ -225,7 +225,7 @@ func (t *Table[T, Tslice, Tset]) UpsertMany(ctx context.Context, exec bob.Execut } q := Insert( - im.Into(t.NameAs(ctx), internal.FilterNonZero(t.setterMapping.NonGenerated)...), + im.Into(t.NameAs(), internal.FilterNonZero(t.setterMapping.NonGenerated)...), im.Returning(t.Columns()), conflictQM, ) @@ -268,7 +268,7 @@ func (t *Table[T, Tslice, Tset]) Delete(ctx context.Context, exec bob.Executor, pkPairs[i] = row.PrimaryKeyVals() } - q := Delete(dm.From(t.NameAs(ctx)), dm.Where(t.pkExpr.In(pkPairs...))) + q := Delete(dm.From(t.NameAs()), dm.Where(t.pkExpr.In(pkPairs...))) ctx, err = t.DeleteQueryHooks.Do(ctx, exec, q.Expression) if err != nil { @@ -289,7 +289,7 @@ func (t *Table[T, Tslice, Tset]) Delete(ctx context.Context, exec bob.Executor, // Starts an insert query for this table func (t *Table[T, Tslice, Tset]) InsertQ(ctx context.Context, exec bob.Executor, queryMods ...bob.Mod[*dialect.InsertQuery]) *TableQuery[*dialect.InsertQuery, T, Tslice] { q := &TableQuery[*dialect.InsertQuery, T, Tslice]{ - BaseQuery: Insert(im.Into(t.NameAs(ctx), internal.FilterNonZero(t.setterMapping.NonGenerated)...)), + BaseQuery: Insert(im.Into(t.NameAs(), internal.FilterNonZero(t.setterMapping.NonGenerated)...)), ctx: ctx, exec: exec, view: t.View, @@ -304,7 +304,7 @@ func (t *Table[T, Tslice, Tset]) InsertQ(ctx context.Context, exec bob.Executor, // Starts an update query for this table func (t *Table[T, Tslice, Tset]) UpdateQ(ctx context.Context, exec bob.Executor, queryMods ...bob.Mod[*dialect.UpdateQuery]) *TableQuery[*dialect.UpdateQuery, T, Tslice] { q := &TableQuery[*dialect.UpdateQuery, T, Tslice]{ - BaseQuery: Update(um.Table(t.NameAs(ctx))), + BaseQuery: Update(um.Table(t.NameAs())), ctx: ctx, exec: exec, view: t.View, @@ -319,7 +319,7 @@ func (t *Table[T, Tslice, Tset]) UpdateQ(ctx context.Context, exec bob.Executor, // Starts a delete query for this table func (t *Table[T, Tslice, Tset]) DeleteQ(ctx context.Context, exec bob.Executor, queryMods ...bob.Mod[*dialect.DeleteQuery]) *TableQuery[*dialect.DeleteQuery, T, Tslice] { q := &TableQuery[*dialect.DeleteQuery, T, Tslice]{ - BaseQuery: Delete(dm.From(t.NameAs(ctx))), + BaseQuery: Delete(dm.From(t.NameAs())), ctx: ctx, exec: exec, view: t.View, diff --git a/dialect/psql/view.go b/dialect/psql/view.go index 48c604bc..02182dd8 100644 --- a/dialect/psql/view.go +++ b/dialect/psql/view.go @@ -62,18 +62,17 @@ type View[T any, Tslice ~[]T] struct { SelectQueryHooks orm.Hooks[*dialect.SelectQuery, orm.SkipQueryHooksKey] } -func (v *View[T, Tslice]) Name(ctx context.Context) Expression { +func (v *View[T, Tslice]) Name() Expression { // schema is not empty, never override if v.schema != "" { return Quote(v.schema, v.name) } - schema, _ := ctx.Value(orm.CtxUseSchema).(string) - return Quote(schema, v.name) + return Expression{}.New(orm.SchemaTable(v.name)) } -func (v *View[T, Tslice]) NameAs(ctx context.Context) bob.Expression { - return v.Name(ctx).As(v.alias) +func (v *View[T, Tslice]) NameAs() bob.Expression { + return v.Name().As(v.alias) } // Returns a column list @@ -85,13 +84,12 @@ func (v *View[T, Tslice]) Columns() orm.Columns { // Starts a select query func (v *View[T, Tslice]) Query(ctx context.Context, exec bob.Executor, queryMods ...bob.Mod[*dialect.SelectQuery]) *ViewQuery[T, Tslice] { q := &ViewQuery[T, Tslice]{ - BaseQuery: Select(sm.From(v.NameAs(ctx))), + BaseQuery: Select(sm.From(v.NameAs())), ctx: ctx, exec: exec, view: v, } - q.Expression.SetLoadContext(ctx) q.Apply(queryMods...) return q @@ -128,23 +126,23 @@ type ViewQuery[T any, Ts ~[]T] struct { } // it is necessary to override this method to be able to add columns if not set -func (v ViewQuery[T, Ts]) WriteSQL(w io.Writer, _ bob.Dialect, start int) ([]any, error) { +func (v ViewQuery[T, Ts]) WriteSQL(ctx context.Context, w io.Writer, _ bob.Dialect, start int) ([]any, error) { // Append the table columns if len(v.BaseQuery.Expression.SelectList.Columns) == 0 { v.BaseQuery.Expression.AppendSelect(v.view.Columns()) } - return v.BaseQuery.WriteSQL(w, v.Dialect, start) + return v.BaseQuery.WriteSQL(ctx, w, v.Dialect, start) } // it is necessary to override this method to be able to add columns if not set -func (v ViewQuery[T, Ts]) WriteQuery(w io.Writer, start int) ([]any, error) { +func (v ViewQuery[T, Ts]) WriteQuery(ctx context.Context, w io.Writer, start int) ([]any, error) { // Append the table columns if len(v.BaseQuery.Expression.SelectList.Columns) == 0 { v.BaseQuery.Expression.AppendSelect(v.view.Columns()) } - return v.BaseQuery.WriteQuery(w, start) + return v.BaseQuery.WriteQuery(ctx, w, start) } func (v *ViewQuery[T, Ts]) hook() error { diff --git a/dialect/sqlite/dialect/builder.go b/dialect/sqlite/dialect/builder.go index 99b8c365..0c6adfe2 100644 --- a/dialect/sqlite/dialect/builder.go +++ b/dialect/sqlite/dialect/builder.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "strings" "github.com/stephenafamo/bob" @@ -20,6 +21,6 @@ func (Expression) New(exp bob.Expression) Expression { // Implements fmt.Stringer() func (x Expression) String() string { w := strings.Builder{} - x.WriteSQL(&w, Dialect, 1) //nolint:errcheck + x.WriteSQL(context.Background(), &w, Dialect, 1) //nolint:errcheck return w.String() } diff --git a/dialect/sqlite/dialect/delete.go b/dialect/sqlite/dialect/delete.go index 43369206..7cdc3a53 100644 --- a/dialect/sqlite/dialect/delete.go +++ b/dialect/sqlite/dialect/delete.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -16,10 +17,10 @@ type DeleteQuery struct { clause.Returning } -func (d DeleteQuery) WriteSQL(w io.Writer, dl bob.Dialect, start int) ([]any, error) { +func (d DeleteQuery) WriteSQL(ctx context.Context, w io.Writer, dl bob.Dialect, start int) ([]any, error) { var args []any - withArgs, err := bob.ExpressIf(w, dl, start+len(args), d.With, + withArgs, err := bob.ExpressIf(ctx, w, dl, start+len(args), d.With, len(d.With.CTEs) > 0, "\n", "") if err != nil { return nil, err @@ -28,20 +29,20 @@ func (d DeleteQuery) WriteSQL(w io.Writer, dl bob.Dialect, start int) ([]any, er w.Write([]byte("DELETE FROM")) - tableArgs, err := bob.ExpressIf(w, dl, start+len(args), d.From, true, " ", "") + tableArgs, err := bob.ExpressIf(ctx, w, dl, start+len(args), d.From, true, " ", "") if err != nil { return nil, err } args = append(args, tableArgs...) - whereArgs, err := bob.ExpressIf(w, dl, start+len(args), d.Where, + whereArgs, err := bob.ExpressIf(ctx, w, dl, start+len(args), d.Where, len(d.Where.Conditions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, whereArgs...) - retArgs, err := bob.ExpressIf(w, dl, start+len(args), d.Returning, + retArgs, err := bob.ExpressIf(ctx, w, dl, start+len(args), d.Returning, len(d.Returning.Expressions) > 0, "\n", "") if err != nil { return nil, err diff --git a/dialect/sqlite/dialect/function.go b/dialect/sqlite/dialect/function.go index cb5b3eeb..d9386213 100644 --- a/dialect/sqlite/dialect/function.go +++ b/dialect/sqlite/dialect/function.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -33,7 +34,7 @@ func (f *Function) SetWindow(w clause.Window) { f.w = &w } -func (f Function) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (f Function) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if f.name == "" { return nil, nil } @@ -45,12 +46,12 @@ func (f Function) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) w.Write([]byte("DISTINCT ")) } - args, err := bob.ExpressSlice(w, d, start, f.args, "", ", ", "") + args, err := bob.ExpressSlice(ctx, w, d, start, f.args, "", ", ", "") if err != nil { return nil, err } - orderArgs, err := bob.ExpressIf(w, d, start+len(args), f.OrderBy, + orderArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), f.OrderBy, len(f.OrderBy.Expressions) > 0, " ", "") if err != nil { return nil, err @@ -59,13 +60,13 @@ func (f Function) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) w.Write([]byte(")")) - filterArgs, err := bob.ExpressSlice(w, d, start, f.Filter, " FILTER (WHERE ", " AND ", ")") + filterArgs, err := bob.ExpressSlice(ctx, w, d, start, f.Filter, " FILTER (WHERE ", " AND ", ")") if err != nil { return nil, err } args = append(args, filterArgs...) - winargs, err := bob.ExpressIf(w, d, start+len(args), f.w, f.w != nil, "OVER (", ")") + winargs, err := bob.ExpressIf(ctx, w, d, start+len(args), f.w, f.w != nil, "OVER (", ")") if err != nil { return nil, err } diff --git a/dialect/sqlite/dialect/insert.go b/dialect/sqlite/dialect/insert.go index e68ed407..78764d36 100644 --- a/dialect/sqlite/dialect/insert.go +++ b/dialect/sqlite/dialect/insert.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -18,10 +19,10 @@ type InsertQuery struct { clause.Returning } -func (i InsertQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (i InsertQuery) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { var args []any - withArgs, err := bob.ExpressIf(w, d, start+len(args), i.With, + withArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), i.With, len(i.With.CTEs) > 0, "", "\n") if err != nil { return nil, err @@ -30,31 +31,31 @@ func (i InsertQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err w.Write([]byte("INSERT")) - _, err = bob.ExpressIf(w, d, start+len(args), i.or, true, " ", "") + _, err = bob.ExpressIf(ctx, w, d, start+len(args), i.or, true, " ", "") if err != nil { return nil, err } - tableArgs, err := bob.ExpressIf(w, d, start+len(args), i.Table, true, " INTO ", "") + tableArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), i.Table, true, " INTO ", "") if err != nil { return nil, err } args = append(args, tableArgs...) - valArgs, err := bob.ExpressIf(w, d, start+len(args), i.Values, true, "\n", "") + valArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), i.Values, true, "\n", "") if err != nil { return nil, err } args = append(args, valArgs...) - conflictArgs, err := bob.ExpressIf(w, d, start+len(args), i.Conflict, + conflictArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), i.Conflict, i.Conflict.Do != "", "\n", "") if err != nil { return nil, err } args = append(args, conflictArgs...) - retArgs, err := bob.ExpressIf(w, d, start+len(args), i.Returning, + retArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), i.Returning, len(i.Returning.Expressions) > 0, "\n", "") if err != nil { return nil, err diff --git a/dialect/sqlite/dialect/mods.go b/dialect/sqlite/dialect/mods.go index 09ca12b2..87d6625e 100644 --- a/dialect/sqlite/dialect/mods.go +++ b/dialect/sqlite/dialect/mods.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "fmt" "io" @@ -233,7 +234,7 @@ type collation struct { name string } -func (c collation) WriteSQL(w io.Writer, d bob.Dialect, _ int) ([]any, error) { +func (c collation) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, _ int) ([]any, error) { if _, err := fmt.Fprintf(w, " COLLATE %s", c.name); err != nil { return nil, err } @@ -353,8 +354,8 @@ func (w *WindowChain[T]) FromUnboundedPreceding() T { func (w *WindowChain[T]) FromPreceding(exp any) T { w.def.SetStart(bob.ExpressionFunc( - func(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressIf(w, d, start, exp, true, "", " PRECEDING") + func(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressIf(ctx, w, d, start, exp, true, "", " PRECEDING") }), ) return w.Wrap @@ -367,8 +368,8 @@ func (w *WindowChain[T]) FromCurrentRow() T { func (w *WindowChain[T]) FromFollowing(exp any) T { w.def.SetStart(bob.ExpressionFunc( - func(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressIf(w, d, start, exp, true, "", " FOLLOWING") + func(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressIf(ctx, w, d, start, exp, true, "", " FOLLOWING") }), ) return w.Wrap @@ -376,8 +377,8 @@ func (w *WindowChain[T]) FromFollowing(exp any) T { func (w *WindowChain[T]) ToPreceding(exp any) T { w.def.SetEnd(bob.ExpressionFunc( - func(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressIf(w, d, start, exp, true, "", " PRECEDING") + func(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressIf(ctx, w, d, start, exp, true, "", " PRECEDING") }), ) return w.Wrap @@ -390,8 +391,8 @@ func (w *WindowChain[T]) ToCurrentRow(count int) T { func (w *WindowChain[T]) ToFollowing(exp any) T { w.def.SetEnd(bob.ExpressionFunc( - func(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressIf(w, d, start, exp, true, "", " FOLLOWING") + func(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressIf(ctx, w, d, start, exp, true, "", " FOLLOWING") }), ) return w.Wrap diff --git a/dialect/sqlite/dialect/or.go b/dialect/sqlite/dialect/or.go index e8143ba5..cbdf5071 100644 --- a/dialect/sqlite/dialect/or.go +++ b/dialect/sqlite/dialect/or.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -14,6 +15,6 @@ func (o *or) SetOr(to string) { o.action = to } -func (o or) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressIf(w, d, start, o.action, o.action != "", " OR ", "") +func (o or) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressIf(ctx, w, d, start, o.action, o.action != "", " OR ", "") } diff --git a/dialect/sqlite/dialect/select.go b/dialect/sqlite/dialect/select.go index d9361e1e..7b50c5a9 100644 --- a/dialect/sqlite/dialect/select.go +++ b/dialect/sqlite/dialect/select.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -25,10 +26,10 @@ type SelectQuery struct { bob.Load } -func (s SelectQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (s SelectQuery) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { var args []any - withArgs, err := bob.ExpressIf(w, d, start+len(args), s.With, + withArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.With, len(s.With.CTEs) > 0, "\n", "") if err != nil { return nil, err @@ -41,68 +42,68 @@ func (s SelectQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err w.Write([]byte("DISTINCT ")) } - selArgs, err := bob.ExpressIf(w, d, start+len(args), s.SelectList, true, "\n", "") + selArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.SelectList, true, "\n", "") if err != nil { return nil, err } args = append(args, selArgs...) - fromArgs, err := bob.ExpressIf(w, d, start+len(args), s.From, s.From.Table != nil, "\nFROM ", "") + fromArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.From, s.From.Table != nil, "\nFROM ", "") if err != nil { return nil, err } args = append(args, fromArgs...) - whereArgs, err := bob.ExpressIf(w, d, start+len(args), s.Where, + whereArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Where, len(s.Where.Conditions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, whereArgs...) - groupByArgs, err := bob.ExpressIf(w, d, start+len(args), s.GroupBy, + groupByArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.GroupBy, len(s.GroupBy.Groups) > 0, "\n", "") if err != nil { return nil, err } args = append(args, groupByArgs...) - havingArgs, err := bob.ExpressIf(w, d, start+len(args), s.Having, + havingArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Having, len(s.Having.Conditions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, havingArgs...) - windowArgs, err := bob.ExpressIf(w, d, start+len(args), s.Windows, + windowArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Windows, len(s.Windows.Windows) > 0, "\n", "") if err != nil { return nil, err } args = append(args, windowArgs...) - combineArgs, err := bob.ExpressIf(w, d, start+len(args), s.Combine, + combineArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Combine, s.Combine.Query != nil, "\n", "") if err != nil { return nil, err } args = append(args, combineArgs...) - orderArgs, err := bob.ExpressIf(w, d, start+len(args), s.OrderBy, + orderArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.OrderBy, len(s.OrderBy.Expressions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, orderArgs...) - limitArgs, err := bob.ExpressIf(w, d, start+len(args), s.Limit, + limitArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Limit, s.Limit.Count != nil, "\n", "") if err != nil { return nil, err } args = append(args, limitArgs...) - offsetArgs, err := bob.ExpressIf(w, d, start+len(args), s.Offset, + offsetArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), s.Offset, s.Offset.Count != nil, "\n", "") if err != nil { return nil, err diff --git a/dialect/sqlite/dialect/update.go b/dialect/sqlite/dialect/update.go index ea51289f..204fdb8d 100644 --- a/dialect/sqlite/dialect/update.go +++ b/dialect/sqlite/dialect/update.go @@ -1,6 +1,7 @@ package dialect import ( + "context" "io" "github.com/stephenafamo/bob" @@ -19,10 +20,10 @@ type UpdateQuery struct { clause.Returning } -func (u UpdateQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (u UpdateQuery) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { var args []any - withArgs, err := bob.ExpressIf(w, d, start+len(args), u.With, + withArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.With, len(u.With.CTEs) > 0, "\n", "") if err != nil { return nil, err @@ -31,38 +32,38 @@ func (u UpdateQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err w.Write([]byte("UPDATE")) - _, err = bob.ExpressIf(w, d, start+len(args), u.or, true, " ", "") + _, err = bob.ExpressIf(ctx, w, d, start+len(args), u.or, true, " ", "") if err != nil { return nil, err } - tableArgs, err := bob.ExpressIf(w, d, start+len(args), u.Table, true, " ", "") + tableArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.Table, true, " ", "") if err != nil { return nil, err } args = append(args, tableArgs...) - setArgs, err := bob.ExpressIf(w, d, start+len(args), u.Set, true, " SET\n", "") + setArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.Set, true, " SET\n", "") if err != nil { return nil, err } args = append(args, setArgs...) - fromArgs, err := bob.ExpressIf(w, d, start+len(args), u.From, + fromArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.From, u.From.Table != nil, "\nFROM ", "") if err != nil { return nil, err } args = append(args, fromArgs...) - whereArgs, err := bob.ExpressIf(w, d, start+len(args), u.Where, + whereArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.Where, len(u.Where.Conditions) > 0, "\n", "") if err != nil { return nil, err } args = append(args, whereArgs...) - retArgs, err := bob.ExpressIf(w, d, start+len(args), u.Returning, + retArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), u.Returning, len(u.Returning.Expressions) > 0, "\n", "") if err != nil { return nil, err diff --git a/dialect/sqlite/load.go b/dialect/sqlite/load.go index 2ec60b74..6787a5d6 100644 --- a/dialect/sqlite/load.go +++ b/dialect/sqlite/load.go @@ -57,8 +57,7 @@ func Preload[T any, Ts ~[]T](rel orm.Relationship, cols []string, opts ...Preloa o.ModifyPreloadSettings(&settings) } - return buildPreloader[T](func(ctx context.Context) (string, mods.QueryMods[*dialect.SelectQuery]) { - parent, _ := ctx.Value(orm.CtxLoadParentAlias).(string) + return buildPreloader[T](func(parent string) (string, mods.QueryMods[*dialect.SelectQuery]) { if parent == "" { parent = rel.Sides[0].From } @@ -89,7 +88,7 @@ func Preload[T any, Ts ~[]T](rel orm.Relationship, cols []string, opts ...Preloa } queryMods = append(queryMods, sm. - LeftJoin(side.ToExpr(ctx)). + LeftJoin(orm.SchemaTable(side.To)). As(alias). On(on...)) @@ -104,17 +103,16 @@ func Preload[T any, Ts ~[]T](rel orm.Relationship, cols []string, opts ...Preloa }, rel.Name, settings) } -func buildPreloader[T any](f func(context.Context) (string, mods.QueryMods[*dialect.SelectQuery]), name string, opt PreloadSettings) Preloader { - return func(ctx context.Context) (bob.Mod[*dialect.SelectQuery], scan.MapperMod, []bob.Loader) { - alias, queryMods := f(ctx) +func buildPreloader[T any](f func(string) (string, mods.QueryMods[*dialect.SelectQuery]), name string, opt PreloadSettings) Preloader { + return func(parent string) (bob.Mod[*dialect.SelectQuery], scan.MapperMod, []bob.Loader) { + alias, queryMods := f(parent) prefix := alias + "." var mapperMods []scan.MapperMod extraLoaders := []bob.Loader{opt.ExtraLoader} - ctx = context.WithValue(ctx, orm.CtxLoadParentAlias, alias) for _, l := range opt.SubLoaders { - queryMod, mapperMod, extraLoader := l(ctx) + queryMod, mapperMod, extraLoader := l(alias) if queryMod != nil { queryMods = append(queryMods, queryMod) } diff --git a/dialect/sqlite/table.go b/dialect/sqlite/table.go index dd303855..560bcc3d 100644 --- a/dialect/sqlite/table.go +++ b/dialect/sqlite/table.go @@ -104,7 +104,7 @@ func (t *Table[T, Tslice, Tset]) InsertMany(ctx context.Context, exec bob.Execut columns := rows[0].SetColumns() q := Insert( - im.Into(t.NameAs(ctx), columns...), + im.Into(t.NameAs(), columns...), im.Returning(t.Columns()), ) @@ -148,7 +148,7 @@ func (t *Table[T, Tslice, Tset]) Update(ctx context.Context, exec bob.Executor, pkPairs[i] = row.PrimaryKeyVals() } - q := Update(um.Table(t.NameAs(ctx)), vals, um.Where(t.pkExpr.In(pkPairs...))) + q := Update(um.Table(t.NameAs()), vals, um.Where(t.pkExpr.In(pkPairs...))) ctx, err = t.UpdateQueryHooks.Do(ctx, exec, q.Expression) if err != nil { @@ -224,7 +224,7 @@ func (t *Table[T, Tslice, Tset]) UpsertMany(ctx context.Context, exec bob.Execut } q := Insert( - im.Into(t.NameAs(ctx), columns...), + im.Into(t.NameAs(), columns...), im.Returning(t.Columns()), conflictQM, ) @@ -268,7 +268,7 @@ func (t *Table[T, Tslice, Tset]) Delete(ctx context.Context, exec bob.Executor, pkPairs[i] = row.PrimaryKeyVals() } - q := Delete(dm.From(t.NameAs(ctx)), dm.Where(t.pkExpr.In(pkPairs...))) + q := Delete(dm.From(t.NameAs()), dm.Where(t.pkExpr.In(pkPairs...))) ctx, err = t.DeleteQueryHooks.Do(ctx, exec, q.Expression) if err != nil { @@ -289,7 +289,7 @@ func (t *Table[T, Tslice, Tset]) Delete(ctx context.Context, exec bob.Executor, // Starts an insert query for this table func (t *Table[T, Tslice, Tset]) InsertQ(ctx context.Context, exec bob.Executor, queryMods ...bob.Mod[*dialect.InsertQuery]) *TQuery[*dialect.InsertQuery, T, Tslice] { q := &TQuery[*dialect.InsertQuery, T, Tslice]{ - BaseQuery: Insert(im.Into(t.NameAs(ctx), internal.FilterNonZero(t.setterMapping.NonGenerated)...)), + BaseQuery: Insert(im.Into(t.NameAs(), internal.FilterNonZero(t.setterMapping.NonGenerated)...)), ctx: ctx, exec: exec, view: t.View, @@ -305,7 +305,7 @@ func (t *Table[T, Tslice, Tset]) InsertQ(ctx context.Context, exec bob.Executor, // Starts an update query for this table func (t *Table[T, Tslice, Tset]) UpdateQ(ctx context.Context, exec bob.Executor, queryMods ...bob.Mod[*dialect.UpdateQuery]) *TQuery[*dialect.UpdateQuery, T, Tslice] { q := &TQuery[*dialect.UpdateQuery, T, Tslice]{ - BaseQuery: Update(um.Table(t.NameAs(ctx))), + BaseQuery: Update(um.Table(t.NameAs())), ctx: ctx, exec: exec, view: t.View, @@ -321,7 +321,7 @@ func (t *Table[T, Tslice, Tset]) UpdateQ(ctx context.Context, exec bob.Executor, // Starts a delete query for this table func (t *Table[T, Tslice, Tset]) DeleteQ(ctx context.Context, exec bob.Executor, queryMods ...bob.Mod[*dialect.DeleteQuery]) *TQuery[*dialect.DeleteQuery, T, Tslice] { q := &TQuery[*dialect.DeleteQuery, T, Tslice]{ - BaseQuery: Delete(dm.From(t.NameAs(ctx))), + BaseQuery: Delete(dm.From(t.NameAs())), ctx: ctx, exec: exec, view: t.View, diff --git a/dialect/sqlite/view.go b/dialect/sqlite/view.go index 3eb7e7cb..0a99d22a 100644 --- a/dialect/sqlite/view.go +++ b/dialect/sqlite/view.go @@ -62,18 +62,17 @@ type View[T any, Tslice ~[]T] struct { SelectQueryHooks orm.Hooks[*dialect.SelectQuery, orm.SkipQueryHooksKey] } -func (v *View[T, Tslice]) Name(ctx context.Context) Expression { +func (v *View[T, Tslice]) Name() Expression { // schema is not empty, never override if v.schema != "" { return Quote(v.schema, v.name) } - schema, _ := ctx.Value(orm.CtxUseSchema).(string) - return Quote(schema, v.name) + return Expression{}.New(orm.SchemaTable(v.name)) } -func (v *View[T, Tslice]) NameAs(ctx context.Context) bob.Expression { - return v.Name(ctx).As(v.alias) +func (v *View[T, Tslice]) NameAs() bob.Expression { + return v.Name().As(v.alias) } // Returns a column list @@ -85,13 +84,12 @@ func (v *View[T, Tslice]) Columns() orm.Columns { // Adds table name et al func (v *View[T, Tslice]) Query(ctx context.Context, exec bob.Executor, queryMods ...bob.Mod[*dialect.SelectQuery]) *ViewQuery[T, Tslice] { q := &ViewQuery[T, Tslice]{ - BaseQuery: Select(sm.From(v.NameAs(ctx))), + BaseQuery: Select(sm.From(v.NameAs())), ctx: ctx, exec: exec, view: v, } - q.Expression.SetLoadContext(ctx) q.Apply(queryMods...) return q @@ -128,23 +126,23 @@ type ViewQuery[T any, Ts ~[]T] struct { } // it is necessary to override this method to be able to add columns if not set -func (v ViewQuery[T, Ts]) WriteSQL(w io.Writer, _ bob.Dialect, start int) ([]any, error) { +func (v ViewQuery[T, Ts]) WriteSQL(ctx context.Context, w io.Writer, _ bob.Dialect, start int) ([]any, error) { // Append the table columns if len(v.BaseQuery.Expression.SelectList.Columns) == 0 { v.BaseQuery.Expression.AppendSelect(v.view.Columns()) } - return v.Expression.WriteSQL(w, v.Dialect, start) + return v.Expression.WriteSQL(ctx, w, v.Dialect, start) } // it is necessary to override this method to be able to add columns if not set -func (v ViewQuery[T, Ts]) WriteQuery(w io.Writer, start int) ([]any, error) { +func (v ViewQuery[T, Ts]) WriteQuery(ctx context.Context, w io.Writer, start int) ([]any, error) { // Append the table columns if len(v.BaseQuery.Expression.SelectList.Columns) == 0 { v.BaseQuery.Expression.AppendSelect(v.view.Columns()) } - return v.BaseQuery.WriteQuery(w, start) + return v.BaseQuery.WriteQuery(ctx, w, start) } func (v *ViewQuery[T, Ts]) hook() error { diff --git a/exec.go b/exec.go index 5824d483..dff8dece 100644 --- a/exec.go +++ b/exec.go @@ -25,7 +25,7 @@ type Executor interface { } func Exec(ctx context.Context, exec Executor, q Query) (sql.Result, error) { - sql, args, err := Build(q) + sql, args, err := Build(ctx, q) if err != nil { return nil, err } @@ -54,7 +54,7 @@ func One[T any](ctx context.Context, exec Executor, q Query, m scan.Mapper[T], o var t T - sql, args, err := Build(q) + sql, args, err := Build(ctx, q) if err != nil { return t, err } @@ -100,7 +100,7 @@ func Allx[T any, Ts ~[]T](ctx context.Context, exec Executor, q Query, m scan.Ma opt(&settings) } - sql, args, err := Build(q) + sql, args, err := Build(ctx, q) if err != nil { return nil, err } @@ -142,7 +142,7 @@ func Cursor[T any](ctx context.Context, exec Executor, q Query, m scan.Mapper[T] opt(&settings) } - sql, args, err := Build(q) + sql, args, err := Build(ctx, q) if err != nil { return nil, err } diff --git a/expr/arg.go b/expr/arg.go index eb07592b..d8dc0091 100644 --- a/expr/arg.go +++ b/expr/arg.go @@ -1,6 +1,7 @@ package expr import ( + "context" "io" "github.com/stephenafamo/bob" @@ -20,7 +21,7 @@ type args struct { grouped bool } -func (a args) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (a args) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if len(a.vals) == 0 { return nil, nil } diff --git a/expr/cast.go b/expr/cast.go index a666475f..9957a7ff 100644 --- a/expr/cast.go +++ b/expr/cast.go @@ -1,6 +1,7 @@ package expr import ( + "context" "io" "github.com/stephenafamo/bob" @@ -15,6 +16,6 @@ type cast struct { typname string } -func (c cast) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressIf(w, d, start, c.e, c.e != nil, "CAST(", " AS "+c.typname+")") +func (c cast) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.ExpressIf(ctx, w, d, start, c.e, c.e != nil, "CAST(", " AS "+c.typname+")") } diff --git a/expr/chain.go b/expr/chain.go index df1109fb..dbf09467 100644 --- a/expr/chain.go +++ b/expr/chain.go @@ -1,6 +1,7 @@ package expr import ( + "context" "io" "github.com/stephenafamo/bob" @@ -11,8 +12,8 @@ type Chain[T bob.Expression, B builder[T]] struct { } // WriteSQL satisfies the bob.Expression interface -func (x Chain[T, B]) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.Express(w, d, start, x.Base) +func (x Chain[T, B]) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return bob.Express(ctx, w, d, start, x.Base) } // IS DISTINCT FROM diff --git a/expr/group.go b/expr/group.go index 33386c11..80d0fec8 100644 --- a/expr/group.go +++ b/expr/group.go @@ -1,6 +1,7 @@ package expr import ( + "context" "io" "github.com/stephenafamo/bob" @@ -9,10 +10,10 @@ import ( // Multiple expressions that will be group together as a single expression type group []bob.Expression -func (g group) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (g group) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if len(g) == 0 { - return bob.ExpressIf(w, d, start, null, true, openPar, closePar) + return bob.ExpressIf(ctx, w, d, start, null, true, openPar, closePar) } - return bob.ExpressSlice(w, d, start, g, openPar, commaSpace, closePar) + return bob.ExpressSlice(ctx, w, d, start, g, openPar, commaSpace, closePar) } diff --git a/expr/operators.go b/expr/operators.go index 4e854dcd..527d8ba4 100644 --- a/expr/operators.go +++ b/expr/operators.go @@ -1,6 +1,7 @@ package expr import ( + "context" "fmt" "io" @@ -14,15 +15,15 @@ type leftRight struct { left any } -func (lr leftRight) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - largs, err := bob.Express(w, d, start, lr.left) +func (lr leftRight) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + largs, err := bob.Express(ctx, w, d, start, lr.left) if err != nil { return nil, err } fmt.Fprintf(w, " %s ", lr.operator) - rargs, err := bob.Express(w, d, start+len(largs), lr.right) + rargs, err := bob.Express(ctx, w, d, start+len(largs), lr.right) if err != nil { return nil, err } @@ -45,11 +46,11 @@ type Join struct { Sep string } -func (s Join) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (s Join) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { sep := s.Sep if sep == "" { sep = " " } - return bob.ExpressSlice(w, d, start, s.Exprs, "", sep, "") + return bob.ExpressSlice(ctx, w, d, start, s.Exprs, "", sep, "") } diff --git a/expr/quote.go b/expr/quote.go index ae5a2819..fd7e0156 100644 --- a/expr/quote.go +++ b/expr/quote.go @@ -1,6 +1,7 @@ package expr import ( + "context" "io" "github.com/stephenafamo/bob" @@ -21,7 +22,7 @@ func Quote(aa ...string) bob.Expression { // quoted and joined... something like "users"."id" type quoted []string -func (q quoted) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (q quoted) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if len(q) == 0 { return nil, nil } diff --git a/expr/raw.go b/expr/raw.go index d73882f8..2ebfba08 100644 --- a/expr/raw.go +++ b/expr/raw.go @@ -1,6 +1,7 @@ package expr import ( + "context" "errors" "fmt" "io" @@ -11,7 +12,7 @@ import ( type Raw []byte -func (r Raw) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (r Raw) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write(r) return nil, nil } @@ -29,9 +30,9 @@ type Clause struct { args []any // The replacements for the placeholders in order } -func (r Clause) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (r Clause) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { // replace the args with positional args appropriately - total, args, err := r.convertQuestionMarks(w, d, start) + total, args, err := r.convertQuestionMarks(ctx, w, d, start) if err != nil { return nil, err } @@ -46,7 +47,7 @@ func (r Clause) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { // convertQuestionMarks converts each occurrence of ? with $ // where is an incrementing digit starting at startAt. // If question-mark (?) is escaped using back-slash (\), it will be ignored. -func (r Clause) convertQuestionMarks(w io.Writer, d bob.Dialect, startAt int) (int, []any, error) { +func (r Clause) convertQuestionMarks(ctx context.Context, w io.Writer, d bob.Dialect, startAt int) (int, []any, error) { if startAt == 0 { panic("Not a valid start number.") } @@ -83,7 +84,7 @@ func (r Clause) convertQuestionMarks(w io.Writer, d bob.Dialect, startAt int) (i arg = r.args[total] } if ex, ok := arg.(bob.Expression); ok { - eargs, err := ex.WriteSQL(w, d, startAt) + eargs, err := ex.WriteSQL(ctx, w, d, startAt) if err != nil { return total, nil, err } diff --git a/expr/string.go b/expr/string.go index 1e42df42..d17ccb80 100644 --- a/expr/string.go +++ b/expr/string.go @@ -1,6 +1,7 @@ package expr import ( + "context" "io" "github.com/stephenafamo/bob" @@ -8,7 +9,7 @@ import ( type rawString string -func (s rawString) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (s rawString) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { w.Write([]byte("'")) w.Write([]byte(s)) w.Write([]byte("'")) diff --git a/expression.go b/expression.go index 4c53428c..7b3be1a1 100644 --- a/expression.go +++ b/expression.go @@ -1,6 +1,7 @@ package bob import ( + "context" "database/sql" "errors" "fmt" @@ -31,16 +32,16 @@ type Expression interface { // Writes the textual representation of the expression to the writer // using the given dialect. // start is the beginning index of the args if it needs to write any - WriteSQL(w io.Writer, d Dialect, start int) (args []any, err error) + WriteSQL(ctx context.Context, w io.Writer, d Dialect, start int) (args []any, err error) } -type ExpressionFunc func(w io.Writer, d Dialect, start int) ([]any, error) +type ExpressionFunc func(ctx context.Context, w io.Writer, d Dialect, start int) ([]any, error) -func (e ExpressionFunc) WriteSQL(w io.Writer, d Dialect, start int) ([]any, error) { - return e(w, d, start) +func (e ExpressionFunc) WriteSQL(ctx context.Context, w io.Writer, d Dialect, start int) ([]any, error) { + return e(ctx, w, d, start) } -func Express(w io.Writer, d Dialect, start int, e any) ([]any, error) { +func Express(ctx context.Context, w io.Writer, d Dialect, start int, e any) ([]any, error) { switch v := e.(type) { case string: w.Write([]byte(v)) @@ -56,7 +57,7 @@ func Express(w io.Writer, d Dialect, start int, e any) ([]any, error) { dn.WriteNamedArg(w, v.Name) return []any{v}, nil case Expression: - return v.WriteSQL(w, d, start) + return v.WriteSQL(ctx, w, d, start) default: fmt.Fprint(w, e) return nil, nil @@ -65,13 +66,13 @@ func Express(w io.Writer, d Dialect, start int, e any) ([]any, error) { // ExpressIf expands an express if the condition evaluates to true // it can also add a prefix and suffix -func ExpressIf(w io.Writer, d Dialect, start int, e any, cond bool, prefix, suffix string) ([]any, error) { +func ExpressIf(ctx context.Context, w io.Writer, d Dialect, start int, e any, cond bool, prefix, suffix string) ([]any, error) { if !cond { return nil, nil } w.Write([]byte(prefix)) - args, err := Express(w, d, start, e) + args, err := Express(ctx, w, d, start, e) if err != nil { return nil, err } @@ -81,7 +82,7 @@ func ExpressIf(w io.Writer, d Dialect, start int, e any, cond bool, prefix, suff } // ExpressSlice is used to express a slice of expressions along with a prefix and suffix -func ExpressSlice[T any](w io.Writer, d Dialect, start int, expressions []T, prefix, sep, suffix string) ([]any, error) { +func ExpressSlice[T any](ctx context.Context, w io.Writer, d Dialect, start int, expressions []T, prefix, sep, suffix string) ([]any, error) { if len(expressions) == 0 { return nil, nil } @@ -94,7 +95,7 @@ func ExpressSlice[T any](w io.Writer, d Dialect, start int, expressions []T, pre w.Write([]byte(sep)) } - newArgs, err := Express(w, d, start+len(args), e) + newArgs, err := Express(ctx, w, d, start+len(args), e) if err != nil { return args, err } diff --git a/expression_test.go b/expression_test.go index 21881db2..4c8acbda 100644 --- a/expression_test.go +++ b/expression_test.go @@ -2,6 +2,7 @@ package bob import ( "bytes" + "context" "database/sql" "errors" "io" @@ -24,7 +25,7 @@ func (d dialect) WriteQuoted(w io.Writer, s string) { w.Write([]byte(`"`)) } -var expression = ExpressionFunc(func(w io.Writer, d Dialect, start int) ([]any, error) { +var expression = ExpressionFunc(func(ctx context.Context, w io.Writer, d Dialect, start int) ([]any, error) { w.Write([]byte("Hello ")) d.WriteArg(w, start) w.Write([]byte(" ")) @@ -54,7 +55,7 @@ func compare(t *testing.T, sqlExpected, sqlGotten string, argsExpected, argsGott func TestExpress(t *testing.T) { w := bytes.NewBuffer(nil) - args, err := Express(w, d, 2, expression) + args, err := Express(context.Background(), w, d, 2, expression) if err != nil { t.Fatalf("err while expressing") } @@ -88,7 +89,7 @@ func TestExpress2(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { w := bytes.NewBuffer(nil) - args, err := Express(w, d, 1, test.value) + args, err := Express(context.Background(), w, d, 1, test.value) if err != nil { t.Fatalf("err while expressing") } @@ -125,7 +126,7 @@ func TestExpressIf(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { w := bytes.NewBuffer(nil) - args, err := ExpressIf(w, d, 1, expression, test.cond, test.prefix, test.suffix) + args, err := ExpressIf(context.Background(), w, d, 1, expression, test.cond, test.prefix, test.suffix) if err != nil { t.Fatalf("err while expressing") } @@ -142,7 +143,7 @@ func TestExpressIf(t *testing.T) { func TestExpressEmptySlice(t *testing.T) { w := bytes.NewBuffer(nil) - args, err := ExpressSlice(w, d, 2, []string{}, "prefix ", ", ", " suffix") + args, err := ExpressSlice(context.Background(), w, d, 2, []string{}, "prefix ", ", ", " suffix") if err != nil { t.Fatalf("err while expressing") } @@ -152,7 +153,7 @@ func TestExpressEmptySlice(t *testing.T) { func TestExpressSlice(t *testing.T) { w := bytes.NewBuffer(nil) - args, err := ExpressSlice(w, d, 2, []string{"one", "two", "three"}, "prefix ", ", ", " suffix") + args, err := ExpressSlice(context.Background(), w, d, 2, []string{"one", "two", "three"}, "prefix ", ", ", " suffix") if err != nil { t.Fatalf("err while expressing") } @@ -170,7 +171,7 @@ func (d dialectWithNamed) WriteNamedArg(w io.Writer, name string) { func TestNamedArgs(t *testing.T) { arg := sql.Named("name", "value") w := bytes.NewBuffer(nil) - args, err := Express(w, dialectWithNamed{}, 1, arg) + args, err := Express(context.Background(), w, dialectWithNamed{}, 1, arg) if err != nil { t.Fatalf("err while expressing") } @@ -181,7 +182,7 @@ func TestNamedArgs(t *testing.T) { func TestErrNoNamedArgs(t *testing.T) { arg := sql.Named("name", "value") w := bytes.NewBuffer(nil) - _, err := Express(w, d, 1, arg) + _, err := Express(context.Background(), w, d, 1, arg) if !errors.Is(err, ErrNoNamedArgs) { t.Fatalf("Expected to get ErrNoNamedArgs but got %v", err) } diff --git a/gen/bobgen-atlas/driver/atlas_test.go b/gen/bobgen-atlas/driver/atlas_test.go index ce77817d..af6b8942 100644 --- a/gen/bobgen-atlas/driver/atlas_test.go +++ b/gen/bobgen-atlas/driver/atlas_test.go @@ -13,7 +13,7 @@ import ( "github.com/stephenafamo/bob/gen" helpers "github.com/stephenafamo/bob/gen/bobgen-helpers" "github.com/stephenafamo/bob/gen/drivers" - testutils "github.com/stephenafamo/bob/test/utils" + testgen "github.com/stephenafamo/bob/test/gen" ) //go:embed test_schema @@ -86,7 +86,7 @@ func testDialect(t *testing.T, tt testCase) { os.RemoveAll(out) }() - testutils.TestDriver(t, testutils.DriverTestConfig[any]{ + testgen.TestDriver(t, testgen.DriverTestConfig[any]{ Root: out, GetDriver: func() drivers.Interface[any] { return New(tt.config, tt.schema) diff --git a/gen/bobgen-mysql/driver/mysql_test.go b/gen/bobgen-mysql/driver/mysql_test.go index 203205b3..efd185b4 100644 --- a/gen/bobgen-mysql/driver/mysql_test.go +++ b/gen/bobgen-mysql/driver/mysql_test.go @@ -16,7 +16,7 @@ import ( helpers "github.com/stephenafamo/bob/gen/bobgen-helpers" "github.com/stephenafamo/bob/gen/drivers" testfiles "github.com/stephenafamo/bob/test/files" - testutils "github.com/stephenafamo/bob/test/utils" + testgen "github.com/stephenafamo/bob/test/gen" ) var ( @@ -175,7 +175,7 @@ func TestDriver(t *testing.T) { os.RemoveAll(out) }() - testutils.TestDriver(t, testutils.DriverTestConfig[any]{ + testgen.TestDriver(t, testgen.DriverTestConfig[any]{ Root: out, GetDriver: func() drivers.Interface[any] { return New(tt.config) diff --git a/gen/bobgen-prisma/driver/prisma_test.go b/gen/bobgen-prisma/driver/prisma_test.go index 1df80275..b6ace533 100644 --- a/gen/bobgen-prisma/driver/prisma_test.go +++ b/gen/bobgen-prisma/driver/prisma_test.go @@ -13,7 +13,7 @@ import ( "github.com/stephenafamo/bob/gen" helpers "github.com/stephenafamo/bob/gen/bobgen-helpers" "github.com/stephenafamo/bob/gen/drivers" - testutils "github.com/stephenafamo/bob/test/utils" + testgen "github.com/stephenafamo/bob/test/gen" ) //go:embed test_data_model.json @@ -89,7 +89,7 @@ func testDialect(t *testing.T, tt testCase) { os.RemoveAll(out) }() - testutils.TestDriver(t, testutils.DriverTestConfig[Extra]{ + testgen.TestDriver(t, testgen.DriverTestConfig[Extra]{ Root: out, GetDriver: func() drivers.Interface[Extra] { return New(Config{}, tt.name, tt.provider, dataModel) diff --git a/gen/bobgen-psql/driver/psql_test.go b/gen/bobgen-psql/driver/psql_test.go index 173e4a23..be163d65 100644 --- a/gen/bobgen-psql/driver/psql_test.go +++ b/gen/bobgen-psql/driver/psql_test.go @@ -16,7 +16,7 @@ import ( helpers "github.com/stephenafamo/bob/gen/bobgen-helpers" "github.com/stephenafamo/bob/gen/drivers" testfiles "github.com/stephenafamo/bob/test/files" - testutils "github.com/stephenafamo/bob/test/utils" + testgen "github.com/stephenafamo/bob/test/gen" ) var flagOverwriteGolden = flag.Bool("overwrite-golden", false, "Overwrite the golden file with the current execution results") @@ -157,7 +157,7 @@ func TestDriver(t *testing.T) { os.RemoveAll(out) }() - testutils.TestDriver(t, testutils.DriverTestConfig[any]{ + testgen.TestDriver(t, testgen.DriverTestConfig[any]{ Root: out, GetDriver: func() drivers.Interface[any] { return New(tt.config) diff --git a/gen/bobgen-sql/driver/sql_test.go b/gen/bobgen-sql/driver/sql_test.go index 7cffd1b6..923a7244 100644 --- a/gen/bobgen-sql/driver/sql_test.go +++ b/gen/bobgen-sql/driver/sql_test.go @@ -10,7 +10,7 @@ import ( _ "github.com/jackc/pgx/v5/stdlib" "github.com/stephenafamo/bob/gen/drivers" testfiles "github.com/stephenafamo/bob/test/files" - testutils "github.com/stephenafamo/bob/test/utils" + testgen "github.com/stephenafamo/bob/test/gen" ) func TestPostgres(t *testing.T) { @@ -22,7 +22,7 @@ func TestPostgres(t *testing.T) { fs: testfiles.PostgresSchema, } - testutils.TestDriver(t, testutils.DriverTestConfig[any]{ + testgen.TestDriver(t, testgen.DriverTestConfig[any]{ Root: out, GetDriver: func() drivers.Interface[any] { d, err := getPsqlDriver(context.Background(), config) @@ -46,7 +46,7 @@ func TestSQLite(t *testing.T) { Schemas: []string{"one"}, } - testutils.TestDriver(t, testutils.DriverTestConfig[any]{ + testgen.TestDriver(t, testgen.DriverTestConfig[any]{ Root: out, GetDriver: func() drivers.Interface[any] { d, err := getSQLiteDriver(context.Background(), config) diff --git a/gen/bobgen-sqlite/driver/sqlite_test.go b/gen/bobgen-sqlite/driver/sqlite_test.go index 41e85c5e..3300b538 100644 --- a/gen/bobgen-sqlite/driver/sqlite_test.go +++ b/gen/bobgen-sqlite/driver/sqlite_test.go @@ -17,7 +17,7 @@ import ( helpers "github.com/stephenafamo/bob/gen/bobgen-helpers" "github.com/stephenafamo/bob/gen/drivers" testfiles "github.com/stephenafamo/bob/test/files" - testutils "github.com/stephenafamo/bob/test/utils" + testgen "github.com/stephenafamo/bob/test/gen" "modernc.org/sqlite" ) @@ -191,7 +191,7 @@ func TestAssemble(t *testing.T) { os.RemoveAll(out) }() - testutils.TestDriver(t, testutils.DriverTestConfig[any]{ + testgen.TestDriver(t, testgen.DriverTestConfig[any]{ Root: out, GetDriver: func() drivers.Interface[any] { return New(tt.config) diff --git a/gen/templates/models/09_rel_query.go.tpl b/gen/templates/models/09_rel_query.go.tpl index d2ec9cf4..82ffcc2e 100644 --- a/gen/templates/models/09_rel_query.go.tpl +++ b/gen/templates/models/09_rel_query.go.tpl @@ -29,7 +29,7 @@ func {{$tAlias.DownPlural}}Join{{$relAlias}}[Q dialect.Joinable](from {{$tAlias. {{if ne $index (sub (len $rel.Sides) 1) -}} to := {{$toCols}}.AliasedAs({{$toCols}}.Alias() + random) {{end -}} - mods = append(mods, dialect.Join[Q](typ, {{$to.UpPlural}}.Name(ctx).As(to.Alias())).On( + mods = append(mods, dialect.Join[Q](typ, {{$to.UpPlural}}.Name().As(to.Alias())).On( {{range $i, $local := $side.FromColumns -}} {{- $fromCol := index $from.Columns $local -}} {{- $toCol := index $to.Columns (index $side.ToColumns $i) -}} @@ -68,7 +68,7 @@ func (o *{{$tAlias.UpSingular}}) {{relQueryMethodName $tAlias $relAlias}}(ctx co {{- $to := $.Aliases.Table $side.To -}} {{- $fromTable := getTable $.Tables $side.From -}} {{- if gt $index 0 -}} - sm.InnerJoin({{$from.UpPlural}}.NameAs(ctx)).On( + sm.InnerJoin({{$from.UpPlural}}.NameAs()).On( {{end -}} {{range $i, $local := $side.FromColumns -}} {{- $fromCol := index $from.Columns $local -}} @@ -128,7 +128,7 @@ func (os {{$tAlias.UpSingular}}Slice) {{relQueryMethodName $tAlias $relAlias}}(c {{- $to := $.Aliases.Table $side.To -}} {{- $fromTable := getTable $.Tables $side.From -}} {{- if gt $index 0 -}} - sm.InnerJoin({{$from.UpPlural}}.NameAs(ctx)).On( + sm.InnerJoin({{$from.UpPlural}}.NameAs()).On( {{range $i, $local := $side.FromColumns -}} {{- $foreign := index $side.ToColumns $i -}} {{- $fromCol := index $from.Columns $local -}} diff --git a/gen/templates/models/10_rel_load.go.tpl b/gen/templates/models/10_rel_load.go.tpl index 1fc43ba7..f64791c4 100644 --- a/gen/templates/models/10_rel_load.go.tpl +++ b/gen/templates/models/10_rel_load.go.tpl @@ -86,11 +86,8 @@ func Preload{{$tAlias.UpSingular}}{{$relAlias}}(opts ...{{$.Dialect}}.PreloadOpt {{- $fromTable := getTable $.Tables $side.From -}} {{- $toTable = getTable $.Tables $side.To -}} { - From: {{quote $fromTable.Key}}, + From: TableNames.{{$from.UpPlural}}, To: TableNames.{{$to.UpPlural}}, - ToExpr: func(ctx context.Context) bob.Expression { - return {{$to.UpPlural}}.Name(ctx) - }, {{if $side.FromColumns -}} FromColumns: []string{ {{range $name := $side.FromColumns -}} diff --git a/go.mod b/go.mod index 22b7daa7..c6bbede4 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( github.com/volatiletech/strmangle v0.0.6 github.com/wasilibs/go-pgquery v0.0.0-20240319230125-b9b2e95c69a7 golang.org/x/mod v0.16.0 + golang.org/x/text v0.14.0 golang.org/x/tools v0.19.0 modernc.org/sqlite v1.20.3 mvdan.cc/gofumpt v0.5.0 @@ -78,7 +79,6 @@ require ( golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect golang.org/x/sync v0.6.0 // indirect golang.org/x/sys v0.18.0 // indirect - golang.org/x/text v0.14.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect google.golang.org/protobuf v1.31.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 1deac288..6a638025 100644 --- a/go.sum +++ b/go.sum @@ -170,8 +170,6 @@ github.com/vmihailenco/msgpack/v4 v4.3.12/go.mod h1:gborTTJjAo/GWTqqRjrLCn9pgNN+ github.com/vmihailenco/tagparser v0.1.1/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= github.com/volatiletech/inflect v0.0.1 h1:2a6FcMQyhmPZcLa+uet3VJ8gLn/9svWhJxJYwvE8KsU= github.com/volatiletech/inflect v0.0.1/go.mod h1:IBti31tG6phkHitLlr5j7shC5SOo//x0AjDzaJU1PLA= -github.com/volatiletech/strmangle v0.0.4 h1:CxrEPhobZL/PCZOTDSH1aq7s4Kv76hQpRoTVVlUOim4= -github.com/volatiletech/strmangle v0.0.4/go.mod h1:ycDvbDkjDvhC0NUU8w3fWwl5JEMTV56vTKXzR3GeR+0= github.com/volatiletech/strmangle v0.0.6 h1:AdOYE3B2ygRDq4rXDij/MMwq6KVK/pWAYxpC7CLrkKQ= github.com/volatiletech/strmangle v0.0.6/go.mod h1:ycDvbDkjDvhC0NUU8w3fWwl5JEMTV56vTKXzR3GeR+0= github.com/wasilibs/go-pgquery v0.0.0-20240319230125-b9b2e95c69a7 h1:sqqLVb63En4uTKFKBWSJ7c1aIFonhM1yn35/+KchOf4= diff --git a/internal/load.go b/internal/load.go index d6f94bbe..b8de54ad 100644 --- a/internal/load.go +++ b/internal/load.go @@ -15,7 +15,6 @@ type Preloadable interface { } type loadable interface { - GetLoadContext() context.Context AppendLoader(f ...bob.Loader) AppendMapperMod(f scan.MapperMod) } @@ -100,14 +99,14 @@ func (filters PreloadWhere[Q]) ModifyPreloadSettings(el *PreloadSettings[Q]) { // while it can be used as a queryMod, it does not have any direct effect. // if using manually, the ApplyPreload method should be called // with the query's context AFTER other mods have been applied -type Preloader[Q loadable] func(ctx context.Context) (bob.Mod[Q], scan.MapperMod, []bob.Loader) +type Preloader[Q loadable] func(parent string) (bob.Mod[Q], scan.MapperMod, []bob.Loader) // Apply satisfies bob.Mod[*dialect.SelectQuery]. // 1. It modifies the query to join the preloading table and the extra columns to retrieve // 2. It modifies the mapper to scan the new columns. // 3. It calls the original object's Preload method with the loaded object func (l Preloader[Q]) Apply(q Q) { - mod, mapperMod, afterLoaders := l(q.GetLoadContext()) + mod, mapperMod, afterLoaders := l("") mod.Apply(q) // add preload columns q.AppendMapperMod(mapperMod) // add mapper diff --git a/internal/reflect_test.go b/internal/reflect_test.go index 9b37cc99..68027ef3 100644 --- a/internal/reflect_test.go +++ b/internal/reflect_test.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "database/sql/driver" "fmt" "io" @@ -248,7 +249,7 @@ type expression struct { func expTransformer(e bob.Expression) expression { buf := &bytes.Buffer{} - args, err := e.WriteSQL(buf, dialect{}, 1) + args, err := e.WriteSQL(context.Background(), buf, dialect{}, 1) return expression{ Query: buf.String(), diff --git a/load.go b/load.go index 9e68d4cd..8646a5e7 100644 --- a/load.go +++ b/load.go @@ -26,21 +26,10 @@ type ( // Load is an embeddable struct that enables Preloading and AfterLoading type Load struct { - loadContext context.Context loadFuncs []Loader preloadMapperMods []scan.MapperMod } -// GetLoadContext -func (l *Load) GetLoadContext() context.Context { - return l.loadContext -} - -// SetLoadContext -func (l *Load) SetLoadContext(ctx context.Context) { - l.loadContext = ctx -} - func (l *Load) SetMapperMods(mods ...scan.MapperMod) { l.preloadMapperMods = mods } diff --git a/mods/conflict_test.go b/mods/conflict_test.go new file mode 100644 index 00000000..592cb70f --- /dev/null +++ b/mods/conflict_test.go @@ -0,0 +1,8 @@ +package mods + +import ( + "github.com/stephenafamo/bob" + "github.com/stephenafamo/bob/clause" +) + +var _ bob.Mod[interface{ SetConflict(clause.Conflict) }] = Conflict[interface{ SetConflict(clause.Conflict) }](nil) diff --git a/mods/mods.go b/mods/mods.go index 05b3f17c..bde6d583 100644 --- a/mods/mods.go +++ b/mods/mods.go @@ -1,6 +1,7 @@ package mods import ( + "context" "io" "github.com/stephenafamo/bob" @@ -26,8 +27,8 @@ func (q QueryModFunc[T]) Apply(query T) { // allows for some fluent API, for example with functions type Moddable[T bob.Expression] func(...bob.Mod[T]) T -func (m Moddable[T]) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return m().WriteSQL(w, d, start) +func (m Moddable[T]) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + return m().WriteSQL(ctx, w, d, start) } type With[Q interface{ AppendWith(clause.CTE) }] clause.CTE diff --git a/mods/mods_test.go b/mods/mods_test.go new file mode 100644 index 00000000..36a94348 --- /dev/null +++ b/mods/mods_test.go @@ -0,0 +1,12 @@ +package mods + +import ( + "github.com/stephenafamo/bob" + "github.com/stephenafamo/bob/clause" +) + +var ( + _ bob.Mod[any] = QueryMods[any](nil) + _ bob.Mod[any] = QueryModFunc[any](nil) + _ bob.Mod[interface{ AppendWith(clause.CTE) }] = With[interface{ AppendWith(clause.CTE) }]{} +) diff --git a/orm/columns.go b/orm/columns.go index 9aaff055..2e41a096 100644 --- a/orm/columns.go +++ b/orm/columns.go @@ -1,6 +1,7 @@ package orm import ( + "context" "io" "github.com/stephenafamo/bob" @@ -56,7 +57,7 @@ func (c Columns) Except(cols ...string) Columns { return c } -func (c Columns) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { +func (c Columns) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { if len(c.names) == 0 { return nil, nil } diff --git a/orm/context.go b/orm/context.go index 0b5cbba9..163412a8 100644 --- a/orm/context.go +++ b/orm/context.go @@ -3,10 +3,8 @@ package orm type ctxKey int const ( - // The alias of an eager loader's parent - CtxLoadParentAlias ctxKey = iota // A schema to use when non was specified during generation - CtxUseSchema + CtxUseSchema ctxKey = iota ) type ( diff --git a/orm/relationship.go b/orm/relationship.go index 03f67d62..299f5ea9 100644 --- a/orm/relationship.go +++ b/orm/relationship.go @@ -1,11 +1,8 @@ package orm import ( - "context" "fmt" "sort" - - "github.com/stephenafamo/bob" ) type RelWhere struct { @@ -57,9 +54,6 @@ type RelSide struct { // relationship without deleting it // this is set in Relationships.init() KeyNullable bool `yaml:"-"` - - // Kinda hacky, used for preloading - ToExpr func(context.Context) bob.Expression `json:"-" yaml:"-"` } type Relationship struct { diff --git a/orm/types.go b/orm/types.go index 4fcf06f3..3feb908e 100644 --- a/orm/types.go +++ b/orm/types.go @@ -1,6 +1,12 @@ package orm -import "github.com/stephenafamo/bob" +import ( + "context" + "io" + + "github.com/stephenafamo/bob" + "github.com/stephenafamo/bob/expr" +) type Table interface { // PrimaryKeyVals returns the values of the primary key columns @@ -19,3 +25,10 @@ type Setter[T any, InsertQ any, UpdateQ any] interface { // Return a mod for the insert query InsertMod() bob.Mod[InsertQ] } + +type SchemaTable string + +func (s SchemaTable) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { + schema, _ := ctx.Value(CtxUseSchema).(string) + return expr.Quote(schema, string(s)).WriteSQL(ctx, w, d, start) +} diff --git a/query.go b/query.go index bb00dcc8..dafbdff4 100644 --- a/query.go +++ b/query.go @@ -24,7 +24,7 @@ type Query interface { // start is the index of the args, usually 1. // it is present to allow re-indexing in cases of a subquery // The method returns the value of any args placed - WriteQuery(w io.Writer, start int) (args []any, err error) + WriteQuery(ctx context.Context, w io.Writer, start int) (args []any, err error) } type Mod[T any] interface { @@ -32,6 +32,7 @@ type Mod[T any] interface { } var ( + _ Query = BaseQuery[Expression]{} _ Loadable = BaseQuery[Expression]{} _ MapperModder = BaseQuery[Expression]{} ) @@ -83,15 +84,15 @@ func (b BaseQuery[E]) Apply(mods ...Mod[E]) { } } -func (b BaseQuery[E]) WriteQuery(w io.Writer, start int) ([]any, error) { - return b.Expression.WriteSQL(w, b.Dialect, start) +func (b BaseQuery[E]) WriteQuery(ctx context.Context, w io.Writer, start int) ([]any, error) { + return b.Expression.WriteSQL(ctx, w, b.Dialect, start) } // Satisfies the Expression interface, but uses its own dialect instead // of the dialect passed to it -func (b BaseQuery[E]) WriteSQL(w io.Writer, _ Dialect, start int) ([]any, error) { +func (b BaseQuery[E]) WriteSQL(ctx context.Context, w io.Writer, _ Dialect, start int) ([]any, error) { w.Write([]byte(openPar)) - args, err := b.Expression.WriteSQL(w, b.Dialect, start) + args, err := b.Expression.WriteSQL(ctx, w, b.Dialect, start) w.Write([]byte(closePar)) return args, err @@ -99,32 +100,32 @@ func (b BaseQuery[E]) WriteSQL(w io.Writer, _ Dialect, start int) ([]any, error) // MustBuild builds the query and panics on error // useful for initializing queries that need to be reused -func (q BaseQuery[E]) MustBuild() (string, []any) { - return MustBuildN(q, 1) +func (q BaseQuery[E]) MustBuild(ctx context.Context) (string, []any) { + return MustBuildN(ctx, q, 1) } // MustBuildN builds the query and panics on error // start numbers the arguments from a different point -func (q BaseQuery[E]) MustBuildN(start int) (string, []any) { - return MustBuildN(q, start) +func (q BaseQuery[E]) MustBuildN(ctx context.Context, start int) (string, []any) { + return MustBuildN(ctx, q, start) } // Convinient function to build query from start -func (q BaseQuery[E]) Build() (string, []any, error) { - return BuildN(q, 1) +func (q BaseQuery[E]) Build(ctx context.Context) (string, []any, error) { + return BuildN(ctx, q, 1) } // Convinient function to build query from a point -func (q BaseQuery[E]) BuildN(start int) (string, []any, error) { - return BuildN(q, start) +func (q BaseQuery[E]) BuildN(ctx context.Context, start int) (string, []any, error) { + return BuildN(ctx, q, start) } // Convinient function to cache a query -func (q BaseQuery[E]) Cache() (BaseQuery[*cached], error) { - return CacheN(q, 1) +func (q BaseQuery[E]) Cache(ctx context.Context) (BaseQuery[*cached], error) { + return CacheN(ctx, q, 1) } // Convinient function to cache a query from a point -func (q BaseQuery[E]) CacheN(start int) (BaseQuery[*cached], error) { - return CacheN(q, start) +func (q BaseQuery[E]) CacheN(ctx context.Context, start int) (BaseQuery[*cached], error) { + return CacheN(ctx, q, start) } diff --git a/stmt.go b/stmt.go index 9d8123cf..519e890b 100644 --- a/stmt.go +++ b/stmt.go @@ -21,7 +21,7 @@ type Statement interface { // retains the expected methods used by *sql.Stmt // This is useful when an existing *sql.Stmt is used in other places in the codebase func Prepare(ctx context.Context, exec Preparer, q Query) (Stmt, error) { - query, args, err := Build(q) + query, args, err := Build(ctx, q) if err != nil { return Stmt{}, err } diff --git a/test/utils/gen.go b/test/gen/gen.go similarity index 99% rename from test/utils/gen.go rename to test/gen/gen.go index 61e033a5..bbdacb77 100644 --- a/test/utils/gen.go +++ b/test/gen/gen.go @@ -1,4 +1,4 @@ -package testutils +package testgen import ( "bufio" diff --git a/test/utils/utils.go b/test/utils/utils.go index 228ca703..a0866461 100644 --- a/test/utils/utils.go +++ b/test/utils/utils.go @@ -1,6 +1,7 @@ package testutils import ( + "context" "fmt" "regexp" "strings" @@ -64,7 +65,7 @@ func RunTests(t *testing.T, cases Testcases, format FormatFunc) { t.Helper() for name, tc := range cases { t.Run(name, func(t *testing.T) { - sql, args, err := bob.Build(tc.Query) + sql, args, err := bob.Build(context.Background(), tc.Query) if err != nil { t.Fatalf("error: %v", err) } @@ -100,7 +101,7 @@ func RunExpressionTests(t *testing.T, d bob.Dialect, cases ExpressionTestcases) for name, tc := range cases { t.Run(name, func(t *testing.T) { b := &strings.Builder{} - args, err := bob.Express(b, d, 1, tc.Expression) + args, err := bob.Express(context.Background(), b, d, 1, tc.Expression) sql := b.String() if diff := ErrDiff(tc.ExpectedError, err); diff != "" { diff --git a/website/docs/code-generation/usage.md b/website/docs/code-generation/usage.md index 33497eb3..19159b8b 100644 --- a/website/docs/code-generation/usage.md +++ b/website/docs/code-generation/usage.md @@ -296,7 +296,7 @@ For even more control, expressions are generated for every column to be used in // ORDER BY "jets"."pilot_id" psql.Select( sm.Columns(models.JetColumns.Name, "count(1)"), - sm.From(models.JetsTable.Name(ctx)), + sm.From(models.JetsTable.Name()), sm.Where(models.JetColumns.ID.Between(50, 5000)), sm.OrderBy(models.JetColumns.PilotID), ) diff --git a/website/docs/query-builder/using-queries.md b/website/docs/query-builder/using-queries.md index d5e35b37..d95ca798 100644 --- a/website/docs/query-builder/using-queries.md +++ b/website/docs/query-builder/using-queries.md @@ -1,8 +1,6 @@ --- - sidebar_position: 1.1 description: How to use queries built with Bob - --- # Using the Query @@ -17,19 +15,19 @@ type Query interface { // it is present to allow re-indexing in cases of a subquery // The method returns the value of any args placed // An `io.Writer` is used for efficiency when building the query. - WriteQuery(w io.Writer, start int) (args []any, err error) + WriteQuery(ctx context.Context, w io.Writer, start int) (args []any, err error) } ``` The `WriteQuery` method is useful when we want to write to an existing `io.Writer`. However, we often just want the query string and arguments. So the Query objects have the following methods: -* `Build() (query string, args []any, err error)` -* `BuildN(start int) (query string, args []any, err error)` -* `MustBuild() (query string, args []any) // panics on error` -* `MustBuildN(start int) (query string, args []any) // panics on error` +- `Build(ctx context.Context) (query string, args []any, err error)` +- `BuildN(ctx context.Context, start int) (query string, args []any, err error)` +- `MustBuild(ctx context.Context) (query string, args []any) // panics on error` +- `MustBuildN(ctx context.Context, start int) (query string, args []any) // panics on error` ```go -queryString, args, err := psql.Select(...).Build() +queryString, args, err := psql.Select(...).Build(ctx) ``` Since the query is built from scratch every time the `WriteQuery()` method is called, it can be useful to initialize the query one time and reuse where necessary. @@ -37,7 +35,7 @@ Since the query is built from scratch every time the `WriteQuery()` method is ca For that, the `MustBuild()` function can be used. This panics on error. ```go -myquery, myargs := psql.Insert(...).MustBuild() +myquery, myargs := psql.Insert(...).MustBuild(ctx) ``` ## Executing queries @@ -48,7 +46,7 @@ The returned `query` and `args` can then be passed to your querier (e.g. `*sql.D ctx := context.Background() // Build the query -myquery, myargs := psql.Insert(...).MustBuild() +myquery, myargs := psql.Insert(...).MustBuild(ctx) // Execute the query err := db.ExecContext(ctx, myquery, myargs...)