diff --git a/internal/engine/execution/queries.go b/internal/engine/execution/queries.go index 971eb2074..be962d21e 100644 --- a/internal/engine/execution/queries.go +++ b/internal/engine/execution/queries.go @@ -52,7 +52,7 @@ func createSchemasTableIfNotExists(ctx context.Context, tx sql.DB) error { // It will also store the schema in the kwil_schemas table. // It also creates the relevant tables, indexes, etc. // If the schema already exists in the Kwil schemas table, it will be updated. -func createSchema(ctx context.Context, tx sql.DB, schema *common.Schema) error { +func createSchema(ctx context.Context, tx sql.TxMaker, schema *common.Schema) error { schemaName := dbidSchema(schema.DBID()) sp, err := tx.BeginTx(ctx) @@ -129,7 +129,7 @@ func getSchemas(ctx context.Context, tx sql.Executor) ([]*common.Schema, error) // deleteSchema deletes a schema from the database. // It will also delete the schema from the kwil_schemas table. -func deleteSchema(ctx context.Context, tx sql.DB, dbid string) error { +func deleteSchema(ctx context.Context, tx sql.TxMaker, dbid string) error { schemaName := dbidSchema(dbid) sp, err := tx.BeginTx(ctx) diff --git a/internal/events/events.go b/internal/events/events.go index 4e5fb7cd3..99daf07dd 100644 --- a/internal/events/events.go +++ b/internal/events/events.go @@ -114,7 +114,7 @@ func (e *EventStore) Store(ctx context.Context, data []byte, eventType string) e } // GetEvents gets all events in the event store. -func GetEvents(ctx context.Context, db sql.DB) ([]*types.VotableEvent, error) { +func GetEvents(ctx context.Context, db sql.Executor) ([]*types.VotableEvent, error) { res, err := db.Execute(ctx, getEvents) if err != nil { return nil, err @@ -147,7 +147,7 @@ func GetEvents(ctx context.Context, db sql.DB) ([]*types.VotableEvent, error) { // DeleteEvent deletes an event from the event store. // It is idempotent. If the event does not exist, it will not return an error. -func DeleteEvent(ctx context.Context, db sql.DB, id types.UUID) error { +func DeleteEvent(ctx context.Context, db sql.Executor, id types.UUID) error { _, err := db.Execute(ctx, deleteEvent, id[:]) return err } @@ -204,7 +204,7 @@ func (e *EventStore) MarkBroadcasted(ctx context.Context, ids []types.UUID) erro } // MarkReceived marks that an event has been received by the network, and should not be re-broadcasted. -func MarkReceived(ctx context.Context, db sql.DB, id types.UUID) error { +func MarkReceived(ctx context.Context, db sql.Executor, id types.UUID) error { _, err := db.Execute(ctx, markReceived, id[:]) return err } diff --git a/internal/sql/versioning/versioning.go b/internal/sql/versioning/versioning.go index 74908cd47..25419406a 100644 --- a/internal/sql/versioning/versioning.go +++ b/internal/sql/versioning/versioning.go @@ -14,7 +14,7 @@ var ( // EnsureVersionTableExists ensures that the version table exists in the database. // If the table does not exist, it will be created, and the first version will be set to version specified. -func ensureVersionTableExists(ctx context.Context, db sql.DB, schema string) error { +func ensureVersionTableExists(ctx context.Context, db sql.TxMaker, schema string) error { tx, err := db.BeginTx(ctx) if err != nil { return err @@ -48,7 +48,7 @@ func ensureVersionTableExists(ctx context.Context, db sql.DB, schema string) err // All versions should start at 0. // If the database is fresh, the schema will be initialized to the target version. // Raw initialization at the target version can be done by providing a function for versions -1. -func Upgrade(ctx context.Context, db sql.DB, schema string, versions map[int64]UpgradeFunc, targetVersion int64) error { +func Upgrade(ctx context.Context, db sql.TxMaker, schema string, versions map[int64]UpgradeFunc, targetVersion int64) error { tx, err := db.BeginTx(ctx) if err != nil { return err diff --git a/internal/txapp/mempool.go b/internal/txapp/mempool.go index f606db612..f102f6b16 100644 --- a/internal/txapp/mempool.go +++ b/internal/txapp/mempool.go @@ -21,7 +21,7 @@ type mempool struct { } // accountInfo retrieves the account info from the mempool state or the account store. -func (m *mempool) accountInfo(ctx context.Context, tx sql.DB, acctID []byte) (*types.Account, error) { +func (m *mempool) accountInfo(ctx context.Context, tx sql.Executor, acctID []byte) (*types.Account, error) { if acctInfo, ok := m.accounts[string(acctID)]; ok { return acctInfo, nil // there is an unconfirmed tx for this account } @@ -39,7 +39,7 @@ func (m *mempool) accountInfo(ctx context.Context, tx sql.DB, acctID []byte) (*t } // accountInfoSafe is wraps accountInfo in a mutex lock. -func (m *mempool) accountInfoSafe(ctx context.Context, tx sql.DB, acctID []byte) (*types.Account, error) { +func (m *mempool) accountInfoSafe(ctx context.Context, tx sql.Executor, acctID []byte) (*types.Account, error) { m.mu.Lock() defer m.mu.Unlock() @@ -47,7 +47,7 @@ func (m *mempool) accountInfoSafe(ctx context.Context, tx sql.DB, acctID []byte) } // applyTransaction validates account specific info and applies valid transactions to the mempool state. -func (m *mempool) applyTransaction(ctx context.Context, tx *transactions.Transaction, dbTx sql.DB, rebroadcaster Rebroadcaster) error { +func (m *mempool) applyTransaction(ctx context.Context, tx *transactions.Transaction, dbTx sql.Executor, rebroadcaster Rebroadcaster) error { m.mu.Lock() defer m.mu.Unlock() diff --git a/internal/txapp/routes_test.go b/internal/txapp/routes_test.go index 2acd7b766..837a26f9f 100644 --- a/internal/txapp/routes_test.go +++ b/internal/txapp/routes_test.go @@ -77,24 +77,24 @@ func Test_Routes(t *testing.T) { deleteCount := 0 // override the functions with mocks - deleteEvent = func(ctx context.Context, db sql.DB, id types.UUID) error { + deleteEvent = func(ctx context.Context, db sql.Executor, id types.UUID) error { deleteCount++ return nil } - approveResolution = func(ctx context.Context, db sql.DB, resolutionID types.UUID, expiration int64, from []byte) error { + approveResolution = func(ctx context.Context, db sql.TxMaker, resolutionID types.UUID, expiration int64, from []byte) error { approveCount++ return nil } - resolutionContainsBody = func(ctx context.Context, db sql.DB, id types.UUID) (bool, error) { + resolutionContainsBody = func(ctx context.Context, db sql.Executor, id types.UUID) (bool, error) { return true, nil } - isProcessed = func(ctx context.Context, db sql.DB, id types.UUID) (bool, error) { + isProcessed = func(ctx context.Context, db sql.Executor, id types.UUID) (bool, error) { return true, nil } - getVoterPower = func(ctx context.Context, db sql.DB, identifier []byte) (int64, error) { + getVoterPower = func(ctx context.Context, db sql.Executor, identifier []byte) (int64, error) { return 1, nil } @@ -121,23 +121,23 @@ func Test_Routes(t *testing.T) { deleteCount := 0 // override the functions with mocks - deleteEvent = func(ctx context.Context, db sql.DB, id types.UUID) error { + deleteEvent = func(ctx context.Context, db sql.Executor, id types.UUID) error { deleteCount++ return nil } - approveResolution = func(ctx context.Context, db sql.DB, resolutionID types.UUID, expiration int64, from []byte) error { + approveResolution = func(ctx context.Context, db sql.TxMaker, resolutionID types.UUID, expiration int64, from []byte) error { approveCount++ return nil } - resolutionContainsBody = func(ctx context.Context, db sql.DB, id types.UUID) (bool, error) { + resolutionContainsBody = func(ctx context.Context, db sql.Executor, id types.UUID) (bool, error) { return true, nil } - isProcessed = func(ctx context.Context, db sql.DB, id types.UUID) (bool, error) { + isProcessed = func(ctx context.Context, db sql.Executor, id types.UUID) (bool, error) { return true, nil } - getVoterPower = func(ctx context.Context, db sql.DB, identifier []byte) (int64, error) { + getVoterPower = func(ctx context.Context, db sql.Executor, identifier []byte) (int64, error) { return 1, nil } @@ -161,7 +161,7 @@ func Test_Routes(t *testing.T) { name: "validator_vote_id, as non-validator", fee: voting.ValidatorVoteIDPrice, fn: func(t *testing.T, callback func(*TxApp)) { - getVoterPower = func(ctx context.Context, db sql.DB, identifier []byte) (int64, error) { + getVoterPower = func(ctx context.Context, db sql.Executor, identifier []byte) (int64, error) { return 0, nil } @@ -184,19 +184,19 @@ func Test_Routes(t *testing.T) { deleteCount := 0 // override the functions with mocks - deleteEvent = func(ctx context.Context, db sql.DB, id types.UUID) error { + deleteEvent = func(ctx context.Context, db sql.Executor, id types.UUID) error { deleteCount++ return nil } - createResolution = func(ctx context.Context, db sql.DB, event *types.VotableEvent, expiration int64, proposer []byte) error { + createResolution = func(ctx context.Context, db sql.TxMaker, event *types.VotableEvent, expiration int64, proposer []byte) error { return nil } - hasVoted = func(ctx context.Context, db sql.DB, resolutionID types.UUID, voter []byte) (bool, error) { + hasVoted = func(ctx context.Context, db sql.Executor, resolutionID types.UUID, voter []byte) (bool, error) { return true, nil } - getVoterPower = func(ctx context.Context, db sql.DB, identifier []byte) (int64, error) { + getVoterPower = func(ctx context.Context, db sql.Executor, identifier []byte) (int64, error) { return 1, nil } @@ -226,17 +226,17 @@ func Test_Routes(t *testing.T) { fn: func(t *testing.T, callback func(*TxApp)) { deleteCount := 0 - deleteEvent = func(ctx context.Context, db sql.DB, id types.UUID) error { + deleteEvent = func(ctx context.Context, db sql.Executor, id types.UUID) error { deleteCount++ return nil } - hasVoted = func(ctx context.Context, db sql.DB, resolutionID types.UUID, voter []byte) (bool, error) { + hasVoted = func(ctx context.Context, db sql.Executor, resolutionID types.UUID, voter []byte) (bool, error) { return true, nil } - getVoterPower = func(ctx context.Context, db sql.DB, identifier []byte) (int64, error) { + getVoterPower = func(ctx context.Context, db sql.Executor, identifier []byte) (int64, error) { return 1, nil } diff --git a/internal/txapp/txapp.go b/internal/txapp/txapp.go index eb788a4c9..843349d2d 100644 --- a/internal/txapp/txapp.go +++ b/internal/txapp/txapp.go @@ -687,7 +687,7 @@ func (r *TxApp) Price(ctx context.Context, tx *transactions.Transaction) (*big.I // It also returns an error code. // if we allow users to implement their own routes, this function will need to // be exported. -func (r *TxApp) checkAndSpend(ctx TxContext, tx *transactions.Transaction, pricer Pricer, dbTx sql.DB) (*big.Int, transactions.TxCode, error) { +func (r *TxApp) checkAndSpend(ctx TxContext, tx *transactions.Transaction, pricer Pricer, dbTx sql.Executor) (*big.Int, transactions.TxCode, error) { amt := big.NewInt(0) var err error diff --git a/internal/voting/voting.go b/internal/voting/voting.go index 5dd3ce9d6..7a963e75b 100644 --- a/internal/voting/voting.go +++ b/internal/voting/voting.go @@ -65,7 +65,7 @@ func initTables(ctx context.Context, db sql.DB) error { // If the voter does not exist, an error will be returned. // If the voter has already approved the resolution, no error will be returned. // If the resolution has already been processed, no error will be returned. -func ApproveResolution(ctx context.Context, db sql.DB, resolutionID types.UUID, expiration int64, from []byte) error { +func ApproveResolution(ctx context.Context, db sql.TxMaker, resolutionID types.UUID, expiration int64, from []byte) error { tx, err := db.BeginTx(ctx) if err != nil { return err @@ -102,7 +102,7 @@ func ApproveResolution(ctx context.Context, db sql.DB, resolutionID types.UUID, // and an expiration. The expiration should be a blockheight. // If the resolution already exists, it will not be changed. // If the resolution was already processed, nothing will happen. -func CreateResolution(ctx context.Context, db sql.DB, event *types.VotableEvent, expiration int64, voteBodyProposer []byte) error { +func CreateResolution(ctx context.Context, db sql.TxMaker, event *types.VotableEvent, expiration int64, voteBodyProposer []byte) error { tx, err := db.BeginTx(ctx) if err != nil { return err @@ -247,7 +247,7 @@ func fromRow(row []any) (*resolutions.Resolution, error) { } // GetResolutionInfo gets a resolution, identified by the ID. -func GetResolutionInfo(ctx context.Context, db sql.DB, id types.UUID) (*resolutions.Resolution, error) { +func GetResolutionInfo(ctx context.Context, db sql.Executor, id types.UUID) (*resolutions.Resolution, error) { res, err := db.Execute(ctx, getFullResolutionInfo, id[:]) if err != nil { return nil, err @@ -266,7 +266,7 @@ func GetResolutionInfo(ctx context.Context, db sql.DB, id types.UUID) (*resoluti // GetExpired returns all resolutions with an expiration // less than or equal to the given blockheight. -func GetExpired(ctx context.Context, db sql.DB, blockheight int64) ([]*resolutions.Resolution, error) { +func GetExpired(ctx context.Context, db sql.Executor, blockheight int64) ([]*resolutions.Resolution, error) { res, err := db.Execute(ctx, getResolutionsFullInfoByExpiration, blockheight) if err != nil { return nil, err @@ -284,7 +284,7 @@ func GetExpired(ctx context.Context, db sql.DB, blockheight int64) ([]*resolutio } // GetResolutionsByThresholdAndType gets all resolutions that have reached the threshold of votes and are of a specific type. -func GetResolutionsByThresholdAndType(ctx context.Context, db sql.DB, threshold *big.Rat, resType string) ([]*resolutions.Resolution, error) { +func GetResolutionsByThresholdAndType(ctx context.Context, db sql.TxMaker, threshold *big.Rat, resType string) ([]*resolutions.Resolution, error) { tx, err := db.BeginTx(ctx) if err != nil { return nil, err @@ -313,7 +313,7 @@ func GetResolutionsByThresholdAndType(ctx context.Context, db sql.DB, threshold } // GetResolutionsByType gets all resolutions of a specific type. -func GetResolutionsByType(ctx context.Context, db sql.DB, resType string) ([]*resolutions.Resolution, error) { +func GetResolutionsByType(ctx context.Context, db sql.Executor, resType string) ([]*resolutions.Resolution, error) { res, err := db.Execute(ctx, getResolutionsFullInfoByType, resType) if err != nil { return nil, err @@ -333,19 +333,19 @@ func GetResolutionsByType(ctx context.Context, db sql.DB, resType string) ([]*re // DeleteResolutions deletes a slice of resolution IDs from the database. // It will mark the resolutions as processed in the processed table. -func DeleteResolutions(ctx context.Context, db sql.DB, ids ...types.UUID) error { +func DeleteResolutions(ctx context.Context, db sql.Executor, ids ...types.UUID) error { _, err := db.Execute(ctx, deleteResolutions, types.UUIDArray(ids)) return err } // MarkProcessed marks a set of resolutions as processed. -func MarkProcessed(ctx context.Context, db sql.DB, ids ...types.UUID) error { +func MarkProcessed(ctx context.Context, db sql.Executor, ids ...types.UUID) error { _, err := db.Execute(ctx, markManyProcessed, types.UUIDArray(ids)) return err } // ResolutionContainsBody returns true if the resolution has a body. -func ResolutionContainsBody(ctx context.Context, db sql.DB, id types.UUID) (bool, error) { +func ResolutionContainsBody(ctx context.Context, db sql.Executor, id types.UUID) (bool, error) { res, err := db.Execute(ctx, containsBody, id[:]) if err != nil { return false, err @@ -368,7 +368,7 @@ func ResolutionContainsBody(ctx context.Context, db sql.DB, id types.UUID) (bool } // IsProcessed checks if a vote has been marked as processed. -func IsProcessed(ctx context.Context, tx sql.DB, resolutionID types.UUID) (bool, error) { +func IsProcessed(ctx context.Context, tx sql.Executor, resolutionID types.UUID) (bool, error) { res, err := tx.Execute(ctx, alreadyProcessed, resolutionID[:]) if err != nil { return false, err @@ -379,7 +379,7 @@ func IsProcessed(ctx context.Context, tx sql.DB, resolutionID types.UUID) (bool, // FilterNotProcessed takes a set of resolutions and returns the ones that have not been processed. // If a resolution does not exist, it WILL be included in the result. -func FilterNotProcessed(ctx context.Context, db sql.DB, ids ...types.UUID) ([]types.UUID, error) { +func FilterNotProcessed(ctx context.Context, db sql.Executor, ids ...types.UUID) ([]types.UUID, error) { res, err := db.Execute(ctx, returnNotProcessed, types.UUIDArray(ids)) if err != nil { return nil, err @@ -398,7 +398,7 @@ func FilterNotProcessed(ctx context.Context, db sql.DB, ids ...types.UUID) ([]ty } // FilterExistsNoBody takes a set of resolutions and returns the ones that do exist but do not have a body. -func FilterExistsNoBody(ctx context.Context, db sql.DB, ids ...types.UUID) ([]types.UUID, error) { +func FilterExistsNoBody(ctx context.Context, db sql.Executor, ids ...types.UUID) ([]types.UUID, error) { res, err := db.Execute(ctx, returnNoBody, types.UUIDArray(ids)) if err != nil { return nil, err @@ -417,7 +417,7 @@ func FilterExistsNoBody(ctx context.Context, db sql.DB, ids ...types.UUID) ([]ty } // HasVoted checks if a voter has voted on a resolution. -func HasVoted(ctx context.Context, tx sql.DB, resolutionID types.UUID, from []byte) (bool, error) { +func HasVoted(ctx context.Context, tx sql.Executor, resolutionID types.UUID, from []byte) (bool, error) { userId := types.NewUUIDV5(from) res, err := tx.Execute(ctx, hasVoted, resolutionID[:], userId[:]) @@ -430,16 +430,10 @@ func HasVoted(ctx context.Context, tx sql.DB, resolutionID types.UUID, from []by // GetValidatorPower gets the power of a voter. // If the voter does not exist, it will return 0. -func GetValidatorPower(ctx context.Context, db sql.DB, identifier []byte) (power int64, err error) { - tx, err := db.BeginTx(ctx) - if err != nil { - return 0, err - } - defer tx.Rollback(ctx) - +func GetValidatorPower(ctx context.Context, db sql.Executor, identifier []byte) (power int64, err error) { uuid := types.NewUUIDV5(identifier) - res, err := tx.Execute(ctx, getVoterPower, uuid[:]) + res, err := db.Execute(ctx, getVoterPower, uuid[:]) if err != nil { return 0, err } @@ -459,11 +453,11 @@ func GetValidatorPower(ctx context.Context, db sql.DB, identifier []byte) (power return 0, fmt.Errorf("invalid type for power (%T). this is an internal bug", powerIface) } - return power, tx.Commit(ctx) + return power, nil } // GetValidators gets all voters in the vote store, along with their power. -func GetValidators(ctx context.Context, db sql.DB) ([]*types.Validator, error) { +func GetValidators(ctx context.Context, db sql.Executor) ([]*types.Validator, error) { res, err := db.Execute(ctx, allVoters) if err != nil { return nil, err @@ -506,7 +500,7 @@ func GetValidators(ctx context.Context, db sql.DB) ([]*types.Validator, error) { // It will create the voter if it does not exist. // It will return an error if a negative power is given. // If set to 0, the voter will be deleted. -func SetValidatorPower(ctx context.Context, db sql.DB, recipient []byte, power int64) error { +func SetValidatorPower(ctx context.Context, db sql.Executor, recipient []byte, power int64) error { if power < 0 { return fmt.Errorf("cannot set a negative power") } @@ -523,7 +517,7 @@ func SetValidatorPower(ctx context.Context, db sql.DB, recipient []byte, power i } // RequiredPower gets the required power to meet the threshold requirements. -func RequiredPower(ctx context.Context, db sql.DB, threshold *big.Rat) (int64, error) { +func RequiredPower(ctx context.Context, db sql.Executor, threshold *big.Rat) (int64, error) { numerator := threshold.Num().Int64() denominator := threshold.Denom().Int64() @@ -557,7 +551,7 @@ func RequiredPower(ctx context.Context, db sql.DB, threshold *big.Rat) (int64, e } // GetResolutionIDsByTypeAndProposer gets all resolution ids of a specific type and the body proposer. -func GetResolutionIDsByTypeAndProposer(ctx context.Context, db sql.DB, resType string, proposer []byte) ([]types.UUID, error) { +func GetResolutionIDsByTypeAndProposer(ctx context.Context, db sql.Executor, resType string, proposer []byte) ([]types.UUID, error) { res, err := db.Execute(ctx, getResolutionByTypeAndProposer, resType, proposer) if err != nil { return nil, err @@ -596,7 +590,7 @@ func intDivUpFraction(val, numerator, divisor int64) int64 { return new(big.Int).Div(tempNumerator, divBig).Int64() } -func GetHeight(ctx context.Context, db sql.DB) (int64, error) { +func GetHeight(ctx context.Context, db sql.Executor) (int64, error) { res, err := db.Execute(ctx, getHeight) if err != nil { return 0, err @@ -615,10 +609,7 @@ func GetHeight(ctx context.Context, db sql.DB) (int64, error) { return height, nil } -func SetHeight(ctx context.Context, db sql.DB, height int64) error { - if _, err := db.Execute(ctx, updateHeight, height); err != nil { - return err - } - - return nil +func SetHeight(ctx context.Context, db sql.Executor, height int64) error { + _, err := db.Execute(ctx, updateHeight, height) + return err }