From da359513f205abb34a120bd914a28717a28d9662 Mon Sep 17 00:00:00 2001 From: Jonathan Chappelow Date: Tue, 12 Mar 2024 18:00:11 -0500 Subject: [PATCH] sql,pg,engine: rework sql interfaces Reduce sql interfaces, and clean up pg methods. --- cmd/kwild/server/build.go | 15 -- common/sql/sql.go | 133 +++++++++--------- internal/accounts/accounts_test.go | 4 - internal/engine/execution/execution_test.go | 2 + internal/engine/execution/global.go | 24 ++-- internal/engine/execution/procedure.go | 21 ++- internal/engine/execution/queries.go | 6 +- internal/engine/integration/execution_test.go | 63 ++++----- internal/engine/integration/setup_test.go | 3 - internal/events/events.go | 6 +- internal/events/events_test.go | 19 +-- internal/services/grpc/txsvc/v1/call.go | 2 +- internal/services/grpc/txsvc/v1/query.go | 4 +- internal/sql/common.go | 55 -------- internal/sql/pg/conn.go | 48 ++----- internal/sql/pg/db.go | 38 +++-- internal/sql/pg/db_live_test.go | 9 +- internal/sql/pg/query.go | 2 +- internal/sql/pg/tx.go | 19 +-- internal/sql/versioning/sql.go | 2 +- internal/sql/versioning/versioning.go | 4 +- internal/txapp/interfaces.go | 4 +- internal/txapp/mempool.go | 6 +- internal/txapp/mempool_test.go | 4 - internal/txapp/routes_test.go | 36 ++--- internal/txapp/txapp.go | 6 +- internal/voting/voting.go | 57 ++++---- 27 files changed, 239 insertions(+), 353 deletions(-) delete mode 100644 internal/sql/common.go diff --git a/cmd/kwild/server/build.go b/cmd/kwild/server/build.go index 7c7537c0d..6cdb30310 100644 --- a/cmd/kwild/server/build.go +++ b/cmd/kwild/server/build.go @@ -265,11 +265,6 @@ func initVoteStore(d *coreDependencies, db *pg.DB) { failBuild(err, "failed to initialize vote store") } - _, err = tx.Precommit(d.ctx) - if err != nil { - failBuild(err, "failed to precommit") - } - err = tx.Commit(d.ctx) if err != nil { failBuild(err, "failed to commit") @@ -353,11 +348,6 @@ func buildEngine(d *coreDependencies, db *pg.DB) *execution.GlobalContext { failBuild(err, "failed to build engine") } - _, err = tx.Precommit(d.ctx) - if err != nil { - failBuild(err, "failed to precommit") - } - err = tx.Commit(d.ctx) if err != nil { failBuild(err, "failed to commit") @@ -378,11 +368,6 @@ func initAccountRepository(d *coreDependencies, db *pg.DB) { failBuild(err, "failed to initialize account store") } - _, err = tx.Precommit(d.ctx) - if err != nil { - failBuild(err, "failed to precommit") - } - err = tx.Commit(d.ctx) if err != nil { failBuild(err, "failed to commit") diff --git a/common/sql/sql.go b/common/sql/sql.go index 138498637..51dffc917 100644 --- a/common/sql/sql.go +++ b/common/sql/sql.go @@ -1,3 +1,5 @@ +// Package sql defines common type required by SQL database implementations and +// consumers. package sql import ( @@ -10,33 +12,6 @@ var ( ErrNoRows = errors.New("no rows in result set") ) -// DB is a connection to a Postgres database. -// It has root user access, and can execute any Postgres command. -type DB interface { - Executor - TxMaker - // AccessMode gets the access mode of the database. - // It can be either read-write or read-only. - AccessMode() AccessMode -} - -// Executor is an interface that can execute queries. -type Executor interface { - // Execute executes a query or command. - // The stmt should be a valid Postgres statement. - Execute(ctx context.Context, stmt string, args ...any) (*ResultSet, error) -} - -// Tx is a database transaction. It can be nested within other -// transactions. -type Tx interface { - DB - // Rollback rolls back the transaction. - Rollback(ctx context.Context) error - // Commit commits the transaction. - Commit(ctx context.Context) error -} - // ResultSet is the result of a query or execution. // It contains the returned columns and the rows. type ResultSet struct { @@ -53,44 +28,45 @@ type CommandTag struct { RowsAffected int64 } -// AccessMode is the type of access to a database. -// It can be read-write or read-only. -type AccessMode uint8 +// Executor is an interface that can execute queries. +type Executor interface { + // Execute executes a query or command. + Execute(ctx context.Context, stmt string, args ...any) (*ResultSet, error) +} -const ( - // ReadWrite is the default access mode. - // It allows for reading and writing to the database. - ReadWrite AccessMode = iota - // ReadOnly allows for reading from the database, but not - // writing. - ReadOnly -) +// TxMaker is an interface that creates a new transaction. In the context of the +// recursive Tx interface, is creates a nested transaction. +type TxMaker interface { + BeginTx(ctx context.Context) (Tx, error) +} -// TxCloser terminates a transaction by committing or rolling it -// back. A method that returns this alone would keep the tx under the -// hood of the parent type, directing queries internally through the -// scope of a transaction/session started with BeginTx. +// TxCloser terminates a transaction by committing or rolling it back. type TxCloser interface { + // Rollback rolls back the transaction. Rollback(ctx context.Context) error + // Commit commits the transaction. Commit(ctx context.Context) error } -// TxPrecommitter is the special kind of transaction that can prepare -// a transaction for commit. It is only available on the outermost -// transaction. -type TxPrecommitter interface { - Precommit(ctx context.Context) ([]byte, error) -} - -type TxBeginner interface { - Begin(ctx context.Context) (TxCloser, error) +// Tx represents a database transaction. It can be nested within other +// transactions, and create new nested transactions. An implementation of Tx may +// also be an AccessModer, but it is not required. +type Tx interface { + Executor + TxCloser + TxMaker // recursive interface + // note: does not embed DB for clear semantics (DB makes a Tx, not the reverse) } -// OuterTxMaker is the special kind of transaction beginner that can -// make nested transactions, and that explicitly scopes Query/Execute -// to the tx. -type OuterTxMaker interface { - BeginTx(ctx context.Context) (OuterTx, error) +// DB is a top level database interface, which may directly execute queries or +// create transactions, which may be closed or create additional nested +// transactions. +// +// Some implementations may also be an OuterTxMaker and/or a ReadTxMaker. Embed +// with those interfaces to compose the minimal interface required. +type DB interface { + Executor + TxMaker } // ReadTxMaker can make read-only transactions. @@ -99,20 +75,39 @@ type ReadTxMaker interface { BeginReadTx(ctx context.Context) (Tx, error) } -// TxMaker is the special kind of transaction beginner that can make -// nested -// transactions, and that explicitly scopes Query/Execute to the tx. -type TxMaker interface { - BeginTx(ctx context.Context) (Tx, error) -} - -// OuterTx is a database transaction. It is the outermost transaction -// type. "nested transactions" are called savepoints, and can be -// created with BeginSavepoint. Savepoints can be nested, and are -// rolled back to the innermost savepoint on Rollback. +// OuterTx is the outermost database transaction. // -// Anything using implicit tx/session management should use TxCloser. +// NOTE: An OuterTx may be used where only a Tx or DB is required since those +// interfaces are a subset of the OuterTx method set. type OuterTx interface { Tx - TxPrecommitter + Precommit(ctx context.Context) ([]byte, error) +} + +// OuterTxMaker is the special kind of transaction that creates a transaction +// that has a Precommit method (see OuterTx), which supports obtaining a commit +// ID using a (two-phase) prepared transaction prior to Commit. This is a +// different method name so that an implementation may satisfy both OuterTxMaker +// and TxMaker. +type OuterTxMaker interface { + BeginOuterTx(ctx context.Context) (OuterTx, error) +} + +// AccessMode is the type of access to a database. +// It can be read-write or read-only. +type AccessMode uint8 + +const ( + // ReadWrite is the default access mode. + // It allows for reading and writing to the database. + ReadWrite AccessMode = iota + // ReadOnly allows for reading from the database, but not writing. + ReadOnly +) + +// AccessModer may be satisfied by implementations of Tx and DB, but is not +// universally required for those interfaces (type assert as needed). +type AccessModer interface { + // AccessMode gets the access mode of the database or transaction. + AccessMode() AccessMode } diff --git a/internal/accounts/accounts_test.go b/internal/accounts/accounts_test.go index 3b4a17299..d5f0cd308 100644 --- a/internal/accounts/accounts_test.go +++ b/internal/accounts/accounts_test.go @@ -34,10 +34,6 @@ func newDB() *mockDB { } } -func (m *mockDB) AccessMode() sql.AccessMode { - return sql.ReadWrite // not use in these tests -} - func (m *mockDB) BeginTx(ctx context.Context) (sql.Tx, error) { return &mockTx{m}, nil } diff --git a/internal/engine/execution/execution_test.go b/internal/engine/execution/execution_test.go index d5785b211..1c8adb389 100644 --- a/internal/engine/execution/execution_test.go +++ b/internal/engine/execution/execution_test.go @@ -240,6 +240,8 @@ type mockDB struct { executedStmts []string } +var _ sql.AccessModer = (*mockDB)(nil) + func (m *mockDB) AccessMode() sql.AccessMode { return m.accessMode } diff --git a/internal/engine/execution/global.go b/internal/engine/execution/global.go index 8f3452403..cf4b3e457 100644 --- a/internal/engine/execution/global.go +++ b/internal/engine/execution/global.go @@ -3,6 +3,7 @@ package execution import ( "bytes" "context" + "errors" "fmt" "sort" "sync" @@ -51,9 +52,10 @@ func InitializeEngine(ctx context.Context, tx sql.DB) error { return nil } -// NewGlobalContext creates a new global context. -// It will load any persisted datasets from the datastore. -func NewGlobalContext(ctx context.Context, tx sql.DB, extensionInitializers map[string]precompiles.Initializer, +// NewGlobalContext creates a new global context. It will load any persisted +// datasets from the datastore. The provided database is only used for +// construction. +func NewGlobalContext(ctx context.Context, db sql.Executor, extensionInitializers map[string]precompiles.Initializer, service *common.Service) (*GlobalContext, error) { g := &GlobalContext{ initializers: extensionInitializers, @@ -61,7 +63,7 @@ func NewGlobalContext(ctx context.Context, tx sql.DB, extensionInitializers map[ service: service, } - schemas, err := getSchemas(ctx, tx) + schemas, err := getSchemas(ctx, db) if err != nil { return nil, err } @@ -200,8 +202,8 @@ func (g *GlobalContext) GetSchema(_ context.Context, dbid string) (*common.Schem return dataset.schema, nil } -// Execute executes a SQL statement on a dataset. -// It uses Kwil's SQL dialect. +// Execute executes a SQL statement on a dataset. If the statement is mutative, +// the tx must also be a sql.AccessModer. It uses Kwil's SQL dialect. func (g *GlobalContext) Execute(ctx context.Context, tx sql.DB, dbid, query string, values map[string]any) (*sql.ResultSet, error) { g.mu.RLock() defer g.mu.RUnlock() @@ -220,8 +222,14 @@ func (g *GlobalContext) Execute(ctx context.Context, tx sql.DB, dbid, query stri return nil, err } - if parsed.Mutative && tx.AccessMode() == sql.ReadOnly { - return nil, fmt.Errorf("cannot execute a mutative query in a read-only transaction") + if parsed.Mutative { + txm, ok := tx.(sql.AccessModer) + if !ok { + return nil, errors.New("DB does not provide access mode needed for mutative statement") + } + if txm.AccessMode() == sql.ReadOnly { + return nil, fmt.Errorf("cannot execute a mutative query in a read-only transaction") + } } args := orderAndCleanValueMap(values, parsed.ParameterOrder) diff --git a/internal/engine/execution/procedure.go b/internal/engine/execution/procedure.go index dd840529f..d3bd8271d 100644 --- a/internal/engine/execution/procedure.go +++ b/internal/engine/execution/procedure.go @@ -3,6 +3,7 @@ package execution import ( "bytes" "context" + "errors" "fmt" "strings" @@ -83,7 +84,11 @@ func prepareProcedure(unparsed *common.Procedure, global *GlobalContext, schema // need to return an error if !isViewProcedure { instructions = append(instructions, instructionFunc(func(scope *precompiles.ProcedureContext, global *GlobalContext, db sql.DB) error { - if db.AccessMode() != sql.ReadWrite { + tx, ok := db.(sql.AccessModer) + if !ok { + return errors.New("DB does not provide access mode needed for mutative action") + } + if tx.AccessMode() != sql.ReadWrite { return fmt.Errorf("cannot call non-view procedure, not in a chain transaction") } @@ -223,8 +228,14 @@ func (p *procedure) call(scope *precompiles.ProcedureContext, global *GlobalCont // if procedure does not have view tag, then it can mutate state // this means that we must have a readwrite connection - if !p.view && db.AccessMode() != sql.ReadWrite { - return fmt.Errorf(`%w: mutable procedure "%s" called with non-mutative scope`, ErrMutativeProcedure, p.name) + if !p.view { + tx, ok := db.(sql.AccessModer) + if !ok { + return errors.New("DB does not provide access mode needed for mutative action") + } + if tx.AccessMode() != sql.ReadWrite { + return fmt.Errorf(`%w: mutable procedure "%s" called with non-mutative scope`, ErrMutativeProcedure, p.name) + } } for i, param := range p.parameters { @@ -260,6 +271,8 @@ type callMethod struct { Receivers []string } +var _ instructionFunc = (&callMethod{}).execute + // Execute calls a method from a namespace that is accessible within this dataset. // If no namespace is specified, the local namespace is used. // It will pass all arguments to the method, and assign the return values to the receivers. @@ -348,6 +361,8 @@ type dmlStmt struct { OrderedParameters []string } +var _ instructionFunc = (&dmlStmt{}).execute + func (e *dmlStmt) execute(scope *precompiles.ProcedureContext, _ *GlobalContext, db sql.DB) error { // Expend the arguments based on the ordered parameters for the DML statement. params := orderAndCleanValueMap(scope.Values(), e.OrderedParameters) diff --git a/internal/engine/execution/queries.go b/internal/engine/execution/queries.go index 988561651..be962d21e 100644 --- a/internal/engine/execution/queries.go +++ b/internal/engine/execution/queries.go @@ -52,7 +52,7 @@ func createSchemasTableIfNotExists(ctx context.Context, tx sql.DB) error { // It will also store the schema in the kwil_schemas table. // It also creates the relevant tables, indexes, etc. // If the schema already exists in the Kwil schemas table, it will be updated. -func createSchema(ctx context.Context, tx sql.DB, schema *common.Schema) error { +func createSchema(ctx context.Context, tx sql.TxMaker, schema *common.Schema) error { schemaName := dbidSchema(schema.DBID()) sp, err := tx.BeginTx(ctx) @@ -98,7 +98,7 @@ func createSchema(ctx context.Context, tx sql.DB, schema *common.Schema) error { } // getSchemas returns all schemas in the kwil_schemas table -func getSchemas(ctx context.Context, tx sql.DB) ([]*common.Schema, error) { +func getSchemas(ctx context.Context, tx sql.Executor) ([]*common.Schema, error) { res, err := tx.Execute(ctx, sqlListSchemaContent) if err != nil { return nil, err @@ -129,7 +129,7 @@ func getSchemas(ctx context.Context, tx sql.DB) ([]*common.Schema, error) { // deleteSchema deletes a schema from the database. // It will also delete the schema from the kwil_schemas table. -func deleteSchema(ctx context.Context, tx sql.DB, dbid string) error { +func deleteSchema(ctx context.Context, tx sql.TxMaker, dbid string) error { schemaName := dbidSchema(dbid) sp, err := tx.BeginTx(ctx) diff --git a/internal/engine/integration/execution_test.go b/internal/engine/integration/execution_test.go index a91a4b9d1..57e77ede9 100644 --- a/internal/engine/integration/execution_test.go +++ b/internal/engine/integration/execution_test.go @@ -20,24 +20,24 @@ func Test_Engine(t *testing.T) { type testCase struct { name string // ses1 is the first round of execution - ses1 func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) + ses1 func(t *testing.T, global *execution.GlobalContext, tx sql.DB) // ses2 is the second round of execution - ses2 func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) + ses2 func(t *testing.T, global *execution.GlobalContext, tx sql.DB) // after is called after the second round // It is not called in a session, and therefore can only read from the database. - after func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) + after func(t *testing.T, global *execution.GlobalContext, tx sql.DB) } tests := []testCase{ { name: "create database", - ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() err := global.CreateDataset(ctx, tx, testdata.TestSchema, testdata.TestSchema.Owner) require.NoError(t, err) }, - after: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + after: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() schema, err := global.GetSchema(ctx, testdata.TestSchema.DBID()) require.NoError(t, err) @@ -53,18 +53,18 @@ func Test_Engine(t *testing.T) { }, { name: "drop database", - ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() err := global.CreateDataset(ctx, tx, testdata.TestSchema, testdata.TestSchema.Owner) require.NoError(t, err) }, - ses2: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + ses2: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() err := global.DeleteDataset(ctx, tx, testdata.TestSchema.DBID(), testdata.TestSchema.Owner) require.NoError(t, err) }, - after: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + after: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() dbs, err := global.ListDatasets(ctx, testdata.TestSchema.Owner) require.NoError(t, err) @@ -74,12 +74,12 @@ func Test_Engine(t *testing.T) { }, { name: "execute procedures", - ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() err := global.CreateDataset(ctx, tx, testdata.TestSchema, testdata.TestSchema.Owner) require.NoError(t, err) }, - ses2: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + ses2: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { signer := "signer" ctx := context.Background() @@ -101,7 +101,7 @@ func Test_Engine(t *testing.T) { }) require.NoError(t, err) }, - after: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + after: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() res, err := global.Procedure(ctx, tx, &common.ExecutionData{ @@ -136,12 +136,12 @@ func Test_Engine(t *testing.T) { }, { name: "executing outside of a commit", - ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() err := global.CreateDataset(ctx, tx, testdata.TestSchema, testdata.TestSchema.Owner) require.NoError(t, err) }, - after: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + after: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() _, err := global.Procedure(ctx, tx, &common.ExecutionData{ @@ -156,7 +156,7 @@ func Test_Engine(t *testing.T) { }, { name: "calling outside of a commit", - ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() err := global.CreateDataset(ctx, tx, testdata.TestSchema, testdata.TestSchema.Owner) require.NoError(t, err) @@ -170,7 +170,7 @@ func Test_Engine(t *testing.T) { }) require.NoError(t, err) }, - after: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + after: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() users, err := global.Procedure(ctx, tx, &common.ExecutionData{ @@ -188,7 +188,7 @@ func Test_Engine(t *testing.T) { }, { name: "deploying database and immediately calling procedure", - ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() err := global.CreateDataset(ctx, tx, testdata.TestSchema, testdata.TestSchema.Owner) @@ -203,7 +203,7 @@ func Test_Engine(t *testing.T) { }) require.NoError(t, err) }, - after: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + after: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() users, err := global.Procedure(ctx, tx, &common.ExecutionData{ @@ -221,7 +221,7 @@ func Test_Engine(t *testing.T) { }, { name: "test failed extension init", - ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() oldExtensions := []*common.Extension{} @@ -251,7 +251,7 @@ func Test_Engine(t *testing.T) { }, { name: "owner only action", - ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() err := global.CreateDataset(ctx, tx, testdata.TestSchema, testdata.TestSchema.Owner) @@ -278,7 +278,7 @@ func Test_Engine(t *testing.T) { }, { name: "private action", - ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() err := global.CreateDataset(ctx, tx, testdata.TestSchema, testdata.TestSchema.Owner) @@ -311,7 +311,7 @@ func Test_Engine(t *testing.T) { // and it's actually preferable that we can support this. Logically, it makes sense // that a deploy tx followed by an execute tx in the same block should work. name: "deploy and call at the same time", - ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() err := global.CreateDataset(ctx, tx, testdata.TestSchema, testdata.TestSchema.Owner) @@ -338,7 +338,7 @@ func Test_Engine(t *testing.T) { }, { name: "deploy many databases", - ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() for i := 0; i < 10; i++ { @@ -352,7 +352,7 @@ func Test_Engine(t *testing.T) { }, { name: "deploying and immediately dropping", - ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() err := global.CreateDataset(ctx, tx, testdata.TestSchema, testdata.TestSchema.Owner) @@ -364,7 +364,7 @@ func Test_Engine(t *testing.T) { }, { name: "case insensitive", - ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) { + ses1: func(t *testing.T, global *execution.GlobalContext, tx sql.DB) { ctx := context.Background() schema := *caseSchema @@ -425,13 +425,13 @@ func Test_Engine(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { if test.ses1 == nil { - test.ses1 = func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) {} + test.ses1 = func(t *testing.T, global *execution.GlobalContext, tx sql.DB) {} } if test.ses2 == nil { - test.ses2 = func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) {} + test.ses2 = func(t *testing.T, global *execution.GlobalContext, tx sql.DB) {} } if test.after == nil { - test.after = func(t *testing.T, global *execution.GlobalContext, tx sql.Tx) {} + test.after = func(t *testing.T, global *execution.GlobalContext, tx sql.DB) {} } global, db, err := setup(t) @@ -442,13 +442,13 @@ func Test_Engine(t *testing.T) { ctx := context.Background() - tx, err := db.BeginTx(ctx) + tx, err := db.BeginOuterTx(ctx) require.NoError(t, err) defer tx.Rollback(ctx) test.ses1(t, global, tx) - id, err := tx.Precommit(ctx) + id, err := tx.Precommit(ctx) // not needed, but test how txApp would use the engine require.NoError(t, err) require.NotEmpty(t, id) @@ -461,9 +461,8 @@ func Test_Engine(t *testing.T) { test.ses2(t, global, tx2) - id, err = tx2.Precommit(ctx) - require.NoError(t, err) - require.NotEmpty(t, id) + // Omit Precommit here, just to test that it's allowed even though + // txApp would want the commit ID. err = tx2.Commit(ctx) require.NoError(t, err) diff --git a/internal/engine/integration/setup_test.go b/internal/engine/integration/setup_test.go index 8f440aff8..761500052 100644 --- a/internal/engine/integration/setup_test.go +++ b/internal/engine/integration/setup_test.go @@ -84,9 +84,6 @@ func setup(t *testing.T) (global *execution.GlobalContext, db *pg.DB, err error) }) require.NoError(t, err) - _, err = tx.Precommit(ctx) - require.NoError(t, err) - err = tx.Commit(ctx) require.NoError(t, err) diff --git a/internal/events/events.go b/internal/events/events.go index 4e5fb7cd3..99daf07dd 100644 --- a/internal/events/events.go +++ b/internal/events/events.go @@ -114,7 +114,7 @@ func (e *EventStore) Store(ctx context.Context, data []byte, eventType string) e } // GetEvents gets all events in the event store. -func GetEvents(ctx context.Context, db sql.DB) ([]*types.VotableEvent, error) { +func GetEvents(ctx context.Context, db sql.Executor) ([]*types.VotableEvent, error) { res, err := db.Execute(ctx, getEvents) if err != nil { return nil, err @@ -147,7 +147,7 @@ func GetEvents(ctx context.Context, db sql.DB) ([]*types.VotableEvent, error) { // DeleteEvent deletes an event from the event store. // It is idempotent. If the event does not exist, it will not return an error. -func DeleteEvent(ctx context.Context, db sql.DB, id types.UUID) error { +func DeleteEvent(ctx context.Context, db sql.Executor, id types.UUID) error { _, err := db.Execute(ctx, deleteEvent, id[:]) return err } @@ -204,7 +204,7 @@ func (e *EventStore) MarkBroadcasted(ctx context.Context, ids []types.UUID) erro } // MarkReceived marks that an event has been received by the network, and should not be re-broadcasted. -func MarkReceived(ctx context.Context, db sql.DB, id types.UUID) error { +func MarkReceived(ctx context.Context, db sql.Executor, id types.UUID) error { _, err := db.Execute(ctx, markReceived, id[:]) return err } diff --git a/internal/events/events_test.go b/internal/events/events_test.go index 189f58916..544c09769 100644 --- a/internal/events/events_test.go +++ b/internal/events/events_test.go @@ -212,25 +212,14 @@ func Test_EventStore(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() - dbVoting, err := dbtest.NewTestDB(t) - txVoting, _ := dbVoting.BeginTx(ctx) - err = voting.InitializeVoteStore(ctx, txVoting) + db, cleanup, err := dbtest.NewTestPool(ctx, []string{schemaName, "kwild_voting"}) // db is the event store specific connection require.NoError(t, err) - _, err = txVoting.Precommit(ctx) - require.NoError(t, err) - err = txVoting.Commit(ctx) - require.NoError(t, err) - // defer // drop kwild_voting - defer func() { - dbVoting.AutoCommit(true) - dbVoting.Execute(ctx, `drop schema if exists kwild_voting cascade;`) - }() + defer cleanup() - db, cleanup, err := dbtest.NewTestPool(ctx, []string{schemaName}) // db is the event store specific connection + err = voting.InitializeVoteStore(ctx, db) require.NoError(t, err) - defer cleanup() - e, err := NewEventStore(ctx, db) + e, err := NewEventStore(ctx, db) // needs BeginReadTx require.NoError(t, err) // create a second db connection to emulate the consensus db diff --git a/internal/services/grpc/txsvc/v1/call.go b/internal/services/grpc/txsvc/v1/call.go index 8878f7d17..7b5be11bc 100644 --- a/internal/services/grpc/txsvc/v1/call.go +++ b/internal/services/grpc/txsvc/v1/call.go @@ -58,7 +58,7 @@ func (s *Service) Call(ctx context.Context, req *txpb.CallRequest) (*txpb.CallRe // marshalling the map is less efficient, but necessary for backwards compatibility - btsResult, err := json.Marshal(ResultMap(executeResult)) + btsResult, err := json.Marshal(resultMap(executeResult)) if err != nil { return nil, status.Errorf(codes.Internal, "failed to marshal call result") } diff --git a/internal/services/grpc/txsvc/v1/query.go b/internal/services/grpc/txsvc/v1/query.go index d17fa97f0..8a3e92c45 100644 --- a/internal/services/grpc/txsvc/v1/query.go +++ b/internal/services/grpc/txsvc/v1/query.go @@ -25,7 +25,7 @@ func (s *Service) Query(ctx context.Context, req *txpb.QueryRequest) (*txpb.Quer return nil, status.Error(codes.InvalidArgument, err.Error()) } - bts, err := json.Marshal(ResultMap(result)) // marshalling the map is less efficient, but necessary for backwards compatibility + bts, err := json.Marshal(resultMap(result)) // marshalling the map is less efficient, but necessary for backwards compatibility if err != nil { return nil, status.Errorf(codes.Internal, "failed to marshal call result") } @@ -35,7 +35,7 @@ func (s *Service) Query(ctx context.Context, req *txpb.QueryRequest) (*txpb.Quer }, nil } -func ResultMap(r *sql.ResultSet) []map[string]any { +func resultMap(r *sql.ResultSet) []map[string]any { m := make([]map[string]any, len(r.Rows)) for i, row := range r.Rows { m2 := make(map[string]any) diff --git a/internal/sql/common.go b/internal/sql/common.go deleted file mode 100644 index 37a171446..000000000 --- a/internal/sql/common.go +++ /dev/null @@ -1,55 +0,0 @@ -package sql - -// var ( -// ErrNoTransaction = errors.New("no transaction") -// ErrNoRows = errors.New("no rows in result set") -// ) - -// // TxCloser terminates a transaction by committing or rolling it back. A method -// // that returns this alone would keep the tx under the hood of the parent type, -// // directing queries internally through the scope of a transaction/session -// // started with BeginTx. -// type TxCloser interface { -// Rollback(ctx context.Context) error -// Commit(ctx context.Context) error -// } - -// // TxPrecommitter is the special kind of transaction that can prepare a -// // transaction for commit. -// // It is only available on the outermost transaction. -// type TxPrecommitter interface { -// Precommit(ctx context.Context) ([]byte, error) -// } - -// type TxBeginner interface { -// Begin(ctx context.Context) (TxCloser, error) -// } - -// // OuterTxMaker is the special kind of transaction beginner that can make nested -// // transactions, and that explicitly scopes Query/Execute to the tx. -// type OuterTxMaker interface { -// BeginTx(ctx context.Context) (OuterTx, error) -// } - -// // ReadTxMaker can make read-only transactions. -// // Many read-only transactions can be made at once. -// type ReadTxMaker interface { -// BeginReadTx(ctx context.Context) (common.Tx, error) -// } - -// // TxMaker is the special kind of transaction beginner that can make nested -// // transactions, and that explicitly scopes Query/Execute to the tx. -// type TxMaker interface { -// BeginTx(ctx context.Context) (common.Tx, error) -// } - -// // OuterTx is a database transaction. It is the outermost transaction type. -// // "nested transactions" are called savepoints, and can be created with -// // BeginSavepoint. Savepoints can be nested, and are rolled back to the -// // innermost savepoint on Rollback. -// // -// // Anything using implicit tx/session management should use TxCloser. -// type OuterTx interface { -// common.Tx -// TxPrecommitter -// } diff --git a/internal/sql/pg/conn.go b/internal/sql/pg/conn.go index 94857e7af..f6437726e 100644 --- a/internal/sql/pg/conn.go +++ b/internal/sql/pg/conn.go @@ -145,6 +145,7 @@ func (p *Pool) Query(ctx context.Context, stmt string, args ...any) (*sql.Result // intended to be used with the DB type, which performs all such operations via // the Tx returned from BeginTx. +// Execute performs a read-write query on the writer connection. func (p *Pool) Execute(ctx context.Context, stmt string, args ...any) (*sql.ResultSet, error) { return query(ctx, &cqWrapper{p.writer}, stmt, args...) } @@ -154,42 +155,12 @@ func (p *Pool) Close() error { return p.writer.Close(context.TODO()) } -type poolTx struct { - pgx.Tx - RowsAffected int64 // for debugging and testing -} - -// Execute is now identical to Query. We should consider removing Query as a -// transaction method since their is no semantic or syntactic difference -// (transactions generated from DB or Pool use the write connection). -func (ptx *poolTx) Execute(ctx context.Context, stmt string, args ...any) (*sql.ResultSet, error) { - // This method is now identical to Query, but we previously used pgx.Tx.Exec - // res,_ := ptx.Tx.Exec(ctx, stmt, args...) - // ptx.RowsAffected += res.RowsAffected() - return query(ctx, ptx.Tx, stmt, args...) -} - -func (ptx *poolTx) Query(ctx context.Context, stmt string, args ...any) (*sql.ResultSet, error) { - return query(ctx, ptx.Tx, stmt, args...) -} - -// Begin starts a read-write transaction on the writer connection. -func (p *Pool) Begin(ctx context.Context) (sql.TxCloser, error) { - tx, err := p.writer.BeginTx(ctx, pgx.TxOptions{ - IsoLevel: pgx.ReadCommitted, - AccessMode: pgx.ReadWrite, - }) - if err != nil { - return nil, err - } - return &poolTx{tx, 0}, nil -} - -// BeginTx starts a read-write transaction. +// BeginTx starts a read-write transaction. It is an error to call this twice +// without first closing the initial transaction. func (p *Pool) BeginTx(ctx context.Context) (sql.Tx, error) { tx, err := p.writer.BeginTx(ctx, pgx.TxOptions{ - IsoLevel: pgx.ReadCommitted, AccessMode: pgx.ReadWrite, + IsoLevel: pgx.ReadCommitted, }) if err != nil { return nil, err @@ -202,7 +173,10 @@ func (p *Pool) BeginTx(ctx context.Context) (sql.Tx, error) { // BeginReadTx starts a read-only transaction. func (p *Pool) BeginReadTx(ctx context.Context) (sql.Tx, error) { - tx, err := p.pgxp.Begin(ctx) + tx, err := p.pgxp.BeginTx(ctx, pgx.TxOptions{ + AccessMode: pgx.ReadOnly, + IsoLevel: pgx.RepeatableRead, + }) if err != nil { return nil, err } @@ -211,9 +185,3 @@ func (p *Pool) BeginReadTx(ctx context.Context) (sql.Tx, error) { accessMode: sql.ReadOnly, }, nil } - -// AccessMode implements the sql.DB interface. -// It is always ReadWrite for the pool. -func (p *Pool) AccessMode() sql.AccessMode { - return sql.ReadWrite -} diff --git a/internal/sql/pg/db.go b/internal/sql/pg/db.go index 3a39b0660..ea59e4f5e 100644 --- a/internal/sql/pg/db.go +++ b/internal/sql/pg/db.go @@ -33,6 +33,9 @@ import ( // 3. Emulating SQLite changesets by collecting WAL data for updates from a // dedicated logical replication connection and slot. The Precommit method // is used to retrieve the commit ID prior to Commit. +// +// DB requires a superuser connection to a Postgres database that can perform +// administrative actions on the database. type DB struct { pool *Pool // raw connection pool repl *replMon // logical replication monitor for collecting commit IDs @@ -80,6 +83,10 @@ var defaultSchemaFilter = func(schema string) bool { // PoolConfig plus a special connection for a logical replication slot receiver. // The database user (postgresql "role") must be a super user for several // reasons: creating triggers, collations, and the replication publication. +// +// WARNING: There must only be ONE instance of a DB for a given postgres +// database. Transactions that use the Precommit method update an internal table +// used to sequence transactions. func NewDB(ctx context.Context, cfg *DBConfig) (*DB, error) { // Create the unrestricted connection pool. pool, err := NewPool(ctx, &cfg.PoolConfig) @@ -196,7 +203,7 @@ var _ sql.OuterTxMaker = (*DB)(nil) // for dataset Registry // The returned transaction is also capable of creating nested transactions. // This functionality is used to prevent user dataset query errors from rolling // back the outermost transaction. -func (db *DB) BeginTx(ctx context.Context) (sql.OuterTx, error) { +func (db *DB) BeginOuterTx(ctx context.Context) (sql.OuterTx, error) { tx, err := db.beginWriterTx(ctx) if err != nil { return nil, err @@ -213,6 +220,13 @@ func (db *DB) BeginTx(ctx context.Context) (sql.OuterTx, error) { }, nil } +var _ sql.TxMaker = (*DB)(nil) +var _ sql.DB = (*DB)(nil) + +func (db *DB) BeginTx(ctx context.Context) (sql.Tx, error) { + return db.BeginOuterTx(ctx) // slice off the Precommit method from sql.OuterTx +} + // ReadTx creates a read-only transaction for the database. // It obtains a read connection from the pool, which will be returned // to the pool when the transaction is closed. @@ -224,7 +238,7 @@ func (db *DB) BeginReadTx(ctx context.Context) (sql.Tx, error) { tx, err := conn.BeginTx(ctx, pgx.TxOptions{ AccessMode: pgx.ReadOnly, - IsoLevel: pgx.RepeatableRead, + IsoLevel: pgx.RepeatableRead, // only for read-only as repeatable ready can fail a write tx commit }) if err != nil { conn.Release() @@ -238,19 +252,10 @@ func (db *DB) BeginReadTx(ctx context.Context) (sql.Tx, error) { return &readTx{ nestedTx: ntx, - release: conn.Release, + release: sync.OnceFunc(conn.Release), }, nil } -var _ sql.TxBeginner = (*DB)(nil) // for CommittableStore => MultiCommitter - -// Begin is for consumers that require a smaller interface on the return but -// same instance of the concrete type, a case which annoyingly creates -// incompatible interfaces in Go. -func (db *DB) Begin(ctx context.Context) (sql.TxCloser, error) { - return db.BeginTx(ctx) // just slice down sql.Tx -} - // beginWriterTx is the critical section of BeginTx. // It creates a new transaction on the write connection, and stores it in the // DB's tx field. It is not exported, and is only called from BeginTx. @@ -264,7 +269,7 @@ func (db *DB) beginWriterTx(ctx context.Context) (pgx.Tx, error) { tx, err := db.pool.writer.BeginTx(ctx, pgx.TxOptions{ AccessMode: pgx.ReadWrite, - IsoLevel: pgx.ReadUncommitted, + IsoLevel: pgx.ReadUncommitted, // consider if ReadCommitted would be fine. uncommitted refers to other transactions, not needed }) if err != nil { return nil, err @@ -328,8 +333,11 @@ func (db *DB) commit(ctx context.Context) error { if db.tx == nil { return errors.New("no tx exists") } - if db.txid == "" { // NOTE: we could consider doing a regular commit if not using prepared, but for now we that flow - return errors.New("transaction not yet prepared") + if db.txid == "" { + // Allow commit without two-phase prepare + err := db.tx.Commit(ctx) + db.tx = nil + return err } defer func() { diff --git a/internal/sql/pg/db_live_test.go b/internal/sql/pg/db_live_test.go index 6939d34e5..fa2e1b817 100644 --- a/internal/sql/pg/db_live_test.go +++ b/internal/sql/pg/db_live_test.go @@ -254,7 +254,7 @@ func TestNestedTx(t *testing.T) { } // Start the outer transaction. - tx, err := db.BeginTx(ctx) + tx, err := db.BeginOuterTx(ctx) if err != nil { t.Fatal(err) } @@ -326,11 +326,6 @@ func TestNestedTx(t *testing.T) { t.Fatal(err) } - err = tx.Commit(ctx) - if err == nil { - t.Fatalf("commit should have errored without precommit first") - } - id, err := tx.Precommit(ctx) if err != nil { t.Fatal(err) @@ -346,6 +341,8 @@ func TestNestedTx(t *testing.T) { // TODO: enure updates in other non-failed savepoints take } +// func TestCommitWithoutPrecommit + // tests that a read tx can be created and used // while another tx is in progress func TestReadTxs(t *testing.T) { diff --git a/internal/sql/pg/query.go b/internal/sql/pg/query.go index 2506e638a..fcdb880b4 100644 --- a/internal/sql/pg/query.go +++ b/internal/sql/pg/query.go @@ -158,7 +158,7 @@ func query(ctx context.Context, cq connQueryer, stmt string, args ...any) (*sql. if mustInferArgs(args) { // return nil, errors.New("cannot use QueryModeInferredArgTypes with query") args = args[1:] // args[0] was QueryModeInferredArgTypes - q = func(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { + q = func(ctx context.Context, stmt string, args ...any) (pgx.Rows, error) { return queryImpliedArgTypes(ctx, cq.Conn(), stmt, args...) } } diff --git a/internal/sql/pg/tx.go b/internal/sql/pg/tx.go index 90650c0e8..a0c4c4869 100644 --- a/internal/sql/pg/tx.go +++ b/internal/sql/pg/tx.go @@ -4,7 +4,6 @@ package pg import ( "context" - "sync" "github.com/jackc/pgx/v5" common "github.com/kwilteam/kwil-db/common/sql" @@ -20,7 +19,9 @@ type nestedTx struct { var _ common.Tx = (*nestedTx)(nil) -// TODO: switch this to be BeginTx +// BeginTx creates a new transaction with the same access mode as the parent. +// Internally this is savepoint, which allows rollback to the innermost +// savepoint rather than the entire outer transaction. func (tx *nestedTx) BeginTx(ctx context.Context) (common.Tx, error) { // Make the nested transaction (savepoint) pgtx, err := tx.Tx.Begin(ctx) @@ -49,13 +50,6 @@ func (tx *nestedTx) AccessMode() common.AccessMode { return tx.accessMode } -// Commit is direct from embedded pgx.Tx. -// func (tx *nestedTx) Commit(ctx context.Context) error { return tx.Tx.Commit(ctx) } - -// Rollback is direct from embedded pgx.Tx. It is ok to call Rollback repeatedly -// and even after Commit with no error. -// func (tx *nestedTx) Rollback(ctx context.Context) error { return tx.Tx.Rollback(ctx) } - // dbTx is the type returned by (*DB).BeginTx. It embeds all the nestedTx // methods (thus returning a *nestedTx from it's BeginTx), but shadows Commit // and Rollback to allow the DB to begin a subsequent transaction, and to @@ -93,8 +87,7 @@ func (tx *dbTx) AccessMode() common.AccessMode { // when it is closed. type readTx struct { *nestedTx - release func() // should only be run once - once sync.Once + release func() } // Commit is a no-op for read-only transactions. @@ -105,7 +98,7 @@ func (tx *readTx) Commit(ctx context.Context) error { return err } - tx.once.Do(tx.release) + tx.release() return nil } @@ -116,6 +109,6 @@ func (tx *readTx) Rollback(ctx context.Context) error { return err } - tx.once.Do(tx.release) + tx.release() return nil } diff --git a/internal/sql/versioning/sql.go b/internal/sql/versioning/sql.go index d1c2d3dc5..9a23524ed 100644 --- a/internal/sql/versioning/sql.go +++ b/internal/sql/versioning/sql.go @@ -29,7 +29,7 @@ var ( ) // getCurrentVersion returns the current version of the database. -func getCurrentVersion(ctx context.Context, db sql.DB, schema string) (int64, error) { +func getCurrentVersion(ctx context.Context, db sql.Executor, schema string) (int64, error) { res, err := db.Execute(ctx, fmt.Sprintf(sqlCurrentVersion, schema)) if err != nil { return 0, err diff --git a/internal/sql/versioning/versioning.go b/internal/sql/versioning/versioning.go index 74908cd47..25419406a 100644 --- a/internal/sql/versioning/versioning.go +++ b/internal/sql/versioning/versioning.go @@ -14,7 +14,7 @@ var ( // EnsureVersionTableExists ensures that the version table exists in the database. // If the table does not exist, it will be created, and the first version will be set to version specified. -func ensureVersionTableExists(ctx context.Context, db sql.DB, schema string) error { +func ensureVersionTableExists(ctx context.Context, db sql.TxMaker, schema string) error { tx, err := db.BeginTx(ctx) if err != nil { return err @@ -48,7 +48,7 @@ func ensureVersionTableExists(ctx context.Context, db sql.DB, schema string) err // All versions should start at 0. // If the database is fresh, the schema will be initialized to the target version. // Raw initialization at the target version can be done by providing a function for versions -1. -func Upgrade(ctx context.Context, db sql.DB, schema string, versions map[int64]UpgradeFunc, targetVersion int64) error { +func Upgrade(ctx context.Context, db sql.TxMaker, schema string, versions map[int64]UpgradeFunc, targetVersion int64) error { tx, err := db.BeginTx(ctx) if err != nil { return err diff --git a/internal/txapp/interfaces.go b/internal/txapp/interfaces.go index ca68abd50..1fda31390 100644 --- a/internal/txapp/interfaces.go +++ b/internal/txapp/interfaces.go @@ -16,7 +16,9 @@ type Rebroadcaster interface { MarkRebroadcast(ctx context.Context, ids []types.UUID) error } -// DB is the interface for the main SQL database. +// DB is the interface for the main SQL database. All queries must be executed +// from within a transaction. A DB can create read transactions or the special +// two-phase outer write transaction. type DB interface { sql.OuterTxMaker sql.ReadTxMaker diff --git a/internal/txapp/mempool.go b/internal/txapp/mempool.go index f606db612..f102f6b16 100644 --- a/internal/txapp/mempool.go +++ b/internal/txapp/mempool.go @@ -21,7 +21,7 @@ type mempool struct { } // accountInfo retrieves the account info from the mempool state or the account store. -func (m *mempool) accountInfo(ctx context.Context, tx sql.DB, acctID []byte) (*types.Account, error) { +func (m *mempool) accountInfo(ctx context.Context, tx sql.Executor, acctID []byte) (*types.Account, error) { if acctInfo, ok := m.accounts[string(acctID)]; ok { return acctInfo, nil // there is an unconfirmed tx for this account } @@ -39,7 +39,7 @@ func (m *mempool) accountInfo(ctx context.Context, tx sql.DB, acctID []byte) (*t } // accountInfoSafe is wraps accountInfo in a mutex lock. -func (m *mempool) accountInfoSafe(ctx context.Context, tx sql.DB, acctID []byte) (*types.Account, error) { +func (m *mempool) accountInfoSafe(ctx context.Context, tx sql.Executor, acctID []byte) (*types.Account, error) { m.mu.Lock() defer m.mu.Unlock() @@ -47,7 +47,7 @@ func (m *mempool) accountInfoSafe(ctx context.Context, tx sql.DB, acctID []byte) } // applyTransaction validates account specific info and applies valid transactions to the mempool state. -func (m *mempool) applyTransaction(ctx context.Context, tx *transactions.Transaction, dbTx sql.DB, rebroadcaster Rebroadcaster) error { +func (m *mempool) applyTransaction(ctx context.Context, tx *transactions.Transaction, dbTx sql.Executor, rebroadcaster Rebroadcaster) error { m.mu.Lock() defer m.mu.Unlock() diff --git a/internal/txapp/mempool_test.go b/internal/txapp/mempool_test.go index d98fb52df..254a458dd 100644 --- a/internal/txapp/mempool_test.go +++ b/internal/txapp/mempool_test.go @@ -97,10 +97,6 @@ func newTx(t *testing.T, nonce uint64, sender string) *transactions.Transaction type mockDb struct{} -func (m *mockDb) AccessMode() sql.AccessMode { - return sql.ReadOnly -} - func (m *mockDb) BeginTx(ctx context.Context) (sql.Tx, error) { return &mockTx{m}, nil } diff --git a/internal/txapp/routes_test.go b/internal/txapp/routes_test.go index 2acd7b766..837a26f9f 100644 --- a/internal/txapp/routes_test.go +++ b/internal/txapp/routes_test.go @@ -77,24 +77,24 @@ func Test_Routes(t *testing.T) { deleteCount := 0 // override the functions with mocks - deleteEvent = func(ctx context.Context, db sql.DB, id types.UUID) error { + deleteEvent = func(ctx context.Context, db sql.Executor, id types.UUID) error { deleteCount++ return nil } - approveResolution = func(ctx context.Context, db sql.DB, resolutionID types.UUID, expiration int64, from []byte) error { + approveResolution = func(ctx context.Context, db sql.TxMaker, resolutionID types.UUID, expiration int64, from []byte) error { approveCount++ return nil } - resolutionContainsBody = func(ctx context.Context, db sql.DB, id types.UUID) (bool, error) { + resolutionContainsBody = func(ctx context.Context, db sql.Executor, id types.UUID) (bool, error) { return true, nil } - isProcessed = func(ctx context.Context, db sql.DB, id types.UUID) (bool, error) { + isProcessed = func(ctx context.Context, db sql.Executor, id types.UUID) (bool, error) { return true, nil } - getVoterPower = func(ctx context.Context, db sql.DB, identifier []byte) (int64, error) { + getVoterPower = func(ctx context.Context, db sql.Executor, identifier []byte) (int64, error) { return 1, nil } @@ -121,23 +121,23 @@ func Test_Routes(t *testing.T) { deleteCount := 0 // override the functions with mocks - deleteEvent = func(ctx context.Context, db sql.DB, id types.UUID) error { + deleteEvent = func(ctx context.Context, db sql.Executor, id types.UUID) error { deleteCount++ return nil } - approveResolution = func(ctx context.Context, db sql.DB, resolutionID types.UUID, expiration int64, from []byte) error { + approveResolution = func(ctx context.Context, db sql.TxMaker, resolutionID types.UUID, expiration int64, from []byte) error { approveCount++ return nil } - resolutionContainsBody = func(ctx context.Context, db sql.DB, id types.UUID) (bool, error) { + resolutionContainsBody = func(ctx context.Context, db sql.Executor, id types.UUID) (bool, error) { return true, nil } - isProcessed = func(ctx context.Context, db sql.DB, id types.UUID) (bool, error) { + isProcessed = func(ctx context.Context, db sql.Executor, id types.UUID) (bool, error) { return true, nil } - getVoterPower = func(ctx context.Context, db sql.DB, identifier []byte) (int64, error) { + getVoterPower = func(ctx context.Context, db sql.Executor, identifier []byte) (int64, error) { return 1, nil } @@ -161,7 +161,7 @@ func Test_Routes(t *testing.T) { name: "validator_vote_id, as non-validator", fee: voting.ValidatorVoteIDPrice, fn: func(t *testing.T, callback func(*TxApp)) { - getVoterPower = func(ctx context.Context, db sql.DB, identifier []byte) (int64, error) { + getVoterPower = func(ctx context.Context, db sql.Executor, identifier []byte) (int64, error) { return 0, nil } @@ -184,19 +184,19 @@ func Test_Routes(t *testing.T) { deleteCount := 0 // override the functions with mocks - deleteEvent = func(ctx context.Context, db sql.DB, id types.UUID) error { + deleteEvent = func(ctx context.Context, db sql.Executor, id types.UUID) error { deleteCount++ return nil } - createResolution = func(ctx context.Context, db sql.DB, event *types.VotableEvent, expiration int64, proposer []byte) error { + createResolution = func(ctx context.Context, db sql.TxMaker, event *types.VotableEvent, expiration int64, proposer []byte) error { return nil } - hasVoted = func(ctx context.Context, db sql.DB, resolutionID types.UUID, voter []byte) (bool, error) { + hasVoted = func(ctx context.Context, db sql.Executor, resolutionID types.UUID, voter []byte) (bool, error) { return true, nil } - getVoterPower = func(ctx context.Context, db sql.DB, identifier []byte) (int64, error) { + getVoterPower = func(ctx context.Context, db sql.Executor, identifier []byte) (int64, error) { return 1, nil } @@ -226,17 +226,17 @@ func Test_Routes(t *testing.T) { fn: func(t *testing.T, callback func(*TxApp)) { deleteCount := 0 - deleteEvent = func(ctx context.Context, db sql.DB, id types.UUID) error { + deleteEvent = func(ctx context.Context, db sql.Executor, id types.UUID) error { deleteCount++ return nil } - hasVoted = func(ctx context.Context, db sql.DB, resolutionID types.UUID, voter []byte) (bool, error) { + hasVoted = func(ctx context.Context, db sql.Executor, resolutionID types.UUID, voter []byte) (bool, error) { return true, nil } - getVoterPower = func(ctx context.Context, db sql.DB, identifier []byte) (int64, error) { + getVoterPower = func(ctx context.Context, db sql.Executor, identifier []byte) (int64, error) { return 1, nil } diff --git a/internal/txapp/txapp.go b/internal/txapp/txapp.go index ba91c949a..843349d2d 100644 --- a/internal/txapp/txapp.go +++ b/internal/txapp/txapp.go @@ -90,7 +90,7 @@ type TxApp struct { // It can assign the initial validator set and initial account balances. // It is only called once for a new chain. func (r *TxApp) GenesisInit(ctx context.Context, validators []*types.Validator, genesisAccounts []*types.Account, initialHeight int64) error { - tx, err := r.Database.BeginTx(ctx) + tx, err := r.Database.BeginOuterTx(ctx) if err != nil { return err } @@ -221,7 +221,7 @@ func (r *TxApp) Begin(ctx context.Context) error { return nil } - tx, err := r.Database.BeginTx(ctx) + tx, err := r.Database.BeginOuterTx(ctx) if err != nil { return err } @@ -687,7 +687,7 @@ func (r *TxApp) Price(ctx context.Context, tx *transactions.Transaction) (*big.I // It also returns an error code. // if we allow users to implement their own routes, this function will need to // be exported. -func (r *TxApp) checkAndSpend(ctx TxContext, tx *transactions.Transaction, pricer Pricer, dbTx sql.DB) (*big.Int, transactions.TxCode, error) { +func (r *TxApp) checkAndSpend(ctx TxContext, tx *transactions.Transaction, pricer Pricer, dbTx sql.Executor) (*big.Int, transactions.TxCode, error) { amt := big.NewInt(0) var err error diff --git a/internal/voting/voting.go b/internal/voting/voting.go index 5dd3ce9d6..7a963e75b 100644 --- a/internal/voting/voting.go +++ b/internal/voting/voting.go @@ -65,7 +65,7 @@ func initTables(ctx context.Context, db sql.DB) error { // If the voter does not exist, an error will be returned. // If the voter has already approved the resolution, no error will be returned. // If the resolution has already been processed, no error will be returned. -func ApproveResolution(ctx context.Context, db sql.DB, resolutionID types.UUID, expiration int64, from []byte) error { +func ApproveResolution(ctx context.Context, db sql.TxMaker, resolutionID types.UUID, expiration int64, from []byte) error { tx, err := db.BeginTx(ctx) if err != nil { return err @@ -102,7 +102,7 @@ func ApproveResolution(ctx context.Context, db sql.DB, resolutionID types.UUID, // and an expiration. The expiration should be a blockheight. // If the resolution already exists, it will not be changed. // If the resolution was already processed, nothing will happen. -func CreateResolution(ctx context.Context, db sql.DB, event *types.VotableEvent, expiration int64, voteBodyProposer []byte) error { +func CreateResolution(ctx context.Context, db sql.TxMaker, event *types.VotableEvent, expiration int64, voteBodyProposer []byte) error { tx, err := db.BeginTx(ctx) if err != nil { return err @@ -247,7 +247,7 @@ func fromRow(row []any) (*resolutions.Resolution, error) { } // GetResolutionInfo gets a resolution, identified by the ID. -func GetResolutionInfo(ctx context.Context, db sql.DB, id types.UUID) (*resolutions.Resolution, error) { +func GetResolutionInfo(ctx context.Context, db sql.Executor, id types.UUID) (*resolutions.Resolution, error) { res, err := db.Execute(ctx, getFullResolutionInfo, id[:]) if err != nil { return nil, err @@ -266,7 +266,7 @@ func GetResolutionInfo(ctx context.Context, db sql.DB, id types.UUID) (*resoluti // GetExpired returns all resolutions with an expiration // less than or equal to the given blockheight. -func GetExpired(ctx context.Context, db sql.DB, blockheight int64) ([]*resolutions.Resolution, error) { +func GetExpired(ctx context.Context, db sql.Executor, blockheight int64) ([]*resolutions.Resolution, error) { res, err := db.Execute(ctx, getResolutionsFullInfoByExpiration, blockheight) if err != nil { return nil, err @@ -284,7 +284,7 @@ func GetExpired(ctx context.Context, db sql.DB, blockheight int64) ([]*resolutio } // GetResolutionsByThresholdAndType gets all resolutions that have reached the threshold of votes and are of a specific type. -func GetResolutionsByThresholdAndType(ctx context.Context, db sql.DB, threshold *big.Rat, resType string) ([]*resolutions.Resolution, error) { +func GetResolutionsByThresholdAndType(ctx context.Context, db sql.TxMaker, threshold *big.Rat, resType string) ([]*resolutions.Resolution, error) { tx, err := db.BeginTx(ctx) if err != nil { return nil, err @@ -313,7 +313,7 @@ func GetResolutionsByThresholdAndType(ctx context.Context, db sql.DB, threshold } // GetResolutionsByType gets all resolutions of a specific type. -func GetResolutionsByType(ctx context.Context, db sql.DB, resType string) ([]*resolutions.Resolution, error) { +func GetResolutionsByType(ctx context.Context, db sql.Executor, resType string) ([]*resolutions.Resolution, error) { res, err := db.Execute(ctx, getResolutionsFullInfoByType, resType) if err != nil { return nil, err @@ -333,19 +333,19 @@ func GetResolutionsByType(ctx context.Context, db sql.DB, resType string) ([]*re // DeleteResolutions deletes a slice of resolution IDs from the database. // It will mark the resolutions as processed in the processed table. -func DeleteResolutions(ctx context.Context, db sql.DB, ids ...types.UUID) error { +func DeleteResolutions(ctx context.Context, db sql.Executor, ids ...types.UUID) error { _, err := db.Execute(ctx, deleteResolutions, types.UUIDArray(ids)) return err } // MarkProcessed marks a set of resolutions as processed. -func MarkProcessed(ctx context.Context, db sql.DB, ids ...types.UUID) error { +func MarkProcessed(ctx context.Context, db sql.Executor, ids ...types.UUID) error { _, err := db.Execute(ctx, markManyProcessed, types.UUIDArray(ids)) return err } // ResolutionContainsBody returns true if the resolution has a body. -func ResolutionContainsBody(ctx context.Context, db sql.DB, id types.UUID) (bool, error) { +func ResolutionContainsBody(ctx context.Context, db sql.Executor, id types.UUID) (bool, error) { res, err := db.Execute(ctx, containsBody, id[:]) if err != nil { return false, err @@ -368,7 +368,7 @@ func ResolutionContainsBody(ctx context.Context, db sql.DB, id types.UUID) (bool } // IsProcessed checks if a vote has been marked as processed. -func IsProcessed(ctx context.Context, tx sql.DB, resolutionID types.UUID) (bool, error) { +func IsProcessed(ctx context.Context, tx sql.Executor, resolutionID types.UUID) (bool, error) { res, err := tx.Execute(ctx, alreadyProcessed, resolutionID[:]) if err != nil { return false, err @@ -379,7 +379,7 @@ func IsProcessed(ctx context.Context, tx sql.DB, resolutionID types.UUID) (bool, // FilterNotProcessed takes a set of resolutions and returns the ones that have not been processed. // If a resolution does not exist, it WILL be included in the result. -func FilterNotProcessed(ctx context.Context, db sql.DB, ids ...types.UUID) ([]types.UUID, error) { +func FilterNotProcessed(ctx context.Context, db sql.Executor, ids ...types.UUID) ([]types.UUID, error) { res, err := db.Execute(ctx, returnNotProcessed, types.UUIDArray(ids)) if err != nil { return nil, err @@ -398,7 +398,7 @@ func FilterNotProcessed(ctx context.Context, db sql.DB, ids ...types.UUID) ([]ty } // FilterExistsNoBody takes a set of resolutions and returns the ones that do exist but do not have a body. -func FilterExistsNoBody(ctx context.Context, db sql.DB, ids ...types.UUID) ([]types.UUID, error) { +func FilterExistsNoBody(ctx context.Context, db sql.Executor, ids ...types.UUID) ([]types.UUID, error) { res, err := db.Execute(ctx, returnNoBody, types.UUIDArray(ids)) if err != nil { return nil, err @@ -417,7 +417,7 @@ func FilterExistsNoBody(ctx context.Context, db sql.DB, ids ...types.UUID) ([]ty } // HasVoted checks if a voter has voted on a resolution. -func HasVoted(ctx context.Context, tx sql.DB, resolutionID types.UUID, from []byte) (bool, error) { +func HasVoted(ctx context.Context, tx sql.Executor, resolutionID types.UUID, from []byte) (bool, error) { userId := types.NewUUIDV5(from) res, err := tx.Execute(ctx, hasVoted, resolutionID[:], userId[:]) @@ -430,16 +430,10 @@ func HasVoted(ctx context.Context, tx sql.DB, resolutionID types.UUID, from []by // GetValidatorPower gets the power of a voter. // If the voter does not exist, it will return 0. -func GetValidatorPower(ctx context.Context, db sql.DB, identifier []byte) (power int64, err error) { - tx, err := db.BeginTx(ctx) - if err != nil { - return 0, err - } - defer tx.Rollback(ctx) - +func GetValidatorPower(ctx context.Context, db sql.Executor, identifier []byte) (power int64, err error) { uuid := types.NewUUIDV5(identifier) - res, err := tx.Execute(ctx, getVoterPower, uuid[:]) + res, err := db.Execute(ctx, getVoterPower, uuid[:]) if err != nil { return 0, err } @@ -459,11 +453,11 @@ func GetValidatorPower(ctx context.Context, db sql.DB, identifier []byte) (power return 0, fmt.Errorf("invalid type for power (%T). this is an internal bug", powerIface) } - return power, tx.Commit(ctx) + return power, nil } // GetValidators gets all voters in the vote store, along with their power. -func GetValidators(ctx context.Context, db sql.DB) ([]*types.Validator, error) { +func GetValidators(ctx context.Context, db sql.Executor) ([]*types.Validator, error) { res, err := db.Execute(ctx, allVoters) if err != nil { return nil, err @@ -506,7 +500,7 @@ func GetValidators(ctx context.Context, db sql.DB) ([]*types.Validator, error) { // It will create the voter if it does not exist. // It will return an error if a negative power is given. // If set to 0, the voter will be deleted. -func SetValidatorPower(ctx context.Context, db sql.DB, recipient []byte, power int64) error { +func SetValidatorPower(ctx context.Context, db sql.Executor, recipient []byte, power int64) error { if power < 0 { return fmt.Errorf("cannot set a negative power") } @@ -523,7 +517,7 @@ func SetValidatorPower(ctx context.Context, db sql.DB, recipient []byte, power i } // RequiredPower gets the required power to meet the threshold requirements. -func RequiredPower(ctx context.Context, db sql.DB, threshold *big.Rat) (int64, error) { +func RequiredPower(ctx context.Context, db sql.Executor, threshold *big.Rat) (int64, error) { numerator := threshold.Num().Int64() denominator := threshold.Denom().Int64() @@ -557,7 +551,7 @@ func RequiredPower(ctx context.Context, db sql.DB, threshold *big.Rat) (int64, e } // GetResolutionIDsByTypeAndProposer gets all resolution ids of a specific type and the body proposer. -func GetResolutionIDsByTypeAndProposer(ctx context.Context, db sql.DB, resType string, proposer []byte) ([]types.UUID, error) { +func GetResolutionIDsByTypeAndProposer(ctx context.Context, db sql.Executor, resType string, proposer []byte) ([]types.UUID, error) { res, err := db.Execute(ctx, getResolutionByTypeAndProposer, resType, proposer) if err != nil { return nil, err @@ -596,7 +590,7 @@ func intDivUpFraction(val, numerator, divisor int64) int64 { return new(big.Int).Div(tempNumerator, divBig).Int64() } -func GetHeight(ctx context.Context, db sql.DB) (int64, error) { +func GetHeight(ctx context.Context, db sql.Executor) (int64, error) { res, err := db.Execute(ctx, getHeight) if err != nil { return 0, err @@ -615,10 +609,7 @@ func GetHeight(ctx context.Context, db sql.DB) (int64, error) { return height, nil } -func SetHeight(ctx context.Context, db sql.DB, height int64) error { - if _, err := db.Execute(ctx, updateHeight, height); err != nil { - return err - } - - return nil +func SetHeight(ctx context.Context, db sql.Executor, height int64) error { + _, err := db.Execute(ctx, updateHeight, height) + return err }