Skip to content

Commit

Permalink
Merge pull request #10 from ctrliq/context-aware-functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mstg authored Mar 28, 2024
2 parents c97a81e + 293a94b commit b5cc282
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 129 deletions.
20 changes: 10 additions & 10 deletions pika.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,32 +87,32 @@ type QuerySet[T any] interface {
ClearAll() QuerySet[T]

// Create creates a new value
Create(value *T, options ...CreateOption) error
Create(ctx context.Context, value *T, options ...CreateOption) error

// Update updates a value
// All filters will be applied
Update(value *T) error
Update(ctx context.Context, value *T) error

// Delete deletes a row
// All filters will be applied
Delete() error
Delete(ctx context.Context) error

// GetOrNil returns a single value or nil
// Multiple values will return an error.
// Ignores Limit
GetOrNil() (*T, error)
GetOrNil(ctx context.Context) (*T, error)

// Get returns a single value
// Returns error if no value is found
// Returns error if multiple values are found
// Ignores Limit
Get() (*T, error)
Get(ctx context.Context) (*T, error)

// All returns all values
All() ([]*T, error)
All(ctx context.Context) ([]*T, error)

// Count returns the number of values
Count() (int, error)
Count(ctx context.Context) (int, error)

// Limit sets the limit for the query
Limit(limit int) QuerySet[T]
Expand Down Expand Up @@ -157,7 +157,7 @@ type QuerySet[T any] interface {
// Page token functionality for gRPC
// The count is optional and returns the total number of rows for the query.
// It is implemented as a variadic function to not break existing code.
GetPage(paginatable Paginatable, options AIPFilterOptions, count ...*int) ([]*T, string, error)
GetPage(ctx context.Context, paginatable Paginatable, options AIPFilterOptions, count ...*int) ([]*T, string, error)

// Join table
InnerJoin(modelFirst, modelSecond interface{}, keyFirst, keySecond string) QuerySet[T]
Expand All @@ -179,7 +179,7 @@ type QuerySet[T any] interface {
// Other filters applied to the query set are also inherited.
// Returns an error if the ID field is not set or does not exist.
// Thus preventing accidental updates to all rows.
U(value *T) error
U(ctx context.Context, value *T) error

// F is a shorthand for Filter. It is a variadic function that accepts a list of filters.
// The filters are applied in the order they are given.
Expand All @@ -190,7 +190,7 @@ type QuerySet[T any] interface {
// Other filters applied to the query set are also inherited.
// Returns an error if the ID field is not set or does not exist.
// Thus preventing accidental deletes to all rows.
D(value *T) error
D(ctx context.Context, value *T) error

// Transaction is a shorthand for wrapping a query set in a transaction.
// Currently Pika transactions affects the full connection, not just the query set.
Expand Down
48 changes: 20 additions & 28 deletions pika_psql.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,12 @@ type Queryable interface {
sqlx.Preparer

GetContext(context.Context, interface{}, string, ...interface{}) error
SelectContext(context.Context, interface{}, string, ...interface{}) error
Get(interface{}, string, ...interface{}) error
MustExecContext(context.Context, string, ...interface{}) sql.Result
NamedExecContext(context.Context, string, interface{}) (sql.Result, error)
PrepareNamedContext(context.Context, string) (*sqlx.NamedStmt, error)
PreparexContext(context.Context, string) (*sqlx.Stmt, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
Select(interface{}, string, ...interface{}) error
QueryRow(string, ...interface{}) *sql.Row
PrepareNamedContext(context.Context, string) (*sqlx.NamedStmt, error)
PrepareNamed(string) (*sqlx.NamedStmt, error)
Preparex(string) (*sqlx.Stmt, error)
NamedExec(string, interface{}) (sql.Result, error)
NamedExecContext(context.Context, string, interface{}) (sql.Result, error)
MustExec(string, ...interface{}) sql.Result
NamedQuery(string, interface{}) (*sqlx.Rows, error)
SelectContext(context.Context, interface{}, string, ...interface{}) error
}

var (
Expand Down Expand Up @@ -258,7 +250,7 @@ func (b *basePsql[T]) ClearAll() QuerySet[T] {
}

// Create creates a new record in the database.
func (b *basePsql[T]) Create(x *T, options ...CreateOption) error {
func (b *basePsql[T]) Create(ctx context.Context, x *T, options ...CreateOption) error {
if b.err != nil {
return b.err
}
Expand All @@ -270,7 +262,7 @@ func (b *basePsql[T]) Create(x *T, options ...CreateOption) error {
b.ignoreOrderBy = origIgnoreOrderBy

// Execute query
err := b.psql.Queryable().Get(x, q, args...)
err := b.psql.Queryable().GetContext(ctx, x, q, args...)
if err != nil {
// ignore no rows in resultset error when ignoreConflict is set to true, this is a normal case
if errors.Is(err, sql.ErrNoRows) && (InsertOnConflictionDoNothing&getOption(options...) != 0) {
Expand All @@ -283,7 +275,7 @@ func (b *basePsql[T]) Create(x *T, options ...CreateOption) error {
}

// Update updates a record in the database.
func (b *basePsql[T]) Update(x *T) error {
func (b *basePsql[T]) Update(ctx context.Context, x *T) error {
if b.err != nil {
return b.err
}
Expand All @@ -294,7 +286,7 @@ func (b *basePsql[T]) Update(x *T) error {
b.ignoreOrderBy = origIgnoreOrderBy

// Execute query
err := b.psql.Queryable().Get(x, q, args...)
err := b.psql.Queryable().GetContext(ctx, x, q, args...)
if err != nil {
return err
}
Expand All @@ -303,7 +295,7 @@ func (b *basePsql[T]) Update(x *T) error {
}

// Delete deletes a record from the database.
func (b *basePsql[T]) Delete() error {
func (b *basePsql[T]) Delete(ctx context.Context) error {
if b.err != nil {
return b.err
}
Expand All @@ -314,7 +306,7 @@ func (b *basePsql[T]) Delete() error {
b.ignoreOrderBy = origIgnoreOrderBy

// Execute query
_, err := b.psql.Queryable().Exec(q, args...)
_, err := b.psql.Queryable().ExecContext(ctx, q, args...)
if err != nil {
return err
}
Expand All @@ -324,7 +316,7 @@ func (b *basePsql[T]) Delete() error {

// GetOrNil returns a single value or nil
// Multiple values will return an error.
func (b *basePsql[T]) GetOrNil() (*T, error) {
func (b *basePsql[T]) GetOrNil(ctx context.Context) (*T, error) {
if b.err != nil {
return nil, b.err
}
Expand All @@ -338,7 +330,7 @@ func (b *basePsql[T]) GetOrNil() (*T, error) {
var x T

// Send arguments to prepared statement
err := b.psql.Queryable().Get(&x, q, args...)
err := b.psql.Queryable().GetContext(ctx, &x, q, args...)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
Expand All @@ -352,7 +344,7 @@ func (b *basePsql[T]) GetOrNil() (*T, error) {
// Get returns a single value
// Returns error if no value is found
// Returns error if multiple values are found
func (b *basePsql[T]) Get() (*T, error) {
func (b *basePsql[T]) Get(ctx context.Context) (*T, error) {
if b.err != nil {
return nil, b.err
}
Expand All @@ -363,7 +355,7 @@ func (b *basePsql[T]) Get() (*T, error) {
var x T

// Send arguments to prepared statement
err := b.psql.Queryable().Get(&x, q, args...)
err := b.psql.Queryable().GetContext(ctx, &x, q, args...)
if err != nil {
return nil, err
}
Expand All @@ -372,7 +364,7 @@ func (b *basePsql[T]) Get() (*T, error) {
}

// All returns all values
func (b *basePsql[T]) All() ([]*T, error) {
func (b *basePsql[T]) All(ctx context.Context) ([]*T, error) {
if b.err != nil {
return nil, b.err
}
Expand All @@ -383,7 +375,7 @@ func (b *basePsql[T]) All() ([]*T, error) {
var x []*T

// Send arguments to prepared statement
err := b.psql.Queryable().Select(&x, q, args...)
err := b.psql.Queryable().SelectContext(ctx, &x, q, args...)
if err != nil {
return nil, err
}
Expand All @@ -392,7 +384,7 @@ func (b *basePsql[T]) All() ([]*T, error) {
}

// Count returns the number of values
func (b *basePsql[T]) Count() (int, error) {
func (b *basePsql[T]) Count(ctx context.Context) (int, error) {
if b.err != nil {
return 0, b.err
}
Expand Down Expand Up @@ -420,7 +412,7 @@ func (b *basePsql[T]) Count() (int, error) {
q := fmt.Sprintf("%s%s", selectQuery, filterStatement)
logger.Debugf("Pika query: %s", q)

err := b.psql.Queryable().Get(&x, q, args...)
err := b.psql.Queryable().GetContext(ctx, &x, q, args...)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -542,7 +534,7 @@ func (b *basePsql[T]) AIP160(filter string, options AIPFilterOptions) (QuerySet[
}

// Page tokens for gRPC
func (b *basePsql[T]) GetPage(paginatable Paginatable, options AIPFilterOptions, countPointer ...*int) ([]*T, string, error) {
func (b *basePsql[T]) GetPage(ctx context.Context, paginatable Paginatable, options AIPFilterOptions, countPointer ...*int) ([]*T, string, error) {
if len(countPointer) > 1 {
return nil, "", fmt.Errorf("too many arguments (count should be one pointer or none)")
}
Expand Down Expand Up @@ -570,15 +562,15 @@ func (b *basePsql[T]) GetPage(paginatable Paginatable, options AIPFilterOptions,
return nil, "", err
}

result, err := qs.All()
result, err := qs.All(ctx)
if err != nil {
return nil, "", err
}

b.PageToken.Offset += uint(len(result))

// Get count and check if there are more results
count, err := b.Count()
count, err := b.Count(ctx)
if err != nil {
return nil, "", fmt.Errorf("getting count: %w", err)
}
Expand Down
8 changes: 4 additions & 4 deletions pika_psql_experimental.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ func (b *basePsql[T]) F(keyval ...any) QuerySet[T] {
return b.Args(args).Filter(queries...)
}

func (b *basePsql[T]) D(x *T) error {
func (b *basePsql[T]) D(ctx context.Context, x *T) error {
id := b.findID(x)
if id == nil {
return fmt.Errorf("id not found")
}

qs := b.F("id", id)
return qs.Delete()
return qs.Delete(ctx)
}

func (b *basePsql[T]) Transaction(ctx context.Context) (QuerySet[T], error) {
Expand All @@ -62,12 +62,12 @@ func (b *basePsql[T]) Transaction(ctx context.Context) (QuerySet[T], error) {
return Q[T](ts), nil
}

func (b *basePsql[T]) U(x *T) error {
func (b *basePsql[T]) U(ctx context.Context, x *T) error {
id := b.findID(x)
if id == nil {
return fmt.Errorf("id not found")
}

qs := b.F("id", id)
return qs.Update(x)
return qs.Update(ctx, x)
}
19 changes: 10 additions & 9 deletions pika_psql_experimental_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pika

import (
"context"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -10,7 +11,7 @@ func TestF_1(t *testing.T) {
psql := newPsql(t)
createTestEntries(t, psql)

m, err := Q[simpleModel1](psql).F("title", "Test").All()
m, err := Q[simpleModel1](psql).F("title", "Test").All(context.Background())
require.Nil(t, err)
require.NotNil(t, m)
require.Equal(t, 1, len(m))
Expand All @@ -21,7 +22,7 @@ func TestF_2(t *testing.T) {
psql := newPsql(t)
createTestEntries(t, psql)

m, err := Q[simpleModel1](psql).F("title", "Test", "description", "Test").All()
m, err := Q[simpleModel1](psql).F("title", "Test", "description", "Test").All(context.Background())
require.Nil(t, err)
require.NotNil(t, m)
require.Equal(t, 1, len(m))
Expand All @@ -32,7 +33,7 @@ func TestF_3Or(t *testing.T) {
psql := newPsql(t)
createTestEntries(t, psql)

m, err := Q[simpleModel1](psql).F("title", "Test", "title__or", "Test2").All()
m, err := Q[simpleModel1](psql).F("title", "Test", "title__or", "Test2").All(context.Background())
require.Nil(t, err)
require.NotNil(t, m)
require.Equal(t, 2, len(m))
Expand All @@ -44,18 +45,18 @@ func TestU(t *testing.T) {
psql := newPsql(t)
createTestEntries(t, psql)

m, err := Q[simpleModel1](psql).F("title", "Test").All()
m, err := Q[simpleModel1](psql).F("title", "Test").All(context.Background())
require.Nil(t, err)
require.NotNil(t, m)
require.Equal(t, 1, len(m))
require.Equal(t, "Test", m[0].Title)

elem := m[0]
elem.Title = "TestUpdated"
err = Q[simpleModel1](psql).U(elem)
err = Q[simpleModel1](psql).U(context.Background(), elem)
require.Nil(t, err)

m, err = Q[simpleModel1](psql).F("title", "TestUpdated").All()
m, err = Q[simpleModel1](psql).F("title", "TestUpdated").All(context.Background())
require.Nil(t, err)
require.NotNil(t, m)
require.Equal(t, 1, len(m))
Expand All @@ -66,17 +67,17 @@ func TestD(t *testing.T) {
psql := newPsql(t)
createTestEntries(t, psql)

m, err := Q[simpleModel1](psql).F("title", "Test").All()
m, err := Q[simpleModel1](psql).F("title", "Test").All(context.Background())
require.Nil(t, err)
require.NotNil(t, m)
require.Equal(t, 1, len(m))
require.Equal(t, "Test", m[0].Title)

elem := m[0]
err = Q[simpleModel1](psql).D(elem)
err = Q[simpleModel1](psql).D(context.Background(), elem)
require.Nil(t, err)

m, err = Q[simpleModel1](psql).F("title", "Test").All()
m, err = Q[simpleModel1](psql).F("title", "Test").All(context.Background())
require.Nil(t, err)
require.Equal(t, 0, len(m))
}
Loading

0 comments on commit b5cc282

Please sign in to comment.