Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql,pg,engine: rework sql interfaces #601

Merged
merged 1 commit into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions cmd/kwild/server/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,6 @@ func initVoteStore(d *coreDependencies, db *pg.DB) {
failBuild(err, "failed to initialize vote store")
}

_, err = tx.Precommit(d.ctx)
if err != nil {
failBuild(err, "failed to precommit")
}

err = tx.Commit(d.ctx)
if err != nil {
failBuild(err, "failed to commit")
Expand Down Expand Up @@ -353,11 +348,6 @@ func buildEngine(d *coreDependencies, db *pg.DB) *execution.GlobalContext {
failBuild(err, "failed to build engine")
}

_, err = tx.Precommit(d.ctx)
if err != nil {
failBuild(err, "failed to precommit")
}

err = tx.Commit(d.ctx)
if err != nil {
failBuild(err, "failed to commit")
Expand All @@ -378,11 +368,6 @@ func initAccountRepository(d *coreDependencies, db *pg.DB) {
failBuild(err, "failed to initialize account store")
}

_, err = tx.Precommit(d.ctx)
if err != nil {
failBuild(err, "failed to precommit")
}

err = tx.Commit(d.ctx)
if err != nil {
failBuild(err, "failed to commit")
Expand Down
133 changes: 64 additions & 69 deletions common/sql/sql.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// Package sql defines common type required by SQL database implementations and
// consumers.
package sql

import (
Expand All @@ -10,33 +12,6 @@ var (
ErrNoRows = errors.New("no rows in result set")
)

// DB is a connection to a Postgres database.
// It has root user access, and can execute any Postgres command.
type DB interface {
Executor
TxMaker
// AccessMode gets the access mode of the database.
// It can be either read-write or read-only.
AccessMode() AccessMode
}

// Executor is an interface that can execute queries.
type Executor interface {
// Execute executes a query or command.
// The stmt should be a valid Postgres statement.
Execute(ctx context.Context, stmt string, args ...any) (*ResultSet, error)
}

// Tx is a database transaction. It can be nested within other
// transactions.
type Tx interface {
DB
// Rollback rolls back the transaction.
Rollback(ctx context.Context) error
// Commit commits the transaction.
Commit(ctx context.Context) error
}

// ResultSet is the result of a query or execution.
// It contains the returned columns and the rows.
type ResultSet struct {
Expand All @@ -53,44 +28,45 @@ type CommandTag struct {
RowsAffected int64
}

// AccessMode is the type of access to a database.
// It can be read-write or read-only.
type AccessMode uint8
// Executor is an interface that can execute queries.
type Executor interface {
// Execute executes a query or command.
Execute(ctx context.Context, stmt string, args ...any) (*ResultSet, error)
}

const (
// ReadWrite is the default access mode.
// It allows for reading and writing to the database.
ReadWrite AccessMode = iota
// ReadOnly allows for reading from the database, but not
// writing.
ReadOnly
)
// TxMaker is an interface that creates a new transaction. In the context of the
// recursive Tx interface, is creates a nested transaction.
type TxMaker interface {
BeginTx(ctx context.Context) (Tx, error)
}

// TxCloser terminates a transaction by committing or rolling it
// back. A method that returns this alone would keep the tx under the
// hood of the parent type, directing queries internally through the
// scope of a transaction/session started with BeginTx.
// TxCloser terminates a transaction by committing or rolling it back.
type TxCloser interface {
// Rollback rolls back the transaction.
Rollback(ctx context.Context) error
// Commit commits the transaction.
Commit(ctx context.Context) error
}

// TxPrecommitter is the special kind of transaction that can prepare
// a transaction for commit. It is only available on the outermost
// transaction.
type TxPrecommitter interface {
Precommit(ctx context.Context) ([]byte, error)
}

type TxBeginner interface {
Begin(ctx context.Context) (TxCloser, error)
// Tx represents a database transaction. It can be nested within other
// transactions, and create new nested transactions. An implementation of Tx may
// also be an AccessModer, but it is not required.
type Tx interface {
Executor
TxCloser
TxMaker // recursive interface
// note: does not embed DB for clear semantics (DB makes a Tx, not the reverse)
}

// OuterTxMaker is the special kind of transaction beginner that can
// make nested transactions, and that explicitly scopes Query/Execute
// to the tx.
type OuterTxMaker interface {
BeginTx(ctx context.Context) (OuterTx, error)
// DB is a top level database interface, which may directly execute queries or
// create transactions, which may be closed or create additional nested
// transactions.
//
// Some implementations may also be an OuterTxMaker and/or a ReadTxMaker. Embed
// with those interfaces to compose the minimal interface required.
type DB interface {
Executor
TxMaker
}

// ReadTxMaker can make read-only transactions.
Expand All @@ -99,20 +75,39 @@ type ReadTxMaker interface {
BeginReadTx(ctx context.Context) (Tx, error)
}

// TxMaker is the special kind of transaction beginner that can make
// nested
// transactions, and that explicitly scopes Query/Execute to the tx.
type TxMaker interface {
BeginTx(ctx context.Context) (Tx, error)
}

// OuterTx is a database transaction. It is the outermost transaction
// type. "nested transactions" are called savepoints, and can be
// created with BeginSavepoint. Savepoints can be nested, and are
// rolled back to the innermost savepoint on Rollback.
// OuterTx is the outermost database transaction.
//
// Anything using implicit tx/session management should use TxCloser.
// NOTE: An OuterTx may be used where only a Tx or DB is required since those
// interfaces are a subset of the OuterTx method set.
type OuterTx interface {
Tx
TxPrecommitter
Precommit(ctx context.Context) ([]byte, error)
}

// OuterTxMaker is the special kind of transaction that creates a transaction
// that has a Precommit method (see OuterTx), which supports obtaining a commit
// ID using a (two-phase) prepared transaction prior to Commit. This is a
// different method name so that an implementation may satisfy both OuterTxMaker
// and TxMaker.
type OuterTxMaker interface {
BeginOuterTx(ctx context.Context) (OuterTx, error)
}

// AccessMode is the type of access to a database.
// It can be read-write or read-only.
type AccessMode uint8

const (
// ReadWrite is the default access mode.
// It allows for reading and writing to the database.
ReadWrite AccessMode = iota
// ReadOnly allows for reading from the database, but not writing.
ReadOnly
)

// AccessModer may be satisfied by implementations of Tx and DB, but is not
// universally required for those interfaces (type assert as needed).
type AccessModer interface {
// AccessMode gets the access mode of the database or transaction.
AccessMode() AccessMode
}
4 changes: 0 additions & 4 deletions internal/accounts/accounts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@ func newDB() *mockDB {
}
}

func (m *mockDB) AccessMode() sql.AccessMode {
return sql.ReadWrite // not use in these tests
}

func (m *mockDB) BeginTx(ctx context.Context) (sql.Tx, error) {
return &mockTx{m}, nil
}
Expand Down
2 changes: 2 additions & 0 deletions internal/engine/execution/execution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ type mockDB struct {
executedStmts []string
}

var _ sql.AccessModer = (*mockDB)(nil)

func (m *mockDB) AccessMode() sql.AccessMode {
return m.accessMode
}
Expand Down
24 changes: 16 additions & 8 deletions internal/engine/execution/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package execution
import (
"bytes"
"context"
"errors"
"fmt"
"sort"
"sync"
Expand Down Expand Up @@ -51,17 +52,18 @@ func InitializeEngine(ctx context.Context, tx sql.DB) error {
return nil
}

// NewGlobalContext creates a new global context.
// It will load any persisted datasets from the datastore.
func NewGlobalContext(ctx context.Context, tx sql.DB, extensionInitializers map[string]precompiles.Initializer,
// NewGlobalContext creates a new global context. It will load any persisted
// datasets from the datastore. The provided database is only used for
// construction.
func NewGlobalContext(ctx context.Context, db sql.Executor, extensionInitializers map[string]precompiles.Initializer,
service *common.Service) (*GlobalContext, error) {
g := &GlobalContext{
initializers: extensionInitializers,
datasets: make(map[string]*baseDataset),
service: service,
}

schemas, err := getSchemas(ctx, tx)
schemas, err := getSchemas(ctx, db)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -200,8 +202,8 @@ func (g *GlobalContext) GetSchema(_ context.Context, dbid string) (*common.Schem
return dataset.schema, nil
}

// Execute executes a SQL statement on a dataset.
// It uses Kwil's SQL dialect.
// Execute executes a SQL statement on a dataset. If the statement is mutative,
// the tx must also be a sql.AccessModer. It uses Kwil's SQL dialect.
func (g *GlobalContext) Execute(ctx context.Context, tx sql.DB, dbid, query string, values map[string]any) (*sql.ResultSet, error) {
g.mu.RLock()
defer g.mu.RUnlock()
Expand All @@ -220,8 +222,14 @@ func (g *GlobalContext) Execute(ctx context.Context, tx sql.DB, dbid, query stri
return nil, err
}

if parsed.Mutative && tx.AccessMode() == sql.ReadOnly {
return nil, fmt.Errorf("cannot execute a mutative query in a read-only transaction")
if parsed.Mutative {
txm, ok := tx.(sql.AccessModer)
if !ok {
return nil, errors.New("DB does not provide access mode needed for mutative statement")
}
if txm.AccessMode() == sql.ReadOnly {
return nil, fmt.Errorf("cannot execute a mutative query in a read-only transaction")
}
}

args := orderAndCleanValueMap(values, parsed.ParameterOrder)
Expand Down
21 changes: 18 additions & 3 deletions internal/engine/execution/procedure.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package execution
import (
"bytes"
"context"
"errors"
"fmt"
"strings"

Expand Down Expand Up @@ -83,7 +84,11 @@ func prepareProcedure(unparsed *common.Procedure, global *GlobalContext, schema
// need to return an error
if !isViewProcedure {
instructions = append(instructions, instructionFunc(func(scope *precompiles.ProcedureContext, global *GlobalContext, db sql.DB) error {
if db.AccessMode() != sql.ReadWrite {
tx, ok := db.(sql.AccessModer)
if !ok {
return errors.New("DB does not provide access mode needed for mutative action")
}
if tx.AccessMode() != sql.ReadWrite {
return fmt.Errorf("cannot call non-view procedure, not in a chain transaction")
}

Expand Down Expand Up @@ -223,8 +228,14 @@ func (p *procedure) call(scope *precompiles.ProcedureContext, global *GlobalCont

// if procedure does not have view tag, then it can mutate state
// this means that we must have a readwrite connection
if !p.view && db.AccessMode() != sql.ReadWrite {
return fmt.Errorf(`%w: mutable procedure "%s" called with non-mutative scope`, ErrMutativeProcedure, p.name)
if !p.view {
tx, ok := db.(sql.AccessModer)
if !ok {
return errors.New("DB does not provide access mode needed for mutative action")
}
if tx.AccessMode() != sql.ReadWrite {
return fmt.Errorf(`%w: mutable procedure "%s" called with non-mutative scope`, ErrMutativeProcedure, p.name)
}
}

for i, param := range p.parameters {
Expand Down Expand Up @@ -260,6 +271,8 @@ type callMethod struct {
Receivers []string
}

var _ instructionFunc = (&callMethod{}).execute

// Execute calls a method from a namespace that is accessible within this dataset.
// If no namespace is specified, the local namespace is used.
// It will pass all arguments to the method, and assign the return values to the receivers.
Expand Down Expand Up @@ -348,6 +361,8 @@ type dmlStmt struct {
OrderedParameters []string
}

var _ instructionFunc = (&dmlStmt{}).execute

func (e *dmlStmt) execute(scope *precompiles.ProcedureContext, _ *GlobalContext, db sql.DB) error {
// Expend the arguments based on the ordered parameters for the DML statement.
params := orderAndCleanValueMap(scope.Values(), e.OrderedParameters)
Expand Down
6 changes: 3 additions & 3 deletions internal/engine/execution/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func createSchemasTableIfNotExists(ctx context.Context, tx sql.DB) error {
// It will also store the schema in the kwil_schemas table.
// It also creates the relevant tables, indexes, etc.
// If the schema already exists in the Kwil schemas table, it will be updated.
func createSchema(ctx context.Context, tx sql.DB, schema *common.Schema) error {
func createSchema(ctx context.Context, tx sql.TxMaker, schema *common.Schema) error {
schemaName := dbidSchema(schema.DBID())

sp, err := tx.BeginTx(ctx)
Expand Down Expand Up @@ -98,7 +98,7 @@ func createSchema(ctx context.Context, tx sql.DB, schema *common.Schema) error {
}

// getSchemas returns all schemas in the kwil_schemas table
func getSchemas(ctx context.Context, tx sql.DB) ([]*common.Schema, error) {
func getSchemas(ctx context.Context, tx sql.Executor) ([]*common.Schema, error) {
res, err := tx.Execute(ctx, sqlListSchemaContent)
if err != nil {
return nil, err
Expand Down Expand Up @@ -129,7 +129,7 @@ func getSchemas(ctx context.Context, tx sql.DB) ([]*common.Schema, error) {

// deleteSchema deletes a schema from the database.
// It will also delete the schema from the kwil_schemas table.
func deleteSchema(ctx context.Context, tx sql.DB, dbid string) error {
func deleteSchema(ctx context.Context, tx sql.TxMaker, dbid string) error {
schemaName := dbidSchema(dbid)

sp, err := tx.BeginTx(ctx)
Expand Down
Loading
Loading