diff --git a/CHANGELOG.md b/CHANGELOG.md index a8600d0..2eb6ccb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `Type() QueryType` method to `bob.Query` to get the type of query it is. Available constants are `Unknown, Select, Insert, Update, Delete`. - Postgres and SQLite Update/Delete queries now refresh the models after the query is executed. This is enabled by the `RETURNING` clause, so it is not available in MySQL. - Added the `Case()` starter to all dialects to build `CASE` expressions. (thanks @k4n4ry) +- Added `bob.Named()` which is used to add named arguments to the query and bind them later. +- Added `bob.BindNamed` which takes an argument (struct, map, or a single value type) to be used to bind named arguments in a query. See changes to `bob.Prepare()` for details of which type can be used. ### Changed @@ -78,6 +80,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `BeforeInsertHooks` now only takes a single `ModelSetter` at a time. This is because it is not possible to know before executing the queries exactly how many setters are being used since additional rows can be inserted by applying another setter as a mod. - `bob.Cache()` now requires an `Executor`. This is used to run any query hooks. +- `bob.Prepare()` now requires a type parameter to be used to bind named arguments. The type can either be: + - A struct with fields that match the named arguments in the query + - A map with string keys. When supplied, the values in the map will be used to bind the named arguments in the query. + - When there is only a single named argument, one of the following can be used: + - A primitive type (int, bool, string, etc) + - `time.Time` + - Any type that implements `driver.Valuer`. ### Removed @@ -89,6 +98,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Remove `Update` and `Delete` methods from `orm.Table` since they are not needed. It is possible to do the same thing, with similar effor using the the `UpdateQ` and `DeleteQ` methods (which are now renamed to `Update` and `Delete`). - `context.Context` and `bob.Executor` are no longer passed when creating a Table/ViewQuery. It is now passed at the point of execution with `Exec/One/All/Cursor`. +- Remove `Prepare` methods from table and view qureries. Since `bob.Prepare()` now takes a type parameter, it is not possible to prepare from a method since Go does not allow additional type parameters in methods. ### Fixed diff --git a/binder.go b/binder.go new file mode 100644 index 0000000..a3a2a13 --- /dev/null +++ b/binder.go @@ -0,0 +1,323 @@ +package bob + +import ( + "database/sql/driver" + "errors" + "fmt" + "reflect" + "time" + + "github.com/stephenafamo/bob/internal/mappings" +) + +//nolint:gochecknoglobals +var ( + ErrBadArgType = errors.New("bind type of multiple named args must be a struct, pointer to struct or map with ~string keys") + ErrTooManyNamedArgs = errors.New("too many named args for single arg binder") + driverValuerIntf = reflect.TypeFor[driver.Valuer]() + timeType = reflect.TypeFor[time.Time]() +) + +type MissingArgError struct{ Name string } + +func (e MissingArgError) Error() string { + return fmt.Sprintf("missing arg %s", e.Name) +} + +type binder[T any] interface { + // list returns the names of the args that the binder expects + list() []string + // Return the args to be run in the query + // this should also include any non-named args in the original query + toArgs(T) []any +} + +func bindArgs[Arg any](args []any, named Arg) ([]any, error) { + binder, err := makeBinder[Arg](args) + if err != nil { + return nil, err + } + + return binder.toArgs(named), nil +} + +func makeBinder[Arg any](args []any) (binder[Arg], error) { + namedArgs := countNamedArgs(args) + + switch namedArgs { + case 0: // no named args + return emptyBinder[Arg](args), nil + case 1: // only one named arg + return makeSingleArgBinder[Arg](args) + default: + return makeMultiArgBinder[Arg](args) + } +} + +func canUseAsSingleValue(typ reflect.Type) bool { + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + switch typ.Kind() { + case reflect.Bool, reflect.String, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + return true + case reflect.Slice: + return typ.Elem().Kind() == reflect.Uint8 + } + + if typ == timeType { + return true + } + + if typ.Implements(driverValuerIntf) { + return true + } + + return false +} + +func makeSingleArgBinder[Arg any](args []any) (binder[Arg], error) { + typ := reflect.TypeFor[Arg]() + if !canUseAsSingleValue(typ) { + return makeMultiArgBinder[Arg](args) + } + + givenArg := make([]any, len(args)) + copy(givenArg, args) + + b := singleBinder[Arg]{givenArg: givenArg} + + for pos, arg := range args { + if name, ok := arg.(namedArg); ok { + b.argIndexs = append(b.argIndexs, pos) + b.name = string(name) + } + } + + return b, nil +} + +func makeMultiArgBinder[Arg any](args []any) (binder[Arg], error) { + typ := reflect.TypeFor[Arg]() + + switch typ.Kind() { + case reflect.Map: + if typ.Key().Kind() != reflect.String { + return nil, ErrBadArgType + } + + return makeMapBinder[Arg](args), nil + + case reflect.Struct: + return makeStructBinder[Arg](args) + + case reflect.Ptr: + if typ.Elem().Kind() == reflect.Struct { + return makeStructBinder[Arg](args) + } + } + + return nil, ErrBadArgType +} + +type emptyBinder[Arg any] []any + +func (b emptyBinder[Arg]) list() []string { + return nil +} + +func (b emptyBinder[Arg]) toArgs(arg Arg) []any { + return b +} + +func makeStructBinder[Arg any](args []any) (binder[Arg], error) { + typ := reflect.TypeFor[Arg]() + + isStruct := typ.Kind() == reflect.Struct + if typ.Kind() == reflect.Ptr { + isStruct = typ.Elem().Kind() == reflect.Struct + } + + if !isStruct { + return structBinder[Arg]{}, errors.New("bind type must be a struct") + } + + givenArg := make([]any, len(args)) + argPositions := make([]string, len(args)) + for pos, arg := range args { + if name, ok := arg.(namedArg); ok { + argPositions[pos] = string(name) + continue + } + + givenArg[pos] = arg + } + + fieldNames := mappings.GetMappings(typ).All + fieldPositions := make([]int, len(argPositions)) + + // check if all positions have matching fields +ArgLoop: + for argIndex, name := range argPositions { + if name == "" { + continue + } + + for fieldIndex, field := range fieldNames { + if field == name { + fieldPositions[argIndex] = fieldIndex + continue ArgLoop + } + } + return structBinder[Arg]{}, MissingArgError{Name: name} + } + + return structBinder[Arg]{ + args: argPositions, + fields: fieldPositions, + givenArg: givenArg, + }, nil +} + +type structBinder[Arg any] struct { + args []string + fields []int + givenArg []any +} + +func (b structBinder[Arg]) list() []string { + names := make([]string, len(b.args)) + for _, name := range b.args { + if name == "" { + continue + } + + names = append(names, name) + } + + return names +} + +func (b structBinder[Arg]) toArgs(arg Arg) []any { + isNil := false + val := reflect.ValueOf(arg) + if val.Kind() == reflect.Pointer { + isNil = val.IsNil() + val = val.Elem() + } + + values := make([]any, len(b.args)) + + for index, argName := range b.args { + if argName == "" { + values[index] = b.givenArg[index] + continue + } + + if isNil { + continue + } + + values[index] = val.Field(b.fields[index]).Interface() + } + + return values +} + +func makeMapBinder[Arg any](args []any) binder[Arg] { + givenArg := make([]any, len(args)) + argPositions := make([]string, len(args)) + for pos, arg := range args { + if name, ok := arg.(namedArg); ok { + argPositions[pos] = string(name) + continue + } + + givenArg[pos] = arg + } + + return mapBinder[Arg]{ + args: argPositions, + givenArg: givenArg, + } +} + +type mapBinder[Arg any] struct { + args []string + givenArg []any +} + +func (b mapBinder[Arg]) list() []string { + names := make([]string, len(b.args)) + for _, name := range b.args { + if name == "" { + continue + } + + names = append(names, name) + } + + return names +} + +func (b mapBinder[Arg]) toArgs(args Arg) []any { + values := make([]any, len(b.args)) + + for index, argName := range b.args { + if argName == "" { + values[index] = b.givenArg[index] + continue + } + + val := reflect.ValueOf(args).MapIndex(reflect.ValueOf(argName)) + if !val.IsValid() { + continue + } + + values[index] = val.Interface() + } + + return values +} + +type singleBinder[Arg any] struct { + givenArg []any + argIndexs []int + name string +} + +func (b singleBinder[Arg]) list() []string { + return []string{b.name} +} + +func (b singleBinder[Arg]) toArgs(arg Arg) []any { + values := make([]any, len(b.givenArg)) + copy(values, b.givenArg) + + for _, i := range b.argIndexs { + values[i] = arg + } + + return values +} + +func countNamedArgs(args []any) int { + names := map[string]struct{}{} + for _, arg := range args { + if name, ok := arg.(namedArg); ok { + names[string(name)] = struct{}{} + continue + } + + if name, ok := arg.(named); ok && len(name.names) == 1 { + names[name.names[0]] = struct{}{} + continue + } + } + + return len(names) +} diff --git a/binder_test.go b/binder_test.go new file mode 100644 index 0000000..7692547 --- /dev/null +++ b/binder_test.go @@ -0,0 +1,192 @@ +package bob + +import ( + "database/sql/driver" + "errors" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +type customString string + +type binderTester interface { + Run(t *testing.T, origin []any) +} + +type binderTests[Arg any] struct { + args Arg + final []any + err error +} + +func (s binderTests[Arg]) Run(t *testing.T, origin []any) { + t.Helper() + + t.Run("", func(t *testing.T) { + binder, err := makeBinder[Arg](origin) + if !errors.Is(err, s.err) { + t.Fatal(err) + } + + if s.err != nil { + return + } + + if diff := cmp.Diff( + s.final, binder.toArgs(s.args), cmpopts.EquateEmpty(), + ); diff != "" { + t.Fatal(diff) + } + }) +} + +func testBinder(t *testing.T, origin []any, tests []binderTester) { + t.Helper() + + for _, test := range tests { + test.Run(t, origin) + } +} + +func TestBinding(t *testing.T) { + t.Run("no args", func(t *testing.T) { + testBinder(t, []any{}, []binderTester{ + binderTests[struct{}]{ + args: struct{}{}, + final: []any{}, + }, + binderTests[map[customString]any]{ + args: nil, + final: []any{}, + }, + binderTests[int]{}, + }) + }) + + t.Run("no named", func(t *testing.T) { + testBinder(t, []any{1, 2, 3, 4}, []binderTester{ + binderTests[struct{}]{ + args: struct{}{}, + final: []any{1, 2, 3, 4}, + }, + binderTests[map[string]any]{ + args: nil, + final: []any{1, 2, 3, 4}, + }, + binderTests[int]{ + args: 0, + final: []any{1, 2, 3, 4}, + }, + }) + }) + + t.Run("all named", func(t *testing.T) { + testBinder(t, []any{namedArg("one"), namedArg("two"), namedArg("three"), namedArg("four")}, []binderTester{ + binderTests[struct{ One, Two, Three, Four int }]{ + args: struct{ One, Two, Three, Four int }{ + One: 1, Two: 2, Three: 3, Four: 4, + }, + final: []any{1, 2, 3, 4}, + }, + binderTests[map[string]int]{ + args: map[string]int{"one": 1, "two": 2, "three": 3, "four": 4}, + final: []any{1, 2, 3, 4}, + }, + binderTests[int]{ + err: ErrBadArgType, + }, + }) + }) + + t.Run("mixed named", func(t *testing.T) { + testBinder(t, []any{1, 2, namedArg("three"), 4}, []binderTester{ + binderTests[struct{ Three int }]{ + args: struct{ Three int }{Three: 3}, + final: []any{1, 2, 3, 4}, + }, + binderTests[map[string]int]{ + args: map[string]int{"three": 3}, + final: []any{1, 2, 3, 4}, + }, + binderTests[int]{ + args: 3, + final: []any{1, 2, 3, 4}, + }, + }) + }) + + t.Run("mixed named with nil arg", func(t *testing.T) { + testBinder(t, []any{1, 2, namedArg("three"), 4}, []binderTester{ + binderTests[*struct{ Three int }]{ + args: nil, + final: []any{1, 2, nil, 4}, + }, + binderTests[map[string]int]{ + args: nil, + final: []any{1, 2, nil, 4}, + }, + binderTests[*int]{ + args: nil, + final: []any{1, 2, (*int)(nil), 4}, + }, + }) + }) + + t.Run("varaitions of single binder", func(t *testing.T) { + timeVal, err := time.Parse(time.RFC3339, "2021-01-01T00:00:00Z") + if err != nil { + t.Fatal(err) + } + + testBinder(t, []any{1, 2, namedArg("three"), 4}, []binderTester{ + binderTests[int]{ + args: 3, + final: []any{1, 2, 3, 4}, + }, + binderTests[*int]{ + args: nil, + final: []any{1, 2, (*int)(nil), 4}, + }, + binderTests[time.Time]{ + args: timeVal, + final: []any{1, 2, timeVal, 4}, + }, + binderTests[valuable]{ + args: valuable{3}, + final: []any{1, 2, valuable{3}, 4}, + }, + }) + }) + + t.Run("reuse names", func(t *testing.T) { + testBinder(t, []any{1, 2, namedArg("three"), 4, namedArg("three")}, []binderTester{ + binderTests[struct{ Three int }]{ + args: struct{ Three int }{Three: 3}, + final: []any{1, 2, 3, 4, 3}, + }, + binderTests[map[string]int]{ + args: map[string]int{"three": 3}, + final: []any{1, 2, 3, 4, 3}, + }, + binderTests[int]{ + args: 3, + final: []any{1, 2, 3, 4, 3}, + }, + }) + }) +} + +type valuable struct { + val int +} + +func (v valuable) Value() (driver.Value, error) { + return v.val, nil +} + +func (v valuable) Equal(other valuable) bool { + return v.val == other.val +} diff --git a/dialect/mysql/view.go b/dialect/mysql/view.go index 098926f..d0a16c4 100644 --- a/dialect/mysql/view.go +++ b/dialect/mysql/view.go @@ -88,16 +88,6 @@ func (v *View[T, Tslice]) Query(queryMods ...bob.Mod[*dialect.SelectQuery]) *Vie return q } -// Prepare a statement that will be mapped to the view's type -func (v *View[T, Tslice]) Prepare(ctx context.Context, exec bob.Preparer, queryMods ...bob.Mod[*dialect.SelectQuery]) (bob.QueryStmt[T, Tslice], error) { - return v.PrepareQuery(ctx, exec, v.Query(queryMods...)) -} - -// Prepare a statement from an existing query that will be mapped to the view's type -func (v *View[T, Tslice]) PrepareQuery(ctx context.Context, exec bob.Preparer, q bob.Query) (bob.QueryStmt[T, Tslice], error) { - return bob.PrepareQueryx[T, Tslice](ctx, exec, q, v.scanner) -} - type ViewQuery[T any, Ts ~[]T] struct { orm.Query[*dialect.SelectQuery, T, Ts] } diff --git a/dialect/psql/view.go b/dialect/psql/view.go index 1ba0b62..74aee7f 100644 --- a/dialect/psql/view.go +++ b/dialect/psql/view.go @@ -106,16 +106,6 @@ func (v *View[T, Tslice]) Query(queryMods ...bob.Mod[*dialect.SelectQuery]) *Vie return q } -// Prepare a statement that will be mapped to the view's type -func (v *View[T, Tslice]) Prepare(ctx context.Context, exec bob.Preparer, queryMods ...bob.Mod[*dialect.SelectQuery]) (bob.QueryStmt[T, Tslice], error) { - return v.PrepareQuery(ctx, exec, v.Query(queryMods...)) -} - -// Prepare a statement from an existing query that will be mapped to the view's type -func (v *View[T, Tslice]) PrepareQuery(ctx context.Context, exec bob.Preparer, q bob.Query) (bob.QueryStmt[T, Tslice], error) { - return bob.PrepareQueryx[T, Tslice](ctx, exec, q, v.scanner) -} - type ViewQuery[T any, Ts ~[]T] struct { orm.Query[*dialect.SelectQuery, T, Ts] } diff --git a/dialect/sqlite/view.go b/dialect/sqlite/view.go index ebcfd12..6bd1348 100644 --- a/dialect/sqlite/view.go +++ b/dialect/sqlite/view.go @@ -106,16 +106,6 @@ func (v *View[T, Tslice]) Query(queryMods ...bob.Mod[*dialect.SelectQuery]) *Vie return q } -// Prepare a statement that will be mapped to the view's type -func (v *View[T, Tslice]) Prepare(ctx context.Context, exec bob.Preparer, queryMods ...bob.Mod[*dialect.SelectQuery]) (bob.QueryStmt[T, Tslice], error) { - return v.PrepareQuery(ctx, exec, v.Query(queryMods...)) -} - -// Prepare a statement from an existing query that will be mapped to the view's type -func (v *View[T, Tslice]) PrepareQuery(ctx context.Context, exec bob.Preparer, q bob.Query) (bob.QueryStmt[T, Tslice], error) { - return bob.PrepareQueryx[T, Tslice](ctx, exec, q, v.scanner) -} - type ViewQuery[T any, Ts ~[]T] struct { orm.Query[*dialect.SelectQuery, T, Ts] } diff --git a/gen/templates/models/09_rel_query.go.tpl b/gen/templates/models/09_rel_query.go.tpl index a98d798..68465be 100644 --- a/gen/templates/models/09_rel_query.go.tpl +++ b/gen/templates/models/09_rel_query.go.tpl @@ -3,7 +3,6 @@ {{if $.Relationships.Get $table.Key -}} {{$.Importer.Import "github.com/stephenafamo/bob"}} {{$.Importer.Import "github.com/stephenafamo/bob/mods"}} - {{$.Importer.Import "context"}} type {{$tAlias.DownSingular}}Joins[Q dialect.Joinable] struct { typ string diff --git a/named.go b/named.go new file mode 100644 index 0000000..917b2a9 --- /dev/null +++ b/named.go @@ -0,0 +1,67 @@ +package bob + +import ( + "context" + "database/sql/driver" + "fmt" + "io" +) + +type RawNamedArgError struct { + Name string +} + +func (e RawNamedArgError) Error() string { + return fmt.Sprintf("raw named arg %q used without rebinding", e.Name) +} + +// named args should ONLY be used to prepare statements +type namedArg string + +// Value implements the driver.Valuer interface. +// it always returns an error because named args should only be used to prepare statements +func (n namedArg) Value() (driver.Value, error) { + return nil, RawNamedArgError{string(n)} +} + +// Named args should ONLY be used to prepare statements +func Named(names ...string) Expression { + return named{names: names} +} + +// NamedGroup is like Named, but wraps in parentheses +func NamedGroup(names ...string) Expression { + return named{names: names} +} + +type named struct { + names []string + grouped bool +} + +func (a named) WriteSQL(ctx context.Context, w io.Writer, d Dialect, start int) ([]any, error) { + if len(a.names) == 0 { + return nil, nil + } + + args := make([]any, len(a.names)) + + if a.grouped { + w.Write([]byte(openPar)) + } + + for k, name := range a.names { + if k > 0 { + w.Write([]byte(commaSpace)) + } + + d.WriteArg(w, start+k) + args[k] = namedArg(name) + } + + if a.grouped { + w.Write([]byte(closePar)) + } + + return args, nil +} diff --git a/orm/query.go b/orm/query.go index 0679ce0..5221a29 100644 --- a/orm/query.go +++ b/orm/query.go @@ -39,10 +39,6 @@ func (q ExecQuery[Q, T, Ts]) RunHooks(ctx context.Context, exec bob.Executor) (c return q.Hooks.RunHooks(ctx, exec, q.BaseQuery.Expression) } -func (q ExecQuery[Q, T, Ts]) Prepare(ctx context.Context, exec bob.Preparer) (bob.QueryStmt[T, Ts], error) { - return bob.PrepareQueryx[T, Ts](ctx, exec, q, q.Scanner) -} - // Execute the query func (q ExecQuery[Q, T, Ts]) Exec(ctx context.Context, exec bob.Executor) (int64, error) { result, err := bob.Exec(ctx, exec, q) diff --git a/query.go b/query.go index eb53114..afde58d 100644 --- a/query.go +++ b/query.go @@ -11,8 +11,9 @@ import ( // To pervent unnecessary allocations const ( - openPar = "(" - closePar = ")" + openPar = "(" + closePar = ")" + commaSpace = ", " ) type QueryType int @@ -100,12 +101,25 @@ func (b BaseQuery[E]) Apply(mods ...Mod[E]) { } func (b BaseQuery[E]) WriteQuery(ctx context.Context, w io.Writer, start int) ([]any, error) { + // If it a query, just call its WriteQuery method + if e, ok := any(b.Expression).(interface { + WriteQuery(context.Context, io.Writer, int) ([]any, error) + }); ok { + return e.WriteQuery(ctx, w, start) + } + return b.Expression.WriteSQL(ctx, w, b.Dialect, start) } // Satisfies the Expression interface, but uses its own dialect instead // of the dialect passed to it -func (b BaseQuery[E]) WriteSQL(ctx context.Context, w io.Writer, _ Dialect, start int) ([]any, error) { +func (b BaseQuery[E]) WriteSQL(ctx context.Context, w io.Writer, d Dialect, start int) ([]any, error) { + // If it a query, don't wrap it in parentheses + // it may already do this on its own and we don't want to double wrap + if e, ok := any(b.Expression).(Query); ok { + return e.WriteSQL(ctx, w, d, start) + } + w.Write([]byte(openPar)) args, err := b.Expression.WriteSQL(ctx, w, b.Dialect, start) w.Write([]byte(closePar)) @@ -144,3 +158,60 @@ func (q BaseQuery[E]) Cache(ctx context.Context, exec Executor) (BaseQuery[*cach func (q BaseQuery[E]) CacheN(ctx context.Context, exec Executor, start int) (BaseQuery[*cached], error) { return CacheN(ctx, exec, q, start) } + +func BindNamed[Arg any](ctx context.Context, q Query, args Arg) BoundQuery[Arg] { + return BoundQuery[Arg]{Query: q, namedArgs: args} +} + +type BoundQuery[Arg any] struct { + Query + namedArgs Arg +} + +func (b BoundQuery[Arg]) WriteQuery(ctx context.Context, w io.Writer, start int) ([]any, error) { + args, err := b.Query.WriteQuery(ctx, w, start) + if err != nil { + return nil, err + } + + return bindArgs(args, b.namedArgs) +} + +// Satisfies the Expression interface, but uses its own dialect instead +// of the dialect passed to it +func (b BoundQuery[E]) WriteSQL(ctx context.Context, w io.Writer, d Dialect, start int) ([]any, error) { + args, err := b.Query.WriteSQL(ctx, w, d, start) + if err != nil { + return nil, err + } + + return bindArgs(args, b.namedArgs) +} + +func (b BoundQuery[E]) Exec(ctx context.Context, exec Executor) (sql.Result, error) { + return Exec(ctx, exec, b) +} + +func (b BoundQuery[E]) RunHooks(ctx context.Context, exec Executor) (context.Context, error) { + if l, ok := any(b.Query).(HookableQuery); ok { + return l.RunHooks(ctx, exec) + } + + return ctx, nil +} + +func (b BoundQuery[E]) GetLoaders() []Loader { + if l, ok := any(b.Query).(Loadable); ok { + return l.GetLoaders() + } + + return nil +} + +func (b BoundQuery[E]) GetMapperMods() []scan.MapperMod { + if l, ok := any(b.Query).(MapperModder); ok { + return l.GetMapperMods() + } + + return nil +} diff --git a/query_test.go b/query_test.go index a04588a..b9531ab 100644 --- a/query_test.go +++ b/query_test.go @@ -1,12 +1,19 @@ package bob var ( - _ Query = BaseQuery[Expression]{} - _ Loadable = BaseQuery[Expression]{} - _ MapperModder = BaseQuery[Expression]{} - _ HookableQuery = BaseQuery[Expression]{} + _ Expression = &cached{} + + _ Query = BaseQuery[Expression]{} + _ Query = BoundQuery[Expression]{} + + _ Loadable = BaseQuery[Expression]{} + _ Loadable = BoundQuery[Expression]{} + _ Loadable = &cached{} - _ Expression = &cached{} - _ Loadable = &cached{} + _ MapperModder = BaseQuery[Expression]{} + _ MapperModder = BoundQuery[Expression]{} _ MapperModder = &cached{} + + _ HookableQuery = BaseQuery[Expression]{} + _ HookableQuery = BoundQuery[Expression]{} ) diff --git a/stdlib.go b/stdlib.go index 46ca52e..0ad4e1b 100644 --- a/stdlib.go +++ b/stdlib.go @@ -42,9 +42,9 @@ type common[T StdInterface] struct { } // PrepareContext creates a prepared statement for later queries or executions -func (c common[T]) PrepareContext(ctx context.Context, query string) (Statement, error) { +func (c common[T]) PrepareContext(ctx context.Context, query string) (StdPrepared, error) { s, err := c.wrapped.PrepareContext(ctx, query) - return stdStmt{s}, err + return StdPrepared{s}, err } // ExecContext executes a query without returning any rows. The args are for any placeholder parameters in the query. @@ -103,6 +103,11 @@ func NewTx(tx *sql.Tx) Tx { return Tx{New(tx)} } +var ( + _ txForStmt[StdPrepared] = &Tx{} + _ Preparer[StdPrepared] = &Tx{} +) + // Tx is similar to *sql.Tx but implements [Queryer] type Tx struct { common[*sql.Tx] @@ -118,6 +123,10 @@ func (t Tx) Rollback() error { return t.wrapped.Rollback() } +func (tx *Tx) StmtContext(ctx context.Context, stmt StdPrepared) StdPrepared { + return StdPrepared{tx.wrapped.StmtContext(ctx, stmt.Stmt)} +} + // NewConn wraps an [*sql.Conn] and returns a type that implements [Queryer] // This is useful when an existing *sql.Conn is used in other places in the codebase func NewConn(conn *sql.Conn) Conn { @@ -150,10 +159,10 @@ func (c Conn) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { return NewTx(tx), nil } -type stdStmt struct { +type StdPrepared struct { *sql.Stmt } -func (s stdStmt) QueryContext(ctx context.Context, args ...any) (scan.Rows, error) { +func (s StdPrepared) QueryContext(ctx context.Context, args ...any) (scan.Rows, error) { return s.Stmt.QueryContext(ctx, args...) } diff --git a/stmt.go b/stmt.go index 01e27a5..b5e1e99 100644 --- a/stmt.go +++ b/stmt.go @@ -7,43 +7,47 @@ import ( "github.com/stephenafamo/scan" ) -type Preparer interface { +type Preparer[P PreparedExecutor] interface { Executor - PrepareContext(ctx context.Context, query string) (Statement, error) + PrepareContext(ctx context.Context, query string) (P, error) } -type Statement interface { +type PreparedExecutor interface { ExecContext(ctx context.Context, args ...any) (sql.Result, error) QueryContext(ctx context.Context, args ...any) (scan.Rows, error) + Close() error } -// NewStmt wraps an [*sql.Stmt] and returns a type that implements [Queryer] but still -// retains the expected methods used by *sql.Stmt -// This is useful when an existing *sql.Stmt is used in other places in the codebase -func Prepare(ctx context.Context, exec Preparer, q Query) (Stmt, error) { +// Prepare prepares a query using the [Preparer] and returns a [NamedStmt] +func Prepare[Arg any, P PreparedExecutor](ctx context.Context, exec Preparer[P], q Query) (Stmt[Arg], error) { var err error if h, ok := q.(HookableQuery); ok { ctx, err = h.RunHooks(ctx, exec) if err != nil { - return Stmt{}, err + return Stmt[Arg]{}, err } } query, args, err := Build(ctx, q) if err != nil { - return Stmt{}, err + return Stmt[Arg]{}, err } - stmt, err := exec.PrepareContext(ctx, query) + binder, err := makeBinder[Arg](args) if err != nil { - return Stmt{}, err + return Stmt[Arg]{}, err } - s := Stmt{ - exec: exec, - stmt: stmt, - lenArgs: len(args), + stmt, err := exec.PrepareContext(ctx, string(query)) + if err != nil { + return Stmt[Arg]{}, err + } + + s := Stmt[Arg]{ + stmt: stmt, + exec: exec, + binder: binder, } if l, ok := q.(Loadable); ok { @@ -56,15 +60,39 @@ func Prepare(ctx context.Context, exec Preparer, q Query) (Stmt, error) { } // Stmt is similar to *sql.Stmt but implements [Queryer] -type Stmt struct { - stmt Statement +// instead of taking a list of args, it takes a struct to bind to the query +type Stmt[Arg any] struct { + stmt PreparedExecutor exec Executor - lenArgs int loaders []Loader + binder binder[Arg] +} + +type txForStmt[Stmt PreparedExecutor] interface { + Executor + StmtContext(context.Context, Stmt) Stmt +} + +// InTx returns a new MappedStmt that will be executed in the given transaction +func InTx[Arg any, S PreparedExecutor](ctx context.Context, s Stmt[Arg], tx txForStmt[S]) Stmt[Arg] { + stmt, ok := s.stmt.(S) + if !ok { + panic("stmt is not an the right type") + } + + s.stmt = tx.StmtContext(ctx, stmt) + s.exec = tx + return s +} + +// Close closes the statement. +func (s Stmt[Arg]) Close() error { + return s.stmt.Close() } // Exec executes a query without returning any rows. The args are for any placeholder parameters in the query. -func (s Stmt) Exec(ctx context.Context, args ...any) (sql.Result, error) { +func (s Stmt[Arg]) Exec(ctx context.Context, arg Arg) (sql.Result, error) { + args := s.binder.toArgs(arg) result, err := s.stmt.ExecContext(ctx, args...) if err != nil { return nil, err @@ -79,14 +107,18 @@ func (s Stmt) Exec(ctx context.Context, args ...any) (sql.Result, error) { return result, nil } -func PrepareQuery[T any](ctx context.Context, exec Preparer, q Query, m scan.Mapper[T]) (QueryStmt[T, []T], error) { - return PrepareQueryx[T, []T](ctx, exec, q, m) +func (s Stmt[Arg]) NamedArgs() []string { + return s.binder.list() +} + +func PrepareQuery[Arg any, P PreparedExecutor, T any](ctx context.Context, exec Preparer[P], q Query, m scan.Mapper[T]) (QueryStmt[Arg, T, []T], error) { + return PrepareQueryx[Arg, P, T, []T](ctx, exec, q, m) } -func PrepareQueryx[T any, Ts ~[]T](ctx context.Context, exec Preparer, q Query, m scan.Mapper[T]) (QueryStmt[T, Ts], error) { - var qs QueryStmt[T, Ts] +func PrepareQueryx[Arg any, P PreparedExecutor, T any, Ts ~[]T](ctx context.Context, exec Preparer[P], q Query, m scan.Mapper[T]) (QueryStmt[Arg, T, Ts], error) { + var qs QueryStmt[Arg, T, Ts] - s, err := Prepare(ctx, exec, q) + s, err := Prepare[Arg](ctx, exec, q) if err != nil { return qs, err } @@ -97,7 +129,7 @@ func PrepareQueryx[T any, Ts ~[]T](ctx context.Context, exec Preparer, q Query, } } - qs = QueryStmt[T, Ts]{ + qs = QueryStmt[Arg, T, Ts]{ Stmt: s, queryType: q.Type(), mapper: m, @@ -106,16 +138,17 @@ func PrepareQueryx[T any, Ts ~[]T](ctx context.Context, exec Preparer, q Query, return qs, nil } -type QueryStmt[T any, Ts ~[]T] struct { - Stmt +type QueryStmt[Arg, T any, Ts ~[]T] struct { + Stmt[Arg] queryType QueryType mapper scan.Mapper[T] } -func (s QueryStmt[T, Ts]) One(ctx context.Context, args ...any) (T, error) { +func (s QueryStmt[Arg, T, Ts]) One(ctx context.Context, arg Arg) (T, error) { var t T + args := s.binder.toArgs(arg) rows, err := s.stmt.QueryContext(ctx, args...) if err != nil { return t, err @@ -141,7 +174,8 @@ func (s QueryStmt[T, Ts]) One(ctx context.Context, args ...any) (T, error) { return t, err } -func (s QueryStmt[T, Ts]) All(ctx context.Context, args ...any) (Ts, error) { +func (s QueryStmt[Arg, T, Ts]) All(ctx context.Context, arg Arg) (Ts, error) { + args := s.binder.toArgs(arg) rows, err := s.stmt.QueryContext(ctx, args...) if err != nil { return nil, err @@ -175,7 +209,8 @@ func (s QueryStmt[T, Ts]) All(ctx context.Context, args ...any) (Ts, error) { return typedSlice, err } -func (s QueryStmt[T, Ts]) Cursor(ctx context.Context, args ...any) (scan.ICursor[T], error) { +func (s QueryStmt[Arg, T, Ts]) Cursor(ctx context.Context, arg Arg) (scan.ICursor[T], error) { + args := s.binder.toArgs(arg) rows, err := s.stmt.QueryContext(ctx, args...) if err != nil { return nil, err