Skip to content

Commit

Permalink
Delete intermediate aggregations on DB replace (#284)
Browse files Browse the repository at this point in the history
* feat: Add unscoped to aggregation db replace operations

* feat: Implement Databaser on Database

* feat: Add unscoped to DB operations and pass model to aggregation store methods

* feat: Use FirstOrCreate when storing messages

* feat: Add DB to Databaser methods
  • Loading branch information
Hyodar authored Aug 13, 2024
1 parent 26cfd46 commit 0154475
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 68 deletions.
8 changes: 4 additions & 4 deletions aggregator/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,12 +454,12 @@ func (agg *Aggregator) handleStateRootUpdateReachedQuorum(blsAggServiceResp blsa

agg.logger.Info("Storing state root update", "digest", blsAggServiceResp.MessageDigest, "status", blsAggServiceResp.Status)

err := agg.msgDb.StoreStateRootUpdate(msg)
msgModel, err := agg.msgDb.StoreStateRootUpdate(msg)
if err != nil {
agg.logger.Error("Aggregator could not store message")
return
}
err = agg.msgDb.StoreStateRootUpdateAggregation(msg, blsAggServiceResp.MessageBlsAggregation)
err = agg.msgDb.StoreStateRootUpdateAggregation(msgModel, blsAggServiceResp.MessageBlsAggregation)
if err != nil {
agg.logger.Error("Aggregator could not store message aggregation")
return
Expand Down Expand Up @@ -499,12 +499,12 @@ func (agg *Aggregator) handleOperatorSetUpdateReachedQuorum(ctx context.Context,

agg.logger.Info("Storing operator set update", "digest", blsAggServiceResp.MessageDigest, "status", blsAggServiceResp.Status)

err := agg.msgDb.StoreOperatorSetUpdate(msg)
msgModel, err := agg.msgDb.StoreOperatorSetUpdate(msg)
if err != nil {
agg.logger.Error("Aggregator could not store message")
return
}
err = agg.msgDb.StoreOperatorSetUpdateAggregation(msg, blsAggServiceResp.MessageBlsAggregation)
err = agg.msgDb.StoreOperatorSetUpdateAggregation(msgModel, blsAggServiceResp.MessageBlsAggregation)
if err != nil {
agg.logger.Error("Aggregator could not store message aggregation")
return
Expand Down
14 changes: 10 additions & 4 deletions aggregator/aggregator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

"github.com/NethermindEth/near-sffl/aggregator/blsagg"
dbmocks "github.com/NethermindEth/near-sffl/aggregator/database/mocks"
"github.com/NethermindEth/near-sffl/aggregator/database/models"
aggmocks "github.com/NethermindEth/near-sffl/aggregator/mocks"
"github.com/NethermindEth/near-sffl/aggregator/types"
taskmanager "github.com/NethermindEth/near-sffl/contracts/bindings/SFFLTaskManager"
Expand Down Expand Up @@ -116,8 +117,11 @@ func TestHandleStateRootUpdateAggregationReachedQuorum(t *testing.T) {
Finished: true,
}

mockMsgDb.EXPECT().StoreStateRootUpdate(msg)
mockMsgDb.EXPECT().StoreStateRootUpdateAggregation(msg, blsAggServiceResp.MessageBlsAggregation)
model := models.NewStateRootUpdateMessageModel(msg)

// get first return from StoreStateRootUpdate and use it as first argument on StoreStateRootUpdateAggregation
mockMsgDb.EXPECT().StoreStateRootUpdate(msg).Return(&model, nil)
mockMsgDb.EXPECT().StoreStateRootUpdateAggregation(&model, blsAggServiceResp.MessageBlsAggregation)

aggregator.handleStateRootUpdateReachedQuorum(blsAggServiceResp)
}
Expand All @@ -144,8 +148,10 @@ func TestHandleOperatorSetUpdateAggregationReachedQuorum(t *testing.T) {
Finished: true,
}

mockMsgDb.EXPECT().StoreOperatorSetUpdate(msg)
mockMsgDb.EXPECT().StoreOperatorSetUpdateAggregation(msg, blsAggServiceResp.MessageBlsAggregation)
msgModel := models.NewOperatorSetUpdateMessageModel(msg)

mockMsgDb.EXPECT().StoreOperatorSetUpdate(msg).Return(&msgModel, nil)
mockMsgDb.EXPECT().StoreOperatorSetUpdateAggregation(&msgModel, blsAggServiceResp.MessageBlsAggregation)

signatureInfo := blsAggServiceResp.ExtractBindingRollup()
mockRollupBroadcaster.EXPECT().BroadcastOperatorSetUpdate(context.Background(), msg, signatureInfo)
Expand Down
81 changes: 45 additions & 36 deletions aggregator/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/prometheus/client_golang/prometheus"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"

"github.com/NethermindEth/near-sffl/aggregator/database/models"
Expand All @@ -23,15 +22,16 @@ type Databaser interface {
core.Metricable

Close() error
StoreStateRootUpdate(stateRootUpdateMessage messages.StateRootUpdateMessage) error
StoreStateRootUpdate(stateRootUpdateMessage messages.StateRootUpdateMessage) (*models.StateRootUpdateMessage, error)
FetchStateRootUpdate(rollupId uint32, blockHeight uint64) (*messages.StateRootUpdateMessage, error)
StoreStateRootUpdateAggregation(stateRootUpdateMessage messages.StateRootUpdateMessage, aggregation messages.MessageBlsAggregation) error
StoreStateRootUpdateAggregation(stateRootUpdateMessage *models.StateRootUpdateMessage, aggregation messages.MessageBlsAggregation) error
FetchStateRootUpdateAggregation(rollupId uint32, blockHeight uint64) (*messages.MessageBlsAggregation, error)
StoreOperatorSetUpdate(operatorSetUpdateMessage messages.OperatorSetUpdateMessage) error
StoreOperatorSetUpdate(operatorSetUpdateMessage messages.OperatorSetUpdateMessage) (*models.OperatorSetUpdateMessage, error)
FetchOperatorSetUpdate(id uint64) (*messages.OperatorSetUpdateMessage, error)
StoreOperatorSetUpdateAggregation(operatorSetUpdateMessage messages.OperatorSetUpdateMessage, aggregation messages.MessageBlsAggregation) error
StoreOperatorSetUpdateAggregation(operatorSetUpdateMessage *models.OperatorSetUpdateMessage, aggregation messages.MessageBlsAggregation) error
FetchOperatorSetUpdateAggregation(id uint64) (*messages.MessageBlsAggregation, error)
FetchCheckpointMessages(fromTimestamp uint64, toTimestamp uint64) (*messages.CheckpointMessages, error)
DB() *gorm.DB
}

type Database struct {
Expand All @@ -41,6 +41,7 @@ type Database struct {
}

var _ core.Metricable = (*Database)(nil)
var _ Databaser = (*Database)(nil)

func NewDatabase(dbPath string) (*Database, error) {
logger := logger.New(
Expand Down Expand Up @@ -105,22 +106,25 @@ func (d *Database) EnableMetrics(registry *prometheus.Registry) error {
return nil
}

func (d *Database) StoreStateRootUpdate(stateRootUpdateMessage messages.StateRootUpdateMessage) error {
func (d *Database) StoreStateRootUpdate(stateRootUpdateMessage messages.StateRootUpdateMessage) (*models.StateRootUpdateMessage, error) {
start := time.Now()
defer func() { d.listener.OnStore(time.Since(start)) }()

model := models.StateRootUpdateMessage{
RollupId: stateRootUpdateMessage.RollupId,
BlockHeight: stateRootUpdateMessage.BlockHeight,
Timestamp: stateRootUpdateMessage.Timestamp,
NearDaTransactionId: stateRootUpdateMessage.NearDaTransactionId[:],
NearDaCommitment: stateRootUpdateMessage.NearDaCommitment[:],
StateRoot: stateRootUpdateMessage.StateRoot[:],
}

tx := d.db.
Clauses(clause.OnConflict{Columns: []clause.Column{{Name: "rollup_id"}, {Name: "block_height"}}, UpdateAll: true}).
Create(&models.StateRootUpdateMessage{
RollupId: stateRootUpdateMessage.RollupId,
BlockHeight: stateRootUpdateMessage.BlockHeight,
Timestamp: stateRootUpdateMessage.Timestamp,
NearDaTransactionId: stateRootUpdateMessage.NearDaTransactionId[:],
NearDaCommitment: stateRootUpdateMessage.NearDaCommitment[:],
StateRoot: stateRootUpdateMessage.StateRoot[:],
})

return tx.Error
Where("rollup_id = ?", stateRootUpdateMessage.RollupId).
Where("block_height = ?", stateRootUpdateMessage.BlockHeight).
FirstOrCreate(&model)

return &model, tx.Error
}

func (d *Database) FetchStateRootUpdate(rollupId uint32, blockHeight uint64) (*messages.StateRootUpdateMessage, error) {
Expand All @@ -141,18 +145,17 @@ func (d *Database) FetchStateRootUpdate(rollupId uint32, blockHeight uint64) (*m
return &stateRootUpdateMessage, nil
}

func (d *Database) StoreStateRootUpdateAggregation(stateRootUpdateMessage messages.StateRootUpdateMessage, aggregation messages.MessageBlsAggregation) error {
func (d *Database) StoreStateRootUpdateAggregation(stateRootUpdateMessage *models.StateRootUpdateMessage, aggregation messages.MessageBlsAggregation) error {
start := time.Now()
defer func() { d.listener.OnStore(time.Since(start)) }()

model := models.NewMessageBlsAggregationModel(aggregation)

err := d.db.
Clauses(clause.OnConflict{UpdateAll: true}).
Model(&models.StateRootUpdateMessage{}).
Where("rollup_id = ?", stateRootUpdateMessage.RollupId).
Where("block_height = ?", stateRootUpdateMessage.BlockHeight).
Unscoped().
Model(stateRootUpdateMessage).
Association("Aggregation").
Unscoped().
Replace(&model)
if err != nil {
return err
Expand Down Expand Up @@ -186,19 +189,21 @@ func (d *Database) FetchStateRootUpdateAggregation(rollupId uint32, blockHeight
return &aggregation, nil
}

func (d *Database) StoreOperatorSetUpdate(operatorSetUpdateMessage messages.OperatorSetUpdateMessage) error {
func (d *Database) StoreOperatorSetUpdate(operatorSetUpdateMessage messages.OperatorSetUpdateMessage) (*models.OperatorSetUpdateMessage, error) {
start := time.Now()
defer func() { d.listener.OnStore(time.Since(start)) }()

model := models.OperatorSetUpdateMessage{
UpdateId: operatorSetUpdateMessage.Id,
Timestamp: operatorSetUpdateMessage.Timestamp,
Operators: operatorSetUpdateMessage.Operators,
}

tx := d.db.
Clauses(clause.OnConflict{Columns: []clause.Column{{Name: "update_id"}}, UpdateAll: true}).
Create(&models.OperatorSetUpdateMessage{
UpdateId: operatorSetUpdateMessage.Id,
Timestamp: operatorSetUpdateMessage.Timestamp,
Operators: operatorSetUpdateMessage.Operators,
})

return tx.Error
Where("update_id = ?", operatorSetUpdateMessage.Id).
FirstOrCreate(&model)

return &model, tx.Error
}

func (d *Database) FetchOperatorSetUpdate(id uint64) (*messages.OperatorSetUpdateMessage, error) {
Expand All @@ -218,17 +223,17 @@ func (d *Database) FetchOperatorSetUpdate(id uint64) (*messages.OperatorSetUpdat
return &operatorSetUpdateMessage, nil
}

func (d *Database) StoreOperatorSetUpdateAggregation(operatorSetUpdateMessage messages.OperatorSetUpdateMessage, aggregation messages.MessageBlsAggregation) error {
func (d *Database) StoreOperatorSetUpdateAggregation(operatorSetUpdateMessage *models.OperatorSetUpdateMessage, aggregation messages.MessageBlsAggregation) error {
start := time.Now()
defer func() { d.listener.OnStore(time.Since(start)) }()

model := models.NewMessageBlsAggregationModel(aggregation)

err := d.db.
Clauses(clause.OnConflict{UpdateAll: true}).
Model(&models.OperatorSetUpdateMessage{}).
Where("update_id = ?", operatorSetUpdateMessage.Id).
Unscoped().
Model(operatorSetUpdateMessage).
Association("Aggregation").
Unscoped().
Replace(&model)
if err != nil {
return err
Expand Down Expand Up @@ -266,7 +271,7 @@ func (d *Database) FetchCheckpointMessages(fromTimestamp uint64, toTimestamp uin
return nil, errors.New("timestamp does not fit in int64")
}

if (toTimestamp < fromTimestamp) {
if toTimestamp < fromTimestamp {
return nil, errors.New("toTimestamp is less than fromTimestamp")
}

Expand Down Expand Up @@ -333,3 +338,7 @@ func (d *Database) FetchCheckpointMessages(fromTimestamp uint64, toTimestamp uin

return result, nil
}

func (d *Database) DB() *gorm.DB {
return d.db
}
Loading

0 comments on commit 0154475

Please sign in to comment.