diff --git a/coordinator/internal/logic/provertask/batch_prover_task.go b/coordinator/internal/logic/provertask/batch_prover_task.go index b043cd29d3..0f50710d3e 100644 --- a/coordinator/internal/logic/provertask/batch_prover_task.go +++ b/coordinator/internal/logic/provertask/batch_prover_task.go @@ -34,12 +34,13 @@ type BatchProverTask struct { func NewBatchProverTask(cfg *config.Config, db *gorm.DB, vk string, reg prometheus.Registerer) *BatchProverTask { bp := &BatchProverTask{ BaseProverTask: BaseProverTask{ - vk: vk, - db: db, - cfg: cfg, - chunkOrm: orm.NewChunk(db), - batchOrm: orm.NewBatch(db), - proverTaskOrm: orm.NewProverTask(db), + vk: vk, + db: db, + cfg: cfg, + chunkOrm: orm.NewChunk(db), + batchOrm: orm.NewBatch(db), + proverTaskOrm: orm.NewProverTask(db), + proverBlockListOrm: orm.NewProverBlockList(db), }, batchAttemptsExceedTotal: promauto.With(reg).NewCounter(prometheus.CounterOpts{ Name: "coordinator_batch_attempts_exceed_total", diff --git a/coordinator/internal/logic/provertask/chunk_prover_task.go b/coordinator/internal/logic/provertask/chunk_prover_task.go index 84b108e0ec..42787bcf47 100644 --- a/coordinator/internal/logic/provertask/chunk_prover_task.go +++ b/coordinator/internal/logic/provertask/chunk_prover_task.go @@ -9,7 +9,6 @@ import ( "github.com/gin-gonic/gin" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - "github.com/scroll-tech/go-ethereum/common" "github.com/scroll-tech/go-ethereum/log" "gorm.io/gorm" @@ -37,12 +36,13 @@ type ChunkProverTask struct { func NewChunkProverTask(cfg *config.Config, db *gorm.DB, vk string, reg prometheus.Registerer) *ChunkProverTask { cp := &ChunkProverTask{ BaseProverTask: BaseProverTask{ - vk: vk, - db: db, - cfg: cfg, - chunkOrm: orm.NewChunk(db), - blockOrm: orm.NewL2Block(db), - proverTaskOrm: orm.NewProverTask(db), + vk: vk, + db: db, + cfg: cfg, + chunkOrm: orm.NewChunk(db), + blockOrm: orm.NewL2Block(db), + proverTaskOrm: orm.NewProverTask(db), + proverBlockListOrm: orm.NewProverBlockList(db), }, chunkAttemptsExceedTotal: promauto.With(reg).NewCounter(prometheus.CounterOpts{ Name: "coordinator_chunk_attempts_exceed_total", @@ -144,14 +144,9 @@ func (cp *ChunkProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinato func (cp *ChunkProverTask) formatProverTask(ctx context.Context, task *orm.ProverTask) (*coordinatorType.GetTaskSchema, error) { // Get block hashes. - wrappedBlocks, wrappedErr := cp.blockOrm.GetL2BlocksByChunkHash(ctx, task.TaskID) - if wrappedErr != nil || len(wrappedBlocks) == 0 { - return nil, fmt.Errorf("failed to fetch wrapped blocks, chunk hash:%s err:%w", task.TaskID, wrappedErr) - } - - blockHashes := make([]common.Hash, len(wrappedBlocks)) - for i, wrappedBlock := range wrappedBlocks { - blockHashes[i] = wrappedBlock.Header.Hash() + blockHashes, dbErr := cp.blockOrm.GetL2BlockHashesByChunkHash(ctx, task.TaskID) + if dbErr != nil || len(blockHashes) == 0 { + return nil, fmt.Errorf("failed to fetch block hashes of a chunk, chunk hash:%s err:%w", task.TaskID, dbErr) } taskDetail := message.ChunkTaskDetail{ diff --git a/coordinator/internal/logic/provertask/prover_task.go b/coordinator/internal/logic/provertask/prover_task.go index 012006dfac..70e258fa55 100644 --- a/coordinator/internal/logic/provertask/prover_task.go +++ b/coordinator/internal/logic/provertask/prover_task.go @@ -24,10 +24,11 @@ type BaseProverTask struct { db *gorm.DB vk string - batchOrm *orm.Batch - chunkOrm *orm.Chunk - blockOrm *orm.L2Block - proverTaskOrm *orm.ProverTask + batchOrm *orm.Batch + chunkOrm *orm.Chunk + blockOrm *orm.L2Block + proverTaskOrm *orm.ProverTask + proverBlockListOrm *orm.ProverBlockList } type proverTaskContext struct { @@ -68,13 +69,21 @@ func (b *BaseProverTask) checkParameter(ctx *gin.Context, getTaskParameter *coor return nil, fmt.Errorf("incompatible vk. please check your params files or config files") } + isBlocked, err := b.proverBlockListOrm.IsPublicKeyBlocked(ctx, publicKey.(string)) + if err != nil { + return nil, fmt.Errorf("failed to check if the public key %s is blocked before assigning a chunk task, err: %w, proverName: %s, proverVersion: %s", publicKey, err, proverName, proverVersion) + } + if isBlocked { + return nil, fmt.Errorf("public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", publicKey, proverName, proverVersion) + } + isAssigned, err := b.proverTaskOrm.IsProverAssigned(ctx, publicKey.(string)) if err != nil { - return nil, fmt.Errorf("failed to check if prover is assigned a task: %w", err) + return nil, fmt.Errorf("failed to check if prover %s is assigned a task, err: %w", publicKey.(string), err) } if isAssigned { - return nil, fmt.Errorf("prover with publicKey %s is already assigned a task", publicKey) + return nil, fmt.Errorf("prover with publicKey %s is already assigned a task. ProverName: %s, ProverVersion: %s", publicKey, proverName, proverVersion) } return &ptc, nil } diff --git a/coordinator/internal/orm/l2_block.go b/coordinator/internal/orm/l2_block.go index 38728dea8f..2030ba7ba0 100644 --- a/coordinator/internal/orm/l2_block.go +++ b/coordinator/internal/orm/l2_block.go @@ -50,42 +50,29 @@ func (*L2Block) TableName() string { return "l2_block" } -// GetL2BlocksByChunkHash retrieves the L2 blocks associated with the specified chunk hash. -// The returned blocks are sorted in ascending order by their block number. -func (o *L2Block) GetL2BlocksByChunkHash(ctx context.Context, chunkHash string) ([]*types.WrappedBlock, error) { +// GetL2BlockHashesByChunkHash retrieves the L2 block hashes associated with the specified chunk hash. +// The returned block hashes are sorted in ascending order by their block number. +func (o *L2Block) GetL2BlockHashesByChunkHash(ctx context.Context, chunkHash string) ([]common.Hash, error) { db := o.db.WithContext(ctx) db = db.Model(&L2Block{}) - db = db.Select("header, transactions, withdraw_root, row_consumption") + db = db.Select("header") db = db.Where("chunk_hash = ?", chunkHash) db = db.Order("number ASC") var l2Blocks []L2Block if err := db.Find(&l2Blocks).Error; err != nil { - return nil, fmt.Errorf("L2Block.GetL2BlocksByChunkHash error: %w, chunk hash: %v", err, chunkHash) + return nil, fmt.Errorf("L2Block.GetL2BlockHashesByChunkHash error: %w, chunk hash: %v", err, chunkHash) } - var wrappedBlocks []*types.WrappedBlock + var blockHashes []common.Hash for _, v := range l2Blocks { - var wrappedBlock types.WrappedBlock - - if err := json.Unmarshal([]byte(v.Transactions), &wrappedBlock.Transactions); err != nil { - return nil, fmt.Errorf("L2Block.GetL2BlocksByChunkHash error: %w, chunk hash: %v", err, chunkHash) - } - - wrappedBlock.Header = &gethTypes.Header{} - if err := json.Unmarshal([]byte(v.Header), wrappedBlock.Header); err != nil { - return nil, fmt.Errorf("L2Block.GetL2BlocksByChunkHash error: %w, chunk hash: %v", err, chunkHash) + var header gethTypes.Header + if err := json.Unmarshal([]byte(v.Header), &header); err != nil { + return nil, fmt.Errorf("L2Block.GetL2BlockHashesByChunkHash error: %w, chunk hash: %v", err, chunkHash) } - - wrappedBlock.WithdrawRoot = common.HexToHash(v.WithdrawRoot) - if err := json.Unmarshal([]byte(v.RowConsumption), &wrappedBlock.RowConsumption); err != nil { - return nil, fmt.Errorf("L2Block.GetL2BlocksByChunkHash error: %w, chunk hash: %v", err, chunkHash) - } - - wrappedBlocks = append(wrappedBlocks, &wrappedBlock) + blockHashes = append(blockHashes, header.Hash()) } - - return wrappedBlocks, nil + return blockHashes, nil } // InsertL2Blocks inserts l2 blocks into the "l2_block" table. diff --git a/coordinator/internal/orm/prover_block_list.go b/coordinator/internal/orm/prover_block_list.go new file mode 100644 index 0000000000..e465673296 --- /dev/null +++ b/coordinator/internal/orm/prover_block_list.go @@ -0,0 +1,77 @@ +package orm + +import ( + "context" + "fmt" + "time" + + "gorm.io/gorm" +) + +// ProverBlockList represents the prover's block entry in the database. +type ProverBlockList struct { + db *gorm.DB `gorm:"-"` + + ID uint `json:"id" gorm:"column:id;primaryKey"` + ProverName string `json:"prover_name" gorm:"column:prover_name"` + PublicKey string `json:"public_key" gorm:"column:public_key"` + + // metadata + CreatedAt time.Time `json:"created_at" gorm:"column:created_at"` + UpdatedAt time.Time `json:"updated_at" gorm:"column:updated_at"` + DeletedAt gorm.DeletedAt `json:"deleted_at" gorm:"column:deleted_at;default:NULL"` +} + +// NewProverBlockList creates a new ProverBlockList instance. +func NewProverBlockList(db *gorm.DB) *ProverBlockList { + return &ProverBlockList{db: db} +} + +// TableName returns the name of the "prover_block_list" table. +func (*ProverBlockList) TableName() string { + return "prover_block_list" +} + +// InsertProverPublicKey adds a new Prover public key to the block list. +// for unit test only. +func (p *ProverBlockList) InsertProverPublicKey(ctx context.Context, proverName, publicKey string) error { + prover := ProverBlockList{ + ProverName: proverName, + PublicKey: publicKey, + } + + db := p.db.WithContext(ctx) + db = db.Model(&ProverBlockList{}) + if err := db.Create(&prover).Error; err != nil { + return fmt.Errorf("ProverBlockList.InsertProverPublicKey error: %w, prover name: %v, public key: %v", err, proverName, publicKey) + } + return nil +} + +// DeleteProverPublicKey marks a Prover public key as deleted in the block list. +// for unit test only. +func (p *ProverBlockList) DeleteProverPublicKey(ctx context.Context, publicKey string) error { + db := p.db.WithContext(ctx) + db = db.Where("public_key = ?", publicKey) + if err := db.Delete(&ProverBlockList{}).Error; err != nil { + return fmt.Errorf("ProverBlockList.DeleteProverPublicKey error: %w, public key: %v", err, publicKey) + } + return nil +} + +// IsPublicKeyBlocked checks if the given public key is blocked. +func (p *ProverBlockList) IsPublicKeyBlocked(ctx context.Context, publicKey string) (bool, error) { + var proverBlock ProverBlockList + + db := p.db.WithContext(ctx) + db = db.Model(&ProverBlockList{}) + db = db.Where("public_key = ?", publicKey) + if err := db.First(&proverBlock).Error; err != nil { + if err == gorm.ErrRecordNotFound { + return false, nil // Public key not found, hence it's not blocked. + } + return true, fmt.Errorf("ProverBlockList.IsPublicKeyBlocked error: %w, public key: %v", err, publicKey) + } + + return true, nil +} diff --git a/coordinator/test/api_test.go b/coordinator/test/api_test.go index 260d9034fe..372a0fe551 100644 --- a/coordinator/test/api_test.go +++ b/coordinator/test/api_test.go @@ -38,11 +38,12 @@ var ( base *docker.App - db *gorm.DB - l2BlockOrm *orm.L2Block - chunkOrm *orm.Chunk - batchOrm *orm.Batch - proverTaskOrm *orm.ProverTask + db *gorm.DB + l2BlockOrm *orm.L2Block + chunkOrm *orm.Chunk + batchOrm *orm.Batch + proverTaskOrm *orm.ProverTask + proverBlockListOrm *orm.ProverBlockList wrappedBlock1 *types.WrappedBlock wrappedBlock2 *types.WrappedBlock @@ -133,6 +134,7 @@ func setEnv(t *testing.T) { chunkOrm = orm.NewChunk(db) l2BlockOrm = orm.NewL2Block(db) proverTaskOrm = orm.NewProverTask(db) + proverBlockListOrm = orm.NewProverBlockList(db) templateBlockTrace, err := os.ReadFile("../../common/testdata/blockTrace_02.json") assert.NoError(t, err) @@ -157,6 +159,7 @@ func TestApis(t *testing.T) { t.Run("TestHandshake", testHandshake) t.Run("TestFailedHandshake", testFailedHandshake) + t.Run("TestGetTaskBlocked", testGetTaskBlocked) t.Run("TestValidProof", testValidProof) t.Run("TestInvalidProof", testInvalidProof) t.Run("TestProofGeneratedFailed", testProofGeneratedFailed) @@ -200,6 +203,50 @@ func testFailedHandshake(t *testing.T) { assert.True(t, batchProver.healthCheckFailure(t)) } +func testGetTaskBlocked(t *testing.T) { + coordinatorURL := randomURL() + collector, httpHandler := setupCoordinator(t, 3, coordinatorURL) + defer func() { + collector.Stop() + assert.NoError(t, httpHandler.Shutdown(context.Background())) + }() + + chunkProver := newMockProver(t, "prover_chunk_test", coordinatorURL, message.ProofTypeChunk) + assert.True(t, chunkProver.healthCheckSuccess(t)) + + batchProver := newMockProver(t, "prover_batch_test", coordinatorURL, message.ProofTypeBatch) + assert.True(t, chunkProver.healthCheckSuccess(t)) + + err := proverBlockListOrm.InsertProverPublicKey(context.Background(), chunkProver.proverName, chunkProver.publicKey()) + assert.NoError(t, err) + + expectedErr := fmt.Errorf("return prover task err:check prover task parameter failed, error:public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", chunkProver.publicKey(), chunkProver.proverName, chunkProver.proverVersion) + code, errMsg := chunkProver.tryGetProverTask(t, message.ProofTypeChunk) + assert.Equal(t, types.ErrCoordinatorGetTaskFailure, code) + assert.Equal(t, expectedErr, fmt.Errorf(errMsg)) + + expectedErr = fmt.Errorf("get empty prover task") + code, errMsg = batchProver.tryGetProverTask(t, message.ProofTypeBatch) + assert.Equal(t, types.ErrCoordinatorEmptyProofData, code) + assert.Equal(t, expectedErr, fmt.Errorf(errMsg)) + + err = proverBlockListOrm.InsertProverPublicKey(context.Background(), batchProver.proverName, batchProver.publicKey()) + assert.NoError(t, err) + + err = proverBlockListOrm.DeleteProverPublicKey(context.Background(), chunkProver.publicKey()) + assert.NoError(t, err) + + expectedErr = fmt.Errorf("get empty prover task") + code, errMsg = chunkProver.tryGetProverTask(t, message.ProofTypeChunk) + assert.Equal(t, types.ErrCoordinatorEmptyProofData, code) + assert.Equal(t, expectedErr, fmt.Errorf(errMsg)) + + expectedErr = fmt.Errorf("return prover task err:check prover task parameter failed, error:public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", batchProver.publicKey(), batchProver.proverName, batchProver.proverVersion) + code, errMsg = batchProver.tryGetProverTask(t, message.ProofTypeBatch) + assert.Equal(t, types.ErrCoordinatorGetTaskFailure, code) + assert.Equal(t, expectedErr, fmt.Errorf(errMsg)) +} + func testValidProof(t *testing.T) { coordinatorURL := randomURL() collector, httpHandler := setupCoordinator(t, 3, coordinatorURL) diff --git a/coordinator/test/mock_prover.go b/coordinator/test/mock_prover.go index 804c163973..885f576d37 100644 --- a/coordinator/test/mock_prover.go +++ b/coordinator/test/mock_prover.go @@ -7,6 +7,7 @@ import ( "net/http" "testing" + "github.com/ethereum/go-ethereum/common" "github.com/go-resty/resty/v2" "github.com/mitchellh/mapstructure" "github.com/scroll-tech/go-ethereum/crypto" @@ -30,6 +31,7 @@ const ( type mockProver struct { proverName string + proverVersion string privKey *ecdsa.PrivateKey proofType message.ProofType coordinatorURL string @@ -41,6 +43,7 @@ func newMockProver(t *testing.T, proverName string, coordinatorURL string, proof prover := &mockProver{ proverName: proverName, + proverVersion: version.Version, privKey: privKey, proofType: proofType, coordinatorURL: coordinatorURL, @@ -78,8 +81,8 @@ func (r *mockProver) login(t *testing.T, challengeString string) string { authMsg := message.AuthMsg{ Identity: &message.Identity{ Challenge: challengeString, - ProverName: "test", - ProverVersion: version.Version, + ProverName: r.proverName, + ProverVersion: r.proverVersion, }, } assert.NoError(t, authMsg.SignWithKey(r.privKey)) @@ -162,6 +165,32 @@ func (r *mockProver) getProverTask(t *testing.T, proofType message.ProofType) *t return &result.Data } +// Testing expected errors returned by coordinator. +func (r *mockProver) tryGetProverTask(t *testing.T, proofType message.ProofType) (int, string) { + // get task from coordinator + token := r.connectToCoordinator(t) + assert.NotEmpty(t, token) + + type response struct { + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + Data types.GetTaskSchema `json:"data"` + } + + var result response + client := resty.New() + resp, err := client.R(). + SetHeader("Content-Type", "application/json"). + SetHeader("Authorization", fmt.Sprintf("Bearer %s", token)). + SetBody(map[string]interface{}{"prover_height": 100, "task_type": int(proofType)}). + SetResult(&result). + Post("http://" + r.coordinatorURL + "/coordinator/v1/get_task") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode()) + + return result.ErrCode, result.ErrMsg +} + func (r *mockProver) submitProof(t *testing.T, proverTaskSchema *types.GetTaskSchema, proofStatus proofStatus, errCode int) { proofMsgStatus := message.StatusOk if proofStatus == generatedFailed { @@ -224,3 +253,7 @@ func (r *mockProver) submitProof(t *testing.T, proverTaskSchema *types.GetTaskSc assert.Equal(t, http.StatusOK, resp.StatusCode()) assert.Equal(t, errCode, result.ErrCode) } + +func (r *mockProver) publicKey() string { + return common.Bytes2Hex(crypto.CompressPubkey(&r.privKey.PublicKey)) +} diff --git a/database/migrate/migrate_test.go b/database/migrate/migrate_test.go index 08b7bd66be..c6c25db039 100644 --- a/database/migrate/migrate_test.go +++ b/database/migrate/migrate_test.go @@ -50,7 +50,7 @@ func TestMigrate(t *testing.T) { func testCurrent(t *testing.T) { cur, err := Current(pgDB.DB) assert.NoError(t, err) - assert.Equal(t, 0, int(cur)) + assert.Equal(t, int64(0), cur) } func testStatus(t *testing.T) { @@ -63,24 +63,31 @@ func testResetDB(t *testing.T) { cur, err := Current(pgDB.DB) assert.NoError(t, err) // total number of tables. - assert.Equal(t, 15, int(cur)) + assert.Equal(t, int64(16), cur) } func testMigrate(t *testing.T) { assert.NoError(t, Migrate(pgDB.DB)) cur, err := Current(pgDB.DB) assert.NoError(t, err) - assert.Equal(t, true, cur > 0) + assert.Equal(t, int64(16), cur) } func testRollback(t *testing.T) { version, err := Current(pgDB.DB) assert.NoError(t, err) - assert.Equal(t, true, version > 0) + assert.Equal(t, int64(16), version) assert.NoError(t, Rollback(pgDB.DB, nil)) cur, err := Current(pgDB.DB) assert.NoError(t, err) - assert.Equal(t, true, cur+1 == version) + assert.Equal(t, version, cur+1) + + targetVersion := int64(0) + assert.NoError(t, Rollback(pgDB.DB, &targetVersion)) + + cur, err = Current(pgDB.DB) + assert.NoError(t, err) + assert.Equal(t, int64(0), cur) } diff --git a/database/migrate/migrations/00016_prover_block_list.sql b/database/migrate/migrations/00016_prover_block_list.sql new file mode 100644 index 0000000000..97e3b930da --- /dev/null +++ b/database/migrate/migrations/00016_prover_block_list.sql @@ -0,0 +1,26 @@ +-- +goose Up +-- +goose StatementBegin + +CREATE TABLE prover_block_list +( + id BIGSERIAL PRIMARY KEY, + + public_key VARCHAR NOT NULL, + +-- debug info + prover_name VARCHAR NOT NULL, + + created_at TIMESTAMP(0) NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP(0) NOT NULL DEFAULT CURRENT_TIMESTAMP, + deleted_at TIMESTAMP(0) DEFAULT NULL +); + +CREATE INDEX idx_prover_block_list_on_public_key ON prover_block_list(public_key); +CREATE INDEX idx_prover_block_list_on_prover_name ON prover_block_list(prover_name); + +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP TABLE IF EXISTS prover_block_list; +-- +goose StatementEnd