Skip to content

Commit

Permalink
Changes based on PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
amsanghi committed Sep 26, 2024
1 parent 5334e63 commit b6cb5a1
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 61 deletions.
88 changes: 35 additions & 53 deletions staker/bold_state_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ func (s *BOLDStateProvider) ExecutionStateAfterPreviousState(
}
toBatch := executionState.GlobalState.Batch
historyCommitStates, _, err := s.StatesInBatchRange(
ctx,
0,
l2stateprovider.Height(maxNumberOfBlocks)+1,
l2stateprovider.Batch(fromBatch),
Expand Down Expand Up @@ -184,6 +185,7 @@ func (s *BOLDStateProvider) isStateValidatedAndMessageCountPastThreshold(
}

func (s *BOLDStateProvider) StatesInBatchRange(
ctx context.Context,
fromHeight,
toHeight l2stateprovider.Height,
fromBatch,
Expand All @@ -198,83 +200,63 @@ func (s *BOLDStateProvider) StatesInBatchRange(
}
// Compute the total desired hashes from this request.
totalDesiredHashes := (toHeight - fromHeight) + 1
machineHashes := make([]common.Hash, 0, totalDesiredHashes)
states := make([]validator.GoGlobalState, 0, totalDesiredHashes)

var prevBatchMsgCount arbutil.MessageIndex
var err error
if fromBatch == 0 {
prevBatchMsgCount, err = s.statelessValidator.inboxTracker.GetBatchMessageCount(0)
batchNum, found, err := s.statelessValidator.inboxTracker.FindInboxBatchContainingMessage(arbutil.MessageIndex(fromHeight))
if err != nil {
return nil, nil, err
}
if !found {
return nil, nil, fmt.Errorf("could not find batch containing message %d", fromHeight)
}
if batchNum == 0 {
prevBatchMsgCount = 0
} else {
prevBatchMsgCount, err = s.statelessValidator.inboxTracker.GetBatchMessageCount(uint64(fromBatch) - 1)
prevBatchMsgCount, err = s.statelessValidator.inboxTracker.GetBatchMessageCount(batchNum - 1)
}
if err != nil {
return nil, nil, err
}
executionResult, err := s.statelessValidator.streamer.ResultAtCount(prevBatchMsgCount)
currBatchMsgCount, err := s.statelessValidator.inboxTracker.GetBatchMessageCount(batchNum)
if err != nil {
return nil, nil, err
}
startState := validator.GoGlobalState{
BlockHash: executionResult.BlockHash,
SendRoot: executionResult.SendRoot,
Batch: uint64(fromBatch),
PosInBatch: 0,
}
machineHashes := make([]common.Hash, 0, totalDesiredHashes)
states := make([]validator.GoGlobalState, 0, totalDesiredHashes)
machineHashes = append(machineHashes, machineHash(startState))
states = append(states, startState)

for batch := fromBatch; batch < toBatch; batch++ {
batchMessageCount, err := s.statelessValidator.inboxTracker.GetBatchMessageCount(uint64(batch))
if err != nil {
return nil, nil, err
posInBatch := uint64(fromHeight) - uint64(prevBatchMsgCount)
for pos := fromHeight; pos <= toHeight; pos++ {
if ctx.Err() != nil {
return nil, nil, ctx.Err()
}
messagesInBatch := batchMessageCount - prevBatchMsgCount

// Obtain the states for each message in the batch.
for i := uint64(0); i < uint64(messagesInBatch); i++ {
msgIndex := uint64(prevBatchMsgCount) + i
messageCount := msgIndex + 1
executionResult, err := s.statelessValidator.streamer.ResultAtCount(arbutil.MessageIndex(messageCount))
if err != nil {
return nil, nil, err
}
// If the position in batch is equal to the number of messages in the batch,
// we do not include this state. Instead, we break and include the state
// that fully consumes the batch.
if i+1 == uint64(messagesInBatch) {
break
}
state := validator.GoGlobalState{
BlockHash: executionResult.BlockHash,
SendRoot: executionResult.SendRoot,
Batch: uint64(batch),
PosInBatch: i + 1,
}
states = append(states, state)
machineHashes = append(machineHashes, machineHash(state))
}

// Fully consume the batch.
executionResult, err := s.statelessValidator.streamer.ResultAtCount(batchMessageCount)
executionResult, err := s.statelessValidator.streamer.ResultAtCount(arbutil.MessageIndex(pos))
if err != nil {
return nil, nil, err
}
state := validator.GoGlobalState{
BlockHash: executionResult.BlockHash,
SendRoot: executionResult.SendRoot,
Batch: uint64(batch) + 1,
PosInBatch: 0,
Batch: batchNum,
PosInBatch: posInBatch,
}
states = append(states, state)
machineHashes = append(machineHashes, machineHash(state))
prevBatchMsgCount = batchMessageCount
if uint64(pos) == uint64(currBatchMsgCount) {
posInBatch = 0
batchNum++
currBatchMsgCount, err = s.statelessValidator.inboxTracker.GetBatchMessageCount(batchNum)
if err != nil {
return nil, nil, err
}
} else {
posInBatch++
}
}
for uint64(len(machineHashes)) < uint64(totalDesiredHashes) {
machineHashes = append(machineHashes, machineHashes[len(machineHashes)-1])
states = append(states, states[len(states)-1])
}
return machineHashes[fromHeight : toHeight+1], states[fromHeight : toHeight+1], nil
return machineHashes, states, nil
}

func machineHash(gs validator.GoGlobalState) common.Hash {
Expand Down Expand Up @@ -309,7 +291,7 @@ func (s *BOLDStateProvider) findGlobalStateFromMessageCountAndBatch(count arbuti
// and up to a required batch index. The hashes used for this commitment are the machine hashes
// at each message number.
func (s *BOLDStateProvider) L2MessageStatesUpTo(
_ context.Context,
ctx context.Context,
fromHeight l2stateprovider.Height,
toHeight option.Option[l2stateprovider.Height],
fromBatch,
Expand All @@ -321,7 +303,7 @@ func (s *BOLDStateProvider) L2MessageStatesUpTo(
} else {
to = s.blockChallengeLeafHeight
}
items, _, err := s.StatesInBatchRange(fromHeight, to, fromBatch, toBatch)
items, _, err := s.StatesInBatchRange(ctx, fromHeight, to, fromBatch, toBatch)
if err != nil {
return nil, err
}
Expand Down
12 changes: 5 additions & 7 deletions system_tests/bold_challenge_protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -542,9 +542,10 @@ func createTestNodeOnL1ForBoldProtocol(
l1info.SetContract("Rollup", addresses.Rollup)
l1info.SetContract("UpgradeExecutor", addresses.UpgradeExecutor)

cacheConfig := TestCachingConfig
cacheConfig.StateScheme = rawdb.HashScheme
_, l2stack, l2chainDb, l2arbDb, l2blockchain = createL2BlockChainWithStackConfig(t, l2info, "", chainConfig, getInitMessage(ctx, t, l1client, addresses), stackConfig, &cacheConfig)
execConfig := ExecConfigDefaultNonSequencerTest(t)
Require(t, execConfig.Validate())
execConfig.Caching.StateScheme = rawdb.HashScheme
_, l2stack, l2chainDb, l2arbDb, l2blockchain = createL2BlockChain(t, l2info, "", chainConfig, execConfig)
var sequencerTxOptsPtr *bind.TransactOpts
var dataSigner signature.DataSignerFunc
if isSequencer {
Expand All @@ -560,9 +561,6 @@ func createTestNodeOnL1ForBoldProtocol(

AddValNodeIfNeeded(t, ctx, nodeConfig, true, "", "")

execConfig := ExecConfigDefaultNonSequencerTest()
Require(t, execConfig.Validate())
execConfig.Caching.StateScheme = rawdb.HashScheme
execConfigFetcher := func() *gethexec.Config { return execConfig }
execNode, err := gethexec.CreateExecutionNode(ctx, l2stack, l2chainDb, l2blockchain, l1client, execConfigFetcher)
Require(t, err)
Expand Down Expand Up @@ -765,7 +763,7 @@ func create2ndNodeWithConfigForBoldProtocol(
initReader := statetransfer.NewMemoryInitDataReader(l2InitData)
initMessage := getInitMessage(ctx, t, l1client, first.DeployInfo)

execConfig := ExecConfigDefaultNonSequencerTest()
execConfig := ExecConfigDefaultNonSequencerTest(t)
Require(t, execConfig.Validate())
execConfig.Caching.StateScheme = rawdb.HashScheme
coreCacheConfig := gethexec.DefaultCacheConfigFor(l2stack, &execConfig.Caching)
Expand Down
2 changes: 1 addition & 1 deletion system_tests/bold_state_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) {
toBatch := l2stateprovider.Batch(3)
fromHeight := l2stateprovider.Height(0)
toHeight := l2stateprovider.Height(14)
stateRoots, states, err := stateManager.StatesInBatchRange(fromHeight, toHeight, fromBatch, toBatch)
stateRoots, states, err := stateManager.StatesInBatchRange(ctx, fromHeight, toHeight, fromBatch, toBatch)
Require(t, err)

if len(stateRoots) != 15 {
Expand Down

0 comments on commit b6cb5a1

Please sign in to comment.