diff --git a/core/state.go b/core/state.go index 6a62e9d86..5cbdcc5c4 100644 --- a/core/state.go +++ b/core/state.go @@ -48,11 +48,15 @@ type StateReader interface { type State struct { txn db.Transaction + + // This map holds the contract objects which are being updated in the current state update. + contracts map[felt.Felt]*StateContract } func NewState(txn db.Transaction) *State { return &State{ - txn: txn, + txn: txn, + contracts: make(map[felt.Felt]*StateContract), } } @@ -292,7 +296,6 @@ func (s *State) Update(blockNumber uint64, update *StateUpdate, declaredClasses return err } - contracts := make(map[felt.Felt]*StateContract) // register deployed contracts for addr, classHash := range update.StateDiff.DeployedContracts { // check if contract is already deployed @@ -305,14 +308,14 @@ func (s *State) Update(blockNumber uint64, update *StateUpdate, declaredClasses return err } - contracts[addr] = NewStateContract(&addr, classHash, &felt.Zero, blockNumber) + s.contracts[addr] = NewStateContract(&addr, classHash, &felt.Zero, blockNumber) } - if err = s.updateContracts(blockNumber, update.StateDiff, true, contracts); err != nil { + if err = s.updateContracts(blockNumber, update.StateDiff, true); err != nil { return err } - if err = s.Commit(stateTrie, contracts, true, blockNumber); err != nil { + if err = s.Commit(stateTrie, true, blockNumber); err != nil { return fmt.Errorf("state commit: %v", err) } @@ -386,7 +389,6 @@ var ( // Commit updates the state by committing the dirty contracts to the database. func (s *State) Commit( stateTrie *trie.Trie, - contracts map[felt.Felt]*StateContract, logChanges bool, blockNumber uint64, ) error { @@ -396,22 +398,21 @@ func (s *State) Commit( } // // sort the contracts in descending storage diff order - keys := slices.SortedStableFunc(maps.Keys(contracts), func(a, b felt.Felt) int { - return len(contracts[a].dirtyStorage) - len(contracts[b].dirtyStorage) + keys := slices.SortedStableFunc(maps.Keys(s.contracts), func(a, b felt.Felt) int { + return len(s.contracts[a].dirtyStorage) - len(s.contracts[b].dirtyStorage) }) contractPools := pool.NewWithResults[*bufferedTransactionWithAddress]().WithErrors().WithMaxGoroutines(runtime.GOMAXPROCS(0)) for _, addr := range keys { - contract := contracts[addr] contractPools.Go(func() (*bufferedTransactionWithAddress, error) { - txn, err := contract.BufferedCommit(s.txn, logChanges, blockNumber) + txn, err := s.contracts[addr].BufferedCommit(s.txn, logChanges, blockNumber) if err != nil { return nil, err } return &bufferedTransactionWithAddress{ txn: txn, - addr: contract.Address, + addr: &addr, }, nil }) } @@ -432,39 +433,37 @@ func (s *State) Commit( } } - for _, contract := range contracts { + for _, contract := range s.contracts { if err := s.updateContractCommitment(stateTrie, contract); err != nil { return err } } + // finally, clear the contracts map + s.contracts = make(map[felt.Felt]*StateContract) + return nil } -func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges bool, contracts map[felt.Felt]*StateContract) error { - if contracts == nil { - return fmt.Errorf("contracts is nil") - } - - if err := s.updateContractClasses(blockNumber, diff.ReplacedClasses, logChanges, contracts); err != nil { +func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges bool) error { + if err := s.updateContractClasses(blockNumber, diff.ReplacedClasses, logChanges); err != nil { return err } - if err := s.updateContractNonces(blockNumber, diff.Nonces, logChanges, contracts); err != nil { + if err := s.updateContractNonces(blockNumber, diff.Nonces, logChanges); err != nil { return err } - return s.updateContractStorages(blockNumber, diff.StorageDiffs, contracts) + return s.updateContractStorages(blockNumber, diff.StorageDiffs) } func (s *State) updateContractClasses( blockNumber uint64, replacedClasses map[felt.Felt]*felt.Felt, logChanges bool, - contracts map[felt.Felt]*StateContract, ) error { for addr, classHash := range replacedClasses { - contract, err := s.getContract(addr, contracts) + contract, err := s.getContract(addr) if err != nil { return err } @@ -484,10 +483,9 @@ func (s *State) updateContractNonces( blockNumber uint64, nonces map[felt.Felt]*felt.Felt, logChanges bool, - contracts map[felt.Felt]*StateContract, ) error { for addr, nonce := range nonces { - contract, err := s.getContract(addr, contracts) + contract, err := s.getContract(addr) if err != nil { return err } @@ -506,14 +504,13 @@ func (s *State) updateContractNonces( func (s *State) updateContractStorages( blockNumber uint64, storageDiffs map[felt.Felt]map[felt.Felt]*felt.Felt, - contracts map[felt.Felt]*StateContract, ) error { for addr, diff := range storageDiffs { - contract, err := s.getContract(addr, contracts) + contract, err := s.getContract(addr) if err != nil { if _, ok := noClassContracts[addr]; ok && errors.Is(err, ErrContractNotDeployed) { contract = NewStateContract(&addr, noClassContractsClassHash, &felt.Zero, blockNumber) - contracts[addr] = contract + s.contracts[addr] = contract } else { return err } @@ -524,15 +521,15 @@ func (s *State) updateContractStorages( return nil } -func (s *State) getContract(addr felt.Felt, contracts map[felt.Felt]*StateContract) (*StateContract, error) { - contract, ok := contracts[addr] +func (s *State) getContract(addr felt.Felt) (*StateContract, error) { + contract, ok := s.contracts[addr] if !ok { var err error contract, err = GetContract(&addr, s.txn) if err != nil { return nil, err } - contracts[addr] = contract + s.contracts[addr] = contract } return contract, nil } @@ -655,12 +652,11 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { return err } - contracts := make(map[felt.Felt]*StateContract) - if err = s.updateContracts(blockNumber, reversedDiff, false, contracts); err != nil { + if err = s.updateContracts(blockNumber, reversedDiff, false); err != nil { return fmt.Errorf("update contracts: %v", err) } - if err = s.Commit(stateTrie, contracts, false, blockNumber); err != nil { + if err = s.Commit(stateTrie, false, blockNumber); err != nil { return fmt.Errorf("state commit: %v", err) }