Skip to content

Commit

Permalink
sql,pg,engine: rework sql interfaces
Browse files Browse the repository at this point in the history
Reduce sql interfaces, and clean up pg methods.
  • Loading branch information
jchappelow committed Mar 15, 2024
1 parent 89f2b36 commit 5af4d4a
Show file tree
Hide file tree
Showing 27 changed files with 239 additions and 353 deletions.
15 changes: 0 additions & 15 deletions cmd/kwild/server/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,6 @@ func initVoteStore(d *coreDependencies, db *pg.DB) {
failBuild(err, "failed to initialize vote store")
}

_, err = tx.Precommit(d.ctx)
if err != nil {
failBuild(err, "failed to precommit")
}

err = tx.Commit(d.ctx)
if err != nil {
failBuild(err, "failed to commit")
Expand Down Expand Up @@ -353,11 +348,6 @@ func buildEngine(d *coreDependencies, db *pg.DB) *execution.GlobalContext {
failBuild(err, "failed to build engine")
}

_, err = tx.Precommit(d.ctx)
if err != nil {
failBuild(err, "failed to precommit")
}

err = tx.Commit(d.ctx)
if err != nil {
failBuild(err, "failed to commit")
Expand All @@ -378,11 +368,6 @@ func initAccountRepository(d *coreDependencies, db *pg.DB) {
failBuild(err, "failed to initialize account store")
}

_, err = tx.Precommit(d.ctx)
if err != nil {
failBuild(err, "failed to precommit")
}

err = tx.Commit(d.ctx)
if err != nil {
failBuild(err, "failed to commit")
Expand Down
133 changes: 64 additions & 69 deletions common/sql/sql.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// Package sql defines common type required by SQL database implementations and
// consumers.
package sql

import (
Expand All @@ -10,33 +12,6 @@ var (
ErrNoRows = errors.New("no rows in result set")
)

// DB is a connection to a Postgres database.
// It has root user access, and can execute any Postgres command.
type DB interface {
Executor
TxMaker
// AccessMode gets the access mode of the database.
// It can be either read-write or read-only.
AccessMode() AccessMode
}

// Executor is an interface that can execute queries.
type Executor interface {
// Execute executes a query or command.
// The stmt should be a valid Postgres statement.
Execute(ctx context.Context, stmt string, args ...any) (*ResultSet, error)
}

// Tx is a database transaction. It can be nested within other
// transactions.
type Tx interface {
DB
// Rollback rolls back the transaction.
Rollback(ctx context.Context) error
// Commit commits the transaction.
Commit(ctx context.Context) error
}

// ResultSet is the result of a query or execution.
// It contains the returned columns and the rows.
type ResultSet struct {
Expand All @@ -53,44 +28,45 @@ type CommandTag struct {
RowsAffected int64
}

// AccessMode is the type of access to a database.
// It can be read-write or read-only.
type AccessMode uint8
// Executor is an interface that can execute queries.
type Executor interface {
// Execute executes a query or command.
Execute(ctx context.Context, stmt string, args ...any) (*ResultSet, error)
}

const (
// ReadWrite is the default access mode.
// It allows for reading and writing to the database.
ReadWrite AccessMode = iota
// ReadOnly allows for reading from the database, but not
// writing.
ReadOnly
)
// TxMaker is an interface that creates a new transaction. In the context of the
// recursive Tx interface, is creates a nested transaction.
type TxMaker interface {
BeginTx(ctx context.Context) (Tx, error)
}

// TxCloser terminates a transaction by committing or rolling it
// back. A method that returns this alone would keep the tx under the
// hood of the parent type, directing queries internally through the
// scope of a transaction/session started with BeginTx.
// TxCloser terminates a transaction by committing or rolling it back.
type TxCloser interface {
// Rollback rolls back the transaction.
Rollback(ctx context.Context) error
// Commit commits the transaction.
Commit(ctx context.Context) error
}

// TxPrecommitter is the special kind of transaction that can prepare
// a transaction for commit. It is only available on the outermost
// transaction.
type TxPrecommitter interface {
Precommit(ctx context.Context) ([]byte, error)
}

type TxBeginner interface {
Begin(ctx context.Context) (TxCloser, error)
// Tx represents a database transaction. It can be nested within other
// transactions, and create new nested transactions. An implementation of Tx may
// also be an AccessModer, but it is not required.
type Tx interface {
Executor
TxCloser
TxMaker // recursive interface
// note: does not embed DB for clear semantics (DB makes a Tx, not the reverse)
}

// OuterTxMaker is the special kind of transaction beginner that can
// make nested transactions, and that explicitly scopes Query/Execute
// to the tx.
type OuterTxMaker interface {
BeginTx(ctx context.Context) (OuterTx, error)
// DB is a top level database interface, which may directly execute queries or
// create transactions, which may be closed or create additional nested
// transactions.
//
// Some implementations may also be an OuterTxMaker and/or a ReadTxMaker. Embed
// with those interfaces to compose the minimal interface required.
type DB interface {
Executor
TxMaker
}

// ReadTxMaker can make read-only transactions.
Expand All @@ -99,20 +75,39 @@ type ReadTxMaker interface {
BeginReadTx(ctx context.Context) (Tx, error)
}

// TxMaker is the special kind of transaction beginner that can make
// nested
// transactions, and that explicitly scopes Query/Execute to the tx.
type TxMaker interface {
BeginTx(ctx context.Context) (Tx, error)
}

// OuterTx is a database transaction. It is the outermost transaction
// type. "nested transactions" are called savepoints, and can be
// created with BeginSavepoint. Savepoints can be nested, and are
// rolled back to the innermost savepoint on Rollback.
// OuterTx is the outermost database transaction.
//
// Anything using implicit tx/session management should use TxCloser.
// NOTE: An OuterTx may be used where only a Tx or DB is required since those
// interfaces are a subset of the OuterTx method set.
type OuterTx interface {
Tx
TxPrecommitter
Precommit(ctx context.Context) ([]byte, error)
}

// OuterTxMaker is the special kind of transaction that creates a transaction
// that has a Precommit method (see OuterTx), which supports obtaining a commit
// ID using a (two-phase) prepared transaction prior to Commit. This is a
// different method name so that an implementation may satisfy both OuterTxMaker
// and TxMaker.
type OuterTxMaker interface {
BeginOuterTx(ctx context.Context) (OuterTx, error)
}

// AccessMode is the type of access to a database.
// It can be read-write or read-only.
type AccessMode uint8

const (
// ReadWrite is the default access mode.
// It allows for reading and writing to the database.
ReadWrite AccessMode = iota
// ReadOnly allows for reading from the database, but not writing.
ReadOnly
)

// AccessModer may be satisfied by implementations of Tx and DB, but is not
// universally required for those interfaces (type assert as needed).
type AccessModer interface {
// AccessMode gets the access mode of the database or transaction.
AccessMode() AccessMode
}
4 changes: 0 additions & 4 deletions internal/accounts/accounts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@ func newDB() *mockDB {
}
}

func (m *mockDB) AccessMode() sql.AccessMode {
return sql.ReadWrite // not use in these tests
}

func (m *mockDB) BeginTx(ctx context.Context) (sql.Tx, error) {
return &mockTx{m}, nil
}
Expand Down
2 changes: 2 additions & 0 deletions internal/engine/execution/execution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ type mockDB struct {
executedStmts []string
}

var _ sql.AccessModer = (*mockDB)(nil)

func (m *mockDB) AccessMode() sql.AccessMode {
return m.accessMode
}
Expand Down
24 changes: 16 additions & 8 deletions internal/engine/execution/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package execution
import (
"bytes"
"context"
"errors"
"fmt"
"sort"
"sync"
Expand Down Expand Up @@ -51,17 +52,18 @@ func InitializeEngine(ctx context.Context, tx sql.DB) error {
return nil
}

// NewGlobalContext creates a new global context.
// It will load any persisted datasets from the datastore.
func NewGlobalContext(ctx context.Context, tx sql.DB, extensionInitializers map[string]precompiles.Initializer,
// NewGlobalContext creates a new global context. It will load any persisted
// datasets from the datastore. The provided database is only used for
// construction.
func NewGlobalContext(ctx context.Context, db sql.Executor, extensionInitializers map[string]precompiles.Initializer,
service *common.Service) (*GlobalContext, error) {
g := &GlobalContext{
initializers: extensionInitializers,
datasets: make(map[string]*baseDataset),
service: service,
}

schemas, err := getSchemas(ctx, tx)
schemas, err := getSchemas(ctx, db)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -200,8 +202,8 @@ func (g *GlobalContext) GetSchema(_ context.Context, dbid string) (*common.Schem
return dataset.schema, nil
}

// Execute executes a SQL statement on a dataset.
// It uses Kwil's SQL dialect.
// Execute executes a SQL statement on a dataset. If the statement is mutative,
// the tx must also be a sql.AccessModer. It uses Kwil's SQL dialect.
func (g *GlobalContext) Execute(ctx context.Context, tx sql.DB, dbid, query string, values map[string]any) (*sql.ResultSet, error) {
g.mu.RLock()
defer g.mu.RUnlock()
Expand All @@ -220,8 +222,14 @@ func (g *GlobalContext) Execute(ctx context.Context, tx sql.DB, dbid, query stri
return nil, err
}

if parsed.Mutative && tx.AccessMode() == sql.ReadOnly {
return nil, fmt.Errorf("cannot execute a mutative query in a read-only transaction")
if parsed.Mutative {
txm, ok := tx.(sql.AccessModer)
if !ok {
return nil, errors.New("DB does not provide access mode needed for mutative statement")
}
if txm.AccessMode() == sql.ReadOnly {
return nil, fmt.Errorf("cannot execute a mutative query in a read-only transaction")
}
}

args := orderAndCleanValueMap(values, parsed.ParameterOrder)
Expand Down
21 changes: 18 additions & 3 deletions internal/engine/execution/procedure.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package execution
import (
"bytes"
"context"
"errors"
"fmt"
"strings"

Expand Down Expand Up @@ -83,7 +84,11 @@ func prepareProcedure(unparsed *common.Procedure, global *GlobalContext, schema
// need to return an error
if !isViewProcedure {
instructions = append(instructions, instructionFunc(func(scope *precompiles.ProcedureContext, global *GlobalContext, db sql.DB) error {
if db.AccessMode() != sql.ReadWrite {
tx, ok := db.(sql.AccessModer)
if !ok {
return errors.New("DB does not provide access mode needed for mutative action")
}
if tx.AccessMode() != sql.ReadWrite {
return fmt.Errorf("cannot call non-view procedure, not in a chain transaction")
}

Expand Down Expand Up @@ -223,8 +228,14 @@ func (p *procedure) call(scope *precompiles.ProcedureContext, global *GlobalCont

// if procedure does not have view tag, then it can mutate state
// this means that we must have a readwrite connection
if !p.view && db.AccessMode() != sql.ReadWrite {
return fmt.Errorf(`%w: mutable procedure "%s" called with non-mutative scope`, ErrMutativeProcedure, p.name)
if !p.view {
tx, ok := db.(sql.AccessModer)
if !ok {
return errors.New("DB does not provide access mode needed for mutative action")
}
if tx.AccessMode() != sql.ReadWrite {
return fmt.Errorf(`%w: mutable procedure "%s" called with non-mutative scope`, ErrMutativeProcedure, p.name)
}
}

for i, param := range p.parameters {
Expand Down Expand Up @@ -260,6 +271,8 @@ type callMethod struct {
Receivers []string
}

var _ instructionFunc = (&callMethod{}).execute

// Execute calls a method from a namespace that is accessible within this dataset.
// If no namespace is specified, the local namespace is used.
// It will pass all arguments to the method, and assign the return values to the receivers.
Expand Down Expand Up @@ -348,6 +361,8 @@ type dmlStmt struct {
OrderedParameters []string
}

var _ instructionFunc = (&dmlStmt{}).execute

func (e *dmlStmt) execute(scope *precompiles.ProcedureContext, _ *GlobalContext, db sql.DB) error {
// Expend the arguments based on the ordered parameters for the DML statement.
params := orderAndCleanValueMap(scope.Values(), e.OrderedParameters)
Expand Down
6 changes: 3 additions & 3 deletions internal/engine/execution/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func createSchemasTableIfNotExists(ctx context.Context, tx sql.DB) error {
// It will also store the schema in the kwil_schemas table.
// It also creates the relevant tables, indexes, etc.
// If the schema already exists in the Kwil schemas table, it will be updated.
func createSchema(ctx context.Context, tx sql.DB, schema *common.Schema) error {
func createSchema(ctx context.Context, tx sql.TxMaker, schema *common.Schema) error {
schemaName := dbidSchema(schema.DBID())

sp, err := tx.BeginTx(ctx)
Expand Down Expand Up @@ -98,7 +98,7 @@ func createSchema(ctx context.Context, tx sql.DB, schema *common.Schema) error {
}

// getSchemas returns all schemas in the kwil_schemas table
func getSchemas(ctx context.Context, tx sql.DB) ([]*common.Schema, error) {
func getSchemas(ctx context.Context, tx sql.Executor) ([]*common.Schema, error) {
res, err := tx.Execute(ctx, sqlListSchemaContent)
if err != nil {
return nil, err
Expand Down Expand Up @@ -129,7 +129,7 @@ func getSchemas(ctx context.Context, tx sql.DB) ([]*common.Schema, error) {

// deleteSchema deletes a schema from the database.
// It will also delete the schema from the kwil_schemas table.
func deleteSchema(ctx context.Context, tx sql.DB, dbid string) error {
func deleteSchema(ctx context.Context, tx sql.TxMaker, dbid string) error {
schemaName := dbidSchema(dbid)

sp, err := tx.BeginTx(ctx)
Expand Down
Loading

0 comments on commit 5af4d4a

Please sign in to comment.