Skip to content

Commit

Permalink
further narrowing of used sql interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
jchappelow committed Mar 13, 2024
1 parent acc5cfc commit 69ff718
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 62 deletions.
4 changes: 2 additions & 2 deletions internal/engine/execution/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func createSchemasTableIfNotExists(ctx context.Context, tx sql.DB) error {
// It will also store the schema in the kwil_schemas table.
// It also creates the relevant tables, indexes, etc.
// If the schema already exists in the Kwil schemas table, it will be updated.
func createSchema(ctx context.Context, tx sql.DB, schema *common.Schema) error {
func createSchema(ctx context.Context, tx sql.TxMaker, schema *common.Schema) error {
schemaName := dbidSchema(schema.DBID())

sp, err := tx.BeginTx(ctx)
Expand Down Expand Up @@ -129,7 +129,7 @@ func getSchemas(ctx context.Context, tx sql.Executor) ([]*common.Schema, error)

// deleteSchema deletes a schema from the database.
// It will also delete the schema from the kwil_schemas table.
func deleteSchema(ctx context.Context, tx sql.DB, dbid string) error {
func deleteSchema(ctx context.Context, tx sql.TxMaker, dbid string) error {
schemaName := dbidSchema(dbid)

sp, err := tx.BeginTx(ctx)
Expand Down
6 changes: 3 additions & 3 deletions internal/events/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (e *EventStore) Store(ctx context.Context, data []byte, eventType string) e
}

// GetEvents gets all events in the event store.
func GetEvents(ctx context.Context, db sql.DB) ([]*types.VotableEvent, error) {
func GetEvents(ctx context.Context, db sql.Executor) ([]*types.VotableEvent, error) {
res, err := db.Execute(ctx, getEvents)
if err != nil {
return nil, err
Expand Down Expand Up @@ -147,7 +147,7 @@ func GetEvents(ctx context.Context, db sql.DB) ([]*types.VotableEvent, error) {

// DeleteEvent deletes an event from the event store.
// It is idempotent. If the event does not exist, it will not return an error.
func DeleteEvent(ctx context.Context, db sql.DB, id types.UUID) error {
func DeleteEvent(ctx context.Context, db sql.Executor, id types.UUID) error {
_, err := db.Execute(ctx, deleteEvent, id[:])
return err
}
Expand Down Expand Up @@ -204,7 +204,7 @@ func (e *EventStore) MarkBroadcasted(ctx context.Context, ids []types.UUID) erro
}

// MarkReceived marks that an event has been received by the network, and should not be re-broadcasted.
func MarkReceived(ctx context.Context, db sql.DB, id types.UUID) error {
func MarkReceived(ctx context.Context, db sql.Executor, id types.UUID) error {
_, err := db.Execute(ctx, markReceived, id[:])
return err
}
Expand Down
4 changes: 2 additions & 2 deletions internal/sql/versioning/versioning.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ var (

// EnsureVersionTableExists ensures that the version table exists in the database.
// If the table does not exist, it will be created, and the first version will be set to version specified.
func ensureVersionTableExists(ctx context.Context, db sql.DB, schema string) error {
func ensureVersionTableExists(ctx context.Context, db sql.TxMaker, schema string) error {
tx, err := db.BeginTx(ctx)
if err != nil {
return err
Expand Down Expand Up @@ -48,7 +48,7 @@ func ensureVersionTableExists(ctx context.Context, db sql.DB, schema string) err
// All versions should start at 0.
// If the database is fresh, the schema will be initialized to the target version.
// Raw initialization at the target version can be done by providing a function for versions -1.
func Upgrade(ctx context.Context, db sql.DB, schema string, versions map[int64]UpgradeFunc, targetVersion int64) error {
func Upgrade(ctx context.Context, db sql.TxMaker, schema string, versions map[int64]UpgradeFunc, targetVersion int64) error {
tx, err := db.BeginTx(ctx)
if err != nil {
return err
Expand Down
6 changes: 3 additions & 3 deletions internal/txapp/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type mempool struct {
}

// accountInfo retrieves the account info from the mempool state or the account store.
func (m *mempool) accountInfo(ctx context.Context, tx sql.DB, acctID []byte) (*types.Account, error) {
func (m *mempool) accountInfo(ctx context.Context, tx sql.Executor, acctID []byte) (*types.Account, error) {
if acctInfo, ok := m.accounts[string(acctID)]; ok {
return acctInfo, nil // there is an unconfirmed tx for this account
}
Expand All @@ -39,15 +39,15 @@ func (m *mempool) accountInfo(ctx context.Context, tx sql.DB, acctID []byte) (*t
}

// accountInfoSafe is wraps accountInfo in a mutex lock.
func (m *mempool) accountInfoSafe(ctx context.Context, tx sql.DB, acctID []byte) (*types.Account, error) {
func (m *mempool) accountInfoSafe(ctx context.Context, tx sql.Executor, acctID []byte) (*types.Account, error) {
m.mu.Lock()
defer m.mu.Unlock()

return m.accountInfo(ctx, tx, acctID)
}

// applyTransaction validates account specific info and applies valid transactions to the mempool state.
func (m *mempool) applyTransaction(ctx context.Context, tx *transactions.Transaction, dbTx sql.DB, rebroadcaster Rebroadcaster) error {
func (m *mempool) applyTransaction(ctx context.Context, tx *transactions.Transaction, dbTx sql.Executor, rebroadcaster Rebroadcaster) error {
m.mu.Lock()
defer m.mu.Unlock()

Expand Down
36 changes: 18 additions & 18 deletions internal/txapp/routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,24 @@ func Test_Routes(t *testing.T) {
deleteCount := 0

// override the functions with mocks
deleteEvent = func(ctx context.Context, db sql.DB, id types.UUID) error {
deleteEvent = func(ctx context.Context, db sql.Executor, id types.UUID) error {
deleteCount++

return nil
}

approveResolution = func(ctx context.Context, db sql.DB, resolutionID types.UUID, expiration int64, from []byte) error {
approveResolution = func(ctx context.Context, db sql.TxMaker, resolutionID types.UUID, expiration int64, from []byte) error {
approveCount++

return nil
}
resolutionContainsBody = func(ctx context.Context, db sql.DB, id types.UUID) (bool, error) {
resolutionContainsBody = func(ctx context.Context, db sql.Executor, id types.UUID) (bool, error) {
return true, nil
}
isProcessed = func(ctx context.Context, db sql.DB, id types.UUID) (bool, error) {
isProcessed = func(ctx context.Context, db sql.Executor, id types.UUID) (bool, error) {
return true, nil
}
getVoterPower = func(ctx context.Context, db sql.DB, identifier []byte) (int64, error) {
getVoterPower = func(ctx context.Context, db sql.Executor, identifier []byte) (int64, error) {
return 1, nil
}

Expand All @@ -121,23 +121,23 @@ func Test_Routes(t *testing.T) {
deleteCount := 0

// override the functions with mocks
deleteEvent = func(ctx context.Context, db sql.DB, id types.UUID) error {
deleteEvent = func(ctx context.Context, db sql.Executor, id types.UUID) error {
deleteCount++

return nil
}
approveResolution = func(ctx context.Context, db sql.DB, resolutionID types.UUID, expiration int64, from []byte) error {
approveResolution = func(ctx context.Context, db sql.TxMaker, resolutionID types.UUID, expiration int64, from []byte) error {
approveCount++

return nil
}
resolutionContainsBody = func(ctx context.Context, db sql.DB, id types.UUID) (bool, error) {
resolutionContainsBody = func(ctx context.Context, db sql.Executor, id types.UUID) (bool, error) {
return true, nil
}
isProcessed = func(ctx context.Context, db sql.DB, id types.UUID) (bool, error) {
isProcessed = func(ctx context.Context, db sql.Executor, id types.UUID) (bool, error) {
return true, nil
}
getVoterPower = func(ctx context.Context, db sql.DB, identifier []byte) (int64, error) {
getVoterPower = func(ctx context.Context, db sql.Executor, identifier []byte) (int64, error) {
return 1, nil
}

Expand All @@ -161,7 +161,7 @@ func Test_Routes(t *testing.T) {
name: "validator_vote_id, as non-validator",
fee: voting.ValidatorVoteIDPrice,
fn: func(t *testing.T, callback func(*TxApp)) {
getVoterPower = func(ctx context.Context, db sql.DB, identifier []byte) (int64, error) {
getVoterPower = func(ctx context.Context, db sql.Executor, identifier []byte) (int64, error) {
return 0, nil
}

Expand All @@ -184,19 +184,19 @@ func Test_Routes(t *testing.T) {
deleteCount := 0

// override the functions with mocks
deleteEvent = func(ctx context.Context, db sql.DB, id types.UUID) error {
deleteEvent = func(ctx context.Context, db sql.Executor, id types.UUID) error {
deleteCount++

return nil
}
createResolution = func(ctx context.Context, db sql.DB, event *types.VotableEvent, expiration int64, proposer []byte) error {
createResolution = func(ctx context.Context, db sql.TxMaker, event *types.VotableEvent, expiration int64, proposer []byte) error {
return nil
}

hasVoted = func(ctx context.Context, db sql.DB, resolutionID types.UUID, voter []byte) (bool, error) {
hasVoted = func(ctx context.Context, db sql.Executor, resolutionID types.UUID, voter []byte) (bool, error) {
return true, nil
}
getVoterPower = func(ctx context.Context, db sql.DB, identifier []byte) (int64, error) {
getVoterPower = func(ctx context.Context, db sql.Executor, identifier []byte) (int64, error) {
return 1, nil
}

Expand Down Expand Up @@ -226,17 +226,17 @@ func Test_Routes(t *testing.T) {
fn: func(t *testing.T, callback func(*TxApp)) {
deleteCount := 0

deleteEvent = func(ctx context.Context, db sql.DB, id types.UUID) error {
deleteEvent = func(ctx context.Context, db sql.Executor, id types.UUID) error {
deleteCount++

return nil
}

hasVoted = func(ctx context.Context, db sql.DB, resolutionID types.UUID, voter []byte) (bool, error) {
hasVoted = func(ctx context.Context, db sql.Executor, resolutionID types.UUID, voter []byte) (bool, error) {
return true, nil
}

getVoterPower = func(ctx context.Context, db sql.DB, identifier []byte) (int64, error) {
getVoterPower = func(ctx context.Context, db sql.Executor, identifier []byte) (int64, error) {
return 1, nil
}

Expand Down
2 changes: 1 addition & 1 deletion internal/txapp/txapp.go
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ func (r *TxApp) Price(ctx context.Context, tx *transactions.Transaction) (*big.I
// It also returns an error code.
// if we allow users to implement their own routes, this function will need to
// be exported.
func (r *TxApp) checkAndSpend(ctx TxContext, tx *transactions.Transaction, pricer Pricer, dbTx sql.DB) (*big.Int, transactions.TxCode, error) {
func (r *TxApp) checkAndSpend(ctx TxContext, tx *transactions.Transaction, pricer Pricer, dbTx sql.Executor) (*big.Int, transactions.TxCode, error) {
amt := big.NewInt(0)
var err error

Expand Down
Loading

0 comments on commit 69ff718

Please sign in to comment.