From 00b54ef90645178d52a00be83ad8933705344670 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Fri, 8 Sep 2023 17:34:58 +0300 Subject: [PATCH 01/10] collect state changes --- genesis/mock/userAccountMock.go | 5 +- state/accountsDB.go | 125 ++++++++++++------ state/accountsDB_test.go | 8 +- state/interface.go | 14 +- state/journalEntries.go | 4 +- state/stateChangesCollector.go | 38 ++++++ state/trackableDataTrie/trackableDataTrie.go | 125 +++++++++++++----- .../trackableDataTrie_test.go | 59 +++++++-- testscommon/state/accountWrapperMock.go | 4 +- testscommon/state/userAccountStub.go | 4 +- testscommon/trie/dataTrieTrackerStub.go | 7 +- 11 files changed, 294 insertions(+), 99 deletions(-) create mode 100644 state/stateChangesCollector.go diff --git a/genesis/mock/userAccountMock.go b/genesis/mock/userAccountMock.go index c64e7da2a70..10362edba89 100644 --- a/genesis/mock/userAccountMock.go +++ b/genesis/mock/userAccountMock.go @@ -7,6 +7,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/state" ) // ErrNegativeValue - @@ -147,8 +148,8 @@ func (uam *UserAccountMock) GetUserName() []byte { } // SaveDirtyData - -func (uam *UserAccountMock) SaveDirtyData(_ common.Trie) ([]core.TrieData, error) { - return nil, nil +func (uam *UserAccountMock) SaveDirtyData(_ common.Trie) ([]state.DataTrieChange, []core.TrieData, error) { + return nil, nil, nil } // IsGuarded - diff --git a/state/accountsDB.go b/state/accountsDB.go index bc41d151da1..8446d8e12b7 100644 --- a/state/accountsDB.go +++ b/state/accountsDB.go @@ -85,13 +85,13 @@ type AccountsDB struct { obsoleteDataTrieHashes map[string][][]byte snapshotsManger SnapshotsManager - lastRootHash []byte - dataTries common.TriesHolder - entries []JournalEntry - - mutOp sync.RWMutex - loadCodeMeasurements *loadingMeasurements - addressConverter core.PubkeyConverter + lastRootHash []byte + dataTries common.TriesHolder + entries []JournalEntry + stateChangesCollector StateChangesCollector + mutOp sync.RWMutex + loadCodeMeasurements *loadingMeasurements + addressConverter core.PubkeyConverter stackDebug []byte } @@ -161,8 +161,9 @@ func createAccountsDb(args ArgsAccountsDB, snapshotManager SnapshotsManager) *Ac loadCodeMeasurements: &loadingMeasurements{ identifier: "load code", }, - addressConverter: args.AddressConverter, - snapshotsManger: snapshotManager, + addressConverter: args.AddressConverter, + snapshotsManger: snapshotManager, + stateChangesCollector: NewStateChangesCollector(), } } @@ -247,7 +248,8 @@ func (adb *AccountsDB) ImportAccount(account vmcommon.AccountHandler) error { } mainTrie := adb.getMainTrie() - return adb.saveAccountToTrie(account, mainTrie) + _, err := adb.saveAccountToTrie(account, mainTrie) + return err } func (adb *AccountsDB) getMainTrie() common.Trie { @@ -287,28 +289,45 @@ func (adb *AccountsDB) SaveAccount(account vmcommon.AccountHandler) error { adb.journalize(entry) } - err = adb.saveCodeAndDataTrie(oldAccount, account) + newDataTrieValues, err := adb.saveCodeAndDataTrie(oldAccount, account) + if err != nil { + return err + } + + marshalledAccount, err := adb.saveAccountToTrie(account, adb.mainTrie) if err != nil { return err } - return adb.saveAccountToTrie(account, adb.mainTrie) + stateChange := StateChangeDTO{ + MainTrieKey: account.AddressBytes(), + MainTrieVal: marshalledAccount, + DataTrieChanges: newDataTrieValues, + } + adb.stateChangesCollector.AddStateChange(stateChange) + + return err } -func (adb *AccountsDB) saveCodeAndDataTrie(oldAcc, newAcc vmcommon.AccountHandler) error { +func (adb *AccountsDB) saveCodeAndDataTrie(oldAcc, newAcc vmcommon.AccountHandler) ([]DataTrieChange, error) { baseNewAcc, newAccOk := newAcc.(baseAccountHandler) baseOldAccount, _ := oldAcc.(baseAccountHandler) if !newAccOk { - return nil + return make([]DataTrieChange, 0), nil } - err := adb.saveDataTrie(baseNewAcc) + newValues, err := adb.saveDataTrie(baseNewAcc) if err != nil { - return err + return nil, err + } + + err = adb.saveCode(baseNewAcc, baseOldAccount) + if err != nil { + return nil, err } - return adb.saveCode(baseNewAcc, baseOldAccount) + return newValues, err } func (adb *AccountsDB) saveCode(newAcc, oldAcc baseAccountHandler) error { @@ -375,15 +394,29 @@ func (adb *AccountsDB) updateOldCodeEntry(oldCodeHash []byte) (*CodeEntry, error return nil, err } + stateChange := StateChangeDTO{ + MainTrieKey: oldCodeHash, + MainTrieVal: nil, + DataTrieChanges: nil, + } + adb.stateChangesCollector.AddStateChange(stateChange) + return unmodifiedOldCodeEntry, nil } oldCodeEntry.NumReferences-- - err = saveCodeEntry(oldCodeHash, oldCodeEntry, adb.mainTrie, adb.marshaller) + codeEntryBytes, err := saveCodeEntry(oldCodeHash, oldCodeEntry, adb.mainTrie, adb.marshaller) if err != nil { return nil, err } + stateChange := StateChangeDTO{ + MainTrieKey: oldCodeHash, + MainTrieVal: codeEntryBytes, + DataTrieChanges: nil, + } + adb.stateChangesCollector.AddStateChange(stateChange) + return unmodifiedOldCodeEntry, nil } @@ -404,11 +437,18 @@ func (adb *AccountsDB) updateNewCodeEntry(newCodeHash []byte, newCode []byte) er } newCodeEntry.NumReferences++ - err = saveCodeEntry(newCodeHash, newCodeEntry, adb.mainTrie, adb.marshaller) + codeEntryBytes, err := saveCodeEntry(newCodeHash, newCodeEntry, adb.mainTrie, adb.marshaller) if err != nil { return err } + stateChange := StateChangeDTO{ + MainTrieKey: newCodeHash, + MainTrieVal: codeEntryBytes, + DataTrieChanges: nil, + } + adb.stateChangesCollector.AddStateChange(stateChange) + return nil } @@ -431,18 +471,18 @@ func getCodeEntry(codeHash []byte, trie Updater, marshalizer marshal.Marshalizer return &codeEntry, nil } -func saveCodeEntry(codeHash []byte, entry *CodeEntry, trie Updater, marshalizer marshal.Marshalizer) error { +func saveCodeEntry(codeHash []byte, entry *CodeEntry, trie Updater, marshalizer marshal.Marshalizer) ([]byte, error) { codeEntry, err := marshalizer.Marshal(entry) if err != nil { - return err + return nil, err } err = trie.Update(codeHash, codeEntry) if err != nil { - return err + return nil, err } - return nil + return codeEntry, nil } // loadDataTrieConcurrentSafe retrieves and saves the SC data inside accountHandler object. @@ -473,27 +513,24 @@ func (adb *AccountsDB) loadDataTrieConcurrentSafe(accountHandler baseAccountHand // SaveDataTrie is used to save the data trie (not committing it) and to recompute the new Root value // If data is not dirtied, method will not create its JournalEntries to keep track of data modification -func (adb *AccountsDB) saveDataTrie(accountHandler baseAccountHandler) error { - oldValues, err := accountHandler.SaveDirtyData(adb.mainTrie) +func (adb *AccountsDB) saveDataTrie(accountHandler baseAccountHandler) ([]DataTrieChange, error) { + newValues, oldValues, err := accountHandler.SaveDirtyData(adb.mainTrie) if err != nil { - return err + return nil, err } if len(oldValues) == 0 { - return nil + return nil, nil } entry, err := NewJournalEntryDataTrieUpdates(oldValues, accountHandler) if err != nil { - return err + return nil, err } adb.journalize(entry) - //TODO in order to avoid recomputing the root hash after every transaction for the same data trie, - // benchmark if it is better to cache the account and compute the rootHash only when the state is committed. - // For this to work, LoadAccount should check that cache first, and only after load from the trie. rootHash, err := accountHandler.DataTrie().RootHash() if err != nil { - return err + return nil, err } accountHandler.SetRootHash(rootHash) @@ -501,16 +538,16 @@ func (adb *AccountsDB) saveDataTrie(accountHandler baseAccountHandler) error { trie, ok := accountHandler.DataTrie().(common.Trie) if !ok { log.Warn("wrong type conversion", "trie type", fmt.Sprintf("%T", accountHandler.DataTrie())) - return nil + return nil, nil } adb.dataTries.Put(accountHandler.AddressBytes(), trie) } - return nil + return newValues, nil } -func (adb *AccountsDB) saveAccountToTrie(accountHandler vmcommon.AccountHandler, mainTrie common.Trie) error { +func (adb *AccountsDB) saveAccountToTrie(accountHandler vmcommon.AccountHandler, mainTrie common.Trie) ([]byte, error) { log.Trace("accountsDB.saveAccountToTrie", "address", hex.EncodeToString(accountHandler.AddressBytes()), ) @@ -518,10 +555,15 @@ func (adb *AccountsDB) saveAccountToTrie(accountHandler vmcommon.AccountHandler, // pass the reference to marshaller, otherwise it will fail marshalling balance buff, err := adb.marshaller.Marshal(accountHandler) if err != nil { - return err + return nil, err + } + + err = mainTrie.Update(accountHandler.AddressBytes(), buff) + if err != nil { + return nil, err } - return mainTrie.Update(accountHandler.AddressBytes(), buff) + return buff, nil } // RemoveAccount removes the account data from underlying trie. @@ -785,7 +827,7 @@ func (adb *AccountsDB) RevertToSnapshot(snapshot int) error { } if !check.IfNil(account) { - err = adb.saveAccountToTrie(account, adb.mainTrie) + _, err = adb.saveAccountToTrie(account, adb.mainTrie) if err != nil { return err } @@ -1225,6 +1267,13 @@ func collectStats( log.Debug(strings.Join(trieStats.ToString(), " ")) } +func (adb *AccountsDB) GetStateChangesForTheLatestTransaction() ([]StateChangeDTO, error) { + stateChanges := adb.stateChangesCollector.GetStateChanges() + adb.stateChangesCollector.Reset() + + return stateChanges, nil +} + // IsSnapshotInProgress returns true if there is a snapshot in progress func (adb *AccountsDB) IsSnapshotInProgress() bool { return adb.snapshotsManger.IsSnapshotInProgress() diff --git a/state/accountsDB_test.go b/state/accountsDB_test.go index 61bba6f978a..d4c902c1d5c 100644 --- a/state/accountsDB_test.go +++ b/state/accountsDB_test.go @@ -342,14 +342,16 @@ func TestAccountsDB_SaveAccountSavesCodeAndDataTrieForUserAccount(t *testing.T) }) dtt := &trieMock.DataTrieTrackerStub{ - SaveDirtyDataCalled: func(_ common.Trie) ([]core.TrieData, error) { - return []core.TrieData{ + SaveDirtyDataCalled: func(_ common.Trie) ([]state.DataTrieChange, []core.TrieData, error) { + var stateChanges []state.DataTrieChange + oldVal := []core.TrieData{ { Key: []byte("key"), Value: []byte("value"), Version: 0, }, - }, nil + } + return stateChanges, oldVal, nil }, DataTrieCalled: func() common.Trie { return trieStub diff --git a/state/interface.go b/state/interface.go index 56dd0e1b8c4..c42ef35256b 100644 --- a/state/interface.go +++ b/state/interface.go @@ -120,7 +120,7 @@ type baseAccountHandler interface { GetRootHash() []byte SetDataTrie(trie common.Trie) DataTrie() common.DataTrieHandler - SaveDirtyData(trie common.Trie) ([]core.TrieData, error) + SaveDirtyData(trie common.Trie) ([]DataTrieChange, []core.TrieData, error) IsInterfaceNil() bool } @@ -183,7 +183,8 @@ type DataTrie interface { } // PeerAccountHandler models a peer state account, which can journalize a normal account's data -// with some extra features like signing statistics or rating information +// +// with some extra features like signing statistics or rating information type PeerAccountHandler interface { SetBLSPublicKey([]byte) error GetRewardAddress() []byte @@ -255,7 +256,7 @@ type DataTrieTracker interface { SaveKeyValue(key []byte, value []byte) error SetDataTrie(tr common.Trie) DataTrie() common.DataTrieHandler - SaveDirtyData(common.Trie) ([]core.TrieData, error) + SaveDirtyData(common.Trie) ([]DataTrieChange, []core.TrieData, error) MigrateDataTrieLeaves(args vmcommon.ArgsMigrateDataTrieLeaves) error IsInterfaceNil() bool } @@ -265,3 +266,10 @@ type SignRate interface { GetNumSuccess() uint32 GetNumFailure() uint32 } + +type StateChangesCollector interface { + AddStateChange(stateChange StateChangeDTO) + GetStateChanges() []StateChangeDTO + Reset() + IsInterfaceNil() bool +} diff --git a/state/journalEntries.go b/state/journalEntries.go index cf26b504689..a4fa75dcf5a 100644 --- a/state/journalEntries.go +++ b/state/journalEntries.go @@ -66,7 +66,7 @@ func (jea *journalEntryCode) revertOldCodeEntry() error { return nil } - err := saveCodeEntry(jea.oldCodeHash, jea.oldCodeEntry, jea.trie, jea.marshalizer) + _, err := saveCodeEntry(jea.oldCodeHash, jea.oldCodeEntry, jea.trie, jea.marshalizer) if err != nil { return err } @@ -94,7 +94,7 @@ func (jea *journalEntryCode) revertNewCodeEntry() error { } newCodeEntry.NumReferences-- - err = saveCodeEntry(jea.newCodeHash, newCodeEntry, jea.trie, jea.marshalizer) + _, err = saveCodeEntry(jea.newCodeHash, newCodeEntry, jea.trie, jea.marshalizer) if err != nil { return err } diff --git a/state/stateChangesCollector.go b/state/stateChangesCollector.go new file mode 100644 index 00000000000..fe69cef0184 --- /dev/null +++ b/state/stateChangesCollector.go @@ -0,0 +1,38 @@ +package state + +type StateChangeDTO struct { + MainTrieKey []byte + MainTrieVal []byte + DataTrieChanges []DataTrieChange +} + +type DataTrieChange struct { + Key []byte + Val []byte +} + +type stateChangesCollector struct { + stateChanges []StateChangeDTO +} + +func NewStateChangesCollector() *stateChangesCollector { + return &stateChangesCollector{ + stateChanges: []StateChangeDTO{}, + } +} + +func (scc *stateChangesCollector) AddStateChange(stateChange StateChangeDTO) { + scc.stateChanges = append(scc.stateChanges, stateChange) +} + +func (scc *stateChangesCollector) GetStateChanges() []StateChangeDTO { + return scc.stateChanges +} + +func (scc *stateChangesCollector) Reset() { + scc.stateChanges = []StateChangeDTO{} +} + +func (scc *stateChangesCollector) IsInterfaceNil() bool { + return scc == nil +} diff --git a/state/trackableDataTrie/trackableDataTrie.go b/state/trackableDataTrie/trackableDataTrie.go index d08af345ef7..7ff84ddcfaf 100644 --- a/state/trackableDataTrie/trackableDataTrie.go +++ b/state/trackableDataTrie/trackableDataTrie.go @@ -1,7 +1,9 @@ package trackableDataTrie import ( + "bytes" "fmt" + "sort" "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" @@ -19,6 +21,7 @@ import ( var log = logger.GetOrCreate("state/trackableDataTrie") type dirtyData struct { + index int value []byte newVersion core.TrieNodeVersion } @@ -97,6 +100,7 @@ func (tdt *trackableDataTrie) SaveKeyValue(key []byte, value []byte) error { } dataEntry := dirtyData{ + index: tdt.getIndexForKey(key), value: value, newVersion: core.GetVersionForNewData(tdt.enableEpochsHandler), } @@ -126,22 +130,32 @@ func (tdt *trackableDataTrie) MigrateDataTrieLeaves(args vmcommon.ArgsMigrateDat dataToBeMigrated := args.TrieMigrator.GetLeavesToBeMigrated() for _, leafData := range dataToBeMigrated { - dataEntry := dirtyData{ - value: leafData.Value, - newVersion: args.NewVersion, - } - originalKey, err := tdt.getOriginalKeyFromTrieData(leafData) if err != nil { return err } + dataEntry := dirtyData{ + index: tdt.getIndexForKey(originalKey), + value: leafData.Value, + newVersion: args.NewVersion, + } + tdt.dirtyData[string(originalKey)] = dataEntry } return nil } +func (tdt *trackableDataTrie) getIndexForKey(key []byte) int { + existingVal, ok := tdt.dirtyData[string(key)] + if ok { + return existingVal.index + } + + return len(tdt.dirtyData) +} + func (tdt *trackableDataTrie) getOriginalKeyFromTrieData(trieData core.TrieData) ([]byte, error) { if trieData.Version == core.AutoBalanceEnabled { valWithMetadata := &dataTrieValue.TrieLeafData{} @@ -196,15 +210,15 @@ func (tdt *trackableDataTrie) DataTrie() common.DataTrieHandler { } // SaveDirtyData saved the dirty data to the trie -func (tdt *trackableDataTrie) SaveDirtyData(mainTrie common.Trie) ([]core.TrieData, error) { +func (tdt *trackableDataTrie) SaveDirtyData(mainTrie common.Trie) ([]state.DataTrieChange, []core.TrieData, error) { if len(tdt.dirtyData) == 0 { - return make([]core.TrieData, 0), nil + return make([]state.DataTrieChange, 0), make([]core.TrieData, 0), nil } if check.IfNil(tdt.tr) { newDataTrie, err := mainTrie.Recreate(make([]byte, 0)) if err != nil { - return nil, err + return nil, nil, err } tdt.tr = newDataTrie @@ -212,39 +226,76 @@ func (tdt *trackableDataTrie) SaveDirtyData(mainTrie common.Trie) ([]core.TrieDa dtr, ok := tdt.tr.(state.DataTrie) if !ok { - return nil, fmt.Errorf("invalid trie, type is %T", tdt.tr) + return nil, nil, fmt.Errorf("invalid trie, type is %T", tdt.tr) } return tdt.updateTrie(dtr) } -func (tdt *trackableDataTrie) updateTrie(dtr state.DataTrie) ([]core.TrieData, error) { +func (tdt *trackableDataTrie) updateTrie(dtr state.DataTrie) ([]state.DataTrieChange, []core.TrieData, error) { oldValues := make([]core.TrieData, len(tdt.dirtyData)) + stateChanges := make([]state.DataTrieChange, len(tdt.dirtyData)) + deletedKeys := make([]state.DataTrieChange, 0) index := 0 for key, dataEntry := range tdt.dirtyData { oldVal, _, err := tdt.retrieveValueFromTrie([]byte(key)) if err != nil { - return nil, err + return nil, nil, err } oldValues[index] = oldVal - err = tdt.deleteOldEntryIfMigrated([]byte(key), dataEntry, oldVal) + wasDeleted, err := tdt.deleteOldEntryIfMigrated([]byte(key), dataEntry, oldVal) if err != nil { - return nil, err + return nil, nil, err } - err = tdt.modifyTrie([]byte(key), dataEntry, oldVal, dtr) + if wasDeleted { + deletedKeys = append(deletedKeys, + state.DataTrieChange{ + Key: []byte(key), + Val: nil, + }, + ) + } + + dataTrieKey, dataTrieVal, err := tdt.modifyTrie([]byte(key), dataEntry, oldVal, dtr) if err != nil { - return nil, err + return nil, nil, err } index++ + + if len(dataTrieKey) == 0 { + continue + } + + if dataEntry.index > len(stateChanges)-1 { + return nil, nil, fmt.Errorf("index out of range") + } + + stateChanges[dataEntry.index] = state.DataTrieChange{ + Key: dataTrieKey, + Val: dataTrieVal, + } } tdt.dirtyData = make(map[string]dirtyData) - return oldValues, nil + for i := range stateChanges { + if len(stateChanges[i].Key) != 0 { + continue + } + + stateChanges = append(stateChanges[:i], stateChanges[i+1:]...) + } + + sort.Slice(deletedKeys, func(i, j int) bool { + return bytes.Compare(deletedKeys[i].Key, deletedKeys[j].Key) < 0 + }) + stateChanges = append(stateChanges, deletedKeys...) + + return stateChanges, oldValues, nil } func (tdt *trackableDataTrie) retrieveValueFromTrie(key []byte) (core.TrieData, uint32, error) { @@ -320,48 +371,60 @@ func (tdt *trackableDataTrie) getValueNotSpecifiedVersion(key []byte, val []byte return trimmedValue, nil } -func (tdt *trackableDataTrie) deleteOldEntryIfMigrated(key []byte, newData dirtyData, oldEntry core.TrieData) error { +func (tdt *trackableDataTrie) deleteOldEntryIfMigrated(key []byte, newData dirtyData, oldEntry core.TrieData) (bool, error) { if !tdt.enableEpochsHandler.IsAutoBalanceDataTriesEnabled() { - return nil + return false, nil } isMigration := oldEntry.Version == core.NotSpecified && newData.newVersion == core.AutoBalanceEnabled if isMigration && len(newData.value) != 0 { - return tdt.tr.Delete(key) + return true, tdt.tr.Delete(key) } - return nil + return false, nil } -func (tdt *trackableDataTrie) modifyTrie(key []byte, dataEntry dirtyData, oldVal core.TrieData, dtr state.DataTrie) error { +func (tdt *trackableDataTrie) modifyTrie(key []byte, dataEntry dirtyData, oldVal core.TrieData, dtr state.DataTrie) ([]byte, []byte, error) { + version := dataEntry.newVersion + newKey := tdt.getKeyForVersion(key, version) + if len(dataEntry.value) == 0 { - return tdt.deleteFromTrie(oldVal, key, dtr) + deletedKey, err := tdt.deleteFromTrie(oldVal, key, dtr) + if err != nil { + return nil, nil, err + } + + return deletedKey, nil, nil } - version := dataEntry.newVersion - newKey := tdt.getKeyForVersion(key, version) value, err := tdt.getValueForVersion(key, dataEntry.value, version) if err != nil { - return err + return nil, nil, err } - return dtr.UpdateWithVersion(newKey, value, version) + err = dtr.UpdateWithVersion(newKey, value, version) + if err != nil { + return nil, nil, err + } + + return newKey, value, nil } -func (tdt *trackableDataTrie) deleteFromTrie(oldVal core.TrieData, key []byte, dtr state.DataTrie) error { +func (tdt *trackableDataTrie) deleteFromTrie(oldVal core.TrieData, key []byte, dtr state.DataTrie) ([]byte, error) { if len(oldVal.Value) == 0 { - return nil + return nil, nil } if oldVal.Version == core.AutoBalanceEnabled { - return dtr.Delete(tdt.hasher.Compute(string(key))) + keyForTrie := tdt.hasher.Compute(string(key)) + return keyForTrie, dtr.Delete(keyForTrie) } if oldVal.Version == core.NotSpecified { - return dtr.Delete(key) + return key, dtr.Delete(key) } - return nil + return nil, nil } // IsInterfaceNil returns true if there is no value under the interface diff --git a/state/trackableDataTrie/trackableDataTrie_test.go b/state/trackableDataTrie/trackableDataTrie_test.go index 38f5b9d33fa..7269b30d0b6 100644 --- a/state/trackableDataTrie/trackableDataTrie_test.go +++ b/state/trackableDataTrie/trackableDataTrie_test.go @@ -335,9 +335,10 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) - oldValues, err := tdt.SaveDirtyData(&trieMock.TrieStub{}) + stateChanges, oldValues, err := tdt.SaveDirtyData(&trieMock.TrieStub{}) assert.Nil(t, err) assert.Equal(t, 0, len(oldValues)) + assert.Equal(t, 0, len(stateChanges)) }) t.Run("nil trie creates a new trie", func(t *testing.T) { @@ -360,12 +361,17 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) key := []byte("key") - _ = tdt.SaveKeyValue(key, []byte("val")) - oldValues, err := tdt.SaveDirtyData(trie) + val := []byte("val") + newVal := []byte("valkeyidentifier") + _ = tdt.SaveKeyValue(key, val) + stateChanges, oldValues, err := tdt.SaveDirtyData(trie) assert.Nil(t, err) assert.Equal(t, 1, len(oldValues)) assert.Equal(t, key, oldValues[0].Key) assert.Equal(t, []byte(nil), oldValues[0].Value) + assert.Equal(t, 1, len(stateChanges)) + assert.Equal(t, key, stateChanges[0].Key) + assert.Equal(t, newVal, stateChanges[0].Val) assert.True(t, recreateCalled) }) @@ -416,11 +422,16 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, expectedVal) - oldValues, err := tdt.SaveDirtyData(trie) + stateChanges, oldValues, err := tdt.SaveDirtyData(trie) assert.Nil(t, err) assert.Equal(t, 1, len(oldValues)) assert.Equal(t, expectedKey, oldValues[0].Key) assert.Equal(t, value, oldValues[0].Value) + assert.Equal(t, 2, len(stateChanges)) + assert.Equal(t, hasher.Compute(string(expectedKey)), stateChanges[0].Key) + assert.Equal(t, serializedTrieVal, stateChanges[0].Val) + assert.Equal(t, expectedKey, stateChanges[1].Key) + assert.Equal(t, []byte(nil), stateChanges[1].Val) assert.True(t, deleteCalled) assert.True(t, updateCalled) }) @@ -463,11 +474,14 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, val) - oldValues, err := tdt.SaveDirtyData(trie) + stateChanges, oldValues, err := tdt.SaveDirtyData(trie) assert.Nil(t, err) assert.Equal(t, 1, len(oldValues)) assert.Equal(t, expectedKey, oldValues[0].Key) assert.Equal(t, expectedVal, oldValues[0].Value) + assert.Equal(t, 1, len(stateChanges)) + assert.Equal(t, expectedKey, stateChanges[0].Key) + assert.Equal(t, expectedVal, stateChanges[0].Val) assert.True(t, updateCalled) }) @@ -522,11 +536,14 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, newVal) - oldValues, err := tdt.SaveDirtyData(trie) + stateChanges, oldValues, err := tdt.SaveDirtyData(trie) assert.Nil(t, err) assert.Equal(t, 1, len(oldValues)) assert.Equal(t, hasher.Compute(string(expectedKey)), oldValues[0].Key) assert.Equal(t, serializedOldTrieVal, oldValues[0].Value) + assert.Equal(t, 1, len(stateChanges)) + assert.Equal(t, hasher.Compute(string(expectedKey)), stateChanges[0].Key) + assert.Equal(t, serializedNewTrieVal, stateChanges[0].Val) assert.True(t, updateCalled) }) @@ -570,11 +587,14 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, newVal) - oldValues, err := tdt.SaveDirtyData(trie) + stateChanges, oldValues, err := tdt.SaveDirtyData(trie) assert.Nil(t, err) assert.Equal(t, 1, len(oldValues)) assert.Equal(t, hasher.Compute(string(expectedKey)), oldValues[0].Key) assert.Equal(t, []byte(nil), oldValues[0].Value) + assert.Equal(t, 1, len(stateChanges)) + assert.Equal(t, hasher.Compute(string(expectedKey)), stateChanges[0].Key) + assert.Equal(t, serializedNewTrieVal, stateChanges[0].Val) assert.True(t, updateCalled) }) @@ -597,7 +617,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, val) - _, err := tdt.SaveDirtyData(trie) + _, _, err := tdt.SaveDirtyData(trie) assert.Nil(t, err) assert.Equal(t, 0, len(tdt.DirtyData())) }) @@ -622,10 +642,13 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, nil) - _, err := tdt.SaveDirtyData(trie) + stateChanges, _, err := tdt.SaveDirtyData(trie) assert.Nil(t, err) assert.Equal(t, 0, len(tdt.DirtyData())) assert.True(t, updateCalled) + assert.Equal(t, 1, len(stateChanges)) + assert.Equal(t, expectedKey, stateChanges[0].Key) + assert.Equal(t, []byte(nil), stateChanges[0].Val) }) t.Run("nil val and nil old val", func(t *testing.T) { @@ -648,10 +671,11 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, nil) - _, err := tdt.SaveDirtyData(trie) + stateChanges, _, err := tdt.SaveDirtyData(trie) assert.Nil(t, err) assert.Equal(t, 0, len(tdt.DirtyData())) assert.False(t, deleteCalled) + assert.Equal(t, 0, len(stateChanges)) }) t.Run("nil val autobalance enabled, old val saved at hashedKey", func(t *testing.T) { @@ -682,10 +706,13 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, nil) - _, err := tdt.SaveDirtyData(trie) + stateChanges, _, err := tdt.SaveDirtyData(trie) assert.Nil(t, err) assert.Equal(t, 0, len(tdt.DirtyData())) assert.True(t, deleteCalled) + assert.Equal(t, 1, len(stateChanges)) + assert.Equal(t, hasher.Compute(string(expectedKey)), stateChanges[0].Key) + assert.Equal(t, []byte(nil), stateChanges[0].Val) }) t.Run("nil val autobalance enabled, old val saved at key", func(t *testing.T) { @@ -715,10 +742,13 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, nil) - _, err := tdt.SaveDirtyData(trie) + stateChanges, _, err := tdt.SaveDirtyData(trie) assert.Nil(t, err) assert.Equal(t, 0, len(tdt.DirtyData())) assert.Equal(t, 1, deleteCalled) + assert.Equal(t, 1, len(stateChanges)) + assert.Equal(t, expectedKey, stateChanges[0].Key) + assert.Equal(t, []byte(nil), stateChanges[0].Val) }) t.Run("not present in trie - autobalance disabled", func(t *testing.T) { @@ -761,12 +791,15 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, newVal) - oldValues, err := tdt.SaveDirtyData(trie) + stateChanges, oldValues, err := tdt.SaveDirtyData(trie) assert.Nil(t, err) assert.Equal(t, 1, len(oldValues)) assert.Equal(t, expectedKey, oldValues[0].Key) assert.Equal(t, []byte(nil), oldValues[0].Value) assert.True(t, updateCalled) + assert.Equal(t, 1, len(stateChanges)) + assert.Equal(t, expectedKey, stateChanges[0].Key) + assert.Equal(t, valueWithMetadata, stateChanges[0].Val) }) } diff --git a/testscommon/state/accountWrapperMock.go b/testscommon/state/accountWrapperMock.go index 9cbac29d8ce..75c7ab97517 100644 --- a/testscommon/state/accountWrapperMock.go +++ b/testscommon/state/accountWrapperMock.go @@ -196,7 +196,7 @@ func (awm *AccountWrapMock) DataTrie() common.DataTrieHandler { } // SaveDirtyData - -func (awm *AccountWrapMock) SaveDirtyData(trie common.Trie) ([]core.TrieData, error) { +func (awm *AccountWrapMock) SaveDirtyData(trie common.Trie) ([]state.DataTrieChange, []core.TrieData, error) { return awm.trackableDataTrie.SaveDirtyData(trie) } @@ -205,7 +205,7 @@ func (awm *AccountWrapMock) SetDataTrie(trie common.Trie) { awm.trackableDataTrie.SetDataTrie(trie) } -//IncreaseNonce adds the given value to the current nonce +// IncreaseNonce adds the given value to the current nonce func (awm *AccountWrapMock) IncreaseNonce(val uint64) { awm.nonce = awm.nonce + val } diff --git a/testscommon/state/userAccountStub.go b/testscommon/state/userAccountStub.go index 3e4278b2d38..90fc1f88ab3 100644 --- a/testscommon/state/userAccountStub.go +++ b/testscommon/state/userAccountStub.go @@ -185,8 +185,8 @@ func (u *UserAccountStub) IsGuarded() bool { } // SaveDirtyData - -func (u *UserAccountStub) SaveDirtyData(_ common.Trie) ([]core.TrieData, error) { - return nil, nil +func (u *UserAccountStub) SaveDirtyData(_ common.Trie) ([]state.DataTrieChange, []core.TrieData, error) { + return nil, nil, nil } // IsInterfaceNil - diff --git a/testscommon/trie/dataTrieTrackerStub.go b/testscommon/trie/dataTrieTrackerStub.go index ead12e35af6..60eca7ef978 100644 --- a/testscommon/trie/dataTrieTrackerStub.go +++ b/testscommon/trie/dataTrieTrackerStub.go @@ -4,6 +4,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/state" vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) @@ -15,7 +16,7 @@ type DataTrieTrackerStub struct { SaveKeyValueCalled func(key []byte, value []byte) error SetDataTrieCalled func(tr common.Trie) DataTrieCalled func() common.Trie - SaveDirtyDataCalled func(trie common.Trie) ([]core.TrieData, error) + SaveDirtyDataCalled func(trie common.Trie) ([]state.DataTrieChange, []core.TrieData, error) SaveTrieDataCalled func(trieData core.TrieData) error MigrateDataTrieLeavesCalled func(args vmcommon.ArgsMigrateDataTrieLeaves) error } @@ -61,12 +62,12 @@ func (dtts *DataTrieTrackerStub) DataTrie() common.DataTrieHandler { } // SaveDirtyData - -func (dtts *DataTrieTrackerStub) SaveDirtyData(mainTrie common.Trie) ([]core.TrieData, error) { +func (dtts *DataTrieTrackerStub) SaveDirtyData(mainTrie common.Trie) ([]state.DataTrieChange, []core.TrieData, error) { if dtts.SaveDirtyDataCalled != nil { return dtts.SaveDirtyDataCalled(mainTrie) } - return make([]core.TrieData, 0), nil + return make([]state.DataTrieChange, 0), make([]core.TrieData, 0), nil } // MigrateDataTrieLeaves - From 75dff9879bdc496a438044f9ec569feb32f6403c Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Wed, 13 Sep 2023 16:01:18 +0300 Subject: [PATCH 02/10] add more unit tests --- state/accountsDB_test.go | 140 ++++++++++++- state/trackableDataTrie/export_test.go | 6 + .../trackableDataTrie_test.go | 187 +++++++++++------- 3 files changed, 252 insertions(+), 81 deletions(-) diff --git a/state/accountsDB_test.go b/state/accountsDB_test.go index d4c902c1d5c..ef961c2f54c 100644 --- a/state/accountsDB_test.go +++ b/state/accountsDB_test.go @@ -24,6 +24,7 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/accounts" + "github.com/multiversx/mx-chain-go/state/dataTrieValue" "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/state/storagePruningManager" @@ -40,6 +41,7 @@ import ( trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie" "github.com/multiversx/mx-chain-go/trie/hashesHolder" + disabledHashesHolder "github.com/multiversx/mx-chain-go/trie/hashesHolder/disabled" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -97,19 +99,25 @@ func generateAddressAccountAccountsDB(trie common.Trie) ([]byte, *stateMock.Acco func getDefaultTrieAndAccountsDb() (common.Trie, *state.AccountsDB) { checkpointHashesHolder := hashesHolder.NewCheckpointHashesHolder(10000000, testscommon.HashSize) - adb, tr, _ := getDefaultStateComponents(checkpointHashesHolder, testscommon.NewSnapshotPruningStorerMock()) + adb, tr, _ := getDefaultStateComponents(checkpointHashesHolder, testscommon.NewSnapshotPruningStorerMock(), &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) return tr, adb } func getDefaultTrieAndAccountsDbWithCustomDB(db common.BaseStorer) (common.Trie, *state.AccountsDB) { checkpointHashesHolder := hashesHolder.NewCheckpointHashesHolder(10000000, testscommon.HashSize) - adb, tr, _ := getDefaultStateComponents(checkpointHashesHolder, db) + adb, tr, _ := getDefaultStateComponents(checkpointHashesHolder, db, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + return tr, adb +} + +func getDefaultStateComponentsWithCustomEnableEpochs(enableEpochs common.EnableEpochsHandler) (common.Trie, *state.AccountsDB) { + adb, tr, _ := getDefaultStateComponents(disabledHashesHolder.NewDisabledCheckpointHashesHolder(), testscommon.NewSnapshotPruningStorerMock(), enableEpochs) return tr, adb } func getDefaultStateComponents( hashesHolder trie.CheckpointHashesHolder, db common.BaseStorer, + enableEpochs common.EnableEpochsHandler, ) (*state.AccountsDB, common.Trie, common.StorageManager) { generalCfg := config.TrieStorageManagerConfig{ PruningBufferLen: 1000, @@ -123,7 +131,7 @@ func getDefaultStateComponents( args.MainStorer = db args.CheckpointHashesHolder = hashesHolder trieStorage, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(trieStorage, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + tr, _ := trie.NewTrie(trieStorage, marshaller, hasher, enableEpochs, 5) ewlArgs := evictionWaitingList.MemoryEvictionWaitingListArgs{ RootHashesSize: 100, HashesSize: 10000, @@ -133,7 +141,7 @@ func getDefaultStateComponents( argsAccCreator := factory.ArgsAccountCreator{ Hasher: hasher, Marshaller: marshaller, - EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + EnableEpochsHandler: enableEpochs, } accCreator, _ := factory.NewAccountCreator(argsAccCreator) @@ -394,6 +402,118 @@ func TestAccountsDB_SaveAccountMalfunctionMarshallerShouldErr(t *testing.T) { assert.NotNil(t, err) } +func TestAccountsDB_SaveAccountCollectsAllStateChanges(t *testing.T) { + t.Parallel() + + enableEpochs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: false, + } + _, adb := getDefaultStateComponentsWithCustomEnableEpochs(enableEpochs) + address := generateRandomByteArray(32) + + stepCreateAccountWithDataTrieAndCode(t, adb, address) + enableEpochs.IsAutoBalanceDataTriesEnabledField = true + stepMigrateDataTrieValAndChangeCode(t, adb, address) +} + +func stepCreateAccountWithDataTrieAndCode( + t *testing.T, + adb *state.AccountsDB, + address []byte, +) { + marshaller := &marshallerMock.MarshalizerMock{} + + acc, _ := adb.LoadAccount(address) + userAcc := acc.(state.UserAccountHandler) + code := []byte("smart contract code") + key1 := []byte("key1") + key2 := []byte("key2") + userAcc.SetCode(code) + _ = userAcc.SaveKeyValue(key1, []byte("value")) + _ = userAcc.SaveKeyValue(key2, []byte("value")) + _ = adb.SaveAccount(userAcc) + + serializedAcc, _ := marshaller.Marshal(userAcc) + codeHash := userAcc.GetCodeHash() + + stateChanges, err := adb.GetStateChangesForTheLatestTransaction() + assert.Nil(t, err) + assert.Equal(t, 2, len(stateChanges)) + + codeStateChange := stateChanges[0] + assert.Equal(t, codeHash, codeStateChange.MainTrieKey) + codeEntry := &state.CodeEntry{ + Code: code, + NumReferences: 1, + } + serializedCodeEntry, _ := marshaller.Marshal(codeEntry) + assert.Equal(t, serializedCodeEntry, codeStateChange.MainTrieVal) + assert.Equal(t, 0, len(codeStateChange.DataTrieChanges)) + + accountStateChange := stateChanges[1] + assert.Equal(t, address, accountStateChange.MainTrieKey) + assert.Equal(t, serializedAcc, accountStateChange.MainTrieVal) + assert.Equal(t, 2, len(accountStateChange.DataTrieChanges)) + assert.Equal(t, key1, accountStateChange.DataTrieChanges[0].Key) + valWithMetadata1 := append([]byte("value"), key1...) + valWithMetadata1 = append(valWithMetadata1, address...) + assert.Equal(t, valWithMetadata1, accountStateChange.DataTrieChanges[0].Val) + valWithMetadata2 := append([]byte("value"), key2...) + valWithMetadata2 = append(valWithMetadata2, address...) + assert.Equal(t, valWithMetadata2, accountStateChange.DataTrieChanges[1].Val) +} + +func stepMigrateDataTrieValAndChangeCode( + t *testing.T, + adb *state.AccountsDB, + address []byte, +) { + marshaller := &marshallerMock.MarshalizerMock{} + hasher := &hashingMocks.HasherMock{} + + acc, _ := adb.LoadAccount(address) + userAcc := acc.(state.UserAccountHandler) + oldCodeHash := userAcc.GetCodeHash() + code := []byte("new smart contract code") + userAcc.SetCode(code) + _ = userAcc.SaveKeyValue([]byte("key1"), []byte("value1")) + _ = adb.SaveAccount(userAcc) + + stateChanges, err := adb.GetStateChangesForTheLatestTransaction() + assert.Nil(t, err) + assert.Equal(t, 3, len(stateChanges)) + + oldCodeEntryChange := stateChanges[0] + assert.Equal(t, oldCodeHash, oldCodeEntryChange.MainTrieKey) + assert.Equal(t, []byte(nil), oldCodeEntryChange.MainTrieVal) + assert.Equal(t, 0, len(oldCodeEntryChange.DataTrieChanges)) + + newCodeEntryChange := stateChanges[1] + codeEntry := &state.CodeEntry{ + Code: code, + NumReferences: 1, + } + serializedCodeEntry, _ := marshaller.Marshal(codeEntry) + assert.Equal(t, userAcc.GetCodeHash(), newCodeEntryChange.MainTrieKey) + assert.Equal(t, serializedCodeEntry, newCodeEntryChange.MainTrieVal) + assert.Equal(t, 0, len(newCodeEntryChange.DataTrieChanges)) + + accountStateChange := stateChanges[2] + serializedAcc, _ := marshaller.Marshal(userAcc) + assert.Equal(t, address, accountStateChange.MainTrieKey) + assert.Equal(t, serializedAcc, accountStateChange.MainTrieVal) + assert.Equal(t, 2, len(accountStateChange.DataTrieChanges)) + trieVal := &dataTrieValue.TrieLeafData{ + Value: []byte("value1"), + Key: []byte("key1"), + Address: address, + } + serializedTrieVal, _ := marshaller.Marshal(trieVal) + assert.Equal(t, hasher.Compute("key1"), accountStateChange.DataTrieChanges[0].Key) + assert.Equal(t, serializedTrieVal, accountStateChange.DataTrieChanges[0].Val) + assert.Equal(t, []byte("key1"), accountStateChange.DataTrieChanges[1].Key) + assert.Equal(t, []byte(nil), accountStateChange.DataTrieChanges[1].Val) +} func TestAccountsDB_SaveAccountWithSomeValuesShouldWork(t *testing.T) { t.Parallel() @@ -2058,7 +2178,7 @@ func TestAccountsDB_CommitAddsDirtyHashesToCheckpointHashesHolder(t *testing.T) }, } - adb, tr, _ := getDefaultStateComponents(checkpointHashesHolder, testscommon.NewSnapshotPruningStorerMock()) + adb, tr, _ := getDefaultStateComponents(checkpointHashesHolder, testscommon.NewSnapshotPruningStorerMock(), &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) accountsAddresses := generateAccounts(t, 3, adb) newHashes, _ = tr.GetDirtyHashes() @@ -2101,7 +2221,7 @@ func TestAccountsDB_CommitSetsStateCheckpointIfCheckpointHashesHolderIsFull(t *t }, } - adb, tr, trieStorage := getDefaultStateComponents(checkpointHashesHolder, testscommon.NewSnapshotPruningStorerMock()) + adb, tr, trieStorage := getDefaultStateComponents(checkpointHashesHolder, testscommon.NewSnapshotPruningStorerMock(), &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) accountsAddresses := generateAccounts(t, 3, adb) newHashes = modifyDataTries(t, accountsAddresses, adb) @@ -2131,7 +2251,7 @@ func TestAccountsDB_SnapshotStateCleansCheckpointHashesHolder(t *testing.T) { return false }, } - adb, tr, trieStorage := getDefaultStateComponents(checkpointHashesHolder, testscommon.NewSnapshotPruningStorerMock()) + adb, tr, trieStorage := getDefaultStateComponents(checkpointHashesHolder, testscommon.NewSnapshotPruningStorerMock(), &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) _ = trieStorage.Put([]byte(common.ActiveDBKey), []byte(common.ActiveDBVal)) accountsAddresses := generateAccounts(t, 3, adb) @@ -2152,7 +2272,7 @@ func TestAccountsDB_SetStateCheckpointCommitsOnlyMissingData(t *testing.T) { t.Parallel() checkpointHashesHolder := hashesHolder.NewCheckpointHashesHolder(100000, testscommon.HashSize) - adb, tr, trieStorage := getDefaultStateComponents(checkpointHashesHolder, testscommon.NewSnapshotPruningStorerMock()) + adb, tr, trieStorage := getDefaultStateComponents(checkpointHashesHolder, testscommon.NewSnapshotPruningStorerMock(), &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) accountsAddresses := generateAccounts(t, 3, adb) rootHash, _ := tr.RootHash() @@ -2229,7 +2349,7 @@ func TestAccountsDB_CheckpointHashesHolderReceivesOnly32BytesData(t *testing.T) return false }, } - adb, _, _ := getDefaultStateComponents(checkpointHashesHolder, testscommon.NewSnapshotPruningStorerMock()) + adb, _, _ := getDefaultStateComponents(checkpointHashesHolder, testscommon.NewSnapshotPruningStorerMock(), &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) accountsAddresses := generateAccounts(t, 3, adb) _ = modifyDataTries(t, accountsAddresses, adb) @@ -2250,7 +2370,7 @@ func TestAccountsDB_PruneRemovesDataFromCheckpointHashesHolder(t *testing.T) { removeCalled++ }, } - adb, tr, _ := getDefaultStateComponents(checkpointHashesHolder, testscommon.NewSnapshotPruningStorerMock()) + adb, tr, _ := getDefaultStateComponents(checkpointHashesHolder, testscommon.NewSnapshotPruningStorerMock(), &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) accountsAddresses := generateAccounts(t, 3, adb) newHashes, _ = tr.GetDirtyHashes() diff --git a/state/trackableDataTrie/export_test.go b/state/trackableDataTrie/export_test.go index cf76a31be37..ae44bbbdf7e 100644 --- a/state/trackableDataTrie/export_test.go +++ b/state/trackableDataTrie/export_test.go @@ -21,3 +21,9 @@ func (tdt *trackableDataTrie) DirtyData() map[string]DirtyData { return dd } + +// GetValueForVersion - +func (tdt *trackableDataTrie) GetValueForVersion(key []byte, val []byte, version core.TrieNodeVersion) []byte { + valWithMetadata, _ := tdt.getValueForVersion(key, val, version) + return valWithMetadata +} diff --git a/state/trackableDataTrie/trackableDataTrie_test.go b/state/trackableDataTrie/trackableDataTrie_test.go index 7269b30d0b6..26e4d7ce1cb 100644 --- a/state/trackableDataTrie/trackableDataTrie_test.go +++ b/state/trackableDataTrie/trackableDataTrie_test.go @@ -379,21 +379,19 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { t.Parallel() identifier := []byte("identifier") - expectedKey := []byte("key") - expectedVal := []byte("value") - value := append(expectedVal, expectedKey...) - value = append(value, identifier...) hasher := &hashingMocks.HasherMock{} marshaller := &marshallerMock.MarshalizerMock{} deleteCalled := false updateCalled := false - - trieVal := &dataTrieValue.TrieLeafData{ - Value: expectedVal, - Key: expectedKey, - Address: identifier, + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, } - serializedTrieVal, _ := marshaller.Marshal(trieVal) + tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) + + expectedKey := []byte("key") + expectedVal := []byte("value") + value := tdt.GetValueForVersion(expectedKey, expectedVal, core.NotSpecified) + serializedTrieVal := tdt.GetValueForVersion(expectedKey, expectedVal, core.AutoBalanceEnabled) trie := &trieMock.TrieStub{ GetCalled: func(key []byte) ([]byte, uint32, error) { @@ -414,11 +412,6 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { return nil }, } - - enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: true, - } - tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, expectedVal) @@ -440,13 +433,17 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { t.Parallel() identifier := []byte("identifier") - expectedKey := []byte("key") - val := []byte("value") - expectedVal := append(val, expectedKey...) - expectedVal = append(expectedVal, identifier...) hasher := &hashingMocks.HasherMock{} marshaller := &marshallerMock.MarshalizerMock{} updateCalled := false + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: false, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) + + expectedKey := []byte("key") + val := []byte("value") + expectedVal := tdt.GetValueForVersion(expectedKey, val, core.NotSpecified) trie := &trieMock.TrieStub{ GetCalled: func(key []byte) ([]byte, uint32, error) { @@ -466,11 +463,6 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { return nil }, } - - enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: false, - } - tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, val) @@ -489,26 +481,19 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { t.Parallel() identifier := []byte("identifier") - expectedKey := []byte("key") - newVal := []byte("value") - oldVal := []byte("old val") hasher := &hashingMocks.HasherMock{} marshaller := &marshallerMock.MarshalizerMock{} updateCalled := false - - oldTrieVal := &dataTrieValue.TrieLeafData{ - Value: oldVal, - Key: expectedKey, - Address: identifier, + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, } - serializedOldTrieVal, _ := marshaller.Marshal(oldTrieVal) + tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) - newTrieVal := &dataTrieValue.TrieLeafData{ - Value: newVal, - Key: expectedKey, - Address: identifier, - } - serializedNewTrieVal, _ := marshaller.Marshal(newTrieVal) + expectedKey := []byte("key") + newVal := []byte("value") + oldVal := []byte("old val") + serializedOldTrieVal := tdt.GetValueForVersion(expectedKey, oldVal, core.AutoBalanceEnabled) + serializedNewTrieVal := tdt.GetValueForVersion(expectedKey, newVal, core.AutoBalanceEnabled) trie := &trieMock.TrieStub{ GetCalled: func(key []byte) ([]byte, uint32, error) { @@ -528,11 +513,6 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { return nil }, } - - enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: true, - } - tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, newVal) @@ -551,18 +531,17 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { t.Parallel() identifier := []byte("identifier") - expectedKey := []byte("key") - newVal := []byte("value") hasher := &hashingMocks.HasherMock{} marshaller := &marshallerMock.MarshalizerMock{} updateCalled := false - - newTrieVal := &dataTrieValue.TrieLeafData{ - Value: newVal, - Key: expectedKey, - Address: identifier, + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, } - serializedNewTrieVal, _ := marshaller.Marshal(newTrieVal) + tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) + + expectedKey := []byte("key") + newVal := []byte("value") + serializedNewTrieVal := tdt.GetValueForVersion(expectedKey, newVal, core.AutoBalanceEnabled) trie := &trieMock.TrieStub{ GetCalled: func(key []byte) ([]byte, uint32, error) { @@ -579,11 +558,6 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { return nil }, } - - enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: true, - } - tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, newVal) @@ -755,13 +729,22 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { t.Parallel() identifier := []byte("identifier") - expectedKey := []byte("key") - newVal := []byte("value") - valueWithMetadata := append(newVal, expectedKey...) - valueWithMetadata = append(valueWithMetadata, identifier...) hasher := &hashingMocks.HasherMock{} marshaller := &marshallerMock.MarshalizerMock{} updateCalled := false + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: false, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie( + identifier, + hasher, + marshaller, + enableEpochsHandler, + ) + + expectedKey := []byte("key") + newVal := []byte("value") + valueWithMetadata := tdt.GetValueForVersion(expectedKey, newVal, core.NotSpecified) trie := &trieMock.TrieStub{ GetCalled: func(key []byte) ([]byte, uint32, error) { @@ -778,16 +761,6 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { return nil }, } - - enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: false, - } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - identifier, - hasher, - marshaller, - enableEpochsHandler, - ) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, newVal) @@ -801,6 +774,78 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { assert.Equal(t, expectedKey, stateChanges[0].Key) assert.Equal(t, valueWithMetadata, stateChanges[0].Val) }) + + t.Run("state changes are ordered deterministically", func(t *testing.T) { + t.Parallel() + + identifier := []byte("identifier") + hasher := &hashingMocks.HasherMock{} + marshaller := &marshallerMock.MarshalizerMock{} + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie( + identifier, + hasher, + marshaller, + enableEpochsHandler, + ) + + key1 := "key1" + key2 := "key2" + key3 := "key3" + key4 := "key4" + + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + if bytes.Equal(key, []byte(key1)) { + return tdt.GetValueForVersion([]byte(key1), []byte("value1"), core.NotSpecified), 0, nil + } + if bytes.Equal(key, []byte(key2)) { + return tdt.GetValueForVersion([]byte(key2), []byte("value2"), core.NotSpecified), 0, nil + } + if bytes.Equal(key, hasher.Compute(key3)) { + return tdt.GetValueForVersion([]byte(key3), []byte("value3"), core.AutoBalanceEnabled), 0, nil + } + if bytes.Equal(key, hasher.Compute(key4)) { + return tdt.GetValueForVersion([]byte(key4), []byte("value4"), core.AutoBalanceEnabled), 0, nil + } + + return nil, 0, nil + }, + UpdateWithVersionCalled: func(_, _ []byte, _ core.TrieNodeVersion) error { + return nil + }, + DeleteCalled: func(_ []byte) error { + return nil + }, + } + tdt.SetDataTrie(trie) + + _ = tdt.SaveKeyValue([]byte(key1), []byte("value")) + _ = tdt.SaveKeyValue([]byte(key2), []byte("value")) + _ = tdt.SaveKeyValue([]byte(key3), []byte("value")) + _ = tdt.SaveKeyValue([]byte(key4), nil) + _ = tdt.SaveKeyValue([]byte("non existent key"), nil) + + stateChanges, oldVals, err := tdt.SaveDirtyData(trie) + assert.Nil(t, err) + assert.Equal(t, 5, len(oldVals)) + assert.Equal(t, 6, len(stateChanges)) + + assert.Equal(t, hasher.Compute(key1), stateChanges[0].Key) + assert.Equal(t, tdt.GetValueForVersion([]byte(key1), []byte("value"), core.AutoBalanceEnabled), stateChanges[0].Val) + assert.Equal(t, hasher.Compute(key2), stateChanges[1].Key) + assert.Equal(t, tdt.GetValueForVersion([]byte(key2), []byte("value"), core.AutoBalanceEnabled), stateChanges[1].Val) + assert.Equal(t, hasher.Compute(key3), stateChanges[2].Key) + assert.Equal(t, tdt.GetValueForVersion([]byte(key3), []byte("value"), core.AutoBalanceEnabled), stateChanges[2].Val) + assert.Equal(t, hasher.Compute(key4), stateChanges[3].Key) + assert.Equal(t, []byte(nil), stateChanges[3].Val) + assert.Equal(t, []byte(key1), stateChanges[4].Key) + assert.Equal(t, []byte(nil), stateChanges[4].Val) + assert.Equal(t, []byte(key2), stateChanges[5].Key) + assert.Equal(t, []byte(nil), stateChanges[5].Val) + }) } func TestTrackableDataTrie_MigrateDataTrieLeaves(t *testing.T) { From ed0f9d60c1831a20a1d44617ab5108becb7626cd Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Thu, 14 Sep 2023 14:47:39 +0300 Subject: [PATCH 03/10] get state changes for each transaction --- .../disabled/disabledAccountsAdapter.go | 5 ++ process/transaction/metaProcess.go | 8 +++ process/transaction/shardProcess.go | 8 +++ .../simulationAccountsDB.go | 5 ++ state/accountsDB.go | 7 ++ state/accountsDBApi.go | 5 ++ state/accountsDBApiWithHistory.go | 5 ++ state/interface.go | 1 + state/trackableDataTrie/trackableDataTrie.go | 13 ++-- testscommon/state/accountsAdapterStub.go | 64 +++++++++++-------- 10 files changed, 88 insertions(+), 33 deletions(-) diff --git a/epochStart/bootstrap/disabled/disabledAccountsAdapter.go b/epochStart/bootstrap/disabled/disabledAccountsAdapter.go index 61e06df194d..431b2fadf4d 100644 --- a/epochStart/bootstrap/disabled/disabledAccountsAdapter.go +++ b/epochStart/bootstrap/disabled/disabledAccountsAdapter.go @@ -137,6 +137,11 @@ func (a *accountsAdapter) GetStackDebugFirstEntry() []byte { return nil } +// GetStateChangesForTheLatestTransaction - +func (a *accountsAdapter) GetStateChangesForTheLatestTransaction() ([]state.StateChangeDTO, error) { + return nil, nil +} + // Close - func (a *accountsAdapter) Close() error { return nil diff --git a/process/transaction/metaProcess.go b/process/transaction/metaProcess.go index 51f2c721552..718dffc3c07 100644 --- a/process/transaction/metaProcess.go +++ b/process/transaction/metaProcess.go @@ -132,6 +132,14 @@ func (txProc *metaTxProcessor) ProcessTransaction(tx *transaction.Transaction) ( txType, _ := txProc.txTypeHandler.ComputeTransactionType(tx) + defer func() { + // TODO collect state changes from each transactions here + _, err = txProc.accounts.GetStateChangesForTheLatestTransaction() + if err != nil { + log.Error("GetStateChangesForTheLatestTransaction error", "err", err.Error()) + } + }() + switch txType { case process.SCDeployment: return txProc.processSCDeployment(tx, tx.SndAddr) diff --git a/process/transaction/shardProcess.go b/process/transaction/shardProcess.go index ea8eb375c56..1e175d9daed 100644 --- a/process/transaction/shardProcess.go +++ b/process/transaction/shardProcess.go @@ -211,6 +211,14 @@ func (txProc *txProcessor) ProcessTransaction(tx *transaction.Transaction) (vmco return vmcommon.UserError, err } + defer func() { + // TODO collect state changes from each transactions here + _, err = txProc.accounts.GetStateChangesForTheLatestTransaction() + if err != nil { + log.Error("GetStateChangesForTheLatestTransaction error", "err", err.Error()) + } + }() + switch txType { case process.MoveBalance: err = txProc.processMoveBalance(tx, acntSnd, acntDst, dstShardTxType, nil, false) diff --git a/process/transactionEvaluator/simulationAccountsDB.go b/process/transactionEvaluator/simulationAccountsDB.go index 25af794e196..46d87c7ab6c 100644 --- a/process/transactionEvaluator/simulationAccountsDB.go +++ b/process/transactionEvaluator/simulationAccountsDB.go @@ -176,6 +176,11 @@ func (r *simulationAccountsDB) GetStackDebugFirstEntry() []byte { return nil } +// GetStateChangesForTheLatestTransaction - +func (r *simulationAccountsDB) GetStateChangesForTheLatestTransaction() ([]state.StateChangeDTO, error) { + return r.originalAccounts.GetStateChangesForTheLatestTransaction() +} + // Close will handle the closing of the underlying components func (r *simulationAccountsDB) Close() error { return nil diff --git a/state/accountsDB.go b/state/accountsDB.go index 8446d8e12b7..05cb71f2acb 100644 --- a/state/accountsDB.go +++ b/state/accountsDB.go @@ -881,6 +881,12 @@ func (adb *AccountsDB) commit() ([]byte, error) { log.Trace("accountsDB.Commit started") adb.entries = make([]JournalEntry, 0) + stateChanges := adb.stateChangesCollector.GetStateChanges() + if len(stateChanges) != 0 { + log.Warn("state changes collector is not empty", "state changes", stateChanges) + adb.stateChangesCollector.Reset() + } + oldHashes := make(common.ModifiedHashes) newHashes := make(common.ModifiedHashes) // Step 1. commit all data tries @@ -1267,6 +1273,7 @@ func collectStats( log.Debug(strings.Join(trieStats.ToString(), " ")) } +// GetStateChangesForTheLatestTransaction will return the state changes since the last call of this method func (adb *AccountsDB) GetStateChangesForTheLatestTransaction() ([]StateChangeDTO, error) { stateChanges := adb.stateChangesCollector.GetStateChanges() adb.stateChangesCollector.Reset() diff --git a/state/accountsDBApi.go b/state/accountsDBApi.go index 89c2a27a636..c7a56b16d6e 100644 --- a/state/accountsDBApi.go +++ b/state/accountsDBApi.go @@ -221,6 +221,11 @@ func (accountsDB *accountsDBApi) GetStackDebugFirstEntry() []byte { return accountsDB.innerAccountsAdapter.GetStackDebugFirstEntry() } +// GetStateChangesForTheLatestTransaction will call the inner accountsAdapter method +func (accountsDB *accountsDBApi) GetStateChangesForTheLatestTransaction() ([]StateChangeDTO, error) { + return accountsDB.innerAccountsAdapter.GetStateChangesForTheLatestTransaction() +} + // Close will handle the closing of the underlying components func (accountsDB *accountsDBApi) Close() error { return accountsDB.innerAccountsAdapter.Close() diff --git a/state/accountsDBApiWithHistory.go b/state/accountsDBApiWithHistory.go index 97d698e0b68..b87fb00e480 100644 --- a/state/accountsDBApiWithHistory.go +++ b/state/accountsDBApiWithHistory.go @@ -144,6 +144,11 @@ func (accountsDB *accountsDBApiWithHistory) GetStackDebugFirstEntry() []byte { return nil } +// GetStateChangesForTheLatestTransaction returns nil +func (accountsDB *accountsDBApiWithHistory) GetStateChangesForTheLatestTransaction() ([]StateChangeDTO, error) { + return nil, nil +} + // Close will handle the closing of the underlying components func (accountsDB *accountsDBApiWithHistory) Close() error { return accountsDB.innerAccountsAdapter.Close() diff --git a/state/interface.go b/state/interface.go index c42ef35256b..4d11bea9ddf 100644 --- a/state/interface.go +++ b/state/interface.go @@ -50,6 +50,7 @@ type AccountsAdapter interface { GetStackDebugFirstEntry() []byte SetSyncer(syncer AccountsDBSyncer) error StartSnapshotIfNeeded() error + GetStateChangesForTheLatestTransaction() ([]StateChangeDTO, error) Close() error IsInterfaceNil() bool } diff --git a/state/trackableDataTrie/trackableDataTrie.go b/state/trackableDataTrie/trackableDataTrie.go index 7ff84ddcfaf..26167769006 100644 --- a/state/trackableDataTrie/trackableDataTrie.go +++ b/state/trackableDataTrie/trackableDataTrie.go @@ -234,7 +234,7 @@ func (tdt *trackableDataTrie) SaveDirtyData(mainTrie common.Trie) ([]state.DataT func (tdt *trackableDataTrie) updateTrie(dtr state.DataTrie) ([]state.DataTrieChange, []core.TrieData, error) { oldValues := make([]core.TrieData, len(tdt.dirtyData)) - stateChanges := make([]state.DataTrieChange, len(tdt.dirtyData)) + newData := make([]state.DataTrieChange, len(tdt.dirtyData)) deletedKeys := make([]state.DataTrieChange, 0) index := 0 @@ -270,11 +270,11 @@ func (tdt *trackableDataTrie) updateTrie(dtr state.DataTrie) ([]state.DataTrieCh continue } - if dataEntry.index > len(stateChanges)-1 { + if dataEntry.index > len(newData)-1 { return nil, nil, fmt.Errorf("index out of range") } - stateChanges[dataEntry.index] = state.DataTrieChange{ + newData[dataEntry.index] = state.DataTrieChange{ Key: dataTrieKey, Val: dataTrieVal, } @@ -282,12 +282,13 @@ func (tdt *trackableDataTrie) updateTrie(dtr state.DataTrie) ([]state.DataTrieCh tdt.dirtyData = make(map[string]dirtyData) - for i := range stateChanges { - if len(stateChanges[i].Key) != 0 { + stateChanges := make([]state.DataTrieChange, 0) + for i := range newData { + if len(newData[i].Key) == 0 { continue } - stateChanges = append(stateChanges[:i], stateChanges[i+1:]...) + stateChanges = append(stateChanges, newData[i]) } sort.Slice(deletedKeys, func(i, j int) bool { diff --git a/testscommon/state/accountsAdapterStub.go b/testscommon/state/accountsAdapterStub.go index c5cf9f74535..6b1af94a1fc 100644 --- a/testscommon/state/accountsAdapterStub.go +++ b/testscommon/state/accountsAdapterStub.go @@ -13,33 +13,34 @@ var errNotImplemented = errors.New("not implemented") // AccountsStub - type AccountsStub struct { - GetExistingAccountCalled func(addressContainer []byte) (vmcommon.AccountHandler, error) - GetAccountFromBytesCalled func(address []byte, accountBytes []byte) (vmcommon.AccountHandler, error) - LoadAccountCalled func(container []byte) (vmcommon.AccountHandler, error) - SaveAccountCalled func(account vmcommon.AccountHandler) error - RemoveAccountCalled func(addressContainer []byte) error - CommitCalled func() ([]byte, error) - CommitInEpochCalled func(uint32, uint32) ([]byte, error) - JournalLenCalled func() int - RevertToSnapshotCalled func(snapshot int) error - RootHashCalled func() ([]byte, error) - RecreateTrieCalled func(rootHash []byte) error - RecreateTrieFromEpochCalled func(options common.RootHashHolder) error - PruneTrieCalled func(rootHash []byte, identifier state.TriePruningIdentifier, handler state.PruningHandler) - CancelPruneCalled func(rootHash []byte, identifier state.TriePruningIdentifier) - SnapshotStateCalled func(rootHash []byte, epoch uint32) - SetStateCheckpointCalled func(rootHash []byte) - IsPruningEnabledCalled func() bool - GetAllLeavesCalled func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, trieLeafParser common.TrieLeafParser) error - RecreateAllTriesCalled func(rootHash []byte) (map[string]common.Trie, error) - GetCodeCalled func([]byte) []byte - GetTrieCalled func([]byte) (common.Trie, error) - GetStackDebugFirstEntryCalled func() []byte - GetAccountWithBlockInfoCalled func(address []byte, options common.RootHashHolder) (vmcommon.AccountHandler, common.BlockInfo, error) - GetCodeWithBlockInfoCalled func(codeHash []byte, options common.RootHashHolder) ([]byte, common.BlockInfo, error) - CloseCalled func() error - SetSyncerCalled func(syncer state.AccountsDBSyncer) error - StartSnapshotIfNeededCalled func() error + GetExistingAccountCalled func(addressContainer []byte) (vmcommon.AccountHandler, error) + GetAccountFromBytesCalled func(address []byte, accountBytes []byte) (vmcommon.AccountHandler, error) + LoadAccountCalled func(container []byte) (vmcommon.AccountHandler, error) + SaveAccountCalled func(account vmcommon.AccountHandler) error + RemoveAccountCalled func(addressContainer []byte) error + CommitCalled func() ([]byte, error) + CommitInEpochCalled func(uint32, uint32) ([]byte, error) + JournalLenCalled func() int + RevertToSnapshotCalled func(snapshot int) error + RootHashCalled func() ([]byte, error) + RecreateTrieCalled func(rootHash []byte) error + RecreateTrieFromEpochCalled func(options common.RootHashHolder) error + PruneTrieCalled func(rootHash []byte, identifier state.TriePruningIdentifier, handler state.PruningHandler) + CancelPruneCalled func(rootHash []byte, identifier state.TriePruningIdentifier) + SnapshotStateCalled func(rootHash []byte, epoch uint32) + SetStateCheckpointCalled func(rootHash []byte) + IsPruningEnabledCalled func() bool + GetAllLeavesCalled func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, trieLeafParser common.TrieLeafParser) error + RecreateAllTriesCalled func(rootHash []byte) (map[string]common.Trie, error) + GetCodeCalled func([]byte) []byte + GetTrieCalled func([]byte) (common.Trie, error) + GetStackDebugFirstEntryCalled func() []byte + GetAccountWithBlockInfoCalled func(address []byte, options common.RootHashHolder) (vmcommon.AccountHandler, common.BlockInfo, error) + GetCodeWithBlockInfoCalled func(codeHash []byte, options common.RootHashHolder) ([]byte, common.BlockInfo, error) + CloseCalled func() error + SetSyncerCalled func(syncer state.AccountsDBSyncer) error + StartSnapshotIfNeededCalled func() error + GetStateChangesForTheLatestTransactionCalled func() ([]state.StateChangeDTO, error) } // CleanCache - @@ -265,6 +266,15 @@ func (as *AccountsStub) GetCodeWithBlockInfo(codeHash []byte, options common.Roo return nil, nil, nil } +// GetStateChangesForTheLatestTransaction - +func (as *AccountsStub) GetStateChangesForTheLatestTransaction() ([]state.StateChangeDTO, error) { + if as.GetStateChangesForTheLatestTransactionCalled != nil { + return as.GetStateChangesForTheLatestTransactionCalled() + } + + return nil, nil +} + // Close - func (as *AccountsStub) Close() error { if as.CloseCalled != nil { From a05808d64c28ce091c2ac8b888af0e53f2ca1ee7 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Fri, 15 Sep 2023 15:01:12 +0300 Subject: [PATCH 04/10] pass stateChangesCollector as argument --- epochStart/metachain/systemSCs_test.go | 1 + factory/api/apiResolverFactory.go | 2 ++ .../processing/blockProcessorCreator_test.go | 1 + factory/state/stateComponents.go | 4 +++ genesis/process/memoryComponents.go | 1 + .../state/stateTrie/stateTrie_test.go | 2 ++ integrationTests/testInitializer.go | 1 + state/accountsDB.go | 6 +++- state/accountsDB_test.go | 2 ++ .../disabled/disabledStateChangesCollector.go | 34 +++++++++++++++++++ state/errors.go | 3 ++ .../factory/accountsAdapterAPICreator_test.go | 1 + .../storagePruningManager_test.go | 1 + testscommon/integrationtests/factory.go | 1 + update/genesis/import.go | 3 ++ 15 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 state/disabled/disabledStateChangesCollector.go diff --git a/epochStart/metachain/systemSCs_test.go b/epochStart/metachain/systemSCs_test.go index e4e168e145e..6a3c6003398 100644 --- a/epochStart/metachain/systemSCs_test.go +++ b/epochStart/metachain/systemSCs_test.go @@ -903,6 +903,7 @@ func createAccountsDB( ProcessStatusHandler: &testscommon.ProcessStatusHandlerStub{}, AppStatusHandler: &statusHandlerMock.AppStatusHandlerStub{}, AddressConverter: &testscommon.PubkeyConverterMock{}, + StateChangesCollector: state.NewStateChangesCollector(), } adb, _ := state.NewAccountsDB(args) return adb diff --git a/factory/api/apiResolverFactory.go b/factory/api/apiResolverFactory.go index bd5c1d4abc9..22f96edc915 100644 --- a/factory/api/apiResolverFactory.go +++ b/factory/api/apiResolverFactory.go @@ -37,6 +37,7 @@ import ( "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/blockInfoProviders" + stateDisabled "github.com/multiversx/mx-chain-go/state/disabled" factoryState "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/storagePruningManager" "github.com/multiversx/mx-chain-go/state/storagePruningManager/evictionWaitingList" @@ -607,6 +608,7 @@ func createNewAccountsAdapterApi(args *scQueryElementArgs, chainHandler data.Cha ProcessStatusHandler: args.coreComponents.ProcessStatusHandler(), AppStatusHandler: args.statusCoreComponents.AppStatusHandler(), AddressConverter: args.coreComponents.AddressPubKeyConverter(), + StateChangesCollector: stateDisabled.NewDisabledStateChangesCollector(), } provider, err := blockInfoProviders.NewCurrentBlockInfo(chainHandler) diff --git a/factory/processing/blockProcessorCreator_test.go b/factory/processing/blockProcessorCreator_test.go index f989bad2571..151a5728fd6 100644 --- a/factory/processing/blockProcessorCreator_test.go +++ b/factory/processing/blockProcessorCreator_test.go @@ -208,6 +208,7 @@ func createAccountAdapter( ProcessStatusHandler: &testscommon.ProcessStatusHandlerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, AddressConverter: &testscommon.PubkeyConverterMock{}, + StateChangesCollector: state.NewStateChangesCollector(), } adb, err := state.NewAccountsDB(args) if err != nil { diff --git a/factory/state/stateComponents.go b/factory/state/stateComponents.go index baefcb6d590..9884724a327 100644 --- a/factory/state/stateComponents.go +++ b/factory/state/stateComponents.go @@ -11,6 +11,7 @@ import ( "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/factory" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/disabled" factoryState "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/storagePruningManager" "github.com/multiversx/mx-chain-go/state/storagePruningManager/evictionWaitingList" @@ -135,6 +136,7 @@ func (scf *stateComponentsFactory) createAccountsAdapters(triesContainer common. ProcessStatusHandler: scf.core.ProcessStatusHandler(), AppStatusHandler: scf.statusCore.AppStatusHandler(), AddressConverter: scf.core.AddressPubKeyConverter(), + StateChangesCollector: state.NewStateChangesCollector(), } accountsAdapter, err := state.NewAccountsDB(argsProcessingAccountsDB) if err != nil { @@ -151,6 +153,7 @@ func (scf *stateComponentsFactory) createAccountsAdapters(triesContainer common. ProcessStatusHandler: scf.core.ProcessStatusHandler(), AppStatusHandler: scf.statusCore.AppStatusHandler(), AddressConverter: scf.core.AddressPubKeyConverter(), + StateChangesCollector: disabled.NewDisabledStateChangesCollector(), } accountsAdapterApiOnFinal, err := factoryState.CreateAccountsAdapterAPIOnFinal(argsAPIAccountsDB, scf.chainHandler) @@ -201,6 +204,7 @@ func (scf *stateComponentsFactory) createPeerAdapter(triesContainer common.Tries ProcessStatusHandler: scf.core.ProcessStatusHandler(), AppStatusHandler: scf.statusCore.AppStatusHandler(), AddressConverter: scf.core.AddressPubKeyConverter(), + StateChangesCollector: state.NewStateChangesCollector(), } peerAdapter, err := state.NewPeerAccountsDB(argsProcessingPeerAccountsDB) if err != nil { diff --git a/genesis/process/memoryComponents.go b/genesis/process/memoryComponents.go index 623c6f69f12..deb2b28ca51 100644 --- a/genesis/process/memoryComponents.go +++ b/genesis/process/memoryComponents.go @@ -36,6 +36,7 @@ func createAccountAdapter( ProcessStatusHandler: commonDisabled.NewProcessStatusHandler(), AppStatusHandler: commonDisabled.NewAppStatusHandler(), AddressConverter: addressConverter, + StateChangesCollector: state.NewStateChangesCollector(), } adb, err := state.NewAccountsDB(args) diff --git a/integrationTests/state/stateTrie/stateTrie_test.go b/integrationTests/state/stateTrie/stateTrie_test.go index f8a7bfae8c5..213c9a86fbe 100644 --- a/integrationTests/state/stateTrie/stateTrie_test.go +++ b/integrationTests/state/stateTrie/stateTrie_test.go @@ -1066,6 +1066,7 @@ func createAccounts( ProcessStatusHandler: &testscommon.ProcessStatusHandlerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, AddressConverter: &testscommon.PubkeyConverterMock{}, + StateChangesCollector: state.NewStateChangesCollector(), } adb, _ := state.NewAccountsDB(argsAccountsDB) @@ -2530,6 +2531,7 @@ func createAccountsDBTestSetup() *state.AccountsDB { ProcessStatusHandler: &testscommon.ProcessStatusHandlerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, AddressConverter: &testscommon.PubkeyConverterMock{}, + StateChangesCollector: state.NewStateChangesCollector(), } adb, _ := state.NewAccountsDB(argsAccountsDB) diff --git a/integrationTests/testInitializer.go b/integrationTests/testInitializer.go index 8d08a89f6fe..e318088ca50 100644 --- a/integrationTests/testInitializer.go +++ b/integrationTests/testInitializer.go @@ -473,6 +473,7 @@ func CreateAccountsDBWithEnableEpochsHandler( ProcessStatusHandler: &testscommon.ProcessStatusHandlerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, AddressConverter: &testscommon.PubkeyConverterMock{}, + StateChangesCollector: state.NewStateChangesCollector(), } adb, _ := state.NewAccountsDB(args) diff --git a/state/accountsDB.go b/state/accountsDB.go index 05cb71f2acb..30615688849 100644 --- a/state/accountsDB.go +++ b/state/accountsDB.go @@ -110,6 +110,7 @@ type ArgsAccountsDB struct { ProcessStatusHandler common.ProcessStatusHandler AppStatusHandler core.AppStatusHandler AddressConverter core.PubkeyConverter + StateChangesCollector StateChangesCollector } // NewAccountsDB creates a new account manager @@ -163,7 +164,7 @@ func createAccountsDb(args ArgsAccountsDB, snapshotManager SnapshotsManager) *Ac }, addressConverter: args.AddressConverter, snapshotsManger: snapshotManager, - stateChangesCollector: NewStateChangesCollector(), + stateChangesCollector: args.StateChangesCollector, } } @@ -186,6 +187,9 @@ func checkArgsAccountsDB(args ArgsAccountsDB) error { if check.IfNil(args.AddressConverter) { return ErrNilAddressConverter } + if check.IfNil(args.StateChangesCollector) { + return ErrNilStateChangesCollector + } return nil } diff --git a/state/accountsDB_test.go b/state/accountsDB_test.go index ef961c2f54c..bc1faa059d3 100644 --- a/state/accountsDB_test.go +++ b/state/accountsDB_test.go @@ -68,6 +68,7 @@ func createMockAccountsDBArgs() state.ArgsAccountsDB { ProcessStatusHandler: &testscommon.ProcessStatusHandlerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, AddressConverter: &testscommon.PubkeyConverterMock{}, + StateChangesCollector: state.NewStateChangesCollector(), } } @@ -155,6 +156,7 @@ func getDefaultStateComponents( ProcessStatusHandler: &testscommon.ProcessStatusHandlerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, AddressConverter: &testscommon.PubkeyConverterMock{}, + StateChangesCollector: state.NewStateChangesCollector(), } adb, _ := state.NewAccountsDB(argsAccountsDB) diff --git a/state/disabled/disabledStateChangesCollector.go b/state/disabled/disabledStateChangesCollector.go new file mode 100644 index 00000000000..dc5d22bc1f2 --- /dev/null +++ b/state/disabled/disabledStateChangesCollector.go @@ -0,0 +1,34 @@ +package disabled + +import ( + "github.com/multiversx/mx-chain-go/state" +) + +// DisabledStateChangesCollector is a state changes collector that does nothing +type DisabledStateChangesCollector struct { +} + +// NewDisabledStateChangesCollector creates a new DisabledStateChangesCollector +func NewDisabledStateChangesCollector() *DisabledStateChangesCollector { + return &DisabledStateChangesCollector{} +} + +// AddStateChange does nothing +func (d *DisabledStateChangesCollector) AddStateChange(_ state.StateChangeDTO) { + +} + +// GetStateChanges returns an empty slice +func (d *DisabledStateChangesCollector) GetStateChanges() []state.StateChangeDTO { + return []state.StateChangeDTO{} +} + +// Reset does nothing +func (d *DisabledStateChangesCollector) Reset() { + +} + +// IsInterfaceNil returns true if there is no value under the interface +func (d *DisabledStateChangesCollector) IsInterfaceNil() bool { + return d == nil +} diff --git a/state/errors.go b/state/errors.go index 5a56aff40ff..d6cbd4697b8 100644 --- a/state/errors.go +++ b/state/errors.go @@ -144,3 +144,6 @@ var ErrNilStateMetrics = errors.New("nil sstate metrics") // ErrNilChannelsProvider signals that a nil channels provider has been given var ErrNilChannelsProvider = errors.New("nil channels provider") + +// ErrNilStateChangesCollector signals that a nil state changes collector has been given +var ErrNilStateChangesCollector = errors.New("nil state changes collector") diff --git a/state/factory/accountsAdapterAPICreator_test.go b/state/factory/accountsAdapterAPICreator_test.go index c6c579985c1..319a5ab2e8e 100644 --- a/state/factory/accountsAdapterAPICreator_test.go +++ b/state/factory/accountsAdapterAPICreator_test.go @@ -31,6 +31,7 @@ func createMockAccountsArgs() state.ArgsAccountsDB { ProcessStatusHandler: &testscommon.ProcessStatusHandlerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, AddressConverter: &testscommon.PubkeyConverterMock{}, + StateChangesCollector: state.NewStateChangesCollector(), } } diff --git a/state/storagePruningManager/storagePruningManager_test.go b/state/storagePruningManager/storagePruningManager_test.go index 104a198becd..40f9ba65206 100644 --- a/state/storagePruningManager/storagePruningManager_test.go +++ b/state/storagePruningManager/storagePruningManager_test.go @@ -54,6 +54,7 @@ func getDefaultTrieAndAccountsDbAndStoragePruningManager() (common.Trie, *state. ProcessStatusHandler: &testscommon.ProcessStatusHandlerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, AddressConverter: &testscommon.PubkeyConverterMock{}, + StateChangesCollector: state.NewStateChangesCollector(), } adb, _ := state.NewAccountsDB(argsAccountsDB) diff --git a/testscommon/integrationtests/factory.go b/testscommon/integrationtests/factory.go index 3abbabae250..8c427a34ed9 100644 --- a/testscommon/integrationtests/factory.go +++ b/testscommon/integrationtests/factory.go @@ -119,6 +119,7 @@ func CreateAccountsDB(db storage.Storer, enableEpochs common.EnableEpochsHandler ProcessStatusHandler: &testscommon.ProcessStatusHandlerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, AddressConverter: &testscommon.PubkeyConverterMock{}, + StateChangesCollector: state.NewStateChangesCollector(), } adb, _ := state.NewAccountsDB(argsAccountsDB) diff --git a/update/genesis/import.go b/update/genesis/import.go index d0da6fac47c..efba2ff503c 100644 --- a/update/genesis/import.go +++ b/update/genesis/import.go @@ -19,6 +19,7 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/state" + stateDisabled "github.com/multiversx/mx-chain-go/state/disabled" "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/storagePruningManager/disabled" "github.com/multiversx/mx-chain-go/trie" @@ -422,6 +423,7 @@ func (si *stateImport) getAccountsDB(accType Type, shardID uint32, accountFactor ProcessStatusHandler: commonDisabled.NewProcessStatusHandler(), AppStatusHandler: commonDisabled.NewAppStatusHandler(), AddressConverter: si.addressConverter, + StateChangesCollector: stateDisabled.NewDisabledStateChangesCollector(), } accountsDB, errCreate := state.NewAccountsDB(argsAccountDB) if errCreate != nil { @@ -447,6 +449,7 @@ func (si *stateImport) getAccountsDB(accType Type, shardID uint32, accountFactor ProcessStatusHandler: commonDisabled.NewProcessStatusHandler(), AppStatusHandler: commonDisabled.NewAppStatusHandler(), AddressConverter: si.addressConverter, + StateChangesCollector: stateDisabled.NewDisabledStateChangesCollector(), } accountsDB, err = state.NewAccountsDB(argsAccountDB) si.accountDBsMap[shardID] = accountsDB From db454f15eb502ecfd71b4a3db823e8facb99e3d9 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Fri, 15 Sep 2023 16:10:49 +0300 Subject: [PATCH 05/10] add flag for collectStateChanges --- cmd/node/config/config.toml | 1 + config/config.go | 1 + factory/state/stateComponents.go | 14 ++++++++++++-- state/disabled/disabledStateChangesCollector.go | 2 +- 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/cmd/node/config/config.toml b/cmd/node/config/config.toml index 310d6be8e8c..f5c82bb81ae 100644 --- a/cmd/node/config/config.toml +++ b/cmd/node/config/config.toml @@ -684,6 +684,7 @@ PeerStatePruningEnabled = true MaxStateTrieLevelInMemory = 5 MaxPeerTrieLevelInMemory = 5 + CollectStateChangesEnabled = false [BlockSizeThrottleConfig] MinSizeInBytes = 104857 # 104857 is 10% from 1MB diff --git a/config/config.go b/config/config.go index 37a98884e5d..e59f0dca160 100644 --- a/config/config.go +++ b/config/config.go @@ -297,6 +297,7 @@ type StateTriesConfig struct { SnapshotsEnabled bool AccountsStatePruningEnabled bool PeerStatePruningEnabled bool + CollectStateChangesEnabled bool MaxStateTrieLevelInMemory uint MaxPeerTrieLevelInMemory uint } diff --git a/factory/state/stateComponents.go b/factory/state/stateComponents.go index 9884724a327..e0959b722d9 100644 --- a/factory/state/stateComponents.go +++ b/factory/state/stateComponents.go @@ -125,6 +125,11 @@ func (scf *stateComponentsFactory) createAccountsAdapters(triesContainer common. return nil, nil, nil, err } + stateChangesCollector := disabled.NewDisabledStateChangesCollector() + if scf.config.StateTriesConfig.CollectStateChangesEnabled { + stateChangesCollector = state.NewStateChangesCollector() + } + argsProcessingAccountsDB := state.ArgsAccountsDB{ Trie: merkleTrie, Hasher: scf.core.Hasher(), @@ -136,7 +141,7 @@ func (scf *stateComponentsFactory) createAccountsAdapters(triesContainer common. ProcessStatusHandler: scf.core.ProcessStatusHandler(), AppStatusHandler: scf.statusCore.AppStatusHandler(), AddressConverter: scf.core.AddressPubKeyConverter(), - StateChangesCollector: state.NewStateChangesCollector(), + StateChangesCollector: stateChangesCollector, } accountsAdapter, err := state.NewAccountsDB(argsProcessingAccountsDB) if err != nil { @@ -193,6 +198,11 @@ func (scf *stateComponentsFactory) createPeerAdapter(triesContainer common.Tries return nil, err } + stateChangesCollector := disabled.NewDisabledStateChangesCollector() + if scf.config.StateTriesConfig.CollectStateChangesEnabled { + stateChangesCollector = state.NewStateChangesCollector() + } + argsProcessingPeerAccountsDB := state.ArgsAccountsDB{ Trie: merkleTrie, Hasher: scf.core.Hasher(), @@ -204,7 +214,7 @@ func (scf *stateComponentsFactory) createPeerAdapter(triesContainer common.Tries ProcessStatusHandler: scf.core.ProcessStatusHandler(), AppStatusHandler: scf.statusCore.AppStatusHandler(), AddressConverter: scf.core.AddressPubKeyConverter(), - StateChangesCollector: state.NewStateChangesCollector(), + StateChangesCollector: stateChangesCollector, } peerAdapter, err := state.NewPeerAccountsDB(argsProcessingPeerAccountsDB) if err != nil { diff --git a/state/disabled/disabledStateChangesCollector.go b/state/disabled/disabledStateChangesCollector.go index dc5d22bc1f2..3d1407d3be5 100644 --- a/state/disabled/disabledStateChangesCollector.go +++ b/state/disabled/disabledStateChangesCollector.go @@ -9,7 +9,7 @@ type DisabledStateChangesCollector struct { } // NewDisabledStateChangesCollector creates a new DisabledStateChangesCollector -func NewDisabledStateChangesCollector() *DisabledStateChangesCollector { +func NewDisabledStateChangesCollector() state.StateChangesCollector { return &DisabledStateChangesCollector{} } From 01d832b4a76b933d3e0158ae46a33b8fd4aebd30 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Fri, 15 Sep 2023 16:24:04 +0300 Subject: [PATCH 06/10] add comments and unit test for stateChangesCollector --- state/stateChangesCollector.go | 7 +++ state/stateChangesCollector_test.go | 68 +++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 state/stateChangesCollector_test.go diff --git a/state/stateChangesCollector.go b/state/stateChangesCollector.go index fe69cef0184..9c23b1db464 100644 --- a/state/stateChangesCollector.go +++ b/state/stateChangesCollector.go @@ -1,11 +1,13 @@ package state +// StateChangeDTO is used to collect state changes type StateChangeDTO struct { MainTrieKey []byte MainTrieVal []byte DataTrieChanges []DataTrieChange } +// DataTrieChange represents a change in the data trie type DataTrieChange struct { Key []byte Val []byte @@ -15,24 +17,29 @@ type stateChangesCollector struct { stateChanges []StateChangeDTO } +// NewStateChangesCollector creates a new StateChangesCollector func NewStateChangesCollector() *stateChangesCollector { return &stateChangesCollector{ stateChanges: []StateChangeDTO{}, } } +// AddStateChange adds a new state change to the collector func (scc *stateChangesCollector) AddStateChange(stateChange StateChangeDTO) { scc.stateChanges = append(scc.stateChanges, stateChange) } +// GetStateChanges returns the accumulated state changes func (scc *stateChangesCollector) GetStateChanges() []StateChangeDTO { return scc.stateChanges } +// Reset resets the state changes collector func (scc *stateChangesCollector) Reset() { scc.stateChanges = []StateChangeDTO{} } +// IsInterfaceNil returns true if there is no value under the interface func (scc *stateChangesCollector) IsInterfaceNil() bool { return scc == nil } diff --git a/state/stateChangesCollector_test.go b/state/stateChangesCollector_test.go new file mode 100644 index 00000000000..6c678ea2a83 --- /dev/null +++ b/state/stateChangesCollector_test.go @@ -0,0 +1,68 @@ +package state + +import ( + "github.com/stretchr/testify/assert" + "strconv" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewStateChangesCollector(t *testing.T) { + t.Parallel() + + stateChangesCollector := NewStateChangesCollector() + require.False(t, stateChangesCollector.IsInterfaceNil()) +} + +func TestStateChangesCollector_AddStateChange(t *testing.T) { + t.Parallel() + + scc := NewStateChangesCollector() + assert.Equal(t, 0, len(scc.stateChanges)) + + numStateChanges := 10 + for i := 0; i < numStateChanges; i++ { + scc.AddStateChange(StateChangeDTO{}) + } + assert.Equal(t, numStateChanges, len(scc.stateChanges)) +} + +func TestStateChangesCollector_GetStateChanges(t *testing.T) { + t.Parallel() + + scc := NewStateChangesCollector() + assert.Equal(t, 0, len(scc.stateChanges)) + + numStateChanges := 10 + for i := 0; i < numStateChanges; i++ { + scc.AddStateChange(StateChangeDTO{ + MainTrieKey: []byte(strconv.Itoa(i)), + }) + } + assert.Equal(t, numStateChanges, len(scc.stateChanges)) + + stateChanges := scc.GetStateChanges() + assert.Equal(t, numStateChanges, len(stateChanges)) + for i := 0; i < numStateChanges; i++ { + assert.Equal(t, []byte(strconv.Itoa(i)), stateChanges[i].MainTrieKey) + } +} + +func TestStateChangesCollector_Reset(t *testing.T) { + t.Parallel() + + scc := NewStateChangesCollector() + assert.Equal(t, 0, len(scc.stateChanges)) + + numStateChanges := 10 + for i := 0; i < numStateChanges; i++ { + scc.AddStateChange(StateChangeDTO{}) + } + assert.Equal(t, numStateChanges, len(scc.stateChanges)) + stateChanges := scc.GetStateChanges() + + scc.Reset() + assert.Equal(t, 0, len(scc.stateChanges)) + assert.Equal(t, numStateChanges, len(stateChanges)) +} From 7e6aa457941d553c489c542f9eb27ce85033c23b Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Fri, 26 Jan 2024 17:40:20 +0200 Subject: [PATCH 07/10] fix after merge --- state/accountsDB_test.go | 9 ++-- .../trackableDataTrie_test.go | 51 ++++++++----------- 2 files changed, 26 insertions(+), 34 deletions(-) diff --git a/state/accountsDB_test.go b/state/accountsDB_test.go index 5f823d28b88..847f4e244cf 100644 --- a/state/accountsDB_test.go +++ b/state/accountsDB_test.go @@ -42,8 +42,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/storageManager" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie" - "github.com/multiversx/mx-chain-go/trie/hashesHolder" - disabledHashesHolder "github.com/multiversx/mx-chain-go/trie/hashesHolder/disabled" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -427,14 +425,17 @@ func TestAccountsDB_SaveAccountMalfunctionMarshallerShouldErr(t *testing.T) { func TestAccountsDB_SaveAccountCollectsAllStateChanges(t *testing.T) { t.Parallel() + autoBalanceFlagEnabled := false enableEpochs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: false, + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return autoBalanceFlagEnabled + }, } _, adb := getDefaultStateComponentsWithCustomEnableEpochs(enableEpochs) address := generateRandomByteArray(32) stepCreateAccountWithDataTrieAndCode(t, adb, address) - enableEpochs.IsAutoBalanceDataTriesEnabledField = true + autoBalanceFlagEnabled = true stepMigrateDataTrieValAndChangeCode(t, adb, address) } diff --git a/state/trackableDataTrie/trackableDataTrie_test.go b/state/trackableDataTrie/trackableDataTrie_test.go index 646502a4536..de995a7082a 100644 --- a/state/trackableDataTrie/trackableDataTrie_test.go +++ b/state/trackableDataTrie/trackableDataTrie_test.go @@ -406,7 +406,9 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { deleteCalled := false updateCalled := false enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: true, + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return flag == common.AutoBalanceDataTriesFlag + }, } tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) @@ -435,12 +437,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { - return flag == common.AutoBalanceDataTriesFlag - }, - } - tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) + tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, expectedVal) @@ -468,8 +465,11 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { marshaller := &marshallerMock.MarshalizerMock{} updateCalled := false enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: false, + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return false + }, } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) expectedKey := []byte("key") @@ -495,12 +495,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { - return false - }, - } - tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) + tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, val) @@ -523,7 +518,9 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { marshaller := &marshallerMock.MarshalizerMock{} updateCalled := false enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: true, + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return flag == common.AutoBalanceDataTriesFlag + }, } tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) @@ -552,12 +549,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { - return flag == common.AutoBalanceDataTriesFlag - }, - } - tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) + tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, newVal) @@ -580,7 +572,9 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { marshaller := &marshallerMock.MarshalizerMock{} updateCalled := false enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: true, + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return flag == common.AutoBalanceDataTriesFlag + }, } tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) @@ -604,12 +598,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { - return flag == common.AutoBalanceDataTriesFlag - }, - } - tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) + tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, newVal) @@ -789,7 +778,9 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { marshaller := &marshallerMock.MarshalizerMock{} updateCalled := false enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: false, + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return false + }, } tdt, _ := trackableDataTrie.NewTrackableDataTrie( identifier, @@ -888,7 +879,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { stateChanges, oldVals, err := tdt.SaveDirtyData(trie) assert.Nil(t, err) - assert.Equal(t, 5, len(oldVals)) + assert.Equal(t, 7, len(oldVals)) assert.Equal(t, 6, len(stateChanges)) assert.Equal(t, hasher.Compute(key1), stateChanges[0].Key) From 33b6daf7b22b13638e625d1b284df4c588032581 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Tue, 30 Jan 2024 11:56:09 +0200 Subject: [PATCH 08/10] correlate state changes to a tx hash --- .../disabled/disabledAccountsAdapter.go | 10 +- process/transaction/metaProcess.go | 10 +- process/transaction/shardProcess.go | 10 +- .../simulationAccountsDB.go | 11 ++- state/accountsDB.go | 26 ++--- state/accountsDBApi.go | 11 ++- state/accountsDBApiWithHistory.go | 10 +- state/accountsDB_test.go | 21 ++-- .../disabled/disabledStateChangesCollector.go | 22 +++-- state/interface.go | 6 +- state/stateChangesCollector.go | 38 ++++++-- state/stateChangesCollector_test.go | 95 ++++++++++++++++--- testscommon/state/accountsAdapterStub.go | 73 +++++++------- 13 files changed, 231 insertions(+), 112 deletions(-) diff --git a/epochStart/bootstrap/disabled/disabledAccountsAdapter.go b/epochStart/bootstrap/disabled/disabledAccountsAdapter.go index 56c29e31331..f2ac368babf 100644 --- a/epochStart/bootstrap/disabled/disabledAccountsAdapter.go +++ b/epochStart/bootstrap/disabled/disabledAccountsAdapter.go @@ -133,9 +133,13 @@ func (a *accountsAdapter) GetStackDebugFirstEntry() []byte { return nil } -// GetStateChangesForTheLatestTransaction - -func (a *accountsAdapter) GetStateChangesForTheLatestTransaction() ([]state.StateChangeDTO, error) { - return nil, nil +// SetTxHashForLatestStateChanges - +func (a *accountsAdapter) SetTxHashForLatestStateChanges(_ []byte) { +} + +// ResetStateChangesCollector - +func (a *accountsAdapter) ResetStateChangesCollector() []state.StateChangesForTx { + return nil } // Close - diff --git a/process/transaction/metaProcess.go b/process/transaction/metaProcess.go index 58965dadd4c..f4a10a0fb6c 100644 --- a/process/transaction/metaProcess.go +++ b/process/transaction/metaProcess.go @@ -124,6 +124,8 @@ func (txProc *metaTxProcessor) ProcessTransaction(tx *transaction.Transaction) ( txProc.pubkeyConv, ) + defer txProc.accounts.SetTxHashForLatestStateChanges(txHash) + err = txProc.checkTxValues(tx, acntSnd, acntDst, false) if err != nil { if errors.Is(err, process.ErrUserNameDoesNotMatchInCrossShardTx) { @@ -138,14 +140,6 @@ func (txProc *metaTxProcessor) ProcessTransaction(tx *transaction.Transaction) ( txType, _ := txProc.txTypeHandler.ComputeTransactionType(tx) - defer func() { - // TODO collect state changes from each transactions here - _, err = txProc.accounts.GetStateChangesForTheLatestTransaction() - if err != nil { - log.Error("GetStateChangesForTheLatestTransaction error", "err", err.Error()) - } - }() - switch txType { case process.SCDeployment: return txProc.processSCDeployment(tx, tx.SndAddr) diff --git a/process/transaction/shardProcess.go b/process/transaction/shardProcess.go index fee263640ba..43c5b4878bf 100644 --- a/process/transaction/shardProcess.go +++ b/process/transaction/shardProcess.go @@ -195,6 +195,8 @@ func (txProc *txProcessor) ProcessTransaction(tx *transaction.Transaction) (vmco txProc.pubkeyConv, ) + defer txProc.accounts.SetTxHashForLatestStateChanges(txHash) + txType, dstShardTxType := txProc.txTypeHandler.ComputeTransactionType(tx) err = txProc.checkTxValues(tx, acntSnd, acntDst, false) if err != nil { @@ -222,14 +224,6 @@ func (txProc *txProcessor) ProcessTransaction(tx *transaction.Transaction) (vmco return vmcommon.UserError, err } - defer func() { - // TODO collect state changes from each transactions here - _, err = txProc.accounts.GetStateChangesForTheLatestTransaction() - if err != nil { - log.Error("GetStateChangesForTheLatestTransaction error", "err", err.Error()) - } - }() - switch txType { case process.MoveBalance: err = txProc.processMoveBalance(tx, acntSnd, acntDst, dstShardTxType, nil, false) diff --git a/process/transactionEvaluator/simulationAccountsDB.go b/process/transactionEvaluator/simulationAccountsDB.go index 2c87689227a..72299570455 100644 --- a/process/transactionEvaluator/simulationAccountsDB.go +++ b/process/transactionEvaluator/simulationAccountsDB.go @@ -172,9 +172,14 @@ func (r *simulationAccountsDB) GetStackDebugFirstEntry() []byte { return nil } -// GetStateChangesForTheLatestTransaction - -func (r *simulationAccountsDB) GetStateChangesForTheLatestTransaction() ([]state.StateChangeDTO, error) { - return r.originalAccounts.GetStateChangesForTheLatestTransaction() +// SetTxHashForLatestStateChanges - +func (r *simulationAccountsDB) SetTxHashForLatestStateChanges(txHash []byte) { + r.originalAccounts.SetTxHashForLatestStateChanges(txHash) +} + +// ResetStateChangesCollector - +func (r *simulationAccountsDB) ResetStateChangesCollector() []state.StateChangesForTx { + return nil } // Close will handle the closing of the underlying components diff --git a/state/accountsDB.go b/state/accountsDB.go index b515d4cc9e9..db1396686d4 100644 --- a/state/accountsDB.go +++ b/state/accountsDB.go @@ -95,14 +95,14 @@ var log = logger.GetOrCreate("state") // ArgsAccountsDB is the arguments DTO for the AccountsDB instance type ArgsAccountsDB struct { - Trie common.Trie - Hasher hashing.Hasher - Marshaller marshal.Marshalizer - AccountFactory AccountFactory - StoragePruningManager StoragePruningManager - AddressConverter core.PubkeyConverter - SnapshotsManager SnapshotsManager - StateChangesCollector StateChangesCollector + Trie common.Trie + Hasher hashing.Hasher + Marshaller marshal.Marshalizer + AccountFactory AccountFactory + StoragePruningManager StoragePruningManager + AddressConverter core.PubkeyConverter + SnapshotsManager SnapshotsManager + StateChangesCollector StateChangesCollector } // NewAccountsDB creates a new account manager @@ -1248,12 +1248,16 @@ func collectStats( log.Debug(strings.Join(trieStats.ToString(), " ")) } -// GetStateChangesForTheLatestTransaction will return the state changes since the last call of this method -func (adb *AccountsDB) GetStateChangesForTheLatestTransaction() ([]StateChangeDTO, error) { +// SetTxHashForLatestStateChanges will return the state changes since the last call of this method +func (adb *AccountsDB) SetTxHashForLatestStateChanges(txHash []byte) { + adb.stateChangesCollector.AddTxHashToCollectedStateChanges(txHash) +} + +func (adb *AccountsDB) ResetStateChangesCollector() []StateChangesForTx { stateChanges := adb.stateChangesCollector.GetStateChanges() adb.stateChangesCollector.Reset() - return stateChanges, nil + return stateChanges } // IsSnapshotInProgress returns true if there is a snapshot in progress diff --git a/state/accountsDBApi.go b/state/accountsDBApi.go index 2da28fbbc6c..0c283d4daec 100644 --- a/state/accountsDBApi.go +++ b/state/accountsDBApi.go @@ -217,9 +217,14 @@ func (accountsDB *accountsDBApi) GetStackDebugFirstEntry() []byte { return accountsDB.innerAccountsAdapter.GetStackDebugFirstEntry() } -// GetStateChangesForTheLatestTransaction will call the inner accountsAdapter method -func (accountsDB *accountsDBApi) GetStateChangesForTheLatestTransaction() ([]StateChangeDTO, error) { - return accountsDB.innerAccountsAdapter.GetStateChangesForTheLatestTransaction() +// SetTxHashForLatestStateChanges will call the inner accountsAdapter method +func (accountsDB *accountsDBApi) SetTxHashForLatestStateChanges(txHash []byte) { + accountsDB.innerAccountsAdapter.SetTxHashForLatestStateChanges(txHash) +} + +// ResetStateChangesCollector returns nil +func (accountsDB *accountsDBApi) ResetStateChangesCollector() []StateChangesForTx { + return nil } // Close will handle the closing of the underlying components diff --git a/state/accountsDBApiWithHistory.go b/state/accountsDBApiWithHistory.go index 1927b6dd174..0377a3b07bb 100644 --- a/state/accountsDBApiWithHistory.go +++ b/state/accountsDBApiWithHistory.go @@ -140,9 +140,13 @@ func (accountsDB *accountsDBApiWithHistory) GetStackDebugFirstEntry() []byte { return nil } -// GetStateChangesForTheLatestTransaction returns nil -func (accountsDB *accountsDBApiWithHistory) GetStateChangesForTheLatestTransaction() ([]StateChangeDTO, error) { - return nil, nil +// SetTxHashForLatestStateChanges returns nil +func (accountsDB *accountsDBApiWithHistory) SetTxHashForLatestStateChanges(_ []byte) { +} + +// ResetStateChangesCollector returns nil +func (accountsDB *accountsDBApiWithHistory) ResetStateChangesCollector() []StateChangesForTx { + return nil } // Close will handle the closing of the underlying components diff --git a/state/accountsDB_test.go b/state/accountsDB_test.go index 847f4e244cf..7f2658058fd 100644 --- a/state/accountsDB_test.go +++ b/state/accountsDB_test.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "errors" "fmt" + "github.com/multiversx/mx-chain-go/state/dataTrieValue" mathRand "math/rand" "strings" "sync" @@ -25,7 +26,6 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/accounts" - "github.com/multiversx/mx-chain-go/state/dataTrieValue" "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/iteratorChannelsProvider" "github.com/multiversx/mx-chain-go/state/lastSnapshotMarker" @@ -455,13 +455,16 @@ func stepCreateAccountWithDataTrieAndCode( _ = userAcc.SaveKeyValue(key1, []byte("value")) _ = userAcc.SaveKeyValue(key2, []byte("value")) _ = adb.SaveAccount(userAcc) - + adb.SetTxHashForLatestStateChanges([]byte("accountCreationTxHash")) serializedAcc, _ := marshaller.Marshal(userAcc) codeHash := userAcc.GetCodeHash() - stateChanges, err := adb.GetStateChangesForTheLatestTransaction() - assert.Nil(t, err) + stateChangesForTx := adb.ResetStateChangesCollector() + assert.Equal(t, 1, len(stateChangesForTx)) + + stateChanges := stateChangesForTx[0].StateChanges assert.Equal(t, 2, len(stateChanges)) + assert.Equal(t, []byte("accountCreationTxHash"), stateChangesForTx[0].TxHash) codeStateChange := stateChanges[0] assert.Equal(t, codeHash, codeStateChange.MainTrieKey) @@ -501,11 +504,14 @@ func stepMigrateDataTrieValAndChangeCode( userAcc.SetCode(code) _ = userAcc.SaveKeyValue([]byte("key1"), []byte("value1")) _ = adb.SaveAccount(userAcc) + adb.SetTxHashForLatestStateChanges([]byte("accountUpdateTxHash")) - stateChanges, err := adb.GetStateChangesForTheLatestTransaction() - assert.Nil(t, err) - assert.Equal(t, 3, len(stateChanges)) + stateChangesForTx := adb.ResetStateChangesCollector() + assert.Equal(t, 1, len(stateChangesForTx)) + assert.Equal(t, 3, len(stateChangesForTx[0].StateChanges)) + assert.Equal(t, []byte("accountUpdateTxHash"), stateChangesForTx[0].TxHash) + stateChanges := stateChangesForTx[0].StateChanges oldCodeEntryChange := stateChanges[0] assert.Equal(t, oldCodeHash, oldCodeEntryChange.MainTrieKey) assert.Equal(t, []byte(nil), oldCodeEntryChange.MainTrieVal) @@ -537,6 +543,7 @@ func stepMigrateDataTrieValAndChangeCode( assert.Equal(t, []byte("key1"), accountStateChange.DataTrieChanges[1].Key) assert.Equal(t, []byte(nil), accountStateChange.DataTrieChanges[1].Val) } + func TestAccountsDB_SaveAccountWithSomeValuesShouldWork(t *testing.T) { t.Parallel() diff --git a/state/disabled/disabledStateChangesCollector.go b/state/disabled/disabledStateChangesCollector.go index 3d1407d3be5..a4f5c15c069 100644 --- a/state/disabled/disabledStateChangesCollector.go +++ b/state/disabled/disabledStateChangesCollector.go @@ -4,31 +4,33 @@ import ( "github.com/multiversx/mx-chain-go/state" ) -// DisabledStateChangesCollector is a state changes collector that does nothing -type DisabledStateChangesCollector struct { +// disabledStateChangesCollector is a state changes collector that does nothing +type disabledStateChangesCollector struct { } -// NewDisabledStateChangesCollector creates a new DisabledStateChangesCollector +// NewDisabledStateChangesCollector creates a new disabledStateChangesCollector func NewDisabledStateChangesCollector() state.StateChangesCollector { - return &DisabledStateChangesCollector{} + return &disabledStateChangesCollector{} } // AddStateChange does nothing -func (d *DisabledStateChangesCollector) AddStateChange(_ state.StateChangeDTO) { - +func (d *disabledStateChangesCollector) AddStateChange(_ state.StateChangeDTO) { } // GetStateChanges returns an empty slice -func (d *DisabledStateChangesCollector) GetStateChanges() []state.StateChangeDTO { - return []state.StateChangeDTO{} +func (d *disabledStateChangesCollector) GetStateChanges() []state.StateChangesForTx { + return make([]state.StateChangesForTx, 0) } // Reset does nothing -func (d *DisabledStateChangesCollector) Reset() { +func (d *disabledStateChangesCollector) Reset() { +} +// AddTxHashToCollectedStateChanges does nothing +func (d *disabledStateChangesCollector) AddTxHashToCollectedStateChanges(_ []byte) { } // IsInterfaceNil returns true if there is no value under the interface -func (d *DisabledStateChangesCollector) IsInterfaceNil() bool { +func (d *disabledStateChangesCollector) IsInterfaceNil() bool { return d == nil } diff --git a/state/interface.go b/state/interface.go index 19806264452..ad415547983 100644 --- a/state/interface.go +++ b/state/interface.go @@ -49,7 +49,8 @@ type AccountsAdapter interface { GetStackDebugFirstEntry() []byte SetSyncer(syncer AccountsDBSyncer) error StartSnapshotIfNeeded() error - GetStateChangesForTheLatestTransaction() ([]StateChangeDTO, error) + SetTxHashForLatestStateChanges(txHash []byte) + ResetStateChangesCollector() []StateChangesForTx Close() error IsInterfaceNil() bool } @@ -283,7 +284,8 @@ type LastSnapshotMarker interface { // StateChangesCollector defines the methods needed for an StateChangesCollector implementation type StateChangesCollector interface { AddStateChange(stateChange StateChangeDTO) - GetStateChanges() []StateChangeDTO + GetStateChanges() []StateChangesForTx Reset() + AddTxHashToCollectedStateChanges(txHash []byte) IsInterfaceNil() bool } diff --git a/state/stateChangesCollector.go b/state/stateChangesCollector.go index 9c23b1db464..9e3f8fabf64 100644 --- a/state/stateChangesCollector.go +++ b/state/stateChangesCollector.go @@ -1,5 +1,11 @@ package state +// DataTrieChange represents a change in the data trie +type DataTrieChange struct { + Key []byte + Val []byte +} + // StateChangeDTO is used to collect state changes type StateChangeDTO struct { MainTrieKey []byte @@ -7,14 +13,15 @@ type StateChangeDTO struct { DataTrieChanges []DataTrieChange } -// DataTrieChange represents a change in the data trie -type DataTrieChange struct { - Key []byte - Val []byte +// StateChangesForTx is used to collect state changes for a transaction hash +type StateChangesForTx struct { + TxHash []byte + StateChanges []StateChangeDTO } type stateChangesCollector struct { - stateChanges []StateChangeDTO + stateChanges []StateChangeDTO + stateChangesForTx []StateChangesForTx } // NewStateChangesCollector creates a new StateChangesCollector @@ -30,13 +37,28 @@ func (scc *stateChangesCollector) AddStateChange(stateChange StateChangeDTO) { } // GetStateChanges returns the accumulated state changes -func (scc *stateChangesCollector) GetStateChanges() []StateChangeDTO { - return scc.stateChanges +func (scc *stateChangesCollector) GetStateChanges() []StateChangesForTx { + if len(scc.stateChanges) > 0 { + scc.AddTxHashToCollectedStateChanges([]byte{}) + } + + return scc.stateChangesForTx } // Reset resets the state changes collector func (scc *stateChangesCollector) Reset() { - scc.stateChanges = []StateChangeDTO{} + scc.stateChanges = make([]StateChangeDTO, 0) + scc.stateChangesForTx = make([]StateChangesForTx, 0) +} + +func (scc *stateChangesCollector) AddTxHashToCollectedStateChanges(txHash []byte) { + stateChangesForTx := StateChangesForTx{ + TxHash: txHash, + StateChanges: scc.stateChanges, + } + + scc.stateChanges = make([]StateChangeDTO, 0) + scc.stateChangesForTx = append(scc.stateChangesForTx, stateChangesForTx) } // IsInterfaceNil returns true if there is no value under the interface diff --git a/state/stateChangesCollector_test.go b/state/stateChangesCollector_test.go index 6c678ea2a83..6a192601f2e 100644 --- a/state/stateChangesCollector_test.go +++ b/state/stateChangesCollector_test.go @@ -31,22 +31,87 @@ func TestStateChangesCollector_AddStateChange(t *testing.T) { func TestStateChangesCollector_GetStateChanges(t *testing.T) { t.Parallel() + t.Run("getStateChanges with tx hash", func(t *testing.T) { + t.Parallel() + + scc := NewStateChangesCollector() + assert.Equal(t, 0, len(scc.stateChanges)) + assert.Equal(t, 0, len(scc.stateChangesForTx)) + + numStateChanges := 10 + for i := 0; i < numStateChanges; i++ { + scc.AddStateChange(StateChangeDTO{ + MainTrieKey: []byte(strconv.Itoa(i)), + }) + } + assert.Equal(t, numStateChanges, len(scc.stateChanges)) + assert.Equal(t, 0, len(scc.stateChangesForTx)) + scc.AddTxHashToCollectedStateChanges([]byte("txHash")) + assert.Equal(t, 0, len(scc.stateChanges)) + assert.Equal(t, 1, len(scc.stateChangesForTx)) + assert.Equal(t, []byte("txHash"), scc.stateChangesForTx[0].TxHash) + assert.Equal(t, numStateChanges, len(scc.stateChangesForTx[0].StateChanges)) + + stateChangesForTx := scc.GetStateChanges() + assert.Equal(t, 1, len(stateChangesForTx)) + assert.Equal(t, []byte("txHash"), stateChangesForTx[0].TxHash) + for i := 0; i < len(stateChangesForTx[0].StateChanges); i++ { + assert.Equal(t, []byte(strconv.Itoa(i)), stateChangesForTx[0].StateChanges[i].MainTrieKey) + } + }) + + t.Run("getStateChanges without tx hash", func(t *testing.T) { + t.Parallel() + + scc := NewStateChangesCollector() + assert.Equal(t, 0, len(scc.stateChanges)) + assert.Equal(t, 0, len(scc.stateChangesForTx)) + + numStateChanges := 10 + for i := 0; i < numStateChanges; i++ { + scc.AddStateChange(StateChangeDTO{ + MainTrieKey: []byte(strconv.Itoa(i)), + }) + } + assert.Equal(t, numStateChanges, len(scc.stateChanges)) + assert.Equal(t, 0, len(scc.stateChangesForTx)) + + stateChangesForTx := scc.GetStateChanges() + assert.Equal(t, 1, len(stateChangesForTx)) + assert.Equal(t, []byte{}, stateChangesForTx[0].TxHash) + for i := 0; i < len(stateChangesForTx[0].StateChanges); i++ { + assert.Equal(t, []byte(strconv.Itoa(i)), stateChangesForTx[0].StateChanges[i].MainTrieKey) + } + }) +} + +func TestStateChangesCollector_AddTxHashToCollectedStateChanges(t *testing.T) { + t.Parallel() + scc := NewStateChangesCollector() assert.Equal(t, 0, len(scc.stateChanges)) + assert.Equal(t, 0, len(scc.stateChangesForTx)) - numStateChanges := 10 - for i := 0; i < numStateChanges; i++ { - scc.AddStateChange(StateChangeDTO{ - MainTrieKey: []byte(strconv.Itoa(i)), - }) + stateChange := StateChangeDTO{ + MainTrieKey: []byte("mainTrieKey"), + MainTrieVal: []byte("mainTrieVal"), + DataTrieChanges: []DataTrieChange{{Key: []byte("dataTrieKey"), Val: []byte("dataTrieVal")}}, } - assert.Equal(t, numStateChanges, len(scc.stateChanges)) + scc.AddStateChange(stateChange) - stateChanges := scc.GetStateChanges() - assert.Equal(t, numStateChanges, len(stateChanges)) - for i := 0; i < numStateChanges; i++ { - assert.Equal(t, []byte(strconv.Itoa(i)), stateChanges[i].MainTrieKey) - } + assert.Equal(t, 1, len(scc.stateChanges)) + assert.Equal(t, 0, len(scc.stateChangesForTx)) + scc.AddTxHashToCollectedStateChanges([]byte("txHash")) + assert.Equal(t, 0, len(scc.stateChanges)) + assert.Equal(t, 1, len(scc.stateChangesForTx)) + + stateChangesForTx := scc.GetStateChanges() + assert.Equal(t, 1, len(stateChangesForTx)) + assert.Equal(t, []byte("txHash"), stateChangesForTx[0].TxHash) + assert.Equal(t, 1, len(stateChangesForTx[0].StateChanges)) + assert.Equal(t, []byte("mainTrieKey"), stateChangesForTx[0].StateChanges[0].MainTrieKey) + assert.Equal(t, []byte("mainTrieVal"), stateChangesForTx[0].StateChanges[0].MainTrieVal) + assert.Equal(t, 1, len(stateChangesForTx[0].StateChanges[0].DataTrieChanges)) } func TestStateChangesCollector_Reset(t *testing.T) { @@ -59,10 +124,14 @@ func TestStateChangesCollector_Reset(t *testing.T) { for i := 0; i < numStateChanges; i++ { scc.AddStateChange(StateChangeDTO{}) } + scc.AddTxHashToCollectedStateChanges([]byte("txHash")) + for i := numStateChanges; i < numStateChanges*2; i++ { + scc.AddStateChange(StateChangeDTO{}) + } assert.Equal(t, numStateChanges, len(scc.stateChanges)) - stateChanges := scc.GetStateChanges() + assert.Equal(t, 1, len(scc.stateChangesForTx)) scc.Reset() assert.Equal(t, 0, len(scc.stateChanges)) - assert.Equal(t, numStateChanges, len(stateChanges)) + assert.Equal(t, 0, len(scc.stateChangesForTx)) } diff --git a/testscommon/state/accountsAdapterStub.go b/testscommon/state/accountsAdapterStub.go index c119b52bcf3..b7dcbbb484a 100644 --- a/testscommon/state/accountsAdapterStub.go +++ b/testscommon/state/accountsAdapterStub.go @@ -13,34 +13,34 @@ var errNotImplemented = errors.New("not implemented") // AccountsStub - type AccountsStub struct { - GetExistingAccountCalled func(addressContainer []byte) (vmcommon.AccountHandler, error) - GetAccountFromBytesCalled func(address []byte, accountBytes []byte) (vmcommon.AccountHandler, error) - LoadAccountCalled func(container []byte) (vmcommon.AccountHandler, error) - SaveAccountCalled func(account vmcommon.AccountHandler) error - RemoveAccountCalled func(addressContainer []byte) error - CommitCalled func() ([]byte, error) - CommitInEpochCalled func(uint32, uint32) ([]byte, error) - JournalLenCalled func() int - RevertToSnapshotCalled func(snapshot int) error - RootHashCalled func() ([]byte, error) - RecreateTrieCalled func(rootHash []byte) error - RecreateTrieFromEpochCalled func(options common.RootHashHolder) error - PruneTrieCalled func(rootHash []byte, identifier state.TriePruningIdentifier, handler state.PruningHandler) - CancelPruneCalled func(rootHash []byte, identifier state.TriePruningIdentifier) - SnapshotStateCalled func(rootHash []byte, epoch uint32) - IsPruningEnabledCalled func() bool - GetAllLeavesCalled func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, trieLeafParser common.TrieLeafParser) error - RecreateAllTriesCalled func(rootHash []byte) (map[string]common.Trie, error) - GetCodeCalled func([]byte) []byte - GetTrieCalled func([]byte) (common.Trie, error) - GetStackDebugFirstEntryCalled func() []byte - GetAccountWithBlockInfoCalled func(address []byte, options common.RootHashHolder) (vmcommon.AccountHandler, common.BlockInfo, error) - GetCodeWithBlockInfoCalled func(codeHash []byte, options common.RootHashHolder) ([]byte, common.BlockInfo, error) - CloseCalled func() error - SetSyncerCalled func(syncer state.AccountsDBSyncer) error - StartSnapshotIfNeededCalled func() error - GetStateChangesForTheLatestTransactionCalled func() ([]state.StateChangeDTO, error) - + GetExistingAccountCalled func(addressContainer []byte) (vmcommon.AccountHandler, error) + GetAccountFromBytesCalled func(address []byte, accountBytes []byte) (vmcommon.AccountHandler, error) + LoadAccountCalled func(container []byte) (vmcommon.AccountHandler, error) + SaveAccountCalled func(account vmcommon.AccountHandler) error + RemoveAccountCalled func(addressContainer []byte) error + CommitCalled func() ([]byte, error) + CommitInEpochCalled func(uint32, uint32) ([]byte, error) + JournalLenCalled func() int + RevertToSnapshotCalled func(snapshot int) error + RootHashCalled func() ([]byte, error) + RecreateTrieCalled func(rootHash []byte) error + RecreateTrieFromEpochCalled func(options common.RootHashHolder) error + PruneTrieCalled func(rootHash []byte, identifier state.TriePruningIdentifier, handler state.PruningHandler) + CancelPruneCalled func(rootHash []byte, identifier state.TriePruningIdentifier) + SnapshotStateCalled func(rootHash []byte, epoch uint32) + IsPruningEnabledCalled func() bool + GetAllLeavesCalled func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, trieLeafParser common.TrieLeafParser) error + RecreateAllTriesCalled func(rootHash []byte) (map[string]common.Trie, error) + GetCodeCalled func([]byte) []byte + GetTrieCalled func([]byte) (common.Trie, error) + GetStackDebugFirstEntryCalled func() []byte + GetAccountWithBlockInfoCalled func(address []byte, options common.RootHashHolder) (vmcommon.AccountHandler, common.BlockInfo, error) + GetCodeWithBlockInfoCalled func(codeHash []byte, options common.RootHashHolder) ([]byte, common.BlockInfo, error) + CloseCalled func() error + SetSyncerCalled func(syncer state.AccountsDBSyncer) error + StartSnapshotIfNeededCalled func() error + SetTxHashForLatestStateChangesCalled func(txHash []byte) + ResetStateChangesCollectorCalled func() []state.StateChangesForTx } // CleanCache - @@ -259,13 +259,20 @@ func (as *AccountsStub) GetCodeWithBlockInfo(codeHash []byte, options common.Roo return nil, nil, nil } -// GetStateChangesForTheLatestTransaction - -func (as *AccountsStub) GetStateChangesForTheLatestTransaction() ([]state.StateChangeDTO, error) { - if as.GetStateChangesForTheLatestTransactionCalled != nil { - return as.GetStateChangesForTheLatestTransactionCalled() +// SetTxHashForLatestStateChanges - +func (as *AccountsStub) SetTxHashForLatestStateChanges(txHash []byte) { + if as.SetTxHashForLatestStateChangesCalled != nil { + as.SetTxHashForLatestStateChangesCalled(txHash) } +} - return nil, nil +// ResetStateChangesCollector - +func (as *AccountsStub) ResetStateChangesCollector() []state.StateChangesForTx { + if as.ResetStateChangesCollectorCalled != nil { + return as.ResetStateChangesCollectorCalled() + } + + return nil } // Close - From 53310726e0e3d4b8f394debdcd1fc975bede412b Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Mon, 18 Mar 2024 11:26:19 +0200 Subject: [PATCH 09/10] fix after merge --- integrationTests/vm/staking/componentsHolderCreator.go | 2 ++ state/accountsDB_test.go | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/integrationTests/vm/staking/componentsHolderCreator.go b/integrationTests/vm/staking/componentsHolderCreator.go index e3673b08ec7..251e7eec511 100644 --- a/integrationTests/vm/staking/componentsHolderCreator.go +++ b/integrationTests/vm/staking/componentsHolderCreator.go @@ -27,6 +27,7 @@ import ( "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/state" + disabledState "github.com/multiversx/mx-chain-go/state/disabled" stateFactory "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/storagePruningManager" "github.com/multiversx/mx-chain-go/state/storagePruningManager/evictionWaitingList" @@ -215,6 +216,7 @@ func createAccountsDB( StoragePruningManager: spm, AddressConverter: coreComponents.AddressPubKeyConverter(), SnapshotsManager: &stateTests.SnapshotsManagerStub{}, + StateChangesCollector: disabledState.NewDisabledStateChangesCollector(), } adb, _ := state.NewAccountsDB(argsAccountsDb) return adb diff --git a/state/accountsDB_test.go b/state/accountsDB_test.go index d3886f2d944..5625bea069a 100644 --- a/state/accountsDB_test.go +++ b/state/accountsDB_test.go @@ -6,7 +6,6 @@ import ( "crypto/rand" "errors" "fmt" - "github.com/multiversx/mx-chain-go/state/dataTrieValue" mathRand "math/rand" "strings" "sync" @@ -26,6 +25,7 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/accounts" + "github.com/multiversx/mx-chain-go/state/dataTrieValue" "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/iteratorChannelsProvider" "github.com/multiversx/mx-chain-go/state/lastSnapshotMarker" From 9836fde68031e311dffc231b56bc444d8371dc4a Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Wed, 4 Sep 2024 15:35:48 +0300 Subject: [PATCH 10/10] fixes after review --- epochStart/metachain/systemSCs_test.go | 2 +- factory/processing/blockProcessorCreator_test.go | 2 +- process/transaction/shardProcess.go | 1 + state/accountsDB.go | 7 ++++--- state/accountsDBApi.go | 2 +- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/epochStart/metachain/systemSCs_test.go b/epochStart/metachain/systemSCs_test.go index d203a2c1075..4d2b5d466a7 100644 --- a/epochStart/metachain/systemSCs_test.go +++ b/epochStart/metachain/systemSCs_test.go @@ -764,7 +764,7 @@ func createAccountsDB( StoragePruningManager: spm, AddressConverter: &testscommon.PubkeyConverterMock{}, SnapshotsManager: disabledState.NewDisabledSnapshotsManager(), - StateChangesCollector: state.NewStateChangesCollector(), + StateChangesCollector: disabledState.NewDisabledStateChangesCollector(), } adb, _ := state.NewAccountsDB(args) return adb diff --git a/factory/processing/blockProcessorCreator_test.go b/factory/processing/blockProcessorCreator_test.go index 92941e4778e..938e1f65ac1 100644 --- a/factory/processing/blockProcessorCreator_test.go +++ b/factory/processing/blockProcessorCreator_test.go @@ -208,7 +208,7 @@ func createAccountAdapter( StoragePruningManager: disabled.NewDisabledStoragePruningManager(), AddressConverter: &testscommon.PubkeyConverterMock{}, SnapshotsManager: disabledState.NewDisabledSnapshotsManager(), - StateChangesCollector: state.NewStateChangesCollector(), + StateChangesCollector: disabledState.NewDisabledStateChangesCollector(), } adb, err := state.NewAccountsDB(args) if err != nil { diff --git a/process/transaction/shardProcess.go b/process/transaction/shardProcess.go index 43c5b4878bf..a3242861f92 100644 --- a/process/transaction/shardProcess.go +++ b/process/transaction/shardProcess.go @@ -195,6 +195,7 @@ func (txProc *txProcessor) ProcessTransaction(tx *transaction.Transaction) (vmco txProc.pubkeyConv, ) + // TODO refactor to set the tx hash for the following state changes before the processing occurs defer txProc.accounts.SetTxHashForLatestStateChanges(txHash) txType, dstShardTxType := txProc.txTypeHandler.ComputeTransactionType(tx) diff --git a/state/accountsDB.go b/state/accountsDB.go index ce450bdcb95..fe92920e383 100644 --- a/state/accountsDB.go +++ b/state/accountsDB.go @@ -280,7 +280,7 @@ func (adb *AccountsDB) SaveAccount(account vmcommon.AccountHandler) error { } adb.stateChangesCollector.AddStateChange(stateChange) - return err + return nil } func (adb *AccountsDB) saveCodeAndDataTrie(oldAcc, newAcc vmcommon.AccountHandler) ([]DataTrieChange, error) { @@ -288,7 +288,7 @@ func (adb *AccountsDB) saveCodeAndDataTrie(oldAcc, newAcc vmcommon.AccountHandle baseOldAccount, _ := oldAcc.(baseAccountHandler) if !newAccOk { - return make([]DataTrieChange, 0), nil + return nil, nil } newValues, err := adb.saveDataTrie(baseNewAcc) @@ -301,7 +301,7 @@ func (adb *AccountsDB) saveCodeAndDataTrie(oldAcc, newAcc vmcommon.AccountHandle return nil, err } - return newValues, err + return newValues, nil } func (adb *AccountsDB) saveCode(newAcc, oldAcc baseAccountHandler) error { @@ -384,6 +384,7 @@ func (adb *AccountsDB) updateOldCodeEntry(oldCodeHash []byte) (*CodeEntry, error return nil, err } + // TODO refactor this after remove code leaf is merged stateChange := StateChangeDTO{ MainTrieKey: oldCodeHash, MainTrieVal: codeEntryBytes, diff --git a/state/accountsDBApi.go b/state/accountsDBApi.go index 2d07d99e818..b548901f167 100644 --- a/state/accountsDBApi.go +++ b/state/accountsDBApi.go @@ -244,7 +244,7 @@ func (accountsDB *accountsDBApi) SetTxHashForLatestStateChanges(txHash []byte) { // ResetStateChangesCollector returns nil func (accountsDB *accountsDBApi) ResetStateChangesCollector() []StateChangesForTx { - return nil + return accountsDB.innerAccountsAdapter.ResetStateChangesCollector() } // Close will handle the closing of the underlying components