diff --git a/common/version/prover_version.go b/common/version/prover_version.go index 2bc72322f1..f9787bf1cb 100644 --- a/common/version/prover_version.go +++ b/common/version/prover_version.go @@ -1,7 +1,10 @@ package version import ( + "strconv" "strings" + + "github.com/scroll-tech/go-ethereum/log" ) // CheckScrollProverVersion check the "scroll-prover" version, if it's different from the local one, return false @@ -19,3 +22,54 @@ func CheckScrollProverVersion(proverVersion string) bool { // compare the `scroll_prover` version return remote[2] == local[2] } + +// parseVersion takes a version string and returns its major, minor, and patch numbers. +func parseVersion(version string) (major, minor, patch int) { + trimVersion := strings.TrimPrefix(version, "v") + if trimVersion == version { + log.Error("version does not start with v", "vesion", version) + return 0, 0, 0 + } + + versionPart := strings.SplitN(trimVersion, "-", 2)[0] + parts := strings.Split(versionPart, ".") + if len(parts) != 3 { + log.Error("invalid version format", "expected format", "v..", "got", version) + return 0, 0, 0 + } + + var err error + major, err = strconv.Atoi(parts[0]) + if err != nil { + log.Error("invalid major version", "value", parts[0], "error", err) + return 0, 0, 0 + } + + minor, err = strconv.Atoi(parts[1]) + if err != nil { + log.Error("invalid minor version", "value", parts[1], "error", err) + return 0, 0, 0 + } + + patch, err = strconv.Atoi(parts[2]) + if err != nil { + log.Error("invalid patch version", "value", parts[2], "error", err) + return 0, 0, 0 + } + + return major, minor, patch +} + +// CheckScrollRepoVersion checks if the proverVersion is at least the minimum required version. +func CheckScrollRepoVersion(proverVersion, minVersion string) bool { + major1, minor1, patch1 := parseVersion(proverVersion) + major2, minor2, patch2 := parseVersion(minVersion) + + if major1 != major2 { + return major1 > major2 + } + if minor1 != minor2 { + return minor1 > minor2 + } + return patch1 >= patch2 +} diff --git a/common/version/prover_version_test.go b/common/version/prover_version_test.go new file mode 100644 index 0000000000..dfbe464d36 --- /dev/null +++ b/common/version/prover_version_test.go @@ -0,0 +1,80 @@ +package version + +import ( + "os" + "testing" + + "github.com/scroll-tech/go-ethereum/log" +) + +func TestMain(m *testing.M) { + glogger := log.NewGlogHandler(log.StreamHandler(os.Stderr, log.LogfmtFormat())) + glogger.Verbosity(log.LvlInfo) + log.Root().SetHandler(glogger) + + m.Run() +} + +func TestCheckScrollProverVersion(t *testing.T) { + tests := []struct { + proverVersion string + want bool + }{ + {Version, true}, + {"tag-commit-111111-000000", false}, + {"incorrect-format", false}, + {"tag-commit-222222-111111", false}, + } + + for _, tt := range tests { + if got := CheckScrollProverVersion(tt.proverVersion); got != tt.want { + t.Errorf("CheckScrollProverVersion(%q) = %v, want %v", tt.proverVersion, got, tt.want) + } + } +} + +func TestParseVersion(t *testing.T) { + tests := []struct { + version string + wantMajor int + wantMinor int + wantPatch int + }{ + {"v1.2.3-commit-111111-000000", 1, 2, 3}, + {"v0.10.0-patch-commit-111111-000000", 0, 10, 0}, + {"v2.0.1-alpha", 2, 0, 1}, + {"v10.0.0", 10, 0, 0}, + {"v1.0", 0, 0, 0}, // Invalid format + {"v..", 0, 0, 0}, // Invalid format + } + + for _, tt := range tests { + gotMajor, gotMinor, gotPatch := parseVersion(tt.version) + if gotMajor != tt.wantMajor || gotMinor != tt.wantMinor || gotPatch != tt.wantPatch { + t.Errorf("parseVersion(%q) = %v, %v, %v, want %v, %v, %v", tt.version, gotMajor, gotMinor, gotPatch, tt.wantMajor, tt.wantMinor, tt.wantPatch) + } + } +} + +func TestCheckScrollRepoVersion(t *testing.T) { + tests := []struct { + proverVersion string + minVersion string + want bool + }{ + {"v1.2.3-commit-111111-000000", "v1.2.3", true}, + {"v1.2.3-patch-commit-111111-000000", "v1.2.2", true}, + {"v1.0.0-alpha", "v1.0.0", true}, + {"v1.2.2", "v1.2.3", false}, + {"v2.0.0", "v1.9.9", true}, + {"v0.9.0", "v1.0.0", false}, + {"v9.9.9", "v10.0.0", false}, + {"v4.1.98-aaa-bbb-ccc", "v999.0.0", false}, + } + + for _, tt := range tests { + if got := CheckScrollRepoVersion(tt.proverVersion, tt.minVersion); got != tt.want { + t.Errorf("CheckScrollRepoVersion(%q, %q) = %v, want %v", tt.proverVersion, tt.minVersion, got, tt.want) + } + } +} diff --git a/coordinator/cmd/api/app/mock_app.go b/coordinator/cmd/api/app/mock_app.go index dafaeba25b..c1bb858372 100644 --- a/coordinator/cmd/api/app/mock_app.go +++ b/coordinator/cmd/api/app/mock_app.go @@ -87,6 +87,7 @@ func (c *CoordinatorApp) MockConfig(store bool) error { ChunkCollectionTimeSec: 60, SessionAttempts: 10, MaxVerifierWorkers: 4, + MinProverVersion: "v1.0.0", } cfg.DB.DSN = base.DBImg.Endpoint() cfg.L2.ChainID = 111 diff --git a/coordinator/conf/config.json b/coordinator/conf/config.json index 6cbeaca0d6..b143a702f7 100644 --- a/coordinator/conf/config.json +++ b/coordinator/conf/config.json @@ -9,7 +9,8 @@ "params_path": "", "assets_path": "" }, - "max_verifier_workers": 4 + "max_verifier_workers": 4, + "min_prover_version": "v1.0.0" }, "db": { "driver_name": "postgres", diff --git a/coordinator/internal/config/config.go b/coordinator/internal/config/config.go index 273869c988..33142b1138 100644 --- a/coordinator/internal/config/config.go +++ b/coordinator/internal/config/config.go @@ -23,6 +23,8 @@ type ProverManager struct { ChunkCollectionTimeSec int `json:"chunk_collection_time_sec"` // Max number of workers in verifier worker pool MaxVerifierWorkers int `json:"max_verifier_workers"` + // MinProverVersion is the minimum version of the prover that is required. + MinProverVersion string `json:"min_prover_version"` } // L2 loads l2geth configuration items. diff --git a/coordinator/internal/config/config_test.go b/coordinator/internal/config/config_test.go index 482f8ec117..748fe9c2a9 100644 --- a/coordinator/internal/config/config_test.go +++ b/coordinator/internal/config/config_test.go @@ -22,7 +22,8 @@ func TestConfig(t *testing.T) { "params_path": "", "agg_vk_path": "" }, - "max_verifier_workers": 4 + "max_verifier_workers": 4, + "min_prover_version": "v1.0.0" }, "db": { "driver_name": "postgres", diff --git a/coordinator/internal/logic/provertask/prover_task.go b/coordinator/internal/logic/provertask/prover_task.go index 70e258fa55..b1782df3f1 100644 --- a/coordinator/internal/logic/provertask/prover_task.go +++ b/coordinator/internal/logic/provertask/prover_task.go @@ -59,6 +59,10 @@ func (b *BaseProverTask) checkParameter(ctx *gin.Context, getTaskParameter *coor } ptc.ProverVersion = proverVersion.(string) + if !version.CheckScrollRepoVersion(proverVersion.(string), b.cfg.ProverManager.MinProverVersion) { + return nil, fmt.Errorf("incompatible prover version. please upgrade your prover, minimum allowed version: %s, actual version: %s", b.cfg.ProverManager.MinProverVersion, proverVersion.(string)) + } + // if the prover has a different vk if getTaskParameter.VK != b.vk { // if the prover reports a different prover version diff --git a/coordinator/test/api_test.go b/coordinator/test/api_test.go index 372a0fe551..eadf19af5d 100644 --- a/coordinator/test/api_test.go +++ b/coordinator/test/api_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/scroll-tech/go-ethereum/log" "github.com/stretchr/testify/assert" "gorm.io/gorm" @@ -53,6 +54,10 @@ var ( ) func TestMain(m *testing.M) { + glogger := log.NewGlogHandler(log.StreamHandler(os.Stderr, log.LogfmtFormat())) + glogger.Verbosity(log.LvlInfo) + log.Root().SetHandler(glogger) + base = docker.NewDockerApp() m.Run() base.Free() @@ -63,7 +68,7 @@ func randomURL() string { return fmt.Sprintf("localhost:%d", 10000+2000+id.Int64()) } -func setupCoordinator(t *testing.T, proversPerSession uint8, coordinatorURL string) (*cron.Collector, *http.Server) { +func setupCoordinator(t *testing.T, proversPerSession uint8, coordinatorURL string, proverVersion string) (*cron.Collector, *http.Server) { var err error db, err = database.InitDB(dbCfg) assert.NoError(t, err) @@ -83,6 +88,7 @@ func setupCoordinator(t *testing.T, proversPerSession uint8, coordinatorURL stri ChunkCollectionTimeSec: 10, MaxVerifierWorkers: 10, SessionAttempts: 5, + MinProverVersion: proverVersion, }, Auth: &config.Auth{ ChallengeExpireDurationSec: tokenTimeout, @@ -160,6 +166,7 @@ func TestApis(t *testing.T) { t.Run("TestHandshake", testHandshake) t.Run("TestFailedHandshake", testFailedHandshake) t.Run("TestGetTaskBlocked", testGetTaskBlocked) + t.Run("TestOutdatedProverVersion", testOutdatedProverVersion) t.Run("TestValidProof", testValidProof) t.Run("TestInvalidProof", testInvalidProof) t.Run("TestProofGeneratedFailed", testProofGeneratedFailed) @@ -174,7 +181,7 @@ func TestApis(t *testing.T) { func testHandshake(t *testing.T) { // Setup coordinator and http server. coordinatorURL := randomURL() - proofCollector, httpHandler := setupCoordinator(t, 1, coordinatorURL) + proofCollector, httpHandler := setupCoordinator(t, 1, coordinatorURL, version.Version) defer func() { proofCollector.Stop() assert.NoError(t, httpHandler.Shutdown(context.Background())) @@ -187,7 +194,7 @@ func testHandshake(t *testing.T) { func testFailedHandshake(t *testing.T) { // Setup coordinator and http server. coordinatorURL := randomURL() - proofCollector, httpHandler := setupCoordinator(t, 1, coordinatorURL) + proofCollector, httpHandler := setupCoordinator(t, 1, coordinatorURL, version.Version) defer func() { proofCollector.Stop() }() @@ -205,7 +212,7 @@ func testFailedHandshake(t *testing.T) { func testGetTaskBlocked(t *testing.T) { coordinatorURL := randomURL() - collector, httpHandler := setupCoordinator(t, 3, coordinatorURL) + collector, httpHandler := setupCoordinator(t, 3, coordinatorURL, version.Version) defer func() { collector.Stop() assert.NoError(t, httpHandler.Shutdown(context.Background())) @@ -247,9 +254,34 @@ func testGetTaskBlocked(t *testing.T) { assert.Equal(t, expectedErr, fmt.Errorf(errMsg)) } +func testOutdatedProverVersion(t *testing.T) { + coordinatorURL := randomURL() + collector, httpHandler := setupCoordinator(t, 3, coordinatorURL, "v999.0.0") + defer func() { + collector.Stop() + assert.NoError(t, httpHandler.Shutdown(context.Background())) + }() + + chunkProver := newMockProver(t, "prover_chunk_test", coordinatorURL, message.ProofTypeChunk) + assert.True(t, chunkProver.healthCheckSuccess(t)) + + batchProver := newMockProver(t, "prover_batch_test", coordinatorURL, message.ProofTypeBatch) + assert.True(t, chunkProver.healthCheckSuccess(t)) + + expectedErr := fmt.Errorf("return prover task err:check prover task parameter failed, error:incompatible prover version. please upgrade your prover, minimum allowed version: v999.0.0, actual version: %s", chunkProver.proverVersion) + code, errMsg := chunkProver.tryGetProverTask(t, message.ProofTypeChunk) + assert.Equal(t, types.ErrCoordinatorGetTaskFailure, code) + assert.Equal(t, expectedErr, fmt.Errorf(errMsg)) + + expectedErr = fmt.Errorf("return prover task err:check prover task parameter failed, error:incompatible prover version. please upgrade your prover, minimum allowed version: v999.0.0, actual version: %s", batchProver.proverVersion) + code, errMsg = batchProver.tryGetProverTask(t, message.ProofTypeBatch) + assert.Equal(t, types.ErrCoordinatorGetTaskFailure, code) + assert.Equal(t, expectedErr, fmt.Errorf(errMsg)) +} + func testValidProof(t *testing.T) { coordinatorURL := randomURL() - collector, httpHandler := setupCoordinator(t, 3, coordinatorURL) + collector, httpHandler := setupCoordinator(t, 3, coordinatorURL, version.Version) defer func() { collector.Stop() assert.NoError(t, httpHandler.Shutdown(context.Background())) @@ -333,7 +365,7 @@ func testValidProof(t *testing.T) { func testInvalidProof(t *testing.T) { // Setup coordinator and ws server. coordinatorURL := randomURL() - collector, httpHandler := setupCoordinator(t, 3, coordinatorURL) + collector, httpHandler := setupCoordinator(t, 3, coordinatorURL, version.Version) defer func() { collector.Stop() assert.NoError(t, httpHandler.Shutdown(context.Background())) @@ -409,7 +441,7 @@ func testInvalidProof(t *testing.T) { func testProofGeneratedFailed(t *testing.T) { // Setup coordinator and ws server. coordinatorURL := randomURL() - collector, httpHandler := setupCoordinator(t, 3, coordinatorURL) + collector, httpHandler := setupCoordinator(t, 3, coordinatorURL, version.Version) defer func() { collector.Stop() assert.NoError(t, httpHandler.Shutdown(context.Background())) @@ -496,7 +528,7 @@ func testProofGeneratedFailed(t *testing.T) { func testTimeoutProof(t *testing.T) { // Setup coordinator and ws server. coordinatorURL := randomURL() - collector, httpHandler := setupCoordinator(t, 1, coordinatorURL) + collector, httpHandler := setupCoordinator(t, 1, coordinatorURL, version.Version) defer func() { collector.Stop() assert.NoError(t, httpHandler.Shutdown(context.Background()))