diff --git a/staker/bold_state_provider.go b/staker/bold_state_provider.go index 4312d75b27..4a25637edb 100644 --- a/staker/bold_state_provider.go +++ b/staker/bold_state_provider.go @@ -139,6 +139,7 @@ func (s *BOLDStateProvider) ExecutionStateAfterPreviousState( } toBatch := executionState.GlobalState.Batch historyCommitStates, _, err := s.StatesInBatchRange( + ctx, 0, l2stateprovider.Height(maxNumberOfBlocks)+1, l2stateprovider.Batch(fromBatch), @@ -184,6 +185,7 @@ func (s *BOLDStateProvider) isStateValidatedAndMessageCountPastThreshold( } func (s *BOLDStateProvider) StatesInBatchRange( + ctx context.Context, fromHeight, toHeight l2stateprovider.Height, fromBatch, @@ -198,83 +200,63 @@ func (s *BOLDStateProvider) StatesInBatchRange( } // Compute the total desired hashes from this request. totalDesiredHashes := (toHeight - fromHeight) + 1 + machineHashes := make([]common.Hash, 0, totalDesiredHashes) + states := make([]validator.GoGlobalState, 0, totalDesiredHashes) var prevBatchMsgCount arbutil.MessageIndex var err error - if fromBatch == 0 { - prevBatchMsgCount, err = s.statelessValidator.inboxTracker.GetBatchMessageCount(0) + batchNum, found, err := s.statelessValidator.inboxTracker.FindInboxBatchContainingMessage(arbutil.MessageIndex(fromHeight)) + if err != nil { + return nil, nil, err + } + if !found { + return nil, nil, fmt.Errorf("could not find batch containing message %d", fromHeight) + } + if batchNum == 0 { + prevBatchMsgCount = 0 } else { - prevBatchMsgCount, err = s.statelessValidator.inboxTracker.GetBatchMessageCount(uint64(fromBatch) - 1) + prevBatchMsgCount, err = s.statelessValidator.inboxTracker.GetBatchMessageCount(batchNum - 1) } if err != nil { return nil, nil, err } - executionResult, err := s.statelessValidator.streamer.ResultAtCount(prevBatchMsgCount) + currBatchMsgCount, err := s.statelessValidator.inboxTracker.GetBatchMessageCount(batchNum) if err != nil { return nil, nil, err } - startState := validator.GoGlobalState{ - BlockHash: executionResult.BlockHash, - SendRoot: executionResult.SendRoot, - Batch: uint64(fromBatch), - PosInBatch: 0, - } - machineHashes := make([]common.Hash, 0, totalDesiredHashes) - states := make([]validator.GoGlobalState, 0, totalDesiredHashes) - machineHashes = append(machineHashes, machineHash(startState)) - states = append(states, startState) - - for batch := fromBatch; batch < toBatch; batch++ { - batchMessageCount, err := s.statelessValidator.inboxTracker.GetBatchMessageCount(uint64(batch)) - if err != nil { - return nil, nil, err + posInBatch := uint64(fromHeight) - uint64(prevBatchMsgCount) + for pos := fromHeight; pos <= toHeight; pos++ { + if ctx.Err() != nil { + return nil, nil, ctx.Err() } - messagesInBatch := batchMessageCount - prevBatchMsgCount - - // Obtain the states for each message in the batch. - for i := uint64(0); i < uint64(messagesInBatch); i++ { - msgIndex := uint64(prevBatchMsgCount) + i - messageCount := msgIndex + 1 - executionResult, err := s.statelessValidator.streamer.ResultAtCount(arbutil.MessageIndex(messageCount)) - if err != nil { - return nil, nil, err - } - // If the position in batch is equal to the number of messages in the batch, - // we do not include this state. Instead, we break and include the state - // that fully consumes the batch. - if i+1 == uint64(messagesInBatch) { - break - } - state := validator.GoGlobalState{ - BlockHash: executionResult.BlockHash, - SendRoot: executionResult.SendRoot, - Batch: uint64(batch), - PosInBatch: i + 1, - } - states = append(states, state) - machineHashes = append(machineHashes, machineHash(state)) - } - - // Fully consume the batch. - executionResult, err := s.statelessValidator.streamer.ResultAtCount(batchMessageCount) + executionResult, err := s.statelessValidator.streamer.ResultAtCount(arbutil.MessageIndex(pos)) if err != nil { return nil, nil, err } state := validator.GoGlobalState{ BlockHash: executionResult.BlockHash, SendRoot: executionResult.SendRoot, - Batch: uint64(batch) + 1, - PosInBatch: 0, + Batch: batchNum, + PosInBatch: posInBatch, } states = append(states, state) machineHashes = append(machineHashes, machineHash(state)) - prevBatchMsgCount = batchMessageCount + if uint64(pos) == uint64(currBatchMsgCount) { + posInBatch = 0 + batchNum++ + currBatchMsgCount, err = s.statelessValidator.inboxTracker.GetBatchMessageCount(batchNum) + if err != nil { + return nil, nil, err + } + } else { + posInBatch++ + } } for uint64(len(machineHashes)) < uint64(totalDesiredHashes) { machineHashes = append(machineHashes, machineHashes[len(machineHashes)-1]) states = append(states, states[len(states)-1]) } - return machineHashes[fromHeight : toHeight+1], states[fromHeight : toHeight+1], nil + return machineHashes, states, nil } func machineHash(gs validator.GoGlobalState) common.Hash { @@ -309,7 +291,7 @@ func (s *BOLDStateProvider) findGlobalStateFromMessageCountAndBatch(count arbuti // and up to a required batch index. The hashes used for this commitment are the machine hashes // at each message number. func (s *BOLDStateProvider) L2MessageStatesUpTo( - _ context.Context, + ctx context.Context, fromHeight l2stateprovider.Height, toHeight option.Option[l2stateprovider.Height], fromBatch, @@ -321,7 +303,7 @@ func (s *BOLDStateProvider) L2MessageStatesUpTo( } else { to = s.blockChallengeLeafHeight } - items, _, err := s.StatesInBatchRange(fromHeight, to, fromBatch, toBatch) + items, _, err := s.StatesInBatchRange(ctx, fromHeight, to, fromBatch, toBatch) if err != nil { return nil, err } diff --git a/system_tests/bold_challenge_protocol_test.go b/system_tests/bold_challenge_protocol_test.go index db9ca08d65..895927c2f4 100644 --- a/system_tests/bold_challenge_protocol_test.go +++ b/system_tests/bold_challenge_protocol_test.go @@ -542,9 +542,10 @@ func createTestNodeOnL1ForBoldProtocol( l1info.SetContract("Rollup", addresses.Rollup) l1info.SetContract("UpgradeExecutor", addresses.UpgradeExecutor) - cacheConfig := TestCachingConfig - cacheConfig.StateScheme = rawdb.HashScheme - _, l2stack, l2chainDb, l2arbDb, l2blockchain = createL2BlockChainWithStackConfig(t, l2info, "", chainConfig, getInitMessage(ctx, t, l1client, addresses), stackConfig, &cacheConfig) + execConfig := ExecConfigDefaultNonSequencerTest(t) + Require(t, execConfig.Validate()) + execConfig.Caching.StateScheme = rawdb.HashScheme + _, l2stack, l2chainDb, l2arbDb, l2blockchain = createL2BlockChain(t, l2info, "", chainConfig, execConfig) var sequencerTxOptsPtr *bind.TransactOpts var dataSigner signature.DataSignerFunc if isSequencer { @@ -560,9 +561,6 @@ func createTestNodeOnL1ForBoldProtocol( AddValNodeIfNeeded(t, ctx, nodeConfig, true, "", "") - execConfig := ExecConfigDefaultNonSequencerTest() - Require(t, execConfig.Validate()) - execConfig.Caching.StateScheme = rawdb.HashScheme execConfigFetcher := func() *gethexec.Config { return execConfig } execNode, err := gethexec.CreateExecutionNode(ctx, l2stack, l2chainDb, l2blockchain, l1client, execConfigFetcher) Require(t, err) @@ -765,7 +763,7 @@ func create2ndNodeWithConfigForBoldProtocol( initReader := statetransfer.NewMemoryInitDataReader(l2InitData) initMessage := getInitMessage(ctx, t, l1client, first.DeployInfo) - execConfig := ExecConfigDefaultNonSequencerTest() + execConfig := ExecConfigDefaultNonSequencerTest(t) Require(t, execConfig.Validate()) execConfig.Caching.StateScheme = rawdb.HashScheme coreCacheConfig := gethexec.DefaultCacheConfigFor(l2stack, &execConfig.Caching) diff --git a/system_tests/bold_state_provider_test.go b/system_tests/bold_state_provider_test.go index dadf698797..db6fa91104 100644 --- a/system_tests/bold_state_provider_test.go +++ b/system_tests/bold_state_provider_test.go @@ -212,7 +212,7 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) { toBatch := l2stateprovider.Batch(3) fromHeight := l2stateprovider.Height(0) toHeight := l2stateprovider.Height(14) - stateRoots, states, err := stateManager.StatesInBatchRange(fromHeight, toHeight, fromBatch, toBatch) + stateRoots, states, err := stateManager.StatesInBatchRange(ctx, fromHeight, toHeight, fromBatch, toBatch) Require(t, err) if len(stateRoots) != 15 {