diff --git a/pika.go b/pika.go index de13491..32bfd6a 100644 --- a/pika.go +++ b/pika.go @@ -4,6 +4,8 @@ package pika import ( + "context" + orderedmap "github.com/wk8/go-ordered-map/v2" ) @@ -165,6 +167,34 @@ type QuerySet[T any] interface { Exclude(excludes ...string) QuerySet[T] // Include fields Include(includes ...string) QuerySet[T] + + // EXPERIMENTAL + // The following methods are EXPERIMENTAL. Think of it as a sneak peek on what's coming. + // It is mostly to experiment with a simpler API for filtering, updating and querying. + // Feel free to test it out and provide feedback. + + // U is a shorthand for Update. ID field is used as the filter. + // Other filters applied to the query set are also inherited. + // Returns an error if the ID field is not set or does not exist. + // Thus preventing accidental updates to all rows. + U(value *T) error + + // F is a shorthand for Filter. It is a variadic function that accepts a list of filters. + // The filters are applied in the order they are given. + // Format is as follows: , etc. + F(keyval ...any) QuerySet[T] + + // D is a shorthand for Delete. ID field is used as the filter. + // Other filters applied to the query set are also inherited. + // Returns an error if the ID field is not set or does not exist. + // Thus preventing accidental deletes to all rows. + D(value *T) error + + // Transaction is a shorthand for wrapping a query set in a transaction. + // Currently Pika transactions affects the full connection, not just the query set. + // That method works if you use factories to create query sets. + // This helper will re-use the internal DB instance to return a new query set with the transaction. + Transaction(ctx context.Context) (QuerySet[T], error) } func NewArgs() *orderedmap.OrderedMap[string, any] { diff --git a/pika_psql_experimental.go b/pika_psql_experimental.go new file mode 100644 index 0000000..64efb24 --- /dev/null +++ b/pika_psql_experimental.go @@ -0,0 +1,70 @@ +package pika + +import ( + "context" + "fmt" + "reflect" +) + +func (b *basePsql[T]) findID(x *T) any { + elem := reflect.ValueOf(x).Elem() + + // Check if ID is a field + idField := elem.FieldByName("ID") + if idField.IsValid() { + return idField.Interface() + } + + // Also check for Id field + idField = elem.FieldByName("Id") + if idField.IsValid() { + return idField.Interface() + } + + // Return nil if ID is not a field + return nil +} + +func (b *basePsql[T]) F(keyval ...any) QuerySet[T] { + args := NewArgs() + var queries []string + for i := 0; i < len(keyval); i += 2 { + args.Set(keyval[i].(string), keyval[i+1]) + filter := fmt.Sprintf("%s=:%s", keyval[i].(string), keyval[i].(string)) + queries = append(queries, filter) + } + + logger.Debugf("F: %s", queries) + + return b.Args(args).Filter(queries...) +} + +func (b *basePsql[T]) D(x *T) error { + id := b.findID(x) + if id == nil { + return fmt.Errorf("id not found") + } + + qs := b.F("id", id) + return qs.Delete() +} + +func (b *basePsql[T]) Transaction(ctx context.Context) (QuerySet[T], error) { + ts := NewPostgreSQLFromDB(b.psql.DB()) + err := ts.Begin(ctx) + if err != nil { + return nil, err + } + + return Q[T](ts), nil +} + +func (b *basePsql[T]) U(x *T) error { + id := b.findID(x) + if id == nil { + return fmt.Errorf("id not found") + } + + qs := b.F("id", id) + return qs.Update(x) +} diff --git a/pika_psql_experimental_test.go b/pika_psql_experimental_test.go new file mode 100644 index 0000000..84377ae --- /dev/null +++ b/pika_psql_experimental_test.go @@ -0,0 +1,82 @@ +package pika + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestF_1(t *testing.T) { + psql := newPsql(t) + createTestEntries(t, psql) + + m, err := Q[simpleModel1](psql).F("title", "Test").All() + require.Nil(t, err) + require.NotNil(t, m) + require.Equal(t, 1, len(m)) + require.Equal(t, "Test", m[0].Title) +} + +func TestF_2(t *testing.T) { + psql := newPsql(t) + createTestEntries(t, psql) + + m, err := Q[simpleModel1](psql).F("title", "Test", "description", "Test").All() + require.Nil(t, err) + require.NotNil(t, m) + require.Equal(t, 1, len(m)) + require.Equal(t, "Test", m[0].Title) +} + +func TestF_3Or(t *testing.T) { + psql := newPsql(t) + createTestEntries(t, psql) + + m, err := Q[simpleModel1](psql).F("title", "Test", "title__or", "Test2").All() + require.Nil(t, err) + require.NotNil(t, m) + require.Equal(t, 2, len(m)) + require.Equal(t, "Test", m[0].Title) + require.Equal(t, "Test2", m[1].Title) +} + +func TestU(t *testing.T) { + psql := newPsql(t) + createTestEntries(t, psql) + + m, err := Q[simpleModel1](psql).F("title", "Test").All() + require.Nil(t, err) + require.NotNil(t, m) + require.Equal(t, 1, len(m)) + require.Equal(t, "Test", m[0].Title) + + elem := m[0] + elem.Title = "TestUpdated" + err = Q[simpleModel1](psql).U(elem) + require.Nil(t, err) + + m, err = Q[simpleModel1](psql).F("title", "TestUpdated").All() + require.Nil(t, err) + require.NotNil(t, m) + require.Equal(t, 1, len(m)) + require.Equal(t, "TestUpdated", m[0].Title) +} + +func TestD(t *testing.T) { + psql := newPsql(t) + createTestEntries(t, psql) + + m, err := Q[simpleModel1](psql).F("title", "Test").All() + require.Nil(t, err) + require.NotNil(t, m) + require.Equal(t, 1, len(m)) + require.Equal(t, "Test", m[0].Title) + + elem := m[0] + err = Q[simpleModel1](psql).D(elem) + require.Nil(t, err) + + m, err = Q[simpleModel1](psql).F("title", "Test").All() + require.Nil(t, err) + require.Equal(t, 0, len(m)) +}