From b6ee9d26218a0ef0297d34e00e2fb7956449be82 Mon Sep 17 00:00:00 2001 From: Pepper Lebeck-Jobe Date: Tue, 26 Nov 2024 17:30:06 +0100 Subject: [PATCH] Make previousGlobalState a value instead of pointer There's no longer a need to be able to pass in nil. --- bold | 2 +- staker/bold/bold_state_provider.go | 45 ++++++++++-------------- system_tests/bold_new_challenge_test.go | 4 +-- system_tests/bold_state_provider_test.go | 18 +++++----- 4 files changed, 31 insertions(+), 38 deletions(-) diff --git a/bold b/bold index b4a1f85be1..9bc97907ce 160000 --- a/bold +++ b/bold @@ -1 +1 @@ -Subproject commit b4a1f85be1a2b00886b06002a02068d4d07b973b +Subproject commit 9bc97907cef8d8123bf46753ed913d51fddf0c0b diff --git a/staker/bold/bold_state_provider.go b/staker/bold/bold_state_provider.go index bda3d95510..9707d09967 100644 --- a/staker/bold/bold_state_provider.go +++ b/staker/bold/bold_state_provider.go @@ -81,7 +81,7 @@ func NewBOLDStateProvider( func (s *BOLDStateProvider) ExecutionStateAfterPreviousState( ctx context.Context, maxInboxCount uint64, - previousGlobalState *protocol.GoGlobalState, + previousGlobalState protocol.GoGlobalState, ) (*protocol.ExecutionState, error) { if maxInboxCount == 0 { return nil, errors.New("max inbox count cannot be zero") @@ -95,26 +95,24 @@ func (s *BOLDStateProvider) ExecutionStateAfterPreviousState( } return nil, err } - if previousGlobalState != nil { - var previousMessageCount arbutil.MessageIndex - if previousGlobalState.Batch > 0 { - previousMessageCount, err = s.statelessValidator.InboxTracker().GetBatchMessageCount(previousGlobalState.Batch - 1) - if err != nil { - if strings.Contains(err.Error(), "not found") { - return nil, fmt.Errorf("%w: batch count %d", l2stateprovider.ErrChainCatchingUp, maxInboxCount) - } - return nil, err + var previousMessageCount arbutil.MessageIndex + if previousGlobalState.Batch > 0 { + previousMessageCount, err = s.statelessValidator.InboxTracker().GetBatchMessageCount(previousGlobalState.Batch - 1) + if err != nil { + if strings.Contains(err.Error(), "not found") { + return nil, fmt.Errorf("%w: batch count %d", l2stateprovider.ErrChainCatchingUp, maxInboxCount) } + return nil, err } - previousMessageCount += arbutil.MessageIndex(previousGlobalState.PosInBatch) - messageDiffBetweenBatches := messageCount - previousMessageCount - maxMessageCount := previousMessageCount + arbutil.MessageIndex(maxNumberOfBlocks) - if messageDiffBetweenBatches > maxMessageCount { - messageCount = maxMessageCount - batchIndex, _, err = s.statelessValidator.InboxTracker().FindInboxBatchContainingMessage(messageCount) - if err != nil { - return nil, err - } + } + previousMessageCount += arbutil.MessageIndex(previousGlobalState.PosInBatch) + messageDiffBetweenBatches := messageCount - previousMessageCount + maxMessageCount := previousMessageCount + arbutil.MessageIndex(maxNumberOfBlocks) + if messageDiffBetweenBatches > maxMessageCount { + messageCount = maxMessageCount + batchIndex, _, err = s.statelessValidator.InboxTracker().FindInboxBatchContainingMessage(messageCount) + if err != nil { + return nil, err } } globalState, err := s.findGlobalStateFromMessageCountAndBatch(messageCount, l2stateprovider.Batch(batchIndex)) @@ -135,15 +133,10 @@ func (s *BOLDStateProvider) ExecutionStateAfterPreviousState( GlobalState: protocol.GoGlobalState(globalState), MachineStatus: protocol.MachineStatusFinished, } - - var previousGlobalStateOrDefault protocol.GoGlobalState - if previousGlobalState != nil { - previousGlobalStateOrDefault = *previousGlobalState - } toBatch := executionState.GlobalState.Batch historyCommitStates, _, err := s.StatesInBatchRange( ctx, - previousGlobalStateOrDefault, + previousGlobalState, toBatch, l2stateprovider.Height(maxNumberOfBlocks), ) @@ -155,7 +148,7 @@ func (s *BOLDStateProvider) ExecutionStateAfterPreviousState( return nil, err } executionState.EndHistoryRoot = historyCommit.Merkle - fmt.Printf("ExecutionStateAfterPreviousState for previous state batch %v pos %v got end batch %v pos %v last leaf %v hash %v\n", previousGlobalStateOrDefault.Batch, previousGlobalStateOrDefault.PosInBatch, executionState.GlobalState.Batch, executionState.GlobalState.PosInBatch, historyCommitStates[len(historyCommitStates)-1], executionState.EndHistoryRoot) + fmt.Printf("ExecutionStateAfterPreviousState for previous state batch %v pos %v got end batch %v pos %v last leaf %v hash %v\n", previousGlobalState.Batch, previousGlobalState.PosInBatch, executionState.GlobalState.Batch, executionState.GlobalState.PosInBatch, historyCommitStates[len(historyCommitStates)-1], executionState.EndHistoryRoot) return executionState, nil } diff --git a/system_tests/bold_new_challenge_test.go b/system_tests/bold_new_challenge_test.go index 73ac8508a6..a9caec45f0 100644 --- a/system_tests/bold_new_challenge_test.go +++ b/system_tests/bold_new_challenge_test.go @@ -46,14 +46,14 @@ type incorrectBlockStateProvider struct { func (s *incorrectBlockStateProvider) ExecutionStateAfterPreviousState( ctx context.Context, maxInboxCount uint64, - previousGlobalState *protocol.GoGlobalState, + previousGlobalState protocol.GoGlobalState, ) (*protocol.ExecutionState, error) { maxNumberOfBlocks := s.chain.SpecChallengeManager().LayerZeroHeights().BlockChallengeHeight.Uint64() executionState, err := s.honest.ExecutionStateAfterPreviousState(ctx, maxInboxCount, previousGlobalState) if err != nil { return nil, err } - evilStates, err := s.L2MessageStatesUpTo(ctx, *previousGlobalState, l2stateprovider.Batch(maxInboxCount), option.Some(l2stateprovider.Height(maxNumberOfBlocks))) + evilStates, err := s.L2MessageStatesUpTo(ctx, previousGlobalState, l2stateprovider.Batch(maxInboxCount), option.Some(l2stateprovider.Height(maxNumberOfBlocks))) if err != nil { return nil, err } diff --git a/system_tests/bold_state_provider_test.go b/system_tests/bold_state_provider_test.go index 766ecce380..40578221db 100644 --- a/system_tests/bold_state_provider_test.go +++ b/system_tests/bold_state_provider_test.go @@ -238,7 +238,7 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) { _, err = stateManager.ExecutionStateAfterPreviousState( ctx, 0, - &protocol.GoGlobalState{ + protocol.GoGlobalState{ Batch: 0, PosInBatch: 1, }, @@ -254,7 +254,7 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) { genesis, err := stateManager.ExecutionStateAfterPreviousState( ctx, 1, - &protocol.GoGlobalState{ + protocol.GoGlobalState{ Batch: 0, PosInBatch: 0, }, @@ -268,7 +268,7 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) { first, err := stateManager.ExecutionStateAfterPreviousState( ctx, 2, - &genesis.GlobalState, + genesis.GlobalState, ) Require(t, err) if first == nil { @@ -279,7 +279,7 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) { _, err = stateManager.ExecutionStateAfterPreviousState( ctx, 10, - &first.GlobalState, + first.GlobalState, ) if err == nil { Fatal(t, "should not agree with execution state") @@ -298,7 +298,7 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) { SendRoot: result.SendRoot, Batch: 3, } - got, err := stateManager.ExecutionStateAfterPreviousState(ctx, 3, &first.GlobalState) + got, err := stateManager.ExecutionStateAfterPreviousState(ctx, 3, first.GlobalState) Require(t, err) if state.Batch != got.GlobalState.Batch { Fatal(t, "wrong batch") @@ -315,7 +315,7 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) { _, err = stateManager.ExecutionStateAfterPreviousState( ctx, state.Batch+1, - &got.GlobalState, + got.GlobalState, ) if err == nil { Fatal(t, "should not agree with execution state") @@ -325,7 +325,7 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) { } }) t.Run("ExecutionStateAfterBatchCount", func(t *testing.T) { - _, err = stateManager.ExecutionStateAfterPreviousState(ctx, 0, &protocol.GoGlobalState{}) + _, err = stateManager.ExecutionStateAfterPreviousState(ctx, 0, protocol.GoGlobalState{}) if err == nil { Fatal(t, "should have failed") } @@ -333,9 +333,9 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) { Fatal(t, "wrong error message", err) } - genesis, err := stateManager.ExecutionStateAfterPreviousState(ctx, 1, &protocol.GoGlobalState{}) + genesis, err := stateManager.ExecutionStateAfterPreviousState(ctx, 1, protocol.GoGlobalState{}) Require(t, err) - execState, err := stateManager.ExecutionStateAfterPreviousState(ctx, totalBatches, &genesis.GlobalState) + execState, err := stateManager.ExecutionStateAfterPreviousState(ctx, totalBatches, genesis.GlobalState) Require(t, err) if execState == nil { Fatal(t, "should not be nil")