diff --git a/cmd/node/CLI.md b/cmd/node/CLI.md index cd5b4b6e2ac..fc95fa874a4 100644 --- a/cmd/node/CLI.md +++ b/cmd/node/CLI.md @@ -73,6 +73,7 @@ GLOBAL OPTIONS: --operation-mode operation mode String flag for specifying the desired operation mode(s) of the node, resulting in altering some configuration values accordingly. Possible values are: snapshotless-observer, full-archive, db-lookup-extension, historical-balances or `""` (empty). Multiple values can be separated via , --repopulate-tokens-supplies Boolean flag for repopulating the tokens supplies database. It will delete the current data, iterate over the entire trie and add he new obtained supplies --p2p-prometheus-metrics Boolean option for enabling the /debug/metrics/prometheus route for p2p prometheus metrics + --state-changes-types-to-collect value String slice option for enabling collecting specified state changes types. Can be (READ, WRITE) --help, -h show help --version, -v print the version diff --git a/cmd/node/config/config.toml b/cmd/node/config/config.toml index 091d4d780c1..05c08687ad5 100644 --- a/cmd/node/config/config.toml +++ b/cmd/node/config/config.toml @@ -662,7 +662,9 @@ MaxStateTrieLevelInMemory = 5 MaxPeerTrieLevelInMemory = 5 StateStatisticsEnabled = false - CollectStateChangesEnabled = false + StateChangesTypesToCollect = [] + StateChangesDataAnalysis = false + StateChangesPeerAccountsEnabled = false [BlockSizeThrottleConfig] MinSizeInBytes = 104857 # 104857 is 10% from 1MB diff --git a/cmd/node/flags.go b/cmd/node/flags.go index 72c86c04f96..29c8528b78e 100644 --- a/cmd/node/flags.go +++ b/cmd/node/flags.go @@ -6,12 +6,13 @@ import ( "os" "runtime" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/urfave/cli" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/operationmodes" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/facade" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/urfave/cli" ) var ( @@ -408,6 +409,13 @@ var ( Name: "p2p-prometheus-metrics", Usage: "Boolean option for enabling the /debug/metrics/prometheus route for p2p prometheus metrics", } + + // stateChangesTypesToCollect defines a flag for collecting specified types of state changes + // If enabled, it will override the configuration + stateChangesTypesToCollect = cli.StringSliceFlag{ + Name: "state-changes-types-to-collect", + Usage: "String slice option for enabling collecting specified state changes types. Can be (READ, WRITE)", + } ) func getFlags() []cli.Flag { @@ -470,6 +478,7 @@ func getFlags() []cli.Flag { operationMode, repopulateTokensSupplies, p2pPrometheusMetrics, + stateChangesTypesToCollect, } } diff --git a/cmd/node/main.go b/cmd/node/main.go index c7cc3c1085c..0169aa278a5 100644 --- a/cmd/node/main.go +++ b/cmd/node/main.go @@ -9,14 +9,15 @@ import ( "github.com/klauspost/cpuid/v2" "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-logger-go/file" + "github.com/urfave/cli" + "github.com/multiversx/mx-chain-go/cmd/node/factory" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/config/overridableConfig" "github.com/multiversx/mx-chain-go/node" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/multiversx/mx-chain-logger-go/file" - "github.com/urfave/cli" // test point 1 for custom profiler ) @@ -253,6 +254,9 @@ func readConfigs(ctx *cli.Context, log logger.Logger) (*config.Configs, error) { if ctx.IsSet(identityFlagName.Name) { preferencesConfig.Preferences.Identity = ctx.GlobalString(identityFlagName.Name) } + if ctx.IsSet(stateChangesTypesToCollect.Name) { + generalConfig.StateTriesConfig.StateChangesTypesToCollect = ctx.GlobalStringSlice(stateChangesTypesToCollect.Name) + } return &config.Configs{ GeneralConfig: generalConfig, diff --git a/config/config.go b/config/config.go index d023a4fd522..cdd64162b4b 100644 --- a/config/config.go +++ b/config/config.go @@ -308,14 +308,15 @@ type FacadeConfig struct { // StateTriesConfig will hold information about state tries type StateTriesConfig struct { - SnapshotsEnabled bool - AccountsStatePruningEnabled bool - PeerStatePruningEnabled bool - CollectStateChangesEnabled bool - CollectStateChangesWithReadEnabled bool - MaxStateTrieLevelInMemory uint - MaxPeerTrieLevelInMemory uint - StateStatisticsEnabled bool + SnapshotsEnabled bool + AccountsStatePruningEnabled bool + PeerStatePruningEnabled bool + StateChangesDataAnalysis bool + StateChangesTypesToCollect []string + StateChangesPeerAccountsEnabled bool + MaxStateTrieLevelInMemory uint + MaxPeerTrieLevelInMemory uint + StateStatisticsEnabled bool } // TrieStorageManagerConfig will hold config information about trie storage manager diff --git a/epochStart/metachain/baseRewards_test.go b/epochStart/metachain/baseRewards_test.go index 87ccb625643..2578b0bde75 100644 --- a/epochStart/metachain/baseRewards_test.go +++ b/epochStart/metachain/baseRewards_test.go @@ -15,12 +15,15 @@ import ( "github.com/multiversx/mx-chain-core-go/data/rewardTx" "github.com/multiversx/mx-chain-core-go/hashing/sha256" "github.com/multiversx/mx-chain-core-go/marshal" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/state/factory" - "github.com/multiversx/mx-chain-go/state/stateChanges" "github.com/multiversx/mx-chain-go/testscommon" txExecOrderStub "github.com/multiversx/mx-chain-go/testscommon/common" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" @@ -30,9 +33,6 @@ import ( stateMock "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestBaseRewardsCreator_NilShardCoordinator(t *testing.T) { @@ -1181,7 +1181,7 @@ func getBaseRewardsArguments() BaseRewardsCreatorArgs { Hasher: hasher, Marshaller: marshalizer, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - StateChangesCollector: stateChanges.NewStateChangesCollector(), + StateChangesCollector: &stateMock.StateChangesCollectorStub{}, } accCreator, _ := factory.NewAccountCreator(argsAccCreator) enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} diff --git a/factory/processing/blockProcessorCreator_test.go b/factory/processing/blockProcessorCreator_test.go index cc8ef9f8ddf..998ff091473 100644 --- a/factory/processing/blockProcessorCreator_test.go +++ b/factory/processing/blockProcessorCreator_test.go @@ -8,6 +8,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" dataComp "github.com/multiversx/mx-chain-go/factory/data" @@ -17,7 +20,6 @@ import ( "github.com/multiversx/mx-chain-go/state/accounts" disabledState "github.com/multiversx/mx-chain-go/state/disabled" factoryState "github.com/multiversx/mx-chain-go/state/factory" - "github.com/multiversx/mx-chain-go/state/stateChanges" "github.com/multiversx/mx-chain-go/state/storagePruningManager/disabled" "github.com/multiversx/mx-chain-go/testscommon" componentsMock "github.com/multiversx/mx-chain-go/testscommon/components" @@ -27,8 +29,6 @@ import ( storageManager "github.com/multiversx/mx-chain-go/testscommon/storage" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/require" ) func Test_newBlockProcessorCreatorForShard(t *testing.T) { @@ -107,7 +107,7 @@ func Test_newBlockProcessorCreatorForMeta(t *testing.T) { Hasher: coreComponents.Hasher(), Marshaller: coreComponents.InternalMarshalizer(), EnableEpochsHandler: coreComponents.EnableEpochsHandler(), - StateChangesCollector: stateChanges.NewStateChangesCollector(), + StateChangesCollector: &stateMock.StateChangesCollectorStub{}, } accCreator, _ := factoryState.NewAccountCreator(argsAccCreator) diff --git a/factory/state/stateComponents.go b/factory/state/stateComponents.go index dce52d14547..cd5959685ed 100644 --- a/factory/state/stateComponents.go +++ b/factory/state/stateComponents.go @@ -2,9 +2,12 @@ package state import ( "fmt" + "strings" "github.com/multiversx/mx-chain-core-go/core/check" chainData "github.com/multiversx/mx-chain-core-go/data" + data "github.com/multiversx/mx-chain-core-go/data/stateChange" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -122,34 +125,55 @@ func (scf *stateComponentsFactory) Create() (*stateComponents, error) { } func (scf *stateComponentsFactory) createStateChangesCollector() (state.StateChangesCollector, error) { - if !scf.config.StateTriesConfig.CollectStateChangesEnabled { + if len(scf.config.StateTriesConfig.StateChangesTypesToCollect) == 0 { return disabled.NewDisabledStateChangesCollector(), nil } - if !scf.config.StateTriesConfig.CollectStateChangesWithReadEnabled { - return stateChanges.NewStateChangesCollector(), nil + collectRead, collectWrite, err := parseStateChangesTypesToCollect(scf.config.StateTriesConfig.StateChangesTypesToCollect) + if err != nil { + return nil, fmt.Errorf("failed to parse state changes types to collect: %w", err) } - // TODO: move to toml config file - dbConfig := config.DBConfig{ - FilePath: "stateChanges", - Type: "LvlDBSerial", - BatchDelaySeconds: 2, - MaxBatchSize: 100, - MaxOpenFiles: 10, + var opts []stateChanges.CollectorOption + if collectRead { + opts = append(opts, stateChanges.WithCollectRead()) + } + if collectWrite { + opts = append(opts, stateChanges.WithCollectWrite()) } - persisterFactory, err := storageFactory.NewPersisterFactory(dbConfig) - if err != nil { - return nil, err + if scf.config.StateTriesConfig.StateChangesDataAnalysis { + // TODO: move to toml config file + dbConfig := config.DBConfig{ + FilePath: "stateChanges", + Type: "LvlDBSerial", + BatchDelaySeconds: 2, + MaxBatchSize: 100, + MaxOpenFiles: 10, + } + + persisterFactory, err := storageFactory.NewPersisterFactory(dbConfig) + if err != nil { + return nil, err + } + + db, err := persisterFactory.CreateWithRetries(dbConfig.FilePath) + if err != nil { + return nil, fmt.Errorf("%w while creating the db for the trie nodes", err) + } + + opts = append(opts, stateChanges.WithStorer(db)) } - db, err := persisterFactory.CreateWithRetries(dbConfig.FilePath) - if err != nil { - return nil, fmt.Errorf("%w while creating the db for the trie nodes", err) + return stateChanges.NewCollector(opts...), nil +} + +func (scf *stateComponentsFactory) createStateChangesCollectorPeerAccounts() (state.StateChangesCollector, error) { + if !scf.config.StateTriesConfig.StateChangesPeerAccountsEnabled { + return disabled.NewDisabledStateChangesCollector(), nil } - return stateChanges.NewDataAnalysisStateChangesCollector(db) + return scf.createStateChangesCollector() } func (scf *stateComponentsFactory) createSnapshotManager( @@ -287,7 +311,7 @@ func (scf *stateComponentsFactory) createPeerAdapter(triesContainer common.Tries return nil, err } - stateChangesCollector, err := scf.createStateChangesCollector() + stateChangesCollector, err := scf.createStateChangesCollectorPeerAccounts() if err != nil { return nil, err } @@ -375,3 +399,32 @@ func (pc *stateComponents) Close() error { } return nil } + +func parseStateChangesTypesToCollect(stateChangesTypes []string) (collectRead bool, collectWrite bool, err error) { + types := sanitizeActionTypes(data.ActionType_value) + for _, stateChangeType := range stateChangesTypes { + if value, ok := types[strings.ToLower(stateChangeType)]; ok { + switch value { + case 0: + collectRead = true + + case 1: + collectWrite = true + } + } else { + return false, false, fmt.Errorf("unknown action type %s", stateChangeType) + } + } + + return collectRead, collectWrite, nil +} + +func sanitizeActionTypes(actionTypes map[string]int32) map[string]int32 { + sanitizedActionTypes := make(map[string]int32, len(actionTypes)) + + for actionType, value := range actionTypes { + sanitizedActionTypes[strings.ToLower(actionType)] = value + } + + return sanitizedActionTypes +} diff --git a/factory/state/stateComponents_test.go b/factory/state/stateComponents_test.go index bf5068e8dd7..50a7ad8184d 100644 --- a/factory/state/stateComponents_test.go +++ b/factory/state/stateComponents_test.go @@ -5,12 +5,13 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/errors" stateComp "github.com/multiversx/mx-chain-go/factory/state" "github.com/multiversx/mx-chain-go/testscommon" componentsMock "github.com/multiversx/mx-chain-go/testscommon/components" "github.com/multiversx/mx-chain-go/testscommon/factory" - "github.com/stretchr/testify/require" ) func TestNewStateComponentsFactory(t *testing.T) { diff --git a/factory/state/stateParser_test.go b/factory/state/stateParser_test.go new file mode 100644 index 00000000000..1b6606751f5 --- /dev/null +++ b/factory/state/stateParser_test.go @@ -0,0 +1,79 @@ +package state + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestStateComponents_ParseStateChangesTypesToCollect(t *testing.T) { + t.Parallel() + + t.Run("should parse state changes: 1 read 1 write", func(t *testing.T) { + t.Parallel() + + collectRead, collectWrite, err := parseStateChangesTypesToCollect([]string{"read", "write"}) + require.NoError(t, err) + require.True(t, collectRead) + require.True(t, collectWrite) + }) + + t.Run("should parse state changes: multiple types", func(t *testing.T) { + t.Parallel() + + collectRead, collectWrite, err := parseStateChangesTypesToCollect([]string{"read", "read", "write", "write"}) + require.NoError(t, err) + require.True(t, collectRead) + require.True(t, collectWrite) + }) + + t.Run("should parse state changes: inconsistent strings", func(t *testing.T) { + collectRead, collectWrite, err := parseStateChangesTypesToCollect([]string{"Read", "read", "Write", "write"}) + require.NoError(t, err) + require.True(t, collectRead) + require.True(t, collectWrite) + + collectRead, collectWrite, err = parseStateChangesTypesToCollect([]string{"Read"}) + require.NoError(t, err) + require.True(t, collectRead) + require.False(t, collectWrite) + + collectRead, collectWrite, err = parseStateChangesTypesToCollect([]string{"Read", "rEaD"}) + require.NoError(t, err) + require.True(t, collectRead) + require.False(t, collectWrite) + + collectRead, collectWrite, err = parseStateChangesTypesToCollect([]string{"Write"}) + require.NoError(t, err) + require.False(t, collectRead) + require.True(t, collectWrite) + + collectRead, collectWrite, err = parseStateChangesTypesToCollect([]string{"Write", "write", "wRiTe"}) + require.NoError(t, err) + require.False(t, collectRead) + require.True(t, collectWrite) + }) + + t.Run("should parse state changes: no types", func(t *testing.T) { + t.Parallel() + + collectRead, collectWrite, err := parseStateChangesTypesToCollect([]string{}) + require.NoError(t, err) + require.False(t, collectRead) + require.False(t, collectWrite) + + collectRead, collectWrite, err = parseStateChangesTypesToCollect(nil) + require.NoError(t, err) + require.False(t, collectRead) + require.False(t, collectWrite) + }) + + t.Run("should not parse state changes: invalid types", func(t *testing.T) { + t.Parallel() + + collectRead, collectWrite, err := parseStateChangesTypesToCollect([]string{"r3ad", "writ3"}) + require.ErrorContains(t, err, "unknown action type") + require.False(t, collectRead) + require.False(t, collectWrite) + }) +} diff --git a/factory/status/export_test.go b/factory/status/export_test.go index e2a33e93a65..68825521d2d 100644 --- a/factory/status/export_test.go +++ b/factory/status/export_test.go @@ -2,6 +2,7 @@ package status import ( "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/epochStart" outportDriverFactory "github.com/multiversx/mx-chain-go/outport/factory" "github.com/multiversx/mx-chain-go/p2p" diff --git a/go.mod b/go.mod index d879c1cf45b..e745e183008 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/klauspost/cpuid/v2 v2.2.5 github.com/mitchellh/mapstructure v1.5.0 github.com/multiversx/mx-chain-communication-go v1.1.0 - github.com/multiversx/mx-chain-core-go v1.2.23-0.20240918093335-b9e28fbed67c + github.com/multiversx/mx-chain-core-go v1.2.23-0.20240924120353-a1e60f8d53f0 github.com/multiversx/mx-chain-crypto-go v1.2.12 github.com/multiversx/mx-chain-es-indexer-go v1.7.5-0.20240807095116-4f2f595e52d9 github.com/multiversx/mx-chain-logger-go v1.0.15 diff --git a/go.sum b/go.sum index 3a6a35eb248..73921ea309d 100644 --- a/go.sum +++ b/go.sum @@ -387,8 +387,8 @@ github.com/multiversx/concurrent-map v0.1.4 h1:hdnbM8VE4b0KYJaGY5yJS2aNIW9TFFsUY github.com/multiversx/concurrent-map v0.1.4/go.mod h1:8cWFRJDOrWHOTNSqgYCUvwT7c7eFQ4U2vKMOp4A/9+o= github.com/multiversx/mx-chain-communication-go v1.1.0 h1:J7bX6HoN3HiHY7cUeEjG8AJWgQDDPcY+OPDOsSUOkRE= github.com/multiversx/mx-chain-communication-go v1.1.0/go.mod h1:WK6bP4pGEHGDDna/AYRIMtl6G9OA0NByI1Lw8PmOnRM= -github.com/multiversx/mx-chain-core-go v1.2.23-0.20240918093335-b9e28fbed67c h1:wPqkaTaiSnMXXGmnqJNtn+xMZUoJEAw16QgtrlUGSgk= -github.com/multiversx/mx-chain-core-go v1.2.23-0.20240918093335-b9e28fbed67c/go.mod h1:B5zU4MFyJezmEzCsAHE9YNULmGCm2zbPHvl9hazNxmE= +github.com/multiversx/mx-chain-core-go v1.2.23-0.20240924120353-a1e60f8d53f0 h1:5Bm6Hg5jO+OuwtRfmwQc8XGmw0z7tQJHIcrZ8IaBtQ4= +github.com/multiversx/mx-chain-core-go v1.2.23-0.20240924120353-a1e60f8d53f0/go.mod h1:B5zU4MFyJezmEzCsAHE9YNULmGCm2zbPHvl9hazNxmE= github.com/multiversx/mx-chain-crypto-go v1.2.12 h1:zWip7rpUS4CGthJxfKn5MZfMfYPjVjIiCID6uX5BSOk= github.com/multiversx/mx-chain-crypto-go v1.2.12/go.mod h1:HzcPpCm1zanNct/6h2rIh+MFrlXbjA5C8+uMyXj3LI4= github.com/multiversx/mx-chain-es-indexer-go v1.7.5-0.20240807095116-4f2f595e52d9 h1:VJOigTM9JbjFdy9ICVhsDfM9YQkFqMigAaQCHaM0iwY= diff --git a/integrationTests/vm/wasm/wasmvm/mockContracts.go b/integrationTests/vm/wasm/wasmvm/mockContracts.go index d8eeccdbd25..cd76450618a 100644 --- a/integrationTests/vm/wasm/wasmvm/mockContracts.go +++ b/integrationTests/vm/wasm/wasmvm/mockContracts.go @@ -8,22 +8,23 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/esdt" + "github.com/multiversx/mx-chain-scenario-go/worldmock" + "github.com/multiversx/mx-chain-vm-go/executor" + contextmock "github.com/multiversx/mx-chain-vm-go/mock/context" + "github.com/multiversx/mx-chain-vm-go/testcommon" + "github.com/multiversx/mx-chain-vm-go/vmhost" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/state" stateFactory "github.com/multiversx/mx-chain-go/state/factory" - "github.com/multiversx/mx-chain-go/state/stateChanges" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + stateMock "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/testscommon/txDataBuilder" - worldmock "github.com/multiversx/mx-chain-scenario-go/worldmock" - "github.com/multiversx/mx-chain-vm-go/executor" - contextmock "github.com/multiversx/mx-chain-vm-go/mock/context" - "github.com/multiversx/mx-chain-vm-go/testcommon" - "github.com/multiversx/mx-chain-vm-go/vmhost" - "github.com/stretchr/testify/require" ) // MockInitialBalance represents a mock balance @@ -110,7 +111,7 @@ func GetAddressForNewAccountOnWalletAndNodeWithVM( Hasher: &hashingMocks.HasherMock{}, Marshaller: &marshallerMock.MarshalizerMock{}, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - StateChangesCollector: stateChanges.NewStateChangesCollector(), + StateChangesCollector: &stateMock.StateChangesCollectorStub{}, } accountFactory, _ := stateFactory.NewAccountCreator(argsAccCreation) diff --git a/outport/process/outportDataProvider.go b/outport/process/outportDataProvider.go index 0528dd6b06d..40ad13aade8 100644 --- a/outport/process/outportDataProvider.go +++ b/outport/process/outportDataProvider.go @@ -139,7 +139,10 @@ func (odp *outportDataProvider) PrepareOutportSaveBlockData(arg ArgPrepareOutpor return nil, err } - stateChanges := odp.stateChangesCollector.GetStateChangesForTxs() + stateChanges, err := odp.stateChangesCollector.Publish() + if err != nil { + return nil, fmt.Errorf("failed to publish state changes: %w", err) + } return &outportcore.OutportBlockWithHeaderAndBody{ OutportBlock: &outportcore.OutportBlock{ diff --git a/process/sync/trieIterators/tokensSuppliesProcessor_test.go b/process/sync/trieIterators/tokensSuppliesProcessor_test.go index d791e28fe66..284af34078a 100644 --- a/process/sync/trieIterators/tokensSuppliesProcessor_test.go +++ b/process/sync/trieIterators/tokensSuppliesProcessor_test.go @@ -9,12 +9,14 @@ import ( "github.com/multiversx/mx-chain-core-go/core/keyValStorage" "github.com/multiversx/mx-chain-core-go/data/esdt" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" coreEsdt "github.com/multiversx/mx-chain-go/dblookupext/esdtSupply" "github.com/multiversx/mx-chain-go/state/accounts" "github.com/multiversx/mx-chain-go/state/parsers" - "github.com/multiversx/mx-chain-go/state/stateChanges" "github.com/multiversx/mx-chain-go/state/trackableDataTrie" chainStorage "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" @@ -24,8 +26,6 @@ import ( stateMock "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/testscommon/trie" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/require" ) func getTokensSuppliesProcessorArgs() ArgsTokensSuppliesProcessor { @@ -202,7 +202,7 @@ func TestTokensSuppliesProcessor_HandleTrieAccountIteration(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) dtlp, _ := parsers.NewDataTrieLeafParser([]byte("addr"), &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) userAcc, _ := accounts.NewUserAccount([]byte("addr"), dtt, dtlp) diff --git a/state/accountsDB.go b/state/accountsDB.go index 52cb1b87b51..31e7ebb3931 100644 --- a/state/accountsDB.go +++ b/state/accountsDB.go @@ -210,14 +210,14 @@ func (adb *AccountsDB) GetCode(codeHash []byte) []byte { return nil } - stateChange := &stateChange.StateChange{ - Type: "read", + sc := &stateChange.StateChange{ + Type: stateChange.Read, MainTrieKey: codeHash, MainTrieVal: val, - Operation: "getCode", + Operation: stateChange.GetCode, DataTrieChanges: nil, } - adb.stateChangesCollector.AddStateChange(stateChange) + adb.stateChangesCollector.AddStateChange(sc) err = adb.marshaller.Unmarshal(&codeEntry, val) if err != nil { @@ -285,15 +285,15 @@ func (adb *AccountsDB) SaveAccount(account vmcommon.AccountHandler) error { return err } - stateChange := &stateChange.StateChange{ - Type: "write", + sc := &stateChange.StateChange{ + Type: stateChange.Write, MainTrieKey: account.AddressBytes(), MainTrieVal: marshalledAccount, DataTrieChanges: newDataTrieValues, - Operation: "saveAccount", + Operation: stateChange.SaveAccount, } - adb.stateChangesCollector.AddSaveAccountStateChange(oldAccount, account, stateChange) + adb.stateChangesCollector.AddSaveAccountStateChange(oldAccount, account, sc) return nil } @@ -373,10 +373,10 @@ func (adb *AccountsDB) updateOldCodeEntry(oldCodeHash []byte) (*CodeEntry, error } sc := &stateChange.StateChange{ - Type: "read", + Type: stateChange.Read, MainTrieKey: oldCodeHash, MainTrieVal: nil, - Operation: "getCode", + Operation: stateChange.GetCode, DataTrieChanges: nil, } adb.stateChangesCollector.AddStateChange(sc) @@ -393,10 +393,10 @@ func (adb *AccountsDB) updateOldCodeEntry(oldCodeHash []byte) (*CodeEntry, error } sc1 := &stateChange.StateChange{ - Type: "write", + Type: stateChange.Write, MainTrieKey: oldCodeHash, MainTrieVal: nil, - Operation: "writeCode", + Operation: stateChange.WriteCode, DataTrieChanges: nil, } adb.stateChangesCollector.AddStateChange(sc1) @@ -411,10 +411,10 @@ func (adb *AccountsDB) updateOldCodeEntry(oldCodeHash []byte) (*CodeEntry, error } sc = &stateChange.StateChange{ - Type: "write", + Type: stateChange.Write, MainTrieKey: oldCodeHash, MainTrieVal: codeEntryBytes, - Operation: "writeCode", + Operation: stateChange.WriteCode, DataTrieChanges: nil, } adb.stateChangesCollector.AddStateChange(sc) @@ -445,10 +445,10 @@ func (adb *AccountsDB) updateNewCodeEntry(newCodeHash []byte, newCode []byte) er } sc := &stateChange.StateChange{ - Type: "write", + Type: stateChange.Write, MainTrieKey: newCodeHash, MainTrieVal: codeEntryBytes, - Operation: "writeCode", + Operation: stateChange.WriteCode, DataTrieChanges: nil, } adb.stateChangesCollector.AddStateChange(sc) @@ -654,9 +654,9 @@ func (adb *AccountsDB) removeDataTrie(baseAcc baseAccountHandler) error { adb.journalize(entry) sc := &stateChange.StateChange{ - Type: "write", + Type: stateChange.Write, MainTrieKey: baseAcc.AddressBytes(), - Operation: "removeDataTrie", + Operation: stateChange.RemoveDataTrie, DataTrieChanges: nil, } adb.stateChangesCollector.AddStateChange(sc) @@ -720,10 +720,10 @@ func (adb *AccountsDB) getAccount(address []byte, mainTrie common.Trie) (vmcommo } stateChange := &stateChange.StateChange{ - Type: "read", + Type: stateChange.Read, MainTrieKey: address, MainTrieVal: val, - Operation: "getAccount", + Operation: stateChange.GetAccount, DataTrieChanges: nil, } adb.stateChangesCollector.AddStateChange(stateChange) @@ -921,8 +921,8 @@ func (adb *AccountsDB) commit() ([]byte, error) { log.Trace("accountsDB.Commit started") adb.entries = make([]JournalEntry, 0) - // TODO: evaluate moving this to procesing on CommitBlock - err := adb.stateChangesCollector.Publish() + // If the stateChangesCollector is configured in data analysis mode, it will persist the state changes locally + err := adb.stateChangesCollector.Store() if err != nil { return nil, err } diff --git a/state/accountsDB_test.go b/state/accountsDB_test.go index b316899cb97..58b339ffebe 100644 --- a/state/accountsDB_test.go +++ b/state/accountsDB_test.go @@ -180,7 +180,7 @@ func getDefaultStateComponents( StoragePruningManager: spm, AddressConverter: &testscommon.PubkeyConverterMock{}, SnapshotsManager: snapshotsManager, - StateChangesCollector: stateChanges.NewStateChangesCollector(), + StateChangesCollector: stateChanges.NewCollector(stateChanges.WithCollectWrite()), } adb, _ := state.NewAccountsDB(argsAccountsDB) @@ -469,7 +469,8 @@ func stepCreateAccountWithDataTrieAndCode( serializedAcc, _ := marshaller.Marshal(userAcc) codeHash := userAcc.GetCodeHash() - stateChangesForTx := adb.ResetStateChangesCollector() + stateChangesForTx, err := adb.ResetStateChangesCollector() + assert.Nil(t, err) assert.Equal(t, 1, len(stateChangesForTx)) stateChanges := stateChangesForTx[string(txHash)].StateChanges @@ -519,7 +520,8 @@ func stepMigrateDataTrieValAndChangeCode( adb.SetTxHashForLatestStateChanges(txHash, &transaction.Transaction{}) - stateChangesForTx := adb.ResetStateChangesCollector() + stateChangesForTx, err := adb.ResetStateChangesCollector() + assert.Nil(t, err) assert.Equal(t, 1, len(stateChangesForTx)) assert.Equal(t, 3, len(stateChangesForTx[string(txHash)].StateChanges)) assert.Equal(t, txHash, stateChangesForTx[string(txHash)].StateChanges[0].GetTxHash()) diff --git a/state/disabled/disabledStateChangesCollector.go b/state/disabled/disabledStateChangesCollector.go index 297d89ce5dd..00027bac2cb 100644 --- a/state/disabled/disabledStateChangesCollector.go +++ b/state/disabled/disabledStateChangesCollector.go @@ -1,7 +1,7 @@ package disabled import ( - "github.com/multiversx/mx-chain-core-go/data/stateChange" + data "github.com/multiversx/mx-chain-core-go/data/stateChange" "github.com/multiversx/mx-chain-core-go/data/transaction" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -43,13 +43,13 @@ func (d *disabledStateChangesCollector) RevertToIndex(index int) error { return nil } -// Publish returns nil -func (d *disabledStateChangesCollector) Publish() error { - return nil +// Publish - +func (d *disabledStateChangesCollector) Publish() (map[string]*data.StateChanges, error) { + return nil, nil } -// GetStateChangesForTxs - -func (d *disabledStateChangesCollector) GetStateChangesForTxs() map[string]*stateChange.StateChanges { +// Store - +func (d *disabledStateChangesCollector) Store() error { return nil } diff --git a/state/export_test.go b/state/export_test.go index 8353abd92d4..629ca6a7e80 100644 --- a/state/export_test.go +++ b/state/export_test.go @@ -3,8 +3,9 @@ package state import ( data "github.com/multiversx/mx-chain-core-go/data/stateChange" "github.com/multiversx/mx-chain-core-go/marshal" - "github.com/multiversx/mx-chain-go/common" vmcommon "github.com/multiversx/mx-chain-vm-common-go" + + "github.com/multiversx/mx-chain-go/common" ) // LoadCode - @@ -28,12 +29,15 @@ func (adb *AccountsDB) GetObsoleteHashes() map[string][][]byte { } // ResetStateChangesCollector - -func (adb *AccountsDB) ResetStateChangesCollector() map[string]*data.StateChanges { - stateChanges := adb.stateChangesCollector.GetStateChangesForTxs() +func (adb *AccountsDB) ResetStateChangesCollector() (map[string]*data.StateChanges, error) { + stateChanges, err := adb.stateChangesCollector.Publish() + if err != nil { + return nil, err + } adb.stateChangesCollector.Reset() - return stateChanges + return stateChanges, nil } // GetCode - diff --git a/state/factory/accountCreator_test.go b/state/factory/accountCreator_test.go index 11490deae17..6f1ab0bbcca 100644 --- a/state/factory/accountCreator_test.go +++ b/state/factory/accountCreator_test.go @@ -4,13 +4,15 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/factory" - "github.com/multiversx/mx-chain-go/state/stateChanges" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + stateMock "github.com/multiversx/mx-chain-go/testscommon/state" + "github.com/stretchr/testify/assert" ) @@ -19,7 +21,7 @@ func getDefaultArgs() factory.ArgsAccountCreator { Hasher: &hashingMocks.HasherMock{}, Marshaller: &marshallerMock.MarshalizerMock{}, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - StateChangesCollector: stateChanges.NewStateChangesCollector(), + StateChangesCollector: &stateMock.StateChangesCollectorStub{}, } } diff --git a/state/factory/accountsAdapterAPICreator_test.go b/state/factory/accountsAdapterAPICreator_test.go index fd524231e5d..13a0bae62eb 100644 --- a/state/factory/accountsAdapterAPICreator_test.go +++ b/state/factory/accountsAdapterAPICreator_test.go @@ -5,15 +5,16 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/state/stateChanges" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" mockState "github.com/multiversx/mx-chain-go/testscommon/state" + stateMock "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/testscommon/storageManager" mockTrie "github.com/multiversx/mx-chain-go/testscommon/trie" - "github.com/stretchr/testify/assert" ) func createMockAccountsArgs() state.ArgsAccountsDB { @@ -29,7 +30,7 @@ func createMockAccountsArgs() state.ArgsAccountsDB { StoragePruningManager: &mockState.StoragePruningManagerStub{}, AddressConverter: &testscommon.PubkeyConverterMock{}, SnapshotsManager: &mockState.SnapshotsManagerStub{}, - StateChangesCollector: stateChanges.NewStateChangesCollector(), + StateChangesCollector: &stateMock.StateChangesCollectorStub{}, } } diff --git a/state/interface.go b/state/interface.go index 73ac6c9d50a..154ac92917c 100644 --- a/state/interface.go +++ b/state/interface.go @@ -362,22 +362,22 @@ type StateChangesCollector interface { AddStateChange(stateChange StateChange) AddSaveAccountStateChange(oldAccount, account vmcommon.AccountHandler, stateChange StateChange) Reset() + Publish() (map[string]*data.StateChanges, error) + Store() error AddTxHashToCollectedStateChanges(txHash []byte, tx *transaction.Transaction) SetIndexToLastStateChange(index int) error RevertToIndex(index int) error - Publish() error IsInterfaceNil() bool - GetStateChangesForTxs() map[string]*data.StateChanges } // StateChange defines the behaviour of a state change holder type StateChange interface { - GetType() string + GetType() data.ActionType GetIndex() int32 GetTxHash() []byte GetMainTrieKey() []byte GetMainTrieVal() []byte - GetOperation() string + GetOperation() data.Operation GetDataTrieChanges() []*data.DataTrieChange SetTxHash(txHash []byte) diff --git a/state/stateChanges/collector.go b/state/stateChanges/collector.go new file mode 100644 index 00000000000..1994c8cd73a --- /dev/null +++ b/state/stateChanges/collector.go @@ -0,0 +1,267 @@ +package stateChanges + +import ( + "bytes" + "encoding/json" + "fmt" + "sync" + + data "github.com/multiversx/mx-chain-core-go/data/stateChange" + "github.com/multiversx/mx-chain-core-go/data/transaction" + logger "github.com/multiversx/mx-chain-logger-go" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + + "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/storage" +) + +var log = logger.GetOrCreate("state/stateChanges") + +// StateChangesForTx is used to collect state changes for a transaction hash +type StateChangesForTx struct { + TxHash []byte `json:"txHash"` + StateChanges []state.StateChange `json:"stateChanges"` +} + +type collector struct { + collectRead bool + collectWrite bool + stateChanges []state.StateChange + stateChangesMut sync.RWMutex + cachedTxs map[string]*transaction.Transaction + storer storage.Persister +} + +// NewCollector will collect based on the options the state changes. +func NewCollector(opts ...CollectorOption) *collector { + c := &collector{stateChanges: make([]state.StateChange, 0)} + for _, opt := range opts { + opt(c) + } + + if c.storer != nil { + c.cachedTxs = make(map[string]*transaction.Transaction) + } + + return c +} + +// AddStateChange adds a new state change to the collector +func (c *collector) AddStateChange(stateChange state.StateChange) { + c.stateChangesMut.Lock() + defer c.stateChangesMut.Unlock() + + if stateChange.GetType() == data.Write && c.collectWrite { + c.stateChanges = append(c.stateChanges, stateChange) + } + + if stateChange.GetType() == data.Read && c.collectRead { + c.stateChanges = append(c.stateChanges, stateChange) + } +} + +// AddSaveAccountStateChange adds a new state change for the save account operation +func (c *collector) AddSaveAccountStateChange(oldAccount, account vmcommon.AccountHandler, stateChange state.StateChange) { + if c.storer != nil { + dataAnalysisStateChange := &dataAnalysisStateChangeDTO{ + StateChange: stateChange, + } + + checkAccountChanges(oldAccount, account, dataAnalysisStateChange) + + c.AddStateChange(dataAnalysisStateChange) + return + } + + c.AddStateChange(stateChange) +} + +// Reset resets the state changes collector +func (c *collector) Reset() { + c.stateChangesMut.Lock() + defer c.stateChangesMut.Unlock() + + c.stateChanges = make([]state.StateChange, 0) + if c.storer != nil { + c.cachedTxs = make(map[string]*transaction.Transaction) + } +} + +// Publish will export state changes +func (c *collector) Publish() (map[string]*data.StateChanges, error) { + c.stateChangesMut.RLock() + defer c.stateChangesMut.RUnlock() + + stateChangesForTxs := make(map[string]*data.StateChanges) + for _, stateChange := range c.stateChanges { + txHash := string(stateChange.GetTxHash()) + + st, ok := stateChange.(*data.StateChange) + if !ok { + continue + } + + _, ok = stateChangesForTxs[txHash] + if !ok { + stateChangesForTxs[txHash] = &data.StateChanges{ + StateChanges: []*data.StateChange{st}, + } + } else { + stateChangesForTxs[txHash].StateChanges = append(stateChangesForTxs[txHash].StateChanges, st) + } + } + + return stateChangesForTxs, nil +} + +// Store will store the collected state changes if it has been configured with a storer +func (c *collector) Store() error { + if c.storer != nil { + return nil + } + + stateChangesForTx, err := c.getDataAnalysisStateChangesForTxs() + if err != nil { + return fmt.Errorf("failed to retrieve data analysis state changes for tx: %w", err) + } + + for _, stateChange := range stateChangesForTx { + marshalledData, err := json.Marshal(stateChange) + if err != nil { + return fmt.Errorf("failed to marshal state changes to JSON: %w", err) + } + + err = c.storer.Put(stateChange.TxHash, marshalledData) + if err != nil { + return fmt.Errorf("failed to store marshalled data: %w", err) + } + } + + return nil +} + +// AddTxHashToCollectedStateChanges will try to set txHash field to each state change +// if the field is not already set +func (c *collector) AddTxHashToCollectedStateChanges(txHash []byte, tx *transaction.Transaction) { + c.stateChangesMut.Lock() + defer c.stateChangesMut.Unlock() + + if c.storer != nil { + c.cachedTxs[string(txHash)] = tx + } + + for i := len(c.stateChanges) - 1; i >= 0; i-- { + if len(c.stateChanges[i].GetTxHash()) > 0 { + break + } + + c.stateChanges[i].SetTxHash(txHash) + } +} + +// SetIndexToLastStateChange will set index to the last state change +func (c *collector) SetIndexToLastStateChange(index int) error { + c.stateChangesMut.Lock() + defer c.stateChangesMut.Unlock() + + if index > len(c.stateChanges) || index < 0 { + return state.ErrStateChangesIndexOutOfBounds + } + + if len(c.stateChanges) == 0 { + return nil + } + + c.stateChanges[len(c.stateChanges)-1].SetIndex(int32(index)) + + return nil +} + +// RevertToIndex will revert to index +func (c *collector) RevertToIndex(index int) error { + if index > len(c.stateChanges) || index < 0 { + return state.ErrStateChangesIndexOutOfBounds + } + + if index == 0 { + c.Reset() + return nil + } + + c.stateChangesMut.Lock() + defer c.stateChangesMut.Unlock() + + for i := len(c.stateChanges) - 1; i >= 0; i-- { + if c.stateChanges[i].GetIndex() == int32(index) { + c.stateChanges = c.stateChanges[:i] + break + } + } + + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (c *collector) IsInterfaceNil() bool { + return c == nil +} + +func (c *collector) getStateChangesForTxs() ([]StateChangesForTx, error) { + c.stateChangesMut.Lock() + defer c.stateChangesMut.Unlock() + + stateChangesForTxs := make([]StateChangesForTx, 0) + + for i := 0; i < len(c.stateChanges); i++ { + txHash := c.stateChanges[i].GetTxHash() + + if len(txHash) == 0 { + log.Warn("empty tx hash, state change event not associated to a transaction") + continue + } + + innerStateChangesForTx := make([]state.StateChange, 0) + for j := i; j < len(c.stateChanges); j++ { + txHash2 := c.stateChanges[j].GetTxHash() + if !bytes.Equal(txHash, txHash2) { + i = j + break + } + + innerStateChangesForTx = append(innerStateChangesForTx, c.stateChanges[j]) + i = j + } + + stateChangesForTx := StateChangesForTx{ + TxHash: txHash, + StateChanges: innerStateChangesForTx, + } + stateChangesForTxs = append(stateChangesForTxs, stateChangesForTx) + } + + return stateChangesForTxs, nil +} + +func (c *collector) getDataAnalysisStateChangesForTxs() ([]dataAnalysisStateChangesForTx, error) { + stateChangesForTxs, err := c.getStateChangesForTxs() + if err != nil { + return nil, err + } + + dataAnalysisStateChangesForTxs := make([]dataAnalysisStateChangesForTx, 0) + + for _, stateChangeForTx := range stateChangesForTxs { + cachedTx, txOk := c.cachedTxs[string(stateChangeForTx.TxHash)] + if !txOk { + return nil, fmt.Errorf("did not find tx in cache") + } + + stateChangesForTx := dataAnalysisStateChangesForTx{ + StateChangesForTx: stateChangeForTx, + Tx: cachedTx, + } + dataAnalysisStateChangesForTxs = append(dataAnalysisStateChangesForTxs, stateChangesForTx) + } + + return dataAnalysisStateChangesForTxs, nil +} diff --git a/state/stateChanges/collectorOptions.go b/state/stateChanges/collectorOptions.go new file mode 100644 index 00000000000..d97a9584e5c --- /dev/null +++ b/state/stateChanges/collectorOptions.go @@ -0,0 +1,29 @@ +package stateChanges + +import ( + "github.com/multiversx/mx-chain-go/storage" +) + +// CollectorOption specifies the possible options for the collector +type CollectorOption func(*collector) + +// WithCollectRead will enable collecting read action types +func WithCollectRead() func(c *collector) { + return func(c *collector) { + c.collectRead = true + } +} + +// WithCollectWrite will enable collecting write action types +func WithCollectWrite() func(c *collector) { + return func(c *collector) { + c.collectWrite = true + } +} + +// WithStorer will enable storing action types +func WithStorer(storer storage.Persister) func(c *collector) { + return func(c *collector) { + c.storer = storer + } +} diff --git a/state/stateChanges/collector_test.go b/state/stateChanges/collector_test.go new file mode 100644 index 00000000000..4c2730331a0 --- /dev/null +++ b/state/stateChanges/collector_test.go @@ -0,0 +1,635 @@ +package stateChanges + +import ( + "fmt" + "math/big" + "strconv" + "testing" + + data "github.com/multiversx/mx-chain-core-go/data/stateChange" + "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/storage/mock" + mockState "github.com/multiversx/mx-chain-go/testscommon/state" + + "github.com/stretchr/testify/require" +) + +func getWriteStateChange() *data.StateChange { + return &data.StateChange{ + Type: data.Write, + } +} + +func getReadStateChange() *data.StateChange { + return &data.StateChange{ + Type: data.Read, + } +} + +func TestNewStateChangesCollector(t *testing.T) { + t.Parallel() + + stateChangesCollector := NewCollector() + require.False(t, stateChangesCollector.IsInterfaceNil()) +} + +func TestStateChangesCollector_AddStateChange(t *testing.T) { + t.Parallel() + + t.Run("default collector", func(t *testing.T) { + t.Parallel() + + c := NewCollector() + assert.Equal(t, 0, len(c.stateChanges)) + + numStateChanges := 10 + for i := 0; i < numStateChanges; i++ { + c.AddStateChange(getWriteStateChange()) + } + assert.Equal(t, 0, len(c.stateChanges)) + }) + + t.Run("collect only write", func(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectWrite()) + assert.Equal(t, 0, len(c.stateChanges)) + + numStateChanges := 10 + for i := 0; i < numStateChanges; i++ { + c.AddStateChange(getWriteStateChange()) + } + + c.AddStateChange(getReadStateChange()) + assert.Equal(t, numStateChanges, len(c.stateChanges)) + }) + + t.Run("collect only read", func(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectRead()) + assert.Equal(t, 0, len(c.stateChanges)) + + numStateChanges := 10 + for i := 0; i < numStateChanges; i++ { + c.AddStateChange(getReadStateChange()) + } + + c.AddStateChange(getWriteStateChange()) + assert.Equal(t, numStateChanges, len(c.stateChanges)) + }) + + t.Run("collect both read and write", func(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectRead(), WithCollectWrite()) + assert.Equal(t, 0, len(c.stateChanges)) + + numStateChanges := 10 + for i := 0; i < numStateChanges; i++ { + if i%2 == 0 { + c.AddStateChange(getReadStateChange()) + } else { + c.AddStateChange(getWriteStateChange()) + } + } + assert.Equal(t, numStateChanges, len(c.stateChanges)) + }) +} + +func TestStateChangesCollector_GetStateChanges(t *testing.T) { + t.Parallel() + + t.Run("getStateChanges with tx hash", func(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectWrite()) + assert.Equal(t, 0, len(c.stateChanges)) + + numStateChanges := 10 + for i := 0; i < numStateChanges; i++ { + c.AddStateChange(&data.StateChange{ + Type: data.Write, + MainTrieKey: []byte(strconv.Itoa(i)), + }) + } + assert.Equal(t, numStateChanges, len(c.stateChanges)) + stateChangesForTxs, _ := c.getStateChangesForTxs() + assert.Equal(t, 0, len(stateChangesForTxs)) + c.AddTxHashToCollectedStateChanges([]byte("txHash"), &transaction.Transaction{}) + assert.Equal(t, numStateChanges, len(c.stateChanges)) + assert.Equal(t, 1, len(c.GetStateChanges())) + assert.Equal(t, []byte("txHash"), c.GetStateChanges()[0].TxHash) + assert.Equal(t, numStateChanges, len(c.GetStateChanges()[0].StateChanges)) + + stateChangesForTx := c.GetStateChanges() + assert.Equal(t, 1, len(stateChangesForTx)) + assert.Equal(t, []byte("txHash"), stateChangesForTx[0].TxHash) + for i := 0; i < len(stateChangesForTx[0].StateChanges); i++ { + sc, ok := stateChangesForTx[0].StateChanges[i].(*data.StateChange) + require.True(t, ok) + + assert.Equal(t, []byte(strconv.Itoa(i)), sc.MainTrieKey) + } + }) + + t.Run("getStateChanges without tx hash", func(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectWrite(), WithStorer(&mock.PersisterStub{})) + assert.Equal(t, 0, len(c.stateChanges)) + assert.Equal(t, 0, len(c.GetStateChanges())) + + numStateChanges := 10 + for i := 0; i < numStateChanges; i++ { + c.AddStateChange(&data.StateChange{ + Type: data.Write, + MainTrieKey: []byte(strconv.Itoa(i)), + }) + } + assert.Equal(t, numStateChanges, len(c.stateChanges)) + assert.Equal(t, 0, len(c.GetStateChanges())) + + stateChangesForTx := c.GetStateChanges() + assert.Equal(t, 0, len(stateChangesForTx)) + }) +} + +func TestStateChangesCollector_AddTxHashToCollectedStateChanges(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectWrite()) + assert.Equal(t, 0, len(c.stateChanges)) + assert.Equal(t, 0, len(c.GetStateChanges())) + + c.AddTxHashToCollectedStateChanges([]byte("txHash0"), &transaction.Transaction{}) + + stateChange := &data.StateChange{ + Type: data.Write, + MainTrieKey: []byte("mainTrieKey"), + MainTrieVal: []byte("mainTrieVal"), + DataTrieChanges: []*data.DataTrieChange{{Key: []byte("dataTrieKey"), Val: []byte("dataTrieVal")}}, + } + c.AddStateChange(stateChange) + + assert.Equal(t, 1, len(c.stateChanges)) + assert.Equal(t, 0, len(c.GetStateChanges())) + c.AddTxHashToCollectedStateChanges([]byte("txHash"), &transaction.Transaction{}) + assert.Equal(t, 1, len(c.stateChanges)) + assert.Equal(t, 1, len(c.GetStateChanges())) + + stateChangesForTx := c.GetStateChanges() + assert.Equal(t, 1, len(stateChangesForTx)) + assert.Equal(t, []byte("txHash"), stateChangesForTx[0].TxHash) + assert.Equal(t, 1, len(stateChangesForTx[0].StateChanges)) + + sc, ok := stateChangesForTx[0].StateChanges[0].(*data.StateChange) + require.True(t, ok) + + assert.Equal(t, []byte("mainTrieKey"), sc.MainTrieKey) + assert.Equal(t, []byte("mainTrieVal"), sc.MainTrieVal) + assert.Equal(t, 1, len(sc.DataTrieChanges)) +} + +func TestStateChangesCollector_RevertToIndex_FailIfWrongIndex(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectWrite()) + numStateChanges := len(c.stateChanges) + + err := c.RevertToIndex(-1) + require.Equal(t, state.ErrStateChangesIndexOutOfBounds, err) + + err = c.RevertToIndex(numStateChanges + 1) + require.Equal(t, state.ErrStateChangesIndexOutOfBounds, err) +} + +func TestStateChangesCollector_RevertToIndex(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectWrite()) + + numStateChanges := 10 + for i := 0; i < numStateChanges; i++ { + c.AddStateChange(getWriteStateChange()) + err := c.SetIndexToLastStateChange(i) + require.Nil(t, err) + } + c.AddTxHashToCollectedStateChanges([]byte("txHash1"), &transaction.Transaction{}) + + for i := numStateChanges; i < numStateChanges*2; i++ { + c.AddStateChange(getWriteStateChange()) + c.AddTxHashToCollectedStateChanges([]byte("txHash"+fmt.Sprintf("%d", i)), &transaction.Transaction{}) + } + err := c.SetIndexToLastStateChange(numStateChanges) + require.Nil(t, err) + + assert.Equal(t, numStateChanges*2, len(c.stateChanges)) + + err = c.RevertToIndex(numStateChanges) + require.Nil(t, err) + assert.Equal(t, numStateChanges*2-1, len(c.stateChanges)) + + err = c.RevertToIndex(numStateChanges - 1) + require.Nil(t, err) + assert.Equal(t, numStateChanges-1, len(c.stateChanges)) + + err = c.RevertToIndex(numStateChanges / 2) + require.Nil(t, err) + assert.Equal(t, numStateChanges/2, len(c.stateChanges)) + + err = c.RevertToIndex(1) + require.Nil(t, err) + assert.Equal(t, 1, len(c.stateChanges)) + + err = c.RevertToIndex(0) + require.Nil(t, err) + assert.Equal(t, 0, len(c.stateChanges)) +} + +func TestStateChangesCollector_SetIndexToLastStateChange(t *testing.T) { + t.Parallel() + + t.Run("should fail if valid index", func(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectWrite()) + + err := c.SetIndexToLastStateChange(-1) + require.Equal(t, state.ErrStateChangesIndexOutOfBounds, err) + + numStateChanges := len(c.stateChanges) + err = c.SetIndexToLastStateChange(numStateChanges + 1) + require.Equal(t, state.ErrStateChangesIndexOutOfBounds, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectWrite()) + + numStateChanges := 10 + for i := 0; i < numStateChanges; i++ { + c.AddStateChange(getWriteStateChange()) + err := c.SetIndexToLastStateChange(i) + require.Nil(t, err) + } + c.AddTxHashToCollectedStateChanges([]byte("txHash1"), &transaction.Transaction{}) + + for i := numStateChanges; i < numStateChanges*2; i++ { + c.AddStateChange(getWriteStateChange()) + c.AddTxHashToCollectedStateChanges([]byte("txHash"+fmt.Sprintf("%d", i)), &transaction.Transaction{}) + } + err := c.SetIndexToLastStateChange(numStateChanges) + require.Nil(t, err) + + assert.Equal(t, numStateChanges*2, len(c.stateChanges)) + }) +} + +func TestStateChangesCollector_Reset(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectWrite()) + assert.Equal(t, 0, len(c.stateChanges)) + + numStateChanges := 10 + for i := 0; i < numStateChanges; i++ { + c.AddStateChange(getWriteStateChange()) + } + c.AddTxHashToCollectedStateChanges([]byte("txHash"), &transaction.Transaction{}) + for i := numStateChanges; i < numStateChanges*2; i++ { + c.AddStateChange(getWriteStateChange()) + } + assert.Equal(t, numStateChanges*2, len(c.stateChanges)) + + assert.Equal(t, 1, len(c.GetStateChanges())) + + c.Reset() + assert.Equal(t, 0, len(c.stateChanges)) + + assert.Equal(t, 0, len(c.GetStateChanges())) +} + +func TestStateChangesCollector_Publish(t *testing.T) { + t.Parallel() + + t.Run("collect only write", func(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectWrite()) + assert.Equal(t, 0, len(c.stateChanges)) + + numStateChanges := 20 + for i := 0; i < numStateChanges; i++ { + if i%2 == 0 { + c.AddStateChange(&data.StateChange{ + Type: data.Write, + // distribute evenly based on parity of the index + TxHash: []byte(fmt.Sprintf("hash%d", i%2)), + }) + } else { + c.AddStateChange(&data.StateChange{ + Type: data.Read, + // distribute evenly based on parity of the index + TxHash: []byte(fmt.Sprintf("hash%d", i%2)), + }) + } + } + + stateChangesForTx, err := c.Publish() + require.NoError(t, err) + + require.Len(t, stateChangesForTx, 1) + require.Len(t, stateChangesForTx["hash0"].StateChanges, 10) + + require.Equal(t, stateChangesForTx, map[string]*data.StateChanges{ + "hash0": { + StateChanges: []*data.StateChange{ + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + }, + }, + }) + }) + + t.Run("collect only read", func(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectRead()) + assert.Equal(t, 0, len(c.stateChanges)) + + numStateChanges := 20 + for i := 0; i < numStateChanges; i++ { + if i%2 == 0 { + c.AddStateChange(&data.StateChange{ + Type: data.Write, + // distribute evenly based on parity of the index + TxHash: []byte(fmt.Sprintf("hash%d", i%2)), + }) + } else { + c.AddStateChange(&data.StateChange{ + Type: data.Read, + // distribute evenly based on parity of the index + TxHash: []byte(fmt.Sprintf("hash%d", i%2)), + }) + } + } + + stateChangesForTx, err := c.Publish() + require.NoError(t, err) + + require.Len(t, stateChangesForTx, 1) + require.Len(t, stateChangesForTx["hash1"].StateChanges, 10) + + require.Equal(t, stateChangesForTx, map[string]*data.StateChanges{ + "hash1": { + StateChanges: []*data.StateChange{ + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + }, + }, + }) + }) + + t.Run("collect both read and write", func(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectRead(), WithCollectWrite()) + assert.Equal(t, 0, len(c.stateChanges)) + + numStateChanges := 20 + for i := 0; i < numStateChanges; i++ { + if i%2 == 0 { + c.AddStateChange(&data.StateChange{ + Type: data.Write, + // distribute evenly based on parity of the index + TxHash: []byte(fmt.Sprintf("hash%d", i%2)), + }) + } else { + c.AddStateChange(&data.StateChange{ + Type: data.Read, + // distribute evenly based on parity of the index + TxHash: []byte(fmt.Sprintf("hash%d", i%2)), + }) + } + } + + stateChangesForTx, err := c.Publish() + require.NoError(t, err) + + require.Len(t, stateChangesForTx, 2) + require.Len(t, stateChangesForTx["hash0"].StateChanges, 10) + require.Len(t, stateChangesForTx["hash1"].StateChanges, 10) + + require.Equal(t, stateChangesForTx, map[string]*data.StateChanges{ + "hash0": { + StateChanges: []*data.StateChange{ + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + {Type: data.Write, TxHash: []byte("hash0")}, + }, + }, + "hash1": { + StateChanges: []*data.StateChange{ + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + {Type: data.Read, TxHash: []byte("hash1")}, + }, + }, + }) + }) +} + +func TestNewDataAnalysisCollector(t *testing.T) { + t.Parallel() + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + c := NewCollector(WithStorer(&mock.PersisterStub{})) + require.False(t, c.IsInterfaceNil()) + require.NotNil(t, c.storer) + }) +} + +func TestDataAnalysisStateChangesCollector_AddSaveAccountStateChange(t *testing.T) { + t.Parallel() + + t.Run("nil old account should return early", func(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectWrite(), WithStorer(&mock.PersisterStub{})) + + c.AddSaveAccountStateChange( + nil, + &mockState.UserAccountStub{}, + &data.StateChange{ + Type: data.Write, + Index: 2, + TxHash: []byte("txHash1"), + MainTrieKey: []byte("key1"), + }, + ) + + c.AddTxHashToCollectedStateChanges([]byte("txHash1"), &transaction.Transaction{}) + + stateChangesForTx := c.GetStateChanges() + require.Equal(t, 1, len(stateChangesForTx)) + + sc := stateChangesForTx[0].StateChanges[0] + dasc, ok := sc.(*dataAnalysisStateChangeDTO) + require.True(t, ok) + + require.False(t, dasc.Nonce) + require.False(t, dasc.Balance) + require.False(t, dasc.CodeHash) + require.False(t, dasc.RootHash) + require.False(t, dasc.DeveloperReward) + require.False(t, dasc.OwnerAddress) + require.False(t, dasc.UserName) + require.False(t, dasc.CodeMetadata) + }) + + t.Run("nil new account should return early", func(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectWrite(), WithStorer(&mock.PersisterStub{})) + + c.AddSaveAccountStateChange( + &mockState.UserAccountStub{}, + nil, + &data.StateChange{ + Type: data.Write, + Index: 2, + TxHash: []byte("txHash1"), + MainTrieKey: []byte("key1"), + }, + ) + + c.AddTxHashToCollectedStateChanges([]byte("txHash1"), &transaction.Transaction{}) + + stateChangesForTx := c.GetStateChanges() + require.Equal(t, 1, len(stateChangesForTx)) + + sc := stateChangesForTx[0].StateChanges[0] + dasc, ok := sc.(*dataAnalysisStateChangeDTO) + require.True(t, ok) + + require.False(t, dasc.Nonce) + require.False(t, dasc.Balance) + require.False(t, dasc.CodeHash) + require.False(t, dasc.RootHash) + require.False(t, dasc.DeveloperReward) + require.False(t, dasc.OwnerAddress) + require.False(t, dasc.UserName) + require.False(t, dasc.CodeMetadata) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectWrite(), WithStorer(&mock.PersisterStub{})) + + c.AddSaveAccountStateChange( + &mockState.UserAccountStub{ + Nonce: 0, + Balance: big.NewInt(0), + DeveloperRewards: big.NewInt(0), + UserName: []byte{0}, + Owner: []byte{0}, + Address: []byte{0}, + CodeMetadata: []byte{0}, + CodeHash: []byte{0}, + GetRootHashCalled: func() []byte { + return []byte{0} + }, + }, + &mockState.UserAccountStub{ + Nonce: 1, + Balance: big.NewInt(1), + DeveloperRewards: big.NewInt(1), + UserName: []byte{1}, + Owner: []byte{1}, + Address: []byte{1}, + CodeMetadata: []byte{1}, + CodeHash: []byte{1}, + GetRootHashCalled: func() []byte { + return []byte{1} + }, + }, + &data.StateChange{ + Type: data.Write, + Index: 2, + TxHash: []byte("txHash1"), + MainTrieKey: []byte("key1"), + }, + ) + + c.AddTxHashToCollectedStateChanges([]byte("txHash1"), &transaction.Transaction{}) + + stateChangesForTx := c.GetStateChanges() + require.Equal(t, 1, len(stateChangesForTx)) + + sc := stateChangesForTx[0].StateChanges[0] + dasc, ok := sc.(*dataAnalysisStateChangeDTO) + require.True(t, ok) + + require.True(t, dasc.Nonce) + require.True(t, dasc.Balance) + require.True(t, dasc.CodeHash) + require.True(t, dasc.RootHash) + require.True(t, dasc.DeveloperReward) + require.True(t, dasc.OwnerAddress) + require.True(t, dasc.UserName) + require.True(t, dasc.CodeMetadata) + }) +} + +func TestDataAnalysisStateChangesCollector_Reset(t *testing.T) { + t.Parallel() + + c := NewCollector(WithCollectWrite(), WithStorer(&mock.PersisterStub{})) + + numStateChanges := 10 + for i := 0; i < numStateChanges; i++ { + c.AddStateChange(getWriteStateChange()) + } + require.Equal(t, numStateChanges, len(c.stateChanges)) + + c.Reset() + require.Equal(t, 0, len(c.GetStateChanges())) +} diff --git a/state/stateChanges/dataAnalysisCollector.go b/state/stateChanges/dataAnalysisCollector.go index 481eb8e9fc6..233242ec795 100644 --- a/state/stateChanges/dataAnalysisCollector.go +++ b/state/stateChanges/dataAnalysisCollector.go @@ -2,16 +2,12 @@ package stateChanges import ( "bytes" - "encoding/json" - "fmt" "math/big" - "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/transaction" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/storage" ) type dataAnalysisStateChangeDTO struct { @@ -43,37 +39,6 @@ type userAccountHandler interface { vmcommon.AccountHandler } -type dataAnalysisCollector struct { - *stateChangesCollector - - cachedTxs map[string]*transaction.Transaction - storer storage.Persister -} - -// NewDataAnalysisStateChangesCollector will create a new instance of data analysis collector -func NewDataAnalysisStateChangesCollector(storer storage.Persister) (*dataAnalysisCollector, error) { - if check.IfNil(storer) { - return nil, storage.ErrNilPersister - } - - return &dataAnalysisCollector{ - stateChangesCollector: NewStateChangesCollector(), - cachedTxs: make(map[string]*transaction.Transaction), - storer: storer, - }, nil -} - -// AddSaveAccountStateChange adds a new state change for the save account operation -func (scc *dataAnalysisCollector) AddSaveAccountStateChange(oldAccount, account vmcommon.AccountHandler, stateChange state.StateChange) { - dataAnalysisStateChange := &dataAnalysisStateChangeDTO{ - StateChange: stateChange, - } - - checkAccountChanges(oldAccount, account, dataAnalysisStateChange) - - scc.AddStateChange(dataAnalysisStateChange) -} - func checkAccountChanges(oldAcc, newAcc vmcommon.AccountHandler, stateChange *dataAnalysisStateChangeDTO) { baseNewAcc, newAccOk := newAcc.(userAccountHandler) if !newAccOk { @@ -116,76 +81,3 @@ func checkAccountChanges(oldAcc, newAcc vmcommon.AccountHandler, stateChange *da stateChange.CodeMetadata = true } } - -// AddStateChange adds a new state change to the collector -func (scc *dataAnalysisCollector) AddStateChange(stateChange state.StateChange) { - scc.stateChangesMut.Lock() - defer scc.stateChangesMut.Unlock() - - scc.stateChanges = append(scc.stateChanges, stateChange) -} - -func (scc *dataAnalysisCollector) getDataAnalysisStateChangesForTxs() ([]dataAnalysisStateChangesForTx, error) { - stateChangesForTxs, err := scc.getStateChangesForTxs() - if err != nil { - return nil, err - } - - dataAnalysisStateChangesForTxs := make([]dataAnalysisStateChangesForTx, 0) - - for _, stateChangeForTx := range stateChangesForTxs { - cachedTx, txOk := scc.cachedTxs[string(stateChangeForTx.TxHash)] - if !txOk { - return nil, fmt.Errorf("did not find tx in cache") - } - - stateChangesForTx := dataAnalysisStateChangesForTx{ - StateChangesForTx: stateChangeForTx, - Tx: cachedTx, - } - dataAnalysisStateChangesForTxs = append(dataAnalysisStateChangesForTxs, stateChangesForTx) - } - - return dataAnalysisStateChangesForTxs, nil -} - -func (scc *dataAnalysisCollector) AddTxHashToCollectedStateChanges(txHash []byte, tx *transaction.Transaction) { - scc.cachedTxs[string(txHash)] = tx - scc.addTxHashToCollectedStateChanges(txHash) -} - -// Reset resets the state changes collector -func (scc *dataAnalysisCollector) Reset() { - scc.stateChangesMut.Lock() - defer scc.stateChangesMut.Unlock() - - scc.resetStateChangesUnprotected() - scc.cachedTxs = make(map[string]*transaction.Transaction) -} - -// Publish will export state changes -func (scc *dataAnalysisCollector) Publish() error { - stateChangesForTx, err := scc.getDataAnalysisStateChangesForTxs() - if err != nil { - return err - } - - for _, stateChange := range stateChangesForTx { - marshalledData, err := json.Marshal(stateChange) - if err != nil { - return err - } - - err = scc.storer.Put(stateChange.TxHash, marshalledData) - if err != nil { - return err - } - } - - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (scc *dataAnalysisCollector) IsInterfaceNil() bool { - return scc == nil -} diff --git a/state/stateChanges/dataAnalysisCollector_test.go b/state/stateChanges/dataAnalysisCollector_test.go deleted file mode 100644 index 1d33d0e79d4..00000000000 --- a/state/stateChanges/dataAnalysisCollector_test.go +++ /dev/null @@ -1,206 +0,0 @@ -package stateChanges - -import ( - "math/big" - "testing" - - data "github.com/multiversx/mx-chain-core-go/data/stateChange" - "github.com/multiversx/mx-chain-core-go/data/transaction" - "github.com/multiversx/mx-chain-go/storage" - "github.com/multiversx/mx-chain-go/storage/mock" - "github.com/multiversx/mx-chain-go/testscommon/state" - "github.com/stretchr/testify/require" -) - -func TestNewDataAnalysisCollector(t *testing.T) { - t.Parallel() - - t.Run("nil storer", func(t *testing.T) { - t.Parallel() - - dsc, err := NewDataAnalysisStateChangesCollector(nil) - require.Nil(t, dsc) - require.Equal(t, storage.ErrNilPersister, err) - }) - - t.Run("should work", func(t *testing.T) { - t.Parallel() - - dsc, err := NewDataAnalysisStateChangesCollector(&mock.PersisterStub{}) - require.Nil(t, err) - require.False(t, dsc.IsInterfaceNil()) - }) -} - -func TestDataAnalysisStateChangesCollector_AddStateChange(t *testing.T) { - t.Parallel() - - dsc, err := NewDataAnalysisStateChangesCollector(&mock.PersisterStub{}) - require.Nil(t, err) - - require.Equal(t, 0, len(dsc.stateChanges)) - - dsc.AddStateChange(&data.StateChange{ - Type: "write", - }) - dsc.AddStateChange(&data.StateChange{ - Type: "read", - }) - dsc.AddStateChange(&data.StateChange{ - Type: "write", - }) - - require.Equal(t, 3, len(dsc.stateChanges)) -} - -func TestDataAnalysisStateChangesCollector_AddSaveAccountStateChange(t *testing.T) { - t.Parallel() - - t.Run("nil old account should return early", func(t *testing.T) { - t.Parallel() - - dsc, err := NewDataAnalysisStateChangesCollector(&mock.PersisterStub{}) - require.Nil(t, err) - - dsc.AddSaveAccountStateChange( - nil, - &state.UserAccountStub{}, - &data.StateChange{ - Type: "saveAccount", - Index: 2, - TxHash: []byte("txHash1"), - MainTrieKey: []byte("key1"), - }, - ) - - dsc.AddTxHashToCollectedStateChanges([]byte("txHash1"), &transaction.Transaction{}) - - stateChangesForTx := dsc.GetStateChanges() - require.Equal(t, 1, len(stateChangesForTx)) - - sc := stateChangesForTx[0].StateChanges[0] - dasc, ok := sc.(*dataAnalysisStateChangeDTO) - require.True(t, ok) - - require.False(t, dasc.Nonce) - require.False(t, dasc.Balance) - require.False(t, dasc.CodeHash) - require.False(t, dasc.RootHash) - require.False(t, dasc.DeveloperReward) - require.False(t, dasc.OwnerAddress) - require.False(t, dasc.UserName) - require.False(t, dasc.CodeMetadata) - }) - - t.Run("nil new account should return early", func(t *testing.T) { - t.Parallel() - - dsc, err := NewDataAnalysisStateChangesCollector(&mock.PersisterStub{}) - require.Nil(t, err) - - dsc.AddSaveAccountStateChange( - &state.UserAccountStub{}, - nil, - &data.StateChange{ - Type: "saveAccount", - Index: 2, - TxHash: []byte("txHash1"), - MainTrieKey: []byte("key1"), - }, - ) - - dsc.AddTxHashToCollectedStateChanges([]byte("txHash1"), &transaction.Transaction{}) - - stateChangesForTx := dsc.GetStateChanges() - require.Equal(t, 1, len(stateChangesForTx)) - - sc := stateChangesForTx[0].StateChanges[0] - dasc, ok := sc.(*dataAnalysisStateChangeDTO) - require.True(t, ok) - - require.False(t, dasc.Nonce) - require.False(t, dasc.Balance) - require.False(t, dasc.CodeHash) - require.False(t, dasc.RootHash) - require.False(t, dasc.DeveloperReward) - require.False(t, dasc.OwnerAddress) - require.False(t, dasc.UserName) - require.False(t, dasc.CodeMetadata) - }) - - t.Run("should work", func(t *testing.T) { - t.Parallel() - - dsc, err := NewDataAnalysisStateChangesCollector(&mock.PersisterStub{}) - require.Nil(t, err) - - dsc.AddSaveAccountStateChange( - &state.UserAccountStub{ - Nonce: 0, - Balance: big.NewInt(0), - DeveloperRewards: big.NewInt(0), - UserName: []byte{0}, - Owner: []byte{0}, - Address: []byte{0}, - CodeMetadata: []byte{0}, - CodeHash: []byte{0}, - GetRootHashCalled: func() []byte { - return []byte{0} - }, - }, - &state.UserAccountStub{ - Nonce: 1, - Balance: big.NewInt(1), - DeveloperRewards: big.NewInt(1), - UserName: []byte{1}, - Owner: []byte{1}, - Address: []byte{1}, - CodeMetadata: []byte{1}, - CodeHash: []byte{1}, - GetRootHashCalled: func() []byte { - return []byte{1} - }, - }, - &data.StateChange{ - Type: "saveAccount", - Index: 2, - TxHash: []byte("txHash1"), - MainTrieKey: []byte("key1"), - }, - ) - - dsc.AddTxHashToCollectedStateChanges([]byte("txHash1"), &transaction.Transaction{}) - - stateChangesForTx := dsc.GetStateChanges() - require.Equal(t, 1, len(stateChangesForTx)) - - sc := stateChangesForTx[0].StateChanges[0] - dasc, ok := sc.(*dataAnalysisStateChangeDTO) - require.True(t, ok) - - require.True(t, dasc.Nonce) - require.True(t, dasc.Balance) - require.True(t, dasc.CodeHash) - require.True(t, dasc.RootHash) - require.True(t, dasc.DeveloperReward) - require.True(t, dasc.OwnerAddress) - require.True(t, dasc.UserName) - require.True(t, dasc.CodeMetadata) - }) -} - -func TestDataAnalysisStateChangesCollector_Reset(t *testing.T) { - t.Parallel() - - dsc, err := NewDataAnalysisStateChangesCollector(&mock.PersisterStub{}) - require.Nil(t, err) - - numStateChanges := 10 - for i := 0; i < numStateChanges; i++ { - dsc.AddStateChange(getDefaultStateChange()) - } - require.Equal(t, numStateChanges, len(dsc.stateChanges)) - - dsc.Reset() - require.Equal(t, 0, len(dsc.GetStateChanges())) -} diff --git a/state/stateChanges/export_test.go b/state/stateChanges/export_test.go index ed4b541ee20..963b6c812b2 100644 --- a/state/stateChanges/export_test.go +++ b/state/stateChanges/export_test.go @@ -1,13 +1,13 @@ package stateChanges // GetStateChanges - -func (scc *stateChangesCollector) GetStateChanges() []StateChangesForTx { - scs, _ := scc.getStateChangesForTxs() +func (c *collector) GetStateChanges() []StateChangesForTx { + scs, _ := c.getStateChangesForTxs() return scs } // GetStateChanges - -func (dsc *dataAnalysisCollector) GetStateChanges() []dataAnalysisStateChangesForTx { - scs, _ := dsc.getDataAnalysisStateChangesForTxs() +func (c *collector) GetDataAnalysisStateChanges() []dataAnalysisStateChangesForTx { + scs, _ := c.getDataAnalysisStateChangesForTxs() return scs } diff --git a/state/stateChanges/writeCollector.go b/state/stateChanges/writeCollector.go deleted file mode 100644 index 1691ca0fedb..00000000000 --- a/state/stateChanges/writeCollector.go +++ /dev/null @@ -1,195 +0,0 @@ -package stateChanges - -import ( - "bytes" - "sync" - - data "github.com/multiversx/mx-chain-core-go/data/stateChange" - "github.com/multiversx/mx-chain-core-go/data/transaction" - "github.com/multiversx/mx-chain-go/state" - logger "github.com/multiversx/mx-chain-logger-go" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" -) - -var log = logger.GetOrCreate("state/stateChanges") - -// StateChangesForTx is used to collect state changes for a transaction hash -type StateChangesForTx struct { - TxHash []byte `json:"txHash"` - StateChanges []state.StateChange `json:"stateChanges"` -} - -type stateChangesCollector struct { - stateChanges []state.StateChange - stateChangesMut sync.RWMutex -} - -// NewStateChangesCollector creates a new StateChangesCollector -func NewStateChangesCollector() *stateChangesCollector { - return &stateChangesCollector{ - stateChanges: make([]state.StateChange, 0), - } -} - -// AddSaveAccountStateChange adds a new state change for the save account operation -func (scc *stateChangesCollector) AddSaveAccountStateChange(_, _ vmcommon.AccountHandler, stateChange state.StateChange) { - scc.AddStateChange(stateChange) -} - -// AddStateChange adds a new state change to the collector -func (scc *stateChangesCollector) AddStateChange(stateChange state.StateChange) { - scc.stateChangesMut.Lock() - defer scc.stateChangesMut.Unlock() - - // TODO: add custom type for stateChange type - if stateChange.GetType() == "write" { - scc.stateChanges = append(scc.stateChanges, stateChange) - } -} - -func (scc *stateChangesCollector) getStateChangesForTxs() ([]StateChangesForTx, error) { - scc.stateChangesMut.Lock() - defer scc.stateChangesMut.Unlock() - - stateChangesForTxs := make([]StateChangesForTx, 0) - - for i := 0; i < len(scc.stateChanges); i++ { - txHash := scc.stateChanges[i].GetTxHash() - - if len(txHash) == 0 { - log.Warn("empty tx hash, state change event not associated to a transaction") - continue - } - - innerStateChangesForTx := make([]state.StateChange, 0) - for j := i; j < len(scc.stateChanges); j++ { - txHash2 := scc.stateChanges[j].GetTxHash() - if !bytes.Equal(txHash, txHash2) { - i = j - break - } - - innerStateChangesForTx = append(innerStateChangesForTx, scc.stateChanges[j]) - i = j - } - - stateChangesForTx := StateChangesForTx{ - TxHash: txHash, - StateChanges: innerStateChangesForTx, - } - stateChangesForTxs = append(stateChangesForTxs, stateChangesForTx) - } - - return stateChangesForTxs, nil -} - -// GetStateChangesForTxs will retrieve the state changes linked with the tx hash. -func (scc *stateChangesCollector) GetStateChangesForTxs() map[string]*data.StateChanges { - scc.stateChangesMut.RLock() - defer scc.stateChangesMut.RUnlock() - - stateChangesForTxs := make(map[string]*data.StateChanges) - - for _, stateChange := range scc.stateChanges { - txHash := string(stateChange.GetTxHash()) - - st, ok := stateChange.(*data.StateChange) - if !ok { - continue - } - - _, ok = stateChangesForTxs[txHash] - if !ok { - stateChangesForTxs[txHash] = &data.StateChanges{ - StateChanges: []*data.StateChange{st}, - } - } else { - stateChangesForTxs[txHash].StateChanges = append(stateChangesForTxs[txHash].StateChanges, st) - } - } - - return stateChangesForTxs -} - -// Reset resets the state changes collector -func (scc *stateChangesCollector) Reset() { - scc.stateChangesMut.Lock() - defer scc.stateChangesMut.Unlock() - - scc.resetStateChangesUnprotected() -} - -func (scc *stateChangesCollector) resetStateChangesUnprotected() { - scc.stateChanges = make([]state.StateChange, 0) -} - -// AddTxHashToCollectedStateChanges will try to set txHash field to each state change -// if the field is not already set -func (scc *stateChangesCollector) AddTxHashToCollectedStateChanges(txHash []byte, _ *transaction.Transaction) { - scc.addTxHashToCollectedStateChanges(txHash) -} - -func (scc *stateChangesCollector) addTxHashToCollectedStateChanges(txHash []byte) { - scc.stateChangesMut.Lock() - defer scc.stateChangesMut.Unlock() - - for i := len(scc.stateChanges) - 1; i >= 0; i-- { - if len(scc.stateChanges[i].GetTxHash()) > 0 { - break - } - - scc.stateChanges[i].SetTxHash(txHash) - } -} - -// SetIndexToLastStateChange will set index to the last state change -func (scc *stateChangesCollector) SetIndexToLastStateChange(index int) error { - scc.stateChangesMut.Lock() - defer scc.stateChangesMut.Unlock() - - if index > len(scc.stateChanges) || index < 0 { - return state.ErrStateChangesIndexOutOfBounds - } - - if len(scc.stateChanges) == 0 { - return nil - } - - scc.stateChanges[len(scc.stateChanges)-1].SetIndex(int32(index)) - - return nil -} - -// RevertToIndex will revert to index -func (scc *stateChangesCollector) RevertToIndex(index int) error { - if index > len(scc.stateChanges) || index < 0 { - return state.ErrStateChangesIndexOutOfBounds - } - - if index == 0 { - scc.Reset() - return nil - } - - scc.stateChangesMut.Lock() - defer scc.stateChangesMut.Unlock() - - for i := len(scc.stateChanges) - 1; i >= 0; i-- { - if scc.stateChanges[i].GetIndex() == int32(index) { - scc.stateChanges = scc.stateChanges[:i] - break - } - } - - return nil -} - -// Publish returns nil -func (scc *stateChangesCollector) Publish() error { - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (scc *stateChangesCollector) IsInterfaceNil() bool { - return scc == nil -} diff --git a/state/stateChanges/writeCollector_test.go b/state/stateChanges/writeCollector_test.go deleted file mode 100644 index 1dc76a124c7..00000000000 --- a/state/stateChanges/writeCollector_test.go +++ /dev/null @@ -1,298 +0,0 @@ -package stateChanges - -import ( - "fmt" - "strconv" - "testing" - - data "github.com/multiversx/mx-chain-core-go/data/stateChange" - "github.com/multiversx/mx-chain-core-go/data/transaction" - "github.com/stretchr/testify/assert" - - "github.com/multiversx/mx-chain-go/state" - - "github.com/stretchr/testify/require" -) - -func getDefaultStateChange() *data.StateChange { - return &data.StateChange{ - Type: "write", - } -} - -func TestNewStateChangesCollector(t *testing.T) { - t.Parallel() - - stateChangesCollector := NewStateChangesCollector() - require.False(t, stateChangesCollector.IsInterfaceNil()) -} - -func TestStateChangesCollector_AddStateChange(t *testing.T) { - t.Parallel() - - scc := NewStateChangesCollector() - assert.Equal(t, 0, len(scc.stateChanges)) - - numStateChanges := 10 - for i := 0; i < numStateChanges; i++ { - scc.AddStateChange(getDefaultStateChange()) - } - assert.Equal(t, numStateChanges, len(scc.stateChanges)) -} - -func TestStateChangesCollector_GetStateChanges(t *testing.T) { - t.Parallel() - - t.Run("getStateChanges with tx hash", func(t *testing.T) { - t.Parallel() - - scc := NewStateChangesCollector() - assert.Equal(t, 0, len(scc.stateChanges)) - assert.Equal(t, 0, len(scc.GetStateChanges())) - - numStateChanges := 10 - for i := 0; i < numStateChanges; i++ { - scc.AddStateChange(&data.StateChange{ - Type: "write", - MainTrieKey: []byte(strconv.Itoa(i)), - }) - } - assert.Equal(t, numStateChanges, len(scc.stateChanges)) - assert.Equal(t, 0, len(scc.GetStateChanges())) - scc.AddTxHashToCollectedStateChanges([]byte("txHash"), &transaction.Transaction{}) - assert.Equal(t, numStateChanges, len(scc.stateChanges)) - assert.Equal(t, 1, len(scc.GetStateChanges())) - assert.Equal(t, []byte("txHash"), scc.GetStateChanges()[0].TxHash) - assert.Equal(t, numStateChanges, len(scc.GetStateChanges()[0].StateChanges)) - - stateChangesForTx := scc.GetStateChanges() - assert.Equal(t, 1, len(stateChangesForTx)) - assert.Equal(t, []byte("txHash"), stateChangesForTx[0].TxHash) - for i := 0; i < len(stateChangesForTx[0].StateChanges); i++ { - sc, ok := stateChangesForTx[0].StateChanges[i].(*data.StateChange) - require.True(t, ok) - - assert.Equal(t, []byte(strconv.Itoa(i)), sc.MainTrieKey) - } - }) - - t.Run("getStateChanges without tx hash", func(t *testing.T) { - t.Parallel() - - scc := NewStateChangesCollector() - assert.Equal(t, 0, len(scc.stateChanges)) - assert.Equal(t, 0, len(scc.GetStateChanges())) - - numStateChanges := 10 - for i := 0; i < numStateChanges; i++ { - scc.AddStateChange(&data.StateChange{ - Type: "write", - MainTrieKey: []byte(strconv.Itoa(i)), - }) - } - assert.Equal(t, numStateChanges, len(scc.stateChanges)) - assert.Equal(t, 0, len(scc.GetStateChanges())) - - stateChangesForTx := scc.GetStateChanges() - assert.Equal(t, 0, len(stateChangesForTx)) - }) -} - -func TestStateChangesCollector_AddTxHashToCollectedStateChanges(t *testing.T) { - t.Parallel() - - scc := NewStateChangesCollector() - assert.Equal(t, 0, len(scc.stateChanges)) - assert.Equal(t, 0, len(scc.GetStateChanges())) - - scc.AddTxHashToCollectedStateChanges([]byte("txHash0"), &transaction.Transaction{}) - - stateChange := &data.StateChange{ - Type: "write", - MainTrieKey: []byte("mainTrieKey"), - MainTrieVal: []byte("mainTrieVal"), - DataTrieChanges: []*data.DataTrieChange{{Key: []byte("dataTrieKey"), Val: []byte("dataTrieVal")}}, - } - scc.AddStateChange(stateChange) - - assert.Equal(t, 1, len(scc.stateChanges)) - assert.Equal(t, 0, len(scc.GetStateChanges())) - scc.AddTxHashToCollectedStateChanges([]byte("txHash"), &transaction.Transaction{}) - assert.Equal(t, 1, len(scc.stateChanges)) - assert.Equal(t, 1, len(scc.GetStateChanges())) - - stateChangesForTx := scc.GetStateChanges() - assert.Equal(t, 1, len(stateChangesForTx)) - assert.Equal(t, []byte("txHash"), stateChangesForTx[0].TxHash) - assert.Equal(t, 1, len(stateChangesForTx[0].StateChanges)) - - sc, ok := stateChangesForTx[0].StateChanges[0].(*data.StateChange) - require.True(t, ok) - - assert.Equal(t, []byte("mainTrieKey"), sc.MainTrieKey) - assert.Equal(t, []byte("mainTrieVal"), sc.MainTrieVal) - assert.Equal(t, 1, len(sc.DataTrieChanges)) -} - -func TestStateChangesCollector_RevertToIndex_FailIfWrongIndex(t *testing.T) { - t.Parallel() - - scc := NewStateChangesCollector() - numStateChanges := len(scc.stateChanges) - - err := scc.RevertToIndex(-1) - require.Equal(t, state.ErrStateChangesIndexOutOfBounds, err) - - err = scc.RevertToIndex(numStateChanges + 1) - require.Equal(t, state.ErrStateChangesIndexOutOfBounds, err) -} - -func TestStateChangesCollector_RevertToIndex(t *testing.T) { - t.Parallel() - - scc := NewStateChangesCollector() - - numStateChanges := 10 - for i := 0; i < numStateChanges; i++ { - scc.AddStateChange(getDefaultStateChange()) - err := scc.SetIndexToLastStateChange(i) - require.Nil(t, err) - } - scc.AddTxHashToCollectedStateChanges([]byte("txHash1"), &transaction.Transaction{}) - - for i := numStateChanges; i < numStateChanges*2; i++ { - scc.AddStateChange(getDefaultStateChange()) - scc.AddTxHashToCollectedStateChanges([]byte("txHash"+fmt.Sprintf("%d", i)), &transaction.Transaction{}) - } - err := scc.SetIndexToLastStateChange(numStateChanges) - require.Nil(t, err) - - assert.Equal(t, numStateChanges*2, len(scc.stateChanges)) - - err = scc.RevertToIndex(numStateChanges) - require.Nil(t, err) - assert.Equal(t, numStateChanges*2-1, len(scc.stateChanges)) - - err = scc.RevertToIndex(numStateChanges - 1) - require.Nil(t, err) - assert.Equal(t, numStateChanges-1, len(scc.stateChanges)) - - err = scc.RevertToIndex(numStateChanges / 2) - require.Nil(t, err) - assert.Equal(t, numStateChanges/2, len(scc.stateChanges)) - - err = scc.RevertToIndex(1) - require.Nil(t, err) - assert.Equal(t, 1, len(scc.stateChanges)) - - err = scc.RevertToIndex(0) - require.Nil(t, err) - assert.Equal(t, 0, len(scc.stateChanges)) -} - -func TestStateChangesCollector_SetIndexToLastStateChange(t *testing.T) { - t.Parallel() - - t.Run("should fail if valid index", func(t *testing.T) { - t.Parallel() - - scc := NewStateChangesCollector() - - err := scc.SetIndexToLastStateChange(-1) - require.Equal(t, state.ErrStateChangesIndexOutOfBounds, err) - - numStateChanges := len(scc.stateChanges) - err = scc.SetIndexToLastStateChange(numStateChanges + 1) - require.Equal(t, state.ErrStateChangesIndexOutOfBounds, err) - }) - - t.Run("should work", func(t *testing.T) { - t.Parallel() - - scc := NewStateChangesCollector() - - numStateChanges := 10 - for i := 0; i < numStateChanges; i++ { - scc.AddStateChange(getDefaultStateChange()) - err := scc.SetIndexToLastStateChange(i) - require.Nil(t, err) - } - scc.AddTxHashToCollectedStateChanges([]byte("txHash1"), &transaction.Transaction{}) - - for i := numStateChanges; i < numStateChanges*2; i++ { - scc.AddStateChange(getDefaultStateChange()) - scc.AddTxHashToCollectedStateChanges([]byte("txHash"+fmt.Sprintf("%d", i)), &transaction.Transaction{}) - } - err := scc.SetIndexToLastStateChange(numStateChanges) - require.Nil(t, err) - - assert.Equal(t, numStateChanges*2, len(scc.stateChanges)) - }) -} - -func TestStateChangesCollector_Reset(t *testing.T) { - t.Parallel() - - scc := NewStateChangesCollector() - assert.Equal(t, 0, len(scc.stateChanges)) - - numStateChanges := 10 - for i := 0; i < numStateChanges; i++ { - scc.AddStateChange(getDefaultStateChange()) - } - scc.AddTxHashToCollectedStateChanges([]byte("txHash"), &transaction.Transaction{}) - for i := numStateChanges; i < numStateChanges*2; i++ { - scc.AddStateChange(getDefaultStateChange()) - } - assert.Equal(t, numStateChanges*2, len(scc.stateChanges)) - - assert.Equal(t, 1, len(scc.GetStateChanges())) - - scc.Reset() - assert.Equal(t, 0, len(scc.stateChanges)) - - assert.Equal(t, 0, len(scc.GetStateChanges())) -} - -func TestStateChangesCollector_GetStateChangesForTx(t *testing.T) { - t.Parallel() - - scc := NewStateChangesCollector() - assert.Equal(t, 0, len(scc.stateChanges)) - - numStateChanges := 10 - for i := 0; i < numStateChanges; i++ { - scc.AddStateChange(&data.StateChange{ - Type: "write", - // distribute evenly based on parity of the index - TxHash: []byte(fmt.Sprintf("hash%d", i%2)), - }) - } - - stateChangesForTx := scc.GetStateChangesForTxs() - - require.Len(t, stateChangesForTx, 2) - require.Len(t, stateChangesForTx["hash0"].StateChanges, 5) - require.Len(t, stateChangesForTx["hash1"].StateChanges, 5) - - require.Equal(t, stateChangesForTx, map[string]*data.StateChanges{ - "hash0": { - StateChanges: []*data.StateChange{ - {Type: "write", TxHash: []byte("hash0")}, - {Type: "write", TxHash: []byte("hash0")}, - {Type: "write", TxHash: []byte("hash0")}, - {Type: "write", TxHash: []byte("hash0")}, - {Type: "write", TxHash: []byte("hash0")}, - }, - }, - "hash1": { - StateChanges: []*data.StateChange{ - {Type: "write", TxHash: []byte("hash1")}, - {Type: "write", TxHash: []byte("hash1")}, - {Type: "write", TxHash: []byte("hash1")}, - {Type: "write", TxHash: []byte("hash1")}, - {Type: "write", TxHash: []byte("hash1")}, - }, - }, - }) -} diff --git a/state/storagePruningManager/storagePruningManager_test.go b/state/storagePruningManager/storagePruningManager_test.go index 8f8c39cee1f..dee2913a012 100644 --- a/state/storagePruningManager/storagePruningManager_test.go +++ b/state/storagePruningManager/storagePruningManager_test.go @@ -3,6 +3,8 @@ package storagePruningManager import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/holders" "github.com/multiversx/mx-chain-go/common/statistics" @@ -11,16 +13,15 @@ import ( "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/iteratorChannelsProvider" "github.com/multiversx/mx-chain-go/state/lastSnapshotMarker" - "github.com/multiversx/mx-chain-go/state/stateChanges" "github.com/multiversx/mx-chain-go/state/storagePruningManager/evictionWaitingList" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + stateMock "github.com/multiversx/mx-chain-go/testscommon/state" testStorage "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" - "github.com/stretchr/testify/assert" ) func getDefaultTrieAndAccountsDbAndStoragePruningManager() (common.Trie, *state.AccountsDB, *storagePruningManager) { @@ -41,13 +42,11 @@ func getDefaultTrieAndAccountsDbAndStoragePruningManager() (common.Trie, *state. ewl, _ := evictionWaitingList.NewMemoryEvictionWaitingList(ewlArgs) spm, _ := NewStoragePruningManager(ewl, generalCfg.PruningBufferLen) - stateChangesCollector := stateChanges.NewStateChangesCollector() - argsAccCreator := factory.ArgsAccountCreator{ Hasher: hasher, Marshaller: marshaller, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - StateChangesCollector: stateChangesCollector, + StateChangesCollector: &stateMock.StateChangesCollectorStub{}, } accCreator, _ := factory.NewAccountCreator(argsAccCreator) @@ -71,7 +70,7 @@ func getDefaultTrieAndAccountsDbAndStoragePruningManager() (common.Trie, *state. StoragePruningManager: spm, AddressConverter: &testscommon.PubkeyConverterMock{}, SnapshotsManager: snapshotsManager, - StateChangesCollector: stateChangesCollector, + StateChangesCollector: &stateMock.StateChangesCollectorStub{}, } adb, _ := state.NewAccountsDB(argsAccountsDB) diff --git a/state/trackableDataTrie/trackableDataTrie.go b/state/trackableDataTrie/trackableDataTrie.go index 4d4f52a8e27..8c4196ae263 100644 --- a/state/trackableDataTrie/trackableDataTrie.go +++ b/state/trackableDataTrie/trackableDataTrie.go @@ -107,19 +107,19 @@ func (tdt *trackableDataTrie) RetrieveValue(key []byte) ([]byte, uint32, error) log.Trace("retrieve value from trie", "key", key, "value", val, "account", tdt.identifier) - stateChange := &stateChange.StateChange{ - Type: "read", + sc := &stateChange.StateChange{ + Type: stateChange.Read, MainTrieKey: tdt.identifier, MainTrieVal: nil, DataTrieChanges: []*stateChange.DataTrieChange{ { - Type: "read", + Type: stateChange.Read, Key: key, Val: val, }, }, } - tdt.stateChangesCollector.AddStateChange(stateChange) + tdt.stateChangesCollector.AddStateChange(sc) return val, depth, nil } @@ -292,7 +292,7 @@ func (tdt *trackableDataTrie) updateTrie(dtr state.DataTrie) ([]*stateChange.Dat if wasDeleted { deletedKeys = append(deletedKeys, &stateChange.DataTrieChange{ - Type: "write", + Type: stateChange.Write, Key: []byte(key), Val: nil, }, @@ -323,7 +323,7 @@ func (tdt *trackableDataTrie) updateTrie(dtr state.DataTrie) ([]*stateChange.Dat } newData[dataEntry.index] = &stateChange.DataTrieChange{ - Type: "write", + Type: stateChange.Write, Key: dataTrieKey, Val: dataTrieVal, } diff --git a/state/trackableDataTrie/trackableDataTrie_test.go b/state/trackableDataTrie/trackableDataTrie_test.go index 99e492f186c..5063efbf4c1 100644 --- a/state/trackableDataTrie/trackableDataTrie_test.go +++ b/state/trackableDataTrie/trackableDataTrie_test.go @@ -7,19 +7,20 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" errorsCommon "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/dataTrieValue" - "github.com/multiversx/mx-chain-go/state/stateChanges" "github.com/multiversx/mx-chain-go/state/trackableDataTrie" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + stateMock "github.com/multiversx/mx-chain-go/testscommon/state" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" ) func TestNewTrackableDataTrie(t *testing.T) { @@ -33,7 +34,7 @@ func TestNewTrackableDataTrie(t *testing.T) { nil, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) assert.Equal(t, state.ErrNilHasher, err) assert.True(t, check.IfNil(tdt)) @@ -47,7 +48,7 @@ func TestNewTrackableDataTrie(t *testing.T) { &hashingMocks.HasherMock{}, nil, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) assert.Equal(t, state.ErrNilMarshalizer, err) assert.True(t, check.IfNil(tdt)) @@ -61,7 +62,7 @@ func TestNewTrackableDataTrie(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, nil, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) assert.Equal(t, state.ErrNilEnableEpochsHandler, err) assert.True(t, check.IfNil(tdt)) @@ -75,7 +76,7 @@ func TestNewTrackableDataTrie(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, enableEpochsHandlerMock.NewEnableEpochsHandlerStubWithNoFlagsDefined(), - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) assert.True(t, errors.Is(err, core.ErrInvalidEnableEpochsHandler)) assert.True(t, check.IfNil(tdt)) @@ -89,7 +90,7 @@ func TestNewTrackableDataTrie(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) assert.Nil(t, err) assert.False(t, check.IfNil(tdt)) @@ -107,7 +108,7 @@ func TestTrackableDataTrie_SaveKeyValue(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) err := tdt.SaveKeyValue([]byte("key"), make([]byte, core.MaxLeafSize+1)) @@ -134,7 +135,7 @@ func TestTrackableDataTrie_SaveKeyValue(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) assert.NotNil(t, tdt) tdt.SetDataTrie(trie) @@ -173,7 +174,7 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) assert.NotNil(t, tdt) tdt.SetDataTrie(trie) @@ -196,7 +197,7 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) assert.Nil(t, err) assert.NotNil(t, tdt) @@ -232,7 +233,7 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, enableEpochsHandler, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) assert.NotNil(t, tdt) tdt.SetDataTrie(trie) @@ -273,7 +274,7 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, enableEpochsHandler, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) assert.NotNil(t, tdt) tdt.SetDataTrie(trie) @@ -318,7 +319,7 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { hasher, marshaller, enableEpochsHandler, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) assert.NotNil(t, tdt) tdt.SetDataTrie(trie) @@ -343,7 +344,7 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) assert.NotNil(t, tdt) tdt.SetDataTrie(trie) @@ -379,7 +380,7 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { hasher, marshaller, enableEpochsHandler, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) assert.NotNil(t, tdt) tdt.SetDataTrie(trie) @@ -415,7 +416,7 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { hasher, marshaller, enableEpochsHandler, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) assert.NotNil(t, tdt) tdt.SetDataTrie(trie) @@ -437,7 +438,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) stateChanges, oldValues, err := tdt.SaveDirtyData(&trieMock.TrieStub{}) @@ -469,7 +470,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) key := []byte("key") @@ -505,7 +506,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { hasher, marshaller, enableEpochsHandler, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) expectedKey := []byte("key") @@ -533,7 +534,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler, stateChanges.NewStateChangesCollector()) + tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler, &stateMock.StateChangesCollectorStub{}) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, expectedVal) @@ -571,7 +572,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { hasher, marshaller, enableEpochsHandler, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) expectedKey := []byte("key") @@ -597,7 +598,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler, stateChanges.NewStateChangesCollector()) + tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler, &stateMock.StateChangesCollectorStub{}) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, val) @@ -629,7 +630,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { hasher, marshaller, enableEpochsHandler, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) expectedKey := []byte("key") @@ -662,7 +663,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { hasher, marshaller, enableEpochsHandler, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) tdt.SetDataTrie(trie) @@ -695,7 +696,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { hasher, marshaller, enableEpochsHandler, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) expectedKey := []byte("key") @@ -718,7 +719,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler, stateChanges.NewStateChangesCollector()) + tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler, &stateMock.StateChangesCollectorStub{}) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, newVal) @@ -748,7 +749,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, stateChanges.NewStateChangesCollector()) + tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &stateMock.StateChangesCollectorStub{}) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, val) @@ -773,7 +774,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, stateChanges.NewStateChangesCollector()) + tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &stateMock.StateChangesCollectorStub{}) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, nil) @@ -802,7 +803,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, stateChanges.NewStateChangesCollector()) + tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &stateMock.StateChangesCollectorStub{}) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, nil) @@ -839,7 +840,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { return flag == common.AutoBalanceDataTriesFlag }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, enableEpchs, stateChanges.NewStateChangesCollector()) + tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, enableEpchs, &stateMock.StateChangesCollectorStub{}) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, nil) @@ -877,7 +878,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { return flag == common.AutoBalanceDataTriesFlag }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, enableEpchs, stateChanges.NewStateChangesCollector()) + tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, enableEpchs, &stateMock.StateChangesCollectorStub{}) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, nil) @@ -907,7 +908,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { hasher, marshaller, enableEpochsHandler, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) expectedKey := []byte("key") @@ -959,7 +960,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { hasher, marshaller, enableEpochsHandler, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) key1 := "key1" @@ -1030,7 +1031,7 @@ func TestTrackableDataTrie_MigrateDataTrieLeaves(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) args := vmcommon.ArgsMigrateDataTrieLeaves{ OldVersion: core.NotSpecified, @@ -1049,7 +1050,7 @@ func TestTrackableDataTrie_MigrateDataTrieLeaves(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) tdt.SetDataTrie(&trieMock.TrieStub{}) @@ -1077,7 +1078,7 @@ func TestTrackableDataTrie_MigrateDataTrieLeaves(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) tdt.SetDataTrie(tr) args := vmcommon.ArgsMigrateDataTrieLeaves{ @@ -1132,7 +1133,7 @@ func TestTrackableDataTrie_MigrateDataTrieLeaves(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, enableEpchs, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) tdt.SetDataTrie(tr) args := vmcommon.ArgsMigrateDataTrieLeaves{ @@ -1161,7 +1162,7 @@ func TestTrackableDataTrie_SetAndGetDataTrie(t *testing.T) { &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - stateChanges.NewStateChangesCollector(), + &stateMock.StateChangesCollectorStub{}, ) newTrie := &trieMock.TrieStub{} diff --git a/testscommon/state/stateChangesCollectorStub.go b/testscommon/state/stateChangesCollectorStub.go index c8fc099879f..a62e5d7ab8d 100644 --- a/testscommon/state/stateChangesCollectorStub.go +++ b/testscommon/state/stateChangesCollectorStub.go @@ -16,9 +16,9 @@ type StateChangesCollectorStub struct { AddTxHashToCollectedStateChangesCalled func(txHash []byte, tx *transaction.Transaction) SetIndexToLastStateChangeCalled func(index int) error RevertToIndexCalled func(index int) error - PublishCalled func() error + PublishCalled func() (map[string]*stateChange.StateChanges, error) + StoreCalled func() error IsInterfaceNilCalled func() bool - GetStateChangesForTxsCalled func() map[string]*stateChange.StateChanges } // AddStateChange - @@ -68,11 +68,19 @@ func (s *StateChangesCollectorStub) RevertToIndex(index int) error { } // Publish - -func (s *StateChangesCollectorStub) Publish() error { +func (s *StateChangesCollectorStub) Publish() (map[string]*stateChange.StateChanges, error) { if s.PublishCalled != nil { return s.PublishCalled() } + return nil, nil +} + +func (s *StateChangesCollectorStub) Store() error { + if s.StoreCalled != nil { + return s.StoreCalled() + } + return nil } @@ -84,12 +92,3 @@ func (s *StateChangesCollectorStub) IsInterfaceNil() bool { return false } - -// GetStateChangesForTxs - -func (s *StateChangesCollectorStub) GetStateChangesForTxs() map[string]*stateChange.StateChanges { - if s.GetStateChangesForTxsCalled != nil { - return s.GetStateChangesForTxsCalled() - } - - return nil -}