diff --git a/components/chains/component.go b/components/chains/component.go index fe6c6ab3c3..2bbd8eccef 100644 --- a/components/chains/component.go +++ b/components/chains/component.go @@ -101,6 +101,7 @@ func provide(c *dig.Container) error { deps.NetworkProvider, deps.TrustedNetworkManager, deps.ChainStateDatabaseManager.ChainStateKVStore, + ParamsWAL.LoadToStore, ParamsWAL.Enabled, ParamsWAL.Path, ParamsStateManager.BlockCacheMaxSize, @@ -111,6 +112,10 @@ func provide(c *dig.Container) error { ParamsStateManager.StateManagerTimerTickPeriod, ParamsStateManager.PruningMinStatesToKeep, ParamsStateManager.PruningMaxStatesToDelete, + ParamsSnapshotManager.Period, + ParamsSnapshotManager.LocalPath, + ParamsSnapshotManager.NetworkPaths, + ParamsSnapshotManager.UpdatePeriod, deps.ChainRecordRegistryProvider, deps.DKShareRegistryProvider, deps.NodeIdentityProvider, diff --git a/components/chains/params.go b/components/chains/params.go index 9543a77bb0..3263316c58 100644 --- a/components/chains/params.go +++ b/components/chains/params.go @@ -17,8 +17,9 @@ type ParametersChains struct { } type ParametersWAL struct { - Enabled bool `default:"true" usage:"whether the \"write-ahead logging\" is enabled"` - Path string `default:"waspdb/wal" usage:"the path to the \"write-ahead logging\" folder"` + LoadToStore bool `default:"false" usage:"load blocks from \"write-ahead log\" to the store on node start-up"` + Enabled bool `default:"true" usage:"whether the \"write-ahead logging\" is enabled"` + Path string `default:"waspdb/wal" usage:"the path to the \"write-ahead logging\" folder"` } type ParametersValidator struct { @@ -36,11 +37,19 @@ type ParametersStateManager struct { PruningMaxStatesToDelete int `default:"1000" usage:"on single store pruning attempt at most this number of states will be deleted"` } +type ParametersSnapshotManager struct { + Period uint32 `default:"0" usage:"how often state snapshots should be made: 1000 meaning \"every 1000th state\", 0 meaning \"making snapshots is disabled\""` + LocalPath string `default:"waspdb/snap" usage:"the path to the snapshots folder in this node's disk"` + NetworkPaths []string `default:"" usage:"the list of paths to the remote (http(s)) snapshot locations; each of listed locations must contain 'INDEX' file with list of snapshot files"` + UpdatePeriod time.Duration `default:"5m" usage:"how often known snapshots list should be updated"` +} + var ( - ParamsChains = &ParametersChains{} - ParamsWAL = &ParametersWAL{} - ParamsValidator = &ParametersValidator{} - ParamsStateManager = &ParametersStateManager{} + ParamsChains = &ParametersChains{} + ParamsWAL = &ParametersWAL{} + ParamsValidator = &ParametersValidator{} + ParamsStateManager = &ParametersStateManager{} + ParamsSnapshotManager = &ParametersSnapshotManager{} ) var params = &app.ComponentParams{ @@ -49,6 +58,7 @@ var params = &app.ComponentParams{ "wal": ParamsWAL, "validator": ParamsValidator, "stateManager": ParamsStateManager, + "snapshots": ParamsSnapshotManager, }, Masked: nil, } diff --git a/packages/chain/node.go b/packages/chain/node.go index ef5f98b4b1..ad3cc5f8bd 100644 --- a/packages/chain/node.go +++ b/packages/chain/node.go @@ -35,6 +35,7 @@ import ( "github.com/iotaledger/wasp/packages/chain/statemanager" "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa" "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa/sm_gpa_utils" + "github.com/iotaledger/wasp/packages/chain/statemanager/sm_snapshots" "github.com/iotaledger/wasp/packages/cryptolib" 
"github.com/iotaledger/wasp/packages/gpa" "github.com/iotaledger/wasp/packages/isc" @@ -262,7 +263,9 @@ func New( processorConfig *processors.Config, dkShareRegistryProvider registry.DKShareRegistryProvider, consensusStateRegistry cmt_log.ConsensusStateRegistry, + recoverFromWAL bool, blockWAL sm_gpa_utils.BlockWAL, + snapshotManager sm_snapshots.SnapshotManager, listener ChainListener, accessNodesFromNode []*cryptolib.PublicKey, net peering.NetworkProvider, @@ -339,7 +342,9 @@ func New( cni.chainMetrics.Pipe.TrackPipeLen("node-serversUpdatedPipe", cni.serversUpdatedPipe.Len) cni.chainMetrics.Pipe.TrackPipeLen("node-netRecvPipe", cni.netRecvPipe.Len) - cni.tryRecoverStoreFromWAL(chainStore, blockWAL) + if recoverFromWAL { + cni.recoverStoreFromWAL(chainStore, blockWAL) + } cni.me = cni.pubKeyAsNodeID(nodeIdentity.GetPublicKey()) // // Create sub-components. @@ -398,11 +403,12 @@ func New( peerPubKeys, net, blockWAL, + snapshotManager, chainStore, shutdownCoordinator.Nested("StateMgr"), chainMetrics.StateManager, chainMetrics.Pipe, - cni.log.Named("SM"), + cni.log, smParameters, ) if err != nil { @@ -1222,20 +1228,7 @@ func (cni *chainNodeImpl) GetConsensusWorkflowStatus() ConsensusWorkflowStatus { return &consensusWorkflowStatusImpl{} } -func (cni *chainNodeImpl) tryRecoverStoreFromWAL(chainStore indexedstore.IndexedStore, chainWAL sm_gpa_utils.BlockWAL) { - defer func() { - if r := recover(); r != nil { - // Don't fail, if this crashes for some reason, that's an optional step. - cni.log.Warnf("TryRecoverStoreFromWAL: Failed to populate chain store from WAL: %v", r) - } - }() - // - // Check, if store is empty. - if _, err := chainStore.BlockByIndex(0); err == nil { - cni.log.Infof("TryRecoverStoreFromWAL: Skipping, because the state is not empty.") - return // Store is not empty, so we skip this. - } - cni.log.Infof("TryRecoverStoreFromWAL: Chain store is empty, will try to load blocks from the WAL.") +func (cni *chainNodeImpl) recoverStoreFromWAL(chainStore indexedstore.IndexedStore, chainWAL sm_gpa_utils.BlockWAL) { // // Load all the existing blocks from the WAL. blocksAdded := 0 diff --git a/packages/chain/node_test.go b/packages/chain/node_test.go index 1eb6a96a58..1d53f19791 100644 --- a/packages/chain/node_test.go +++ b/packages/chain/node_test.go @@ -21,6 +21,7 @@ import ( "github.com/iotaledger/wasp/packages/chain" "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa" "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa/sm_gpa_utils" + "github.com/iotaledger/wasp/packages/chain/statemanager/sm_snapshots" "github.com/iotaledger/wasp/packages/cryptolib" "github.com/iotaledger/wasp/packages/isc" "github.com/iotaledger/wasp/packages/kv/dict" @@ -451,7 +452,9 @@ func newEnv(t *testing.T, n, f int, reliable bool) *testEnv { coreprocessors.NewConfigWithCoreContracts().WithNativeContracts(inccounter.Processor), dkShareProviders[i], testutil.NewConsensusStateRegistry(), + false, sm_gpa_utils.NewMockedTestBlockWAL(), + sm_snapshots.NewEmptySnapshotManager(), chain.NewEmptyChainListener(), []*cryptolib.PublicKey{}, // Access nodes. 
te.networkProviders[i], diff --git a/packages/chain/statemanager/sm_gpa/block_fetcher.go b/packages/chain/statemanager/sm_gpa/block_fetcher.go index cdedac7c59..a1d33bc4e8 100644 --- a/packages/chain/statemanager/sm_gpa/block_fetcher.go +++ b/packages/chain/statemanager/sm_gpa/block_fetcher.go @@ -8,6 +8,7 @@ import ( type blockFetcherImpl struct { start time.Time + stateIndex uint32 commitment *state.L1Commitment callbacks []blockRequestCallback related []blockFetcher @@ -15,27 +16,33 @@ type blockFetcherImpl struct { var _ blockFetcher = &blockFetcherImpl{} -func newBlockFetcher(commitment *state.L1Commitment) blockFetcher { +func newBlockFetcher(stateIndex uint32, commitment *state.L1Commitment) blockFetcher { return &blockFetcherImpl{ start: time.Now(), + stateIndex: stateIndex, commitment: commitment, callbacks: make([]blockRequestCallback, 0), related: make([]blockFetcher, 0), } } -func newBlockFetcherWithCallback(commitment *state.L1Commitment, callback blockRequestCallback) blockFetcher { - result := newBlockFetcher(commitment) +func newBlockFetcherWithCallback(stateIndex uint32, commitment *state.L1Commitment, callback blockRequestCallback) blockFetcher { + result := newBlockFetcher(stateIndex, commitment) result.addCallback(callback) return result } func newBlockFetcherWithRelatedFetcher(commitment *state.L1Commitment, fetcher blockFetcher) blockFetcher { - result := newBlockFetcher(commitment) + newStateIndex := fetcher.getStateIndex() - 1 + result := newBlockFetcher(newStateIndex, commitment) result.addRelatedFetcher(fetcher) return result } +func (bfiT *blockFetcherImpl) getStateIndex() uint32 { + return bfiT.stateIndex +} + func (bfiT *blockFetcherImpl) getCommitment() *state.L1Commitment { return bfiT.commitment } @@ -52,17 +59,21 @@ func (bfiT *blockFetcherImpl) addRelatedFetcher(fetcher blockFetcher) { bfiT.related = append(bfiT.related, fetcher) } -func (bfiT *blockFetcherImpl) notifyFetched(notifyFun func(blockFetcher) bool) { - if notifyFun(bfiT) { - for _, callback := range bfiT.callbacks { - if callback.isValid() { - callback.requestCompleted() - } - } - for _, fetcher := range bfiT.related { - fetcher.notifyFetched(notifyFun) +func (bfiT *blockFetcherImpl) commitAndNotifyFetched(commitFun func(blockFetcher) bool) { + if commitFun(bfiT) { + bfiT.notifyFetched(commitFun) + } +} + +func (bfiT *blockFetcherImpl) notifyFetched(commitFun func(blockFetcher) bool) { + for _, callback := range bfiT.callbacks { + if callback.isValid() { + callback.requestCompleted() } } + for _, fetcher := range bfiT.related { + fetcher.commitAndNotifyFetched(commitFun) + } } func (bfiT *blockFetcherImpl) cleanCallbacks() { diff --git a/packages/chain/statemanager/sm_gpa/interface.go b/packages/chain/statemanager/sm_gpa/interface.go index fff383390a..279ddb4c37 100644 --- a/packages/chain/statemanager/sm_gpa/interface.go +++ b/packages/chain/statemanager/sm_gpa/interface.go @@ -3,18 +3,32 @@ package sm_gpa import ( "time" + "github.com/iotaledger/wasp/packages/chain/statemanager/sm_snapshots" "github.com/iotaledger/wasp/packages/state" ) +type StateManagerOutput interface { + addBlockCommitted(uint32, *state.L1Commitment) + addSnapshotToLoad(uint32, *state.L1Commitment) + setUpdateSnapshots() + TakeBlocksCommitted() []sm_snapshots.SnapshotInfo + TakeSnapshotToLoad() sm_snapshots.SnapshotInfo + TakeUpdateSnapshots() bool +} + +type SnapshotExistsFun func(uint32, *state.L1Commitment) bool + type blockRequestCallback interface { isValid() bool requestCompleted() } type blockFetcher interface { + 
getStateIndex() uint32 getCommitment() *state.L1Commitment getCallbacksCount() int - notifyFetched(func(blockFetcher) bool) // calls fun for this fetcher and each related recursively; fun for parent block is always called before fun for related block + commitAndNotifyFetched(func(blockFetcher) bool) // calls fun for this block, notifies waiting callbacks of this fetcher and does the same for each related fetcher recursively; fun for parent block is always called before fun for related block + notifyFetched(func(blockFetcher) bool) // notifies waiting callbacks of this fetcher, then calls fun and notifies waiting callbacks of all related fetchers recursively; fun for parent block is always called before fun for related block addCallback(blockRequestCallback) addRelatedFetcher(blockFetcher) cleanCallbacks() diff --git a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_cache.go b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_cache.go index 2376f01a54..25c214875f 100644 --- a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_cache.go +++ b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_cache.go @@ -30,7 +30,7 @@ var _ BlockCache = &blockCache{} func NewBlockCache(tp TimeProvider, maxCacheSize int, wal BlockWAL, metrics *metrics.ChainStateManagerMetrics, log *logger.Logger) (BlockCache, error) { return &blockCache{ - log: log.Named("bc"), + log: log.Named("BC"), blocks: shrinkingmap.New[BlockKey, state.Block](), maxCacheSize: maxCacheSize, wal: wal, diff --git a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal.go b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal.go index 0d93a81be1..09df85015e 100644 --- a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal.go +++ b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal.go @@ -24,7 +24,7 @@ type blockWAL struct { metrics *metrics.ChainBlockWALMetrics } -const constFileSuffix = ".blk" +const constBlockWALFileSuffix = ".blk" func NewBlockWAL(log *logger.Logger, baseDir string, chainID isc.ChainID, metrics *metrics.ChainBlockWALMetrics) (BlockWAL, error) { dir := filepath.Join(baseDir, chainID.String()) @@ -33,7 +33,7 @@ func NewBlockWAL(log *logger.Logger, baseDir string, chainID isc.ChainID, metric } result := &blockWAL{ - WrappedLogger: logger.NewWrappedLogger(log), + WrappedLogger: logger.NewWrappedLogger(log.Named("WAL")), dir: dir, metrics: metrics, } @@ -45,7 +45,7 @@ func NewBlockWAL(log *logger.Logger, baseDir string, chainID isc.ChainID, metric func (bwT *blockWAL) Write(block state.Block) error { blockIndex := block.StateIndex() commitment := block.L1Commitment() - fileName := fileName(commitment.BlockHash()) + fileName := blockWALFileName(commitment.BlockHash()) filePath := filepath.Join(bwT.dir, fileName) f, err := os.OpenFile(filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o666) if err != nil { @@ -69,12 +69,12 @@ func (bwT *blockWAL) Write(block state.Block) error { } func (bwT *blockWAL) Contains(blockHash state.BlockHash) bool { - _, err := os.Stat(filepath.Join(bwT.dir, fileName(blockHash))) + _, err := os.Stat(filepath.Join(bwT.dir, blockWALFileName(blockHash))) return err == nil } func (bwT *blockWAL) Read(blockHash state.BlockHash) (state.Block, error) { - fileName := fileName(blockHash) + fileName := blockWALFileName(blockHash) filePath := filepath.Join(bwT.dir, fileName) block, err := blockFromFilePath(filePath) if err != nil { @@ -97,7 +97,7 @@ func (bwT *blockWAL) ReadAllByStateIndex(cb func(stateIndex uint32, block state. 
if !dirEntry.Type().IsRegular() { continue } - if !strings.HasSuffix(dirEntry.Name(), constFileSuffix) { + if !strings.HasSuffix(dirEntry.Name(), constBlockWALFileSuffix) { continue } filePath := filepath.Join(bwT.dir, dirEntry.Name()) @@ -160,6 +160,6 @@ func blockFromFilePath(filePath string) (state.Block, error) { return block, nil } -func fileName(blockHash state.BlockHash) string { - return blockHash.String() + constFileSuffix +func blockWALFileName(blockHash state.BlockHash) string { + return blockHash.String() + constBlockWALFileSuffix } diff --git a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal_rapid_test.go b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal_rapid_test.go index 45620296d5..be7c42cc8c 100644 --- a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal_rapid_test.go +++ b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal_rapid_test.go @@ -189,7 +189,7 @@ func (bwtsmT *blockWALTestSM) getGoodBlockHashes() []state.BlockHash { } func (bwtsmT *blockWALTestSM) pathFromHash(blockHash state.BlockHash) string { - return filepath.Join(constTestFolder, bwtsmT.factory.GetChainID().String(), fileName(blockHash)) + return filepath.Join(constTestFolder, bwtsmT.factory.GetChainID().String(), blockWALFileName(blockHash)) } func (bwtsmT *blockWALTestSM) invariantAllWrittenBlocksExist(t *rapid.T) { diff --git a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal_test.go b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal_test.go index 49233a6a4a..81ea15a778 100644 --- a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal_test.go +++ b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal_test.go @@ -64,7 +64,7 @@ func TestBlockWALOverwrite(t *testing.T) { require.NoError(t, err) } pathFromHashFun := func(blockHash state.BlockHash) string { - return filepath.Join(constTestFolder, factory.GetChainID().String(), fileName(blockHash)) + return filepath.Join(constTestFolder, factory.GetChainID().String(), blockWALFileName(blockHash)) } file0Path := pathFromHashFun(blocks[0].Hash()) file1Path := pathFromHashFun(blocks[1].Hash()) diff --git a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/read_only_store.go b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/read_only_store.go new file mode 100644 index 0000000000..66718a3ae6 --- /dev/null +++ b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/read_only_store.go @@ -0,0 +1,84 @@ +package sm_gpa_utils + +import ( + "fmt" + "time" + + "github.com/iotaledger/hive.go/kvstore" + "github.com/iotaledger/wasp/packages/state" + "github.com/iotaledger/wasp/packages/trie" +) + +type readOnlyStore struct { + store state.Store +} + +var _ state.Store = &readOnlyStore{} + +func NewReadOnlyStore(store state.Store) state.Store { + return &readOnlyStore{store: store} +} + +func (ros *readOnlyStore) HasTrieRoot(trieRoot trie.Hash) bool { + return ros.store.HasTrieRoot(trieRoot) +} + +func (ros *readOnlyStore) BlockByTrieRoot(trieRoot trie.Hash) (state.Block, error) { + return ros.store.BlockByTrieRoot(trieRoot) +} + +func (ros *readOnlyStore) StateByTrieRoot(trieRoot trie.Hash) (state.State, error) { + return ros.store.StateByTrieRoot(trieRoot) +} + +func (ros *readOnlyStore) SetLatest(trie.Hash) error { + return fmt.Errorf("cannot write to read-only store") +} + +func (ros *readOnlyStore) LatestBlockIndex() (uint32, error) { + return ros.store.LatestBlockIndex() +} + +func (ros *readOnlyStore) LatestBlock() (state.Block, error) { + return ros.store.LatestBlock() +} + +func (ros *readOnlyStore) 
LatestState() (state.State, error) { + return ros.store.LatestState() +} + +func (ros *readOnlyStore) LatestTrieRoot() (trie.Hash, error) { + return ros.store.LatestTrieRoot() +} + +func (ros *readOnlyStore) NewOriginStateDraft() state.StateDraft { + panic("Cannot create origin state draft in read-only store") +} + +func (ros *readOnlyStore) NewStateDraft(time.Time, *state.L1Commitment) (state.StateDraft, error) { + return nil, fmt.Errorf("cannot create state draft in read-only store") +} + +func (ros *readOnlyStore) NewEmptyStateDraft(prevL1Commitment *state.L1Commitment) (state.StateDraft, error) { + return nil, fmt.Errorf("cannot create empty state draft in read-only store") +} + +func (ros *readOnlyStore) Commit(state.StateDraft) state.Block { + panic("Cannot commit to read-only store") +} + +func (ros *readOnlyStore) ExtractBlock(stateDraft state.StateDraft) state.Block { + return ros.store.ExtractBlock(stateDraft) +} + +func (ros *readOnlyStore) Prune(trie.Hash) (trie.PruneStats, error) { + panic("Cannot prune read-only store") +} + +func (ros *readOnlyStore) TakeSnapshot(trieRoot trie.Hash, kvStore kvstore.KVStore) error { + return ros.store.TakeSnapshot(trieRoot, kvStore) +} + +func (ros *readOnlyStore) RestoreSnapshot(trie.Hash, kvstore.KVStore) error { + return fmt.Errorf("cannot write snapshot into read-only store") +} diff --git a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/test_block_factory.go b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/test_block_factory.go index 1a6c31e729..1a2168c4af 100644 --- a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/test_block_factory.go +++ b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/test_block_factory.go @@ -153,6 +153,10 @@ func (bfT *BlockFactory) GetNextBlock( return block } +func (bfT *BlockFactory) GetStore() state.Store { + return NewReadOnlyStore(bfT.store) +} + func (bfT *BlockFactory) GetStateDraft(block state.Block) state.StateDraft { result, err := bfT.store.NewEmptyStateDraft(block.PreviousL1Commitment()) require.NoError(bfT.t, err) @@ -160,12 +164,6 @@ func (bfT *BlockFactory) GetStateDraft(block state.Block) state.StateDraft { return result } -func (bfT *BlockFactory) GetState(commitment *state.L1Commitment) state.State { - result, err := bfT.store.StateByTrieRoot(commitment.TrieRoot()) - require.NoError(bfT.t, err) - return result -} - func (bfT *BlockFactory) GetAliasOutput(commitment *state.L1Commitment) *isc.AliasOutputWithID { result, ok := bfT.aliasOutputs[commitment.BlockHash()] require.True(bfT.t, ok) diff --git a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/test_utils.go b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/test_utils.go index b1e181ac7d..246c01dc91 100644 --- a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/test_utils.go +++ b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/test_utils.go @@ -4,8 +4,11 @@ package sm_gpa_utils import ( + "bytes" + "github.com/stretchr/testify/require" + "github.com/iotaledger/wasp/packages/kv" "github.com/iotaledger/wasp/packages/state" ) @@ -34,3 +37,63 @@ func BlocksEqual(block1, block2 state.Block) bool { func CheckBlocksDifferent(t require.TestingT, block1, block2 state.Block) { require.False(t, block1.Hash().Equals(block2.Hash())) } + +// ----------------------------------------------------------------------------- +func CheckStateInStores(t require.TestingT, storeOrig, storeNew state.Store, commitment *state.L1Commitment) { + origState, err := storeOrig.StateByTrieRoot(commitment.TrieRoot()) + require.NoError(t, err) + CheckStateInStore(t, storeNew, 
origState) +} + +func CheckStateInStore(t require.TestingT, store state.Store, origState state.State) { + stateFromStore, err := store.StateByTrieRoot(origState.TrieRoot()) + require.NoError(t, err) + require.True(t, origState.TrieRoot().Equals(stateFromStore.TrieRoot())) + require.Equal(t, origState.BlockIndex(), stateFromStore.BlockIndex()) + require.Equal(t, origState.Timestamp(), stateFromStore.Timestamp()) + require.True(t, origState.PreviousL1Commitment().Equals(stateFromStore.PreviousL1Commitment())) + commonState := getCommonState(origState, stateFromStore) + for _, entry := range commonState { + require.True(t, bytes.Equal(entry.value1, entry.value2)) + } +} + +// NOTE: this function should not exist. state.State should have Equals method +func StatesEqual(state1, state2 state.State) bool { + if !state1.TrieRoot().Equals(state2.TrieRoot()) || + state1.BlockIndex() != state2.BlockIndex() || + state1.Timestamp() != state2.Timestamp() || + !state1.PreviousL1Commitment().Equals(state2.PreviousL1Commitment()) { + return false + } + commonState := getCommonState(state1, state2) + for _, entry := range commonState { + if !bytes.Equal(entry.value1, entry.value2) { + return false + } + } + return true +} + +type commonEntry struct { + value1 []byte + value2 []byte +} + +func getCommonState(state1, state2 state.State) map[kv.Key]*commonEntry { + result := make(map[kv.Key]*commonEntry) + iterateFun := func(iterState state.State, setValueFun func(*commonEntry, []byte)) { + iterState.Iterate(kv.EmptyPrefix, func(key kv.Key, value []byte) bool { + entry, ok := result[key] + if !ok { + entry = &commonEntry{} + result[key] = entry + } + setValueFun(entry, value) + return true + }) + } + iterateFun(state1, func(entry *commonEntry, value []byte) { entry.value1 = value }) + iterateFun(state2, func(entry *commonEntry, value []byte) { entry.value2 = value }) + return result +} diff --git a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/time_provider_artifficial.go b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/time_provider_artifficial.go index 1539362b4a..9f29c9c624 100644 --- a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/time_provider_artifficial.go +++ b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/time_provider_artifficial.go @@ -1,12 +1,14 @@ package sm_gpa_utils import ( + "sync" "time" ) type artifficialTimeProvider struct { now time.Time timers []*timer + mutex sync.Mutex } type timer struct { @@ -26,10 +28,14 @@ func NewArtifficialTimeProvider(nowOpt ...time.Time) TimeProvider { return &artifficialTimeProvider{ now: now, timers: make([]*timer, 0), + mutex: sync.Mutex{}, } } func (atpT *artifficialTimeProvider) SetNow(now time.Time) { + atpT.mutex.Lock() + defer atpT.mutex.Unlock() + atpT.now = now var i int for i = 0; i < len(atpT.timers) && atpT.timers[i].time.Before(atpT.now); i++ { @@ -40,26 +46,37 @@ func (atpT *artifficialTimeProvider) SetNow(now time.Time) { } func (atpT *artifficialTimeProvider) GetNow() time.Time { + atpT.mutex.Lock() + defer atpT.mutex.Unlock() + return atpT.now } func (atpT *artifficialTimeProvider) After(d time.Duration) <-chan time.Time { - timerTime := atpT.now.Add(d) + channel := make(chan time.Time, 1) + if d == 0 { + channel <- atpT.now + close(channel) + } else { + atpT.mutex.Lock() + defer atpT.mutex.Unlock() - var count int - for i := 0; i < len(atpT.timers) && atpT.timers[i].time.Before(timerTime); i++ { - count++ - } + timerTime := atpT.now.Add(d) - if count == len(atpT.timers) { - atpT.timers = append(atpT.timers, nil) - } else { - atpT.timers 
= append(atpT.timers[:count+1], atpT.timers[count:]...) - } - channel := make(chan time.Time, 1) - atpT.timers[count] = &timer{ - time: timerTime, - channel: channel, + var count int + for i := 0; i < len(atpT.timers) && atpT.timers[i].time.Before(timerTime); i++ { + count++ + } + + if count == len(atpT.timers) { + atpT.timers = append(atpT.timers, nil) + } else { + atpT.timers = append(atpT.timers[:count+1], atpT.timers[count:]...) + } + atpT.timers[count] = &timer{ + time: timerTime, + channel: channel, + } } return channel } diff --git a/packages/chain/statemanager/sm_gpa/sm_inputs/consensus_decided_state.go b/packages/chain/statemanager/sm_gpa/sm_inputs/consensus_decided_state.go index a2f271bded..e6e341499b 100644 --- a/packages/chain/statemanager/sm_gpa/sm_inputs/consensus_decided_state.go +++ b/packages/chain/statemanager/sm_gpa/sm_inputs/consensus_decided_state.go @@ -11,6 +11,7 @@ import ( type ConsensusDecidedState struct { context context.Context + stateIndex uint32 l1Commitment *state.L1Commitment resultCh chan<- state.State } @@ -25,11 +26,16 @@ func NewConsensusDecidedState(ctx context.Context, aliasOutput *isc.AliasOutputW resultChannel := make(chan state.State, 1) return &ConsensusDecidedState{ context: ctx, + stateIndex: aliasOutput.GetStateIndex(), l1Commitment: commitment, resultCh: resultChannel, }, resultChannel } +func (cdsT *ConsensusDecidedState) GetStateIndex() uint32 { + return cdsT.stateIndex +} + func (cdsT *ConsensusDecidedState) GetL1Commitment() *state.L1Commitment { return cdsT.l1Commitment } diff --git a/packages/chain/statemanager/sm_gpa/sm_inputs/consensus_state_proposal.go b/packages/chain/statemanager/sm_gpa/sm_inputs/consensus_state_proposal.go index ada5eb5db0..6d42ee9449 100644 --- a/packages/chain/statemanager/sm_gpa/sm_inputs/consensus_state_proposal.go +++ b/packages/chain/statemanager/sm_gpa/sm_inputs/consensus_state_proposal.go @@ -11,6 +11,7 @@ import ( type ConsensusStateProposal struct { context context.Context + stateIndex uint32 l1Commitment *state.L1Commitment resultCh chan<- interface{} } @@ -25,11 +26,16 @@ func NewConsensusStateProposal(ctx context.Context, aliasOutput *isc.AliasOutput resultChannel := make(chan interface{}, 1) return &ConsensusStateProposal{ context: ctx, + stateIndex: aliasOutput.GetStateIndex(), l1Commitment: commitment, resultCh: resultChannel, }, resultChannel } +func (cspT *ConsensusStateProposal) GetStateIndex() uint32 { + return cspT.stateIndex +} + func (cspT *ConsensusStateProposal) GetL1Commitment() *state.L1Commitment { return cspT.l1Commitment } diff --git a/packages/chain/statemanager/sm_gpa/sm_inputs/snapshot_manager_snapshot_done.go b/packages/chain/statemanager/sm_gpa/sm_inputs/snapshot_manager_snapshot_done.go new file mode 100644 index 0000000000..b6faafd75d --- /dev/null +++ b/packages/chain/statemanager/sm_gpa/sm_inputs/snapshot_manager_snapshot_done.go @@ -0,0 +1,34 @@ +package sm_inputs + +import ( + "github.com/iotaledger/wasp/packages/gpa" + "github.com/iotaledger/wasp/packages/state" +) + +type SnapshotManagerSnapshotDone struct { + stateIndex uint32 + commitment *state.L1Commitment + result error +} + +var _ gpa.Input = &SnapshotManagerSnapshotDone{} + +func NewSnapshotManagerSnapshotDone(stateIndex uint32, commitment *state.L1Commitment, result error) *SnapshotManagerSnapshotDone { + return &SnapshotManagerSnapshotDone{ + stateIndex: stateIndex, + commitment: commitment, + result: result, + } +} + +func (smsdT *SnapshotManagerSnapshotDone) GetStateIndex() uint32 { + return smsdT.stateIndex +} 
+ +func (smsdT *SnapshotManagerSnapshotDone) GetCommitment() *state.L1Commitment { + return smsdT.commitment +} + +func (smsdT *SnapshotManagerSnapshotDone) GetResult() error { + return smsdT.result +} diff --git a/packages/chain/statemanager/sm_gpa/state_manager_gpa.go b/packages/chain/statemanager/sm_gpa/state_manager_gpa.go index bd090698b7..77b13f7b73 100644 --- a/packages/chain/statemanager/sm_gpa/state_manager_gpa.go +++ b/packages/chain/statemanager/sm_gpa/state_manager_gpa.go @@ -28,12 +28,15 @@ type stateManagerGPA struct { blocksToFetch blockFetchers blocksFetched blockFetchers nodeRandomiser sm_utils.NodeRandomiser + snapshotExistsFun SnapshotExistsFun store state.Store + output StateManagerOutput parameters StateManagerParameters lastGetBlocksTime time.Time lastCleanBlockCacheTime time.Time lastCleanRequestsTime time.Time lastStatusLogTime time.Time + lastSnapshotsUpdateTime time.Time metrics *metrics.ChainStateManagerMetrics } @@ -48,13 +51,14 @@ func New( chainID isc.ChainID, nr sm_utils.NodeRandomiser, wal sm_gpa_utils.BlockWAL, + snapshotExistsFun SnapshotExistsFun, store state.Store, metrics *metrics.ChainStateManagerMetrics, log *logger.Logger, parameters StateManagerParameters, ) (gpa.GPA, error) { var err error - smLog := log.Named("gpa") + smLog := log.Named("GPA") blockCache, err := sm_gpa_utils.NewBlockCache(parameters.TimeProvider, parameters.BlockCacheMaxSize, wal, metrics, smLog) if err != nil { return nil, fmt.Errorf("error creating block cache: %v", err) @@ -66,11 +70,15 @@ func New( blocksToFetch: newBlockFetchers(newBlockFetchersMetrics(metrics.IncBlocksFetching, metrics.DecBlocksFetching, metrics.StateManagerBlockFetched)), blocksFetched: newBlockFetchers(newBlockFetchersMetrics(metrics.IncBlocksPending, metrics.DecBlocksPending, bfmNopDurationFun)), nodeRandomiser: nr, + snapshotExistsFun: snapshotExistsFun, store: store, + output: newOutput(), parameters: parameters, lastGetBlocksTime: time.Time{}, lastCleanBlockCacheTime: time.Time{}, + lastCleanRequestsTime: time.Time{}, lastStatusLogTime: time.Time{}, + lastSnapshotsUpdateTime: time.Time{}, metrics: metrics, } @@ -91,6 +99,8 @@ func (smT *stateManagerGPA) Input(input gpa.Input) gpa.OutMessages { return smT.handleConsensusBlockProduced(inputCasted) case *sm_inputs.ChainFetchStateDiff: // From mempool return smT.handleChainFetchStateDiff(inputCasted) + case *sm_inputs.SnapshotManagerSnapshotDone: // From snapshot manager + return smT.handleSnapshotManagerSnapshotDone(inputCasted) case *sm_inputs.StateManagerTimerTick: // From state manager go routine return smT.handleStateManagerTimerTick(inputCasted.GetTime()) default: @@ -112,7 +122,7 @@ func (smT *stateManagerGPA) Message(msg gpa.Message) gpa.OutMessages { } func (smT *stateManagerGPA) Output() gpa.Output { - return nil + return smT.output } func (smT *stateManagerGPA) StatusString() string { @@ -174,25 +184,25 @@ func (smT *stateManagerGPA) handlePeerBlock(from gpa.NodeID, block state.Block) func (smT *stateManagerGPA) handleConsensusStateProposal(csp *sm_inputs.ConsensusStateProposal) gpa.OutMessages { start := time.Now() - smT.log.Debugf("Input consensus state proposal %s received...", csp.GetL1Commitment()) + smT.log.Debugf("Input consensus state proposal index %v %s received...", csp.GetStateIndex(), csp.GetL1Commitment()) callback := newBlockRequestCallback( func() bool { return csp.IsValid() }, func() { csp.Respond() - smT.log.Debugf("Input consensus state proposal %s: responded to consensus", csp.GetL1Commitment()) + smT.log.Debugf("Input 
consensus state proposal index %v %s: responded to consensus", csp.GetStateIndex(), csp.GetL1Commitment()) smT.metrics.ConsensusStateProposalHandled(time.Since(start)) }, ) - messages := smT.traceBlockChainWithCallback(csp.GetL1Commitment(), callback) - smT.log.Debugf("Input consensus state proposal %s handled", csp.GetL1Commitment()) + messages := smT.traceBlockChainWithCallback(csp.GetStateIndex(), csp.GetL1Commitment(), callback) + smT.log.Debugf("Input consensus state proposal index %v %s handled", csp.GetStateIndex(), csp.GetL1Commitment()) return messages } func (smT *stateManagerGPA) handleConsensusDecidedState(cds *sm_inputs.ConsensusDecidedState) gpa.OutMessages { start := time.Now() - smT.log.Debugf("Input consensus decided state %s received...", cds.GetL1Commitment()) + smT.log.Debugf("Input consensus decided state index %v %s received...", cds.GetStateIndex(), cds.GetL1Commitment()) callback := newBlockRequestCallback( func() bool { return cds.IsValid() @@ -200,22 +210,23 @@ func (smT *stateManagerGPA) handleConsensusDecidedState(cds *sm_inputs.Consensus func() { state, err := smT.store.StateByTrieRoot(cds.GetL1Commitment().TrieRoot()) if err != nil { - smT.log.Errorf("Input consensus decided state %s: error obtaining state: %w", cds.GetL1Commitment(), err) + smT.log.Errorf("Input consensus decided state index %v %s: error obtaining state: %w", cds.GetStateIndex(), cds.GetL1Commitment(), err) return } cds.Respond(state) - smT.log.Debugf("Input consensus decided state %s: responded to consensus with state index %v", cds.GetL1Commitment(), state.BlockIndex()) + smT.log.Debugf("Input consensus decided state index %v %s: responded to consensus with state index %v", + cds.GetStateIndex(), cds.GetL1Commitment(), state.BlockIndex()) smT.metrics.ConsensusDecidedStateHandled(time.Since(start)) }, ) - messages := smT.traceBlockChainWithCallback(cds.GetL1Commitment(), callback) - smT.log.Debugf("Input consensus decided state %s handled", cds.GetL1Commitment()) + messages := smT.traceBlockChainWithCallback(cds.GetStateIndex(), cds.GetL1Commitment(), callback) + smT.log.Debugf("Input consensus decided state index %v %s handled", cds.GetStateIndex(), cds.GetL1Commitment()) return messages } func (smT *stateManagerGPA) handleConsensusBlockProduced(input *sm_inputs.ConsensusBlockProduced) gpa.OutMessages { start := time.Now() - stateIndex := input.GetStateDraft().BlockIndex() + stateIndex := input.GetStateDraft().BlockIndex() - 1 // NOTE: as this state draft is complete, the returned index is the one of the next state (which will be obtained, once this state draft is committed); to get the index of the base state, we need to subtract one commitment := input.GetStateDraft().BaseL1Commitment() smT.log.Debugf("Input block produced on state index %v %s received...", stateIndex, commitment) if !smT.store.HasTrieRoot(commitment.TrieRoot()) { @@ -226,12 +237,12 @@ func (smT *stateManagerGPA) handleConsensusBlockProduced(input *sm_inputs.Consen blockCommitment := block.L1Commitment() smT.blockCache.AddBlock(block) input.Respond(block) - smT.log.Debugf("Input block produced on state index %v %s: state draft index %v has been committed to the store, responded to consensus with resulting block %s", - stateIndex, commitment, input.GetStateDraft().BlockIndex(), blockCommitment) + smT.log.Debugf("Input block produced on state index %v %s: state draft has been committed to the store, responded to consensus with resulting block index %v %s", + stateIndex, commitment, block.StateIndex(), blockCommitment) 
fetcher := smT.blocksToFetch.takeFetcher(blockCommitment) var result gpa.OutMessages if fetcher != nil { - result = smT.markFetched(fetcher) + result = smT.markFetched(fetcher, false) } smT.log.Debugf("Input block produced on state index %v %s handled", stateIndex, commitment) smT.metrics.ConsensusBlockProducedHandled(time.Since(start)) @@ -245,54 +256,9 @@ func (smT *stateManagerGPA) handleChainFetchStateDiff(input *sm_inputs.ChainFetc oldBlockRequestCompleted := false newBlockRequestCompleted := false isValidFun := func() bool { return input.IsValid() } - obtainCommittedBlockFun := func(commitment *state.L1Commitment) state.Block { - result := smT.getBlock(commitment) - if result == nil { - smT.log.Panicf("Input mempool state request for state index %v %s: cannot obtain block %s", input.GetNewStateIndex(), input.GetNewL1Commitment(), commitment) - } - return result - } - lastBlockFun := func(blocks []state.Block) state.Block { - return blocks[len(blocks)-1] - } - respondFun := func() { - oldBlock := obtainCommittedBlockFun(input.GetOldL1Commitment()) - newBlock := obtainCommittedBlockFun(input.GetNewL1Commitment()) - oldChainOfBlocks := []state.Block{oldBlock} - newChainOfBlocks := []state.Block{newBlock} - for lastBlockFun(oldChainOfBlocks).StateIndex() > lastBlockFun(newChainOfBlocks).StateIndex() { - oldChainOfBlocks = append(oldChainOfBlocks, obtainCommittedBlockFun(lastBlockFun(oldChainOfBlocks).PreviousL1Commitment())) - } - for lastBlockFun(oldChainOfBlocks).StateIndex() < lastBlockFun(newChainOfBlocks).StateIndex() { - newChainOfBlocks = append(newChainOfBlocks, obtainCommittedBlockFun(lastBlockFun(newChainOfBlocks).PreviousL1Commitment())) - } - for lastBlockFun(oldChainOfBlocks).StateIndex() > 0 { - if lastBlockFun(oldChainOfBlocks).L1Commitment().Equals(lastBlockFun(newChainOfBlocks).L1Commitment()) { - break - } - oldChainOfBlocks = append(oldChainOfBlocks, obtainCommittedBlockFun(lastBlockFun(oldChainOfBlocks).PreviousL1Commitment())) - newChainOfBlocks = append(newChainOfBlocks, obtainCommittedBlockFun(lastBlockFun(newChainOfBlocks).PreviousL1Commitment())) - } - commonIndex := lastBlockFun(oldChainOfBlocks).StateIndex() - commonCommitment := lastBlockFun(oldChainOfBlocks).L1Commitment() - oldChainOfBlocks = lo.Reverse(oldChainOfBlocks[:len(oldChainOfBlocks)-1]) - newChainOfBlocks = lo.Reverse(newChainOfBlocks[:len(newChainOfBlocks)-1]) - newState, err := smT.store.StateByTrieRoot(input.GetNewL1Commitment().TrieRoot()) - if err != nil { - smT.log.Errorf("Input mempool state request for state index %v %s: error obtaining state: %w", - input.GetNewStateIndex(), input.GetNewL1Commitment(), err) - return - } - input.Respond(sm_inputs.NewChainFetchStateDiffResults(newState, newChainOfBlocks, oldChainOfBlocks)) - smT.log.Debugf("Input mempool state request for state index %v %s: responded to chain with requested state, "+ - "and block chains of length %v (requested) and %v (old) with common ancestor index %v %s", - input.GetNewStateIndex(), input.GetNewL1Commitment(), len(newChainOfBlocks), len(oldChainOfBlocks), - commonIndex, commonCommitment) - smT.metrics.ChainFetchStateDiffHandled(time.Since(start)) - } respondIfNeededFun := func() { if oldBlockRequestCompleted && newBlockRequestCompleted { - respondFun() + smT.handleChainFetchStateDiffRespond(input, start) } } oldRequestCallback := newBlockRequestCallback(isValidFun, func() { @@ -308,13 +274,121 @@ func (smT *stateManagerGPA) handleChainFetchStateDiff(input *sm_inputs.ChainFetc respondIfNeededFun() }) result := 
gpa.NoMessages() - result.AddAll(smT.traceBlockChainWithCallback(input.GetOldL1Commitment(), oldRequestCallback)) - result.AddAll(smT.traceBlockChainWithCallback(input.GetNewL1Commitment(), newRequestCallback)) + result.AddAll(smT.traceBlockChainWithCallback(input.GetOldStateIndex(), input.GetOldL1Commitment(), oldRequestCallback)) + result.AddAll(smT.traceBlockChainWithCallback(input.GetNewStateIndex(), input.GetNewL1Commitment(), newRequestCallback)) smT.log.Debugf("Input mempool state request for state index %v %s handled", input.GetNewStateIndex(), input.GetNewL1Commitment()) return result } +func (smT *stateManagerGPA) handleChainFetchStateDiffRespond(input *sm_inputs.ChainFetchStateDiff, start time.Time) { //nolint:funlen + makeCallbackFun := func(part string) blockRequestCallback { + return newBlockRequestCallback( + func() bool { return input.IsValid() }, + func() { + smT.log.Debugf("Input mempool state request for state index %v %s: %s block request completed once again", + input.GetNewStateIndex(), input.GetNewL1Commitment(), part) + smT.handleChainFetchStateDiffRespond(input, start) + }, + ) + } + obtainCommittedPreviousBlockFun := func(block state.Block, part string) state.Block { + commitment := block.PreviousL1Commitment() + result := smT.getBlock(commitment) + if result == nil { + blockIndex := block.StateIndex() - 1 + smT.log.Debugf("Input mempool state request for state index %v %s: block %v %s in the %s block chain is missing; fetching it", + input.GetNewStateIndex(), input.GetNewL1Commitment(), blockIndex, commitment, part) + // NOTE: returned messages are not sent out; only GetBlock messages are possible in this case and + // these messages will be sent out at the next retry; + smT.traceBlockChainWithCallback(blockIndex, commitment, makeCallbackFun(part)) + } + return result + } + lastBlockFun := func(blocks []state.Block) state.Block { + return blocks[len(blocks)-1] + } + oldBlock := smT.getBlock(input.GetOldL1Commitment()) + if oldBlock == nil { + smT.log.Panicf("Input mempool state request for state index %v %s: cannot obtain final old block %s", + input.GetNewStateIndex(), input.GetNewL1Commitment(), input.GetOldL1Commitment()) + return + } + newBlock := smT.getBlock(input.GetNewL1Commitment()) + if newBlock == nil { + smT.log.Panicf("Input mempool state request for state index %v %s: cannot obtain final new block %s", + input.GetNewStateIndex(), input.GetNewL1Commitment(), input.GetNewL1Commitment()) + return + } + oldChainOfBlocks := []state.Block{oldBlock} + newChainOfBlocks := []state.Block{newBlock} + for lastBlockFun(oldChainOfBlocks).StateIndex() > lastBlockFun(newChainOfBlocks).StateIndex() { + oldBlock = obtainCommittedPreviousBlockFun(lastBlockFun(oldChainOfBlocks), "old") + if oldBlock == nil { + return + } + oldChainOfBlocks = append(oldChainOfBlocks, oldBlock) + } + for lastBlockFun(oldChainOfBlocks).StateIndex() < lastBlockFun(newChainOfBlocks).StateIndex() { + newBlock = obtainCommittedPreviousBlockFun(lastBlockFun(newChainOfBlocks), "new") + if newBlock == nil { + return + } + newChainOfBlocks = append(newChainOfBlocks, newBlock) + } + for lastBlockFun(oldChainOfBlocks).StateIndex() > 0 { + if lastBlockFun(oldChainOfBlocks).L1Commitment().Equals(lastBlockFun(newChainOfBlocks).L1Commitment()) { + break + } + oldBlock = obtainCommittedPreviousBlockFun(lastBlockFun(oldChainOfBlocks), "old") + if oldBlock == nil { + return + } + newBlock = obtainCommittedPreviousBlockFun(lastBlockFun(newChainOfBlocks), "new") + if newBlock == nil { + return + } + 
oldChainOfBlocks = append(oldChainOfBlocks, oldBlock) + newChainOfBlocks = append(newChainOfBlocks, newBlock) + } + commonIndex := lastBlockFun(oldChainOfBlocks).StateIndex() + commonCommitment := lastBlockFun(oldChainOfBlocks).L1Commitment() + oldChainOfBlocks = lo.Reverse(oldChainOfBlocks[:len(oldChainOfBlocks)-1]) + newChainOfBlocks = lo.Reverse(newChainOfBlocks[:len(newChainOfBlocks)-1]) + newState, err := smT.store.StateByTrieRoot(input.GetNewL1Commitment().TrieRoot()) + if err != nil { + smT.log.Errorf("Input mempool state request for state index %v %s: error obtaining state: %w", + input.GetNewStateIndex(), input.GetNewL1Commitment(), err) + return + } + input.Respond(sm_inputs.NewChainFetchStateDiffResults(newState, newChainOfBlocks, oldChainOfBlocks)) + smT.log.Debugf("Input mempool state request for state index %v %s: responded to chain with requested state, "+ + "and block chains of length %v (requested) and %v (old) with common ancestor index %v %s", + input.GetNewStateIndex(), input.GetNewL1Commitment(), len(newChainOfBlocks), len(oldChainOfBlocks), + commonIndex, commonCommitment) + smT.metrics.ChainFetchStateDiffHandled(time.Since(start)) +} + +func (smT *stateManagerGPA) handleSnapshotManagerSnapshotDone(input *sm_inputs.SnapshotManagerSnapshotDone) gpa.OutMessages { + stateIndex := input.GetStateIndex() + commitment := input.GetCommitment() + smT.log.Debugf("Input snapshot manager snapshot %v %s done received, result=%v...", stateIndex, commitment, input.GetResult()) + fetcher := smT.blocksFetched.takeFetcher(input.GetCommitment()) + if fetcher == nil { + smT.log.Warnf("Input snapshot manager snapshot %v %s done: snapshot no longer needed, ignoring it", stateIndex, commitment) + return nil // No messages to send + } + if input.GetResult() != nil { + // TODO: maybe downloading snapshot should be retried? 
+ smT.log.Errorf("Input snapshot manager snapshot %v %s done: retrieving snapshot failed %v", stateIndex, commitment, input.GetResult()) + smT.blocksToFetch.addFetcher(fetcher) + return smT.makeGetBlockRequestMessages(commitment) + } + result := smT.markFetched(fetcher, false) + smT.log.Debugf("Input snapshot manager snapshot %v %s done handled.", stateIndex, commitment) + return result +} + func (smT *stateManagerGPA) getBlock(commitment *state.L1Commitment) state.Block { block := smT.blockCache.GetBlock(commitment) if block != nil { @@ -346,23 +420,23 @@ func (smT *stateManagerGPA) getBlock(commitment *state.L1Commitment) state.Block return block } -func (smT *stateManagerGPA) traceBlockChainWithCallback(lastCommitment *state.L1Commitment, callback blockRequestCallback) gpa.OutMessages { +func (smT *stateManagerGPA) traceBlockChainWithCallback(index uint32, lastCommitment *state.L1Commitment, callback blockRequestCallback) gpa.OutMessages { if smT.store.HasTrieRoot(lastCommitment.TrieRoot()) { - smT.log.Debugf("Tracing block %s chain: the block is already in the store, calling back", lastCommitment) + smT.log.Debugf("Tracing block index %v %s chain: the block is already in the store, calling back", index, lastCommitment) callback.requestCompleted() return nil // No messages to send } if smT.blocksToFetch.addCallback(lastCommitment, callback) { smT.metrics.IncRequestsWaiting() - smT.log.Debugf("Tracing block %s chain: the block is already being fetched", lastCommitment) + smT.log.Debugf("Tracing block index %v %s chain: the block is already being fetched", index, lastCommitment) return nil } if smT.blocksFetched.addCallback(lastCommitment, callback) { smT.metrics.IncRequestsWaiting() - smT.log.Debugf("Tracing block %s chain: the block is already fetched, but cannot yet be committed", lastCommitment) + smT.log.Debugf("Tracing block index %v %s chain: the block is already fetched, but cannot yet be committed", index, lastCommitment) return nil } - fetcher := newBlockFetcherWithCallback(lastCommitment, callback) + fetcher := newBlockFetcherWithCallback(index, lastCommitment, callback) smT.metrics.IncRequestsWaiting() return smT.traceBlockChain(fetcher) } @@ -372,10 +446,16 @@ func (smT *stateManagerGPA) traceBlockChainWithCallback(lastCommitment *state.L1 // requested node has the required block committed into the store, it certainly // has all the blocks before it. 
func (smT *stateManagerGPA) traceBlockChain(fetcher blockFetcher) gpa.OutMessages { + stateIndex := fetcher.getStateIndex() commitment := fetcher.getCommitment() if !smT.store.HasTrieRoot(commitment.TrieRoot()) { block := smT.blockCache.GetBlock(commitment) if block == nil { + if smT.snapshotExistsFun(stateIndex, commitment) { + smT.output.addSnapshotToLoad(stateIndex, commitment) + smT.blocksFetched.addFetcher(fetcher) + return nil // No messages to send + } smT.blocksToFetch.addFetcher(fetcher) smT.log.Debugf("Block %s is missing, starting fetching it", commitment) return smT.makeGetBlockRequestMessages(commitment) @@ -385,29 +465,29 @@ func (smT *stateManagerGPA) traceBlockChain(fetcher blockFetcher) gpa.OutMessage previousCommitment := block.PreviousL1Commitment() smT.log.Debugf("Tracing block index %v %s -> previous block %v %s", blockIndex, commitment, previousBlockIndex, previousCommitment) if previousCommitment == nil { - result := smT.markFetched(fetcher) + result := smT.markFetched(fetcher, true) smT.log.Debugf("Traced to the initial block") return result } smT.blocksFetched.addFetcher(fetcher) if smT.blocksToFetch.addRelatedFetcher(previousCommitment, fetcher) { smT.log.Debugf("Block %v %s is already being fetched", previousBlockIndex, previousCommitment) - return nil + return nil // No messages to send } if smT.blocksFetched.addRelatedFetcher(previousCommitment, fetcher) { smT.log.Debugf("Block %v %s is already fetched, but cannot yet be committed", previousBlockIndex, previousCommitment) - return nil + return nil // No messages to send } return smT.traceBlockChain(newBlockFetcherWithRelatedFetcher(previousCommitment, fetcher)) } - result := smT.markFetched(fetcher) + result := smT.markFetched(fetcher, false) smT.log.Debugf("Block %s is already committed", commitment) return result } -func (smT *stateManagerGPA) markFetched(fetcher blockFetcher) gpa.OutMessages { +func (smT *stateManagerGPA) markFetched(fetcher blockFetcher, commitInitial bool) gpa.OutMessages { result := gpa.NoMessages() - fetcher.notifyFetched(func(bf blockFetcher) bool { + commitFun := func(bf blockFetcher) bool { commitment := bf.getCommitment() block := smT.blockCache.GetBlock(commitment) if block == nil { @@ -444,7 +524,12 @@ func (smT *stateManagerGPA) markFetched(fetcher blockFetcher) gpa.OutMessages { _ = smT.blocksFetched.takeFetcher(commitment) smT.metrics.SubRequestsWaiting(bf.getCallbacksCount()) return true - }) + } + if commitInitial { + fetcher.commitAndNotifyFetched(commitFun) + } else { + fetcher.notifyFetched(commitFun) + } return result } @@ -493,6 +578,13 @@ func (smT *stateManagerGPA) handleStateManagerTimerTick(now time.Time) gpa.OutMe smT.log.Debugf("Callbacks of block fetchers cleaned, %v waiting callbacks remained, next cleaning not earlier than %v", waitingCallbacks, smT.lastCleanRequestsTime.Add(smT.parameters.StateManagerRequestCleaningPeriod)) } + nextSnapshotsUpdateTime := smT.lastSnapshotsUpdateTime.Add(smT.parameters.SnapshotManagerUpdatePeriod) + if now.After(nextSnapshotsUpdateTime) { + smT.output.setUpdateSnapshots() + smT.lastSnapshotsUpdateTime = now + smT.log.Debugf("Ordered snapshot update, next update not earlier than %v", + smT.lastSnapshotsUpdateTime.Add(smT.parameters.SnapshotManagerUpdatePeriod)) + } smT.metrics.StateManagerTimerTickHandled(time.Since(start)) return result } @@ -507,6 +599,7 @@ func (smT *stateManagerGPA) commitStateDraft(stateDraft state.StateDraft) state. 
if smT.pruningNeeded() { smT.pruneStore(block.PreviousL1Commitment()) } + smT.output.addBlockCommitted(block.StateIndex(), block.L1Commitment()) return block } diff --git a/packages/chain/statemanager/sm_gpa/state_manager_gpa_test.go b/packages/chain/statemanager/sm_gpa/state_manager_gpa_test.go index c56b888218..f4cdf599dd 100644 --- a/packages/chain/statemanager/sm_gpa/state_manager_gpa_test.go +++ b/packages/chain/statemanager/sm_gpa/state_manager_gpa_test.go @@ -8,19 +8,25 @@ import ( "github.com/stretchr/testify/require" + "github.com/iotaledger/hive.go/logger" "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa/sm_gpa_utils" "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa/sm_inputs" + "github.com/iotaledger/wasp/packages/chain/statemanager/sm_snapshots" "github.com/iotaledger/wasp/packages/gpa" "github.com/iotaledger/wasp/packages/origin" "github.com/iotaledger/wasp/packages/state" ) +var newEmptySnapshotManagerFun = func(_, _ state.Store, _ sm_gpa_utils.TimeProvider, _ *logger.Logger) sm_snapshots.SnapshotManagerTest { + return sm_snapshots.NewEmptySnapshotManager() +} + // Single node network. 8 blocks are sent to state manager. The result is checked // by checking the store and sending consensus requests, which force the access // of the blocks. func TestBasic(t *testing.T) { nodeIDs := gpa.MakeTestNodeIDs(1) - env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewMockedTestBlockWAL) + env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewMockedTestBlockWAL, newEmptySnapshotManagerFun) defer env.finalize() nodeID := nodeIDs[0] @@ -41,7 +47,7 @@ func TestManyNodes(t *testing.T) { nodeIDs := gpa.MakeTestNodeIDs(10) smParameters := NewStateManagerParameters() smParameters.StateManagerGetBlockRetry = 100 * time.Millisecond - env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewMockedTestBlockWAL, smParameters) + env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewMockedTestBlockWAL, newEmptySnapshotManagerFun, smParameters) defer env.finalize() blocks := env.bf.GetBlocks(16, 1) @@ -109,7 +115,7 @@ func TestFull(t *testing.T) { nodeIDs := gpa.MakeTestNodeIDs(nodeCount) smParameters := NewStateManagerParameters() smParameters.StateManagerGetBlockRetry = 100 * time.Millisecond - env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewMockedTestBlockWAL, smParameters) + env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewMockedTestBlockWAL, newEmptySnapshotManagerFun, smParameters) defer env.finalize() lastCommitment := origin.L1Commitment(nil, 0) @@ -191,7 +197,7 @@ func TestMempoolRequest(t *testing.T) { nodeIDs := gpa.MakeTestNodeIDs(nodeCount) smParameters := NewStateManagerParameters() smParameters.StateManagerGetBlockRetry = 100 * time.Millisecond - env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewMockedTestBlockWAL, smParameters) + env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewMockedTestBlockWAL, newEmptySnapshotManagerFun, smParameters) defer env.finalize() mainBlocks := env.bf.GetBlocks(mainSize, 1) @@ -220,7 +226,7 @@ func TestMempoolRequest(t *testing.T) { // and block 0 as an old block. 
func TestMempoolRequestFirstStep(t *testing.T) { nodeIDs := gpa.MakeTestNodeIDs(1) - env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewMockedTestBlockWAL) + env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewMockedTestBlockWAL, newEmptySnapshotManagerFun) defer env.finalize() nodeID := nodeIDs[0] @@ -243,7 +249,7 @@ func TestMempoolRequestNoBranch(t *testing.T) { middleBlock := 4 nodeIDs := gpa.MakeTestNodeIDs(1) - env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewMockedTestBlockWAL) + env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewMockedTestBlockWAL, newEmptySnapshotManagerFun) defer env.finalize() nodeID := nodeIDs[0] @@ -269,7 +275,7 @@ func TestMempoolRequestBranchFromOrigin(t *testing.T) { branchSize := 8 nodeIDs := gpa.MakeTestNodeIDs(1) - env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewMockedTestBlockWAL) + env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewMockedTestBlockWAL, newEmptySnapshotManagerFun) defer env.finalize() nodeID := nodeIDs[0] @@ -286,6 +292,53 @@ func TestMempoolRequestBranchFromOrigin(t *testing.T) { require.True(env.t, env.sendAndEnsureCompletedChainFetchStateDiff(oldCommitment, newCommitment, oldBlocks, newBlocks, nodeID, 1, 0*time.Second)) } +// Two node setting. +// 1. A batch of 10 consecutive blocks is generated, each of them is sent +// to the first node. +// 2. A batch of 5 consecutive blocks is branched from block 4. Each of +// the blocks is sent to the first node. +// 3. Second node is configured to download snapshot at index 7 of both branches +// 4. A ChainFetchStateDiff request is sent for the branch as a new and +// and original batch as old. +func TestMempoolSnapshotInTheMiddle(t *testing.T) { + batchSize := 10 + branchSize := 5 + branchIndex := 4 + snapshottedIndex := 7 + + nodeIDs := gpa.MakeTestNodeIDs(2) + newMockedSnapshotManagerFun := func(origStore, nodeStore state.Store, timeProvider sm_gpa_utils.TimeProvider, log *logger.Logger) sm_snapshots.SnapshotManagerTest { + return sm_snapshots.NewMockedSnapshotManager(t, 0, origStore, nodeStore, 0*time.Second, 0*time.Second, timeProvider, log) + } + smParameters := NewStateManagerParameters() + smParameters.StateManagerGetBlockRetry = 100 * time.Millisecond + env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewMockedTestBlockWAL, newMockedSnapshotManagerFun, smParameters) + defer env.finalize() + + oldBlocks := env.bf.GetBlocks(batchSize, 1) + newBlocks := env.bf.GetBlocksFrom(branchSize, 1, oldBlocks[branchIndex].L1Commitment(), 2) + oldSnapshottedBlock := oldBlocks[snapshottedIndex] + newSnapshottedBlock := newBlocks[snapshottedIndex-branchIndex-1] + env.snapms[nodeIDs[1]].SnapshotReady(sm_snapshots.NewSnapshotInfo(oldSnapshottedBlock.StateIndex(), oldSnapshottedBlock.L1Commitment())) + env.snapms[nodeIDs[1]].SnapshotReady(sm_snapshots.NewSnapshotInfo(newSnapshottedBlock.StateIndex(), newSnapshottedBlock.L1Commitment())) + env.snapms[nodeIDs[1]].UpdateAsync() + + env.sendBlocksToNode(nodeIDs[0], 0*time.Second, oldBlocks...) + require.True(env.t, env.ensureStoreContainsBlocksNoWait(nodeIDs[0], oldBlocks)) + + env.sendBlocksToNode(nodeIDs[0], 0*time.Second, newBlocks...) 
+ require.True(env.t, env.ensureStoreContainsBlocksNoWait(nodeIDs[0], newBlocks)) + + oldCommitment := oldBlocks[len(oldBlocks)-1].L1Commitment() + newCommitment := newBlocks[len(newBlocks)-1].L1Commitment() + responseCh := env.sendChainFetchStateDiff(oldCommitment, newCommitment, nodeIDs[1]) + time.Sleep(10 * time.Millisecond) // To allow snapshot manager to receive load old state snapshot request + env.sendTimerTickToNodes(100 * time.Millisecond) // To check the response from snapshot manager about loaded old state snapshot; timer tick is not necessary: any input would be suitable + time.Sleep(10 * time.Millisecond) // To allow snapshot manager to receive load new state snapshot request + env.sendTimerTickToNodes(100 * time.Millisecond) // To check the response from snapshot manager about loaded new state snapshot; timer tick is not necessary: any input would be suitable + require.True(env.t, env.ensureCompletedChainFetchStateDiff(responseCh, oldBlocks[branchIndex+1:], newBlocks, 10, 100*time.Millisecond)) +} + // Single node setting, pruning leaves 10 historic blocks. // - 11 blocks are added into the store one by one; each time it is checked if // all of the added blocks are in the store (none of them got pruned). @@ -300,7 +353,7 @@ func TestPruningSequentially(t *testing.T) { nodeID := nodeIDs[0] smParameters := NewStateManagerParameters() smParameters.PruningMinStatesToKeep = blocksToKeep - env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewEmptyTestBlockWAL, smParameters) + env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewEmptyTestBlockWAL, newEmptySnapshotManagerFun, smParameters) defer env.finalize() blocks := env.bf.GetBlocks(blockCount, 1) @@ -338,7 +391,7 @@ func TestPruningMany(t *testing.T) { nodeID := nodeIDs[0] smParameters := NewStateManagerParameters() smParameters.PruningMinStatesToKeep = blocksToKeep // Also initializes chain with this value in governance contract - env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewEmptyTestBlockWAL, smParameters) + env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewEmptyTestBlockWAL, newEmptySnapshotManagerFun, smParameters) defer env.finalize() sm, ok := env.sms[nodeID] @@ -379,7 +432,7 @@ func TestPruningTooMuch(t *testing.T) { smParameters := NewStateManagerParameters() smParameters.PruningMinStatesToKeep = blocksToKeep // Also initializes chain with this value in governance contract smParameters.PruningMaxStatesToDelete = blocksToPrune - env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewEmptyTestBlockWAL, smParameters) + env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewEmptyTestBlockWAL, newEmptySnapshotManagerFun, smParameters) defer env.finalize() sm, ok := env.sms[nodeID] @@ -412,6 +465,146 @@ func TestPruningTooMuch(t *testing.T) { } } +// Two nodes setting: first node is making snapshots, the other is using them. +// - 30 blocks are committed to the first node. +// - Update is triggered (in both nodes) and created snapshots are available +// to be used by state manager. +// - The other node is requested to obtain snapshotted state, and several not +// snapshotted states. The results are checked. 
+func TestSnapshots(t *testing.T) { + blockCount := 30 + snapshotCreatePeriod := uint32(5) + snapshotCreateTime := 1 * time.Second + snapshotLoadTime := 2 * time.Second + snapshotCreatedFun := func(index uint32) bool { + return (index+1)%snapshotCreatePeriod == 0 + } + requestedStateIndex1 := uint32(5) + requestedStateIndex2 := uint32(14) + requestedStateIndex3 := uint32(23) + timerTickPeriod := 150 * time.Millisecond + + nodeIDs := gpa.MakeTestNodeIDs(2) + nodeIDFirst := nodeIDs[0] + nodeIDOther := nodeIDs[1] + newEmptyTestBlockWALFun := func(gpa.NodeID) sm_gpa_utils.TestBlockWAL { return sm_gpa_utils.NewEmptyTestBlockWAL() } + newMockedSnapshotManagerFun := func(nodeID gpa.NodeID, origStore, nodeStore state.Store, tp sm_gpa_utils.TimeProvider, log *logger.Logger) sm_snapshots.SnapshotManagerTest { + if nodeID.Equals(nodeIDFirst) { + return sm_snapshots.NewMockedSnapshotManager(t, snapshotCreatePeriod, origStore, nodeStore, snapshotCreateTime, snapshotLoadTime, tp, log) + } + return sm_snapshots.NewMockedSnapshotManager(t, 0, origStore, nodeStore, snapshotCreateTime, snapshotLoadTime, tp, log) + } + smParameters := NewStateManagerParameters() + smParameters.SnapshotManagerUpdatePeriod = 2 * time.Second + env := newVariedTestEnv(t, nodeIDs, newEmptyTestBlockWALFun, newMockedSnapshotManagerFun, smParameters) + env.snapms[nodeIDFirst].SetAfterSnapshotCreated(func(snapshotInfo sm_snapshots.SnapshotInfo) { + <-env.timeProvider.After(100 * time.Millisecond) // Other node knows about the snapshot a bit later + env.snapms[nodeIDOther].SnapshotReady(snapshotInfo) + }) + defer env.finalize() + + blocks := env.bf.GetBlocks(blockCount, 1) + snapshotInfos := make([]sm_snapshots.SnapshotInfo, len(blocks)) + for i := range snapshotInfos { + snapshotInfos[i] = sm_snapshots.NewSnapshotInfo(blocks[i].StateIndex(), blocks[i].L1Commitment()) + } + type expectedValues struct { + snapshotReady bool // Ready to be picked up by snapshot manager's Update + snapshotExists bool // Already picked up by snapshot manager's Update; available to node + blockCommitted bool + } + checkBlocksInNodeFun := func(expected []expectedValues, nodeID gpa.NodeID) { + snapM, ok := env.snapms[nodeID] + require.True(env.t, ok) + store, ok := env.stores[nodeID] + require.True(env.t, ok) + for i := range expected { + env.t.Logf("Checking snapshot/block index %v at node %v", snapshotInfos[i].GetStateIndex(), nodeID) + require.Equal(env.t, expected[i].snapshotReady, snapM.IsSnapshotReady(snapshotInfos[i])) + require.Equal(env.t, expected[i].snapshotExists, snapM.SnapshotExists(snapshotInfos[i].GetStateIndex(), snapshotInfos[i].GetCommitment())) + require.Equal(env.t, expected[i].blockCommitted, store.HasTrieRoot(snapshotInfos[i].GetTrieRoot())) + } + } + expectedFirst := make([]expectedValues, len(blocks)) + expectedOther := make([]expectedValues, len(blocks)) + checkBlocksFun := func() { + checkBlocksInNodeFun(expectedFirst, nodeIDFirst) + checkBlocksInNodeFun(expectedOther, nodeIDOther) + } + checkBlocksFun() // At start no blocks/snapshots are in any node + env.sendTimerTickToNodes(0 * time.Second) // Initial timer tick to send first snapshot manager Update request + time.Sleep(10 * time.Millisecond) // Time for first snapshot manager Update request to propagate to snapshot manager (and do nothing) + + // Blocks are sent to the first node: they are committed there, snapshots are being produced, but not yet available + env.sendBlocksToNode(nodeIDFirst, 0*time.Second, blocks...) 
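// NOTE (explanatory, not part of this diff): the timing in this test relies on the artificial time provider
// installed by newVariedTestEnv. The real time.Sleep calls only give the mocked snapshot manager goroutines a
// chance to run; snapshot creation and loading complete only once enough artificial time has been advanced via
// timer ticks. With snapshotCreateTime = 1s and timerTickPeriod = 150ms, 7 ticks advance ~1.05s, which is why
// the loops below send 7 ticks before snapshots are expected to be ready; loading uses snapshotLoadTime = 2s,
// hence the 14 ticks (~2.1s) in sendAndEnsureCompletedConsensusStateProposalWithWaitFun. Also note that
// snapshotCreatedFun marks every fifth entry of blocks (slice indices 4, 9, 14, 19, 24, 29) as snapshotted,
// which the scenario checks at the end of the test depend on.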
+ for i := range blocks { + expectedFirst[i].blockCommitted = true + } + checkBlocksFun() + + // Time is passing, snapshots are produced and are ready in the first node; Update hasn't picked them up yet + for i := 0; i < 7; i++ { + env.sendTimerTickToNodes(timerTickPeriod) // Timer tick is not necessary; it's just a way to advance artificial timer + } + time.Sleep(10 * time.Millisecond) // To allow threads, that "create snapshots", to wake up + for i := range blocks { + if snapshotCreatedFun(uint32(i)) { + expectedFirst[i].snapshotReady = true + } + } + checkBlocksFun() + + // Some more time passes, produced snapshots are visible in other node too; Update hasn't picked them up yet + env.sendTimerTickToNodes(timerTickPeriod) // Timer tick is not necessary; it's just a way to advance artificial timer + for i := range blocks { + expectedOther[i].snapshotReady = expectedFirst[i].snapshotReady + } + time.Sleep(10 * time.Millisecond) // To allow threads, that "create snapshots", to wake up + checkBlocksFun() + + // More time passes, Update event is triggered in both nodes, snapshots are available for state managers of both nodes + for i := 0; i < 7; i++ { + env.sendTimerTickToNodes(timerTickPeriod) // Only the last timer tick is necessary as it sends Update request to snapshot manager + } + time.Sleep(10 * time.Millisecond) // Time for snapshot manager Update request to propagate to snapshot manager + for i := range blocks { + expectedFirst[i].snapshotExists = expectedFirst[i].snapshotReady + expectedOther[i].snapshotExists = expectedOther[i].snapshotReady + } + checkBlocksFun() + + sendAndEnsureCompletedConsensusStateProposalWithWaitFun := func(snapshotInfo sm_snapshots.SnapshotInfo) { + respCh := env.sendConsensusStateProposal(snapshotInfo.GetCommitment(), nodeIDOther) + time.Sleep(10 * time.Millisecond) // Time for load snapshot request to propagate to snapshot manager + for i := 0; i < 14; i++ { + env.sendTimerTickToNodes(timerTickPeriod) // Timer tick is not necessary; it's just a way to advance artificial timer + } + time.Sleep(10 * time.Millisecond) // To allow snapshot manager thread to wake up and respond + require.True(env.t, env.ensureCompletedConsensusStateProposal(respCh, 2, timerTickPeriod)) + } + + // Request for other node to have state, which contains snapshot; snapshot is downloaded, no other blocks are committed + require.True(env.t, snapshotCreatedFun(requestedStateIndex2)) + sendAndEnsureCompletedConsensusStateProposalWithWaitFun(snapshotInfos[requestedStateIndex2]) + expectedOther[requestedStateIndex2].blockCommitted = true + checkBlocksFun() + + // Request for other node to have state, which is one index after snapshot; snapshot and the requested block are downloaded + require.True(env.t, snapshotCreatedFun(requestedStateIndex1-1)) + sendAndEnsureCompletedConsensusStateProposalWithWaitFun(snapshotInfos[requestedStateIndex1]) + expectedOther[requestedStateIndex1].blockCommitted = true + expectedOther[requestedStateIndex1-1].blockCommitted = true + checkBlocksFun() + + // Request for other node to have state, which is one index before snapshot; all the blocks up to previous snapshot and the previous snapshot are downloaded + require.True(env.t, snapshotCreatedFun(requestedStateIndex3+1)) + sendAndEnsureCompletedConsensusStateProposalWithWaitFun(snapshotInfos[requestedStateIndex3]) + for i := requestedStateIndex3; i > requestedStateIndex3-snapshotCreatePeriod; i-- { + expectedOther[i].blockCommitted = true + } + checkBlocksFun() +} + // Single node network. 
Checks if block cache is cleaned via state manager // timer events. func TestBlockCacheCleaningAuto(t *testing.T) { @@ -419,7 +612,7 @@ func TestBlockCacheCleaningAuto(t *testing.T) { smParameters := NewStateManagerParameters() smParameters.BlockCacheBlocksInCacheDuration = 300 * time.Millisecond smParameters.BlockCacheBlockCleaningPeriod = 70 * time.Millisecond - env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewEmptyTestBlockWAL, smParameters) + env := newTestEnv(t, nodeIDs, sm_gpa_utils.NewEmptyTestBlockWAL, newEmptySnapshotManagerFun, smParameters) defer env.finalize() nodeID := nodeIDs[0] diff --git a/packages/chain/statemanager/sm_gpa/state_manager_output.go b/packages/chain/statemanager/sm_gpa/state_manager_output.go new file mode 100644 index 0000000000..7684f41450 --- /dev/null +++ b/packages/chain/statemanager/sm_gpa/state_manager_output.go @@ -0,0 +1,60 @@ +package sm_gpa + +import ( + "github.com/iotaledger/wasp/packages/chain/statemanager/sm_snapshots" + "github.com/iotaledger/wasp/packages/gpa" + "github.com/iotaledger/wasp/packages/state" +) + +type smOutputImpl struct { + blocksCommitted []sm_snapshots.SnapshotInfo + snapshotsToLoad []sm_snapshots.SnapshotInfo + updateSnapshots bool +} + +var ( + _ gpa.Output = &smOutputImpl{} + _ StateManagerOutput = &smOutputImpl{} +) + +func newOutput() StateManagerOutput { + return &smOutputImpl{ + blocksCommitted: make([]sm_snapshots.SnapshotInfo, 0), + snapshotsToLoad: make([]sm_snapshots.SnapshotInfo, 0, 1), + } +} + +func (smoi *smOutputImpl) addBlockCommitted(stateIndex uint32, commitment *state.L1Commitment) { + smoi.blocksCommitted = append(smoi.blocksCommitted, sm_snapshots.NewSnapshotInfo(stateIndex, commitment)) +} + +func (smoi *smOutputImpl) addSnapshotToLoad(stateIndex uint32, commitment *state.L1Commitment) { + smoi.snapshotsToLoad = append(smoi.snapshotsToLoad, sm_snapshots.NewSnapshotInfo(stateIndex, commitment)) +} + +func (smoi *smOutputImpl) setUpdateSnapshots() { + smoi.updateSnapshots = true +} + +func (smoi *smOutputImpl) TakeBlocksCommitted() []sm_snapshots.SnapshotInfo { + result := smoi.blocksCommitted + smoi.blocksCommitted = make([]sm_snapshots.SnapshotInfo, 0) + return result +} + +func (smoi *smOutputImpl) TakeSnapshotToLoad() sm_snapshots.SnapshotInfo { + if len(smoi.snapshotsToLoad) == 0 { + return nil + } + result := smoi.snapshotsToLoad[0] + smoi.snapshotsToLoad = smoi.snapshotsToLoad[1:] + return result +} + +func (smoi *smOutputImpl) TakeUpdateSnapshots() bool { + if smoi.updateSnapshots { + smoi.updateSnapshots = false + return true + } + return false +} diff --git a/packages/chain/statemanager/sm_gpa/state_manager_parameters.go b/packages/chain/statemanager/sm_gpa/state_manager_parameters.go index d8209fbdfc..3fa287a38d 100644 --- a/packages/chain/statemanager/sm_gpa/state_manager_parameters.go +++ b/packages/chain/statemanager/sm_gpa/state_manager_parameters.go @@ -26,6 +26,8 @@ type StateManagerParameters struct { PruningMinStatesToKeep int // On single store pruning attempt at most this number of states will be deleted PruningMaxStatesToDelete int + // How often snapshot manager should update list of known snapshots + SnapshotManagerUpdatePeriod time.Duration TimeProvider sm_gpa_utils.TimeProvider } @@ -46,6 +48,7 @@ func NewStateManagerParameters(tpOpt ...sm_gpa_utils.TimeProvider) StateManagerP StateManagerTimerTickPeriod: 1 * time.Second, PruningMinStatesToKeep: 10000, PruningMaxStatesToDelete: 1000, + SnapshotManagerUpdatePeriod: 5 * time.Minute, TimeProvider: tp, } } diff --git 
a/packages/chain/statemanager/sm_gpa/test_env.go b/packages/chain/statemanager/sm_gpa/test_env.go index 629a3b6209..b56ef575e1 100644 --- a/packages/chain/statemanager/sm_gpa/test_env.go +++ b/packages/chain/statemanager/sm_gpa/test_env.go @@ -12,6 +12,7 @@ import ( "github.com/iotaledger/hive.go/logger" "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa/sm_gpa_utils" "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa/sm_inputs" + "github.com/iotaledger/wasp/packages/chain/statemanager/sm_snapshots" "github.com/iotaledger/wasp/packages/chain/statemanager/sm_utils" "github.com/iotaledger/wasp/packages/gpa" "github.com/iotaledger/wasp/packages/isc" @@ -31,14 +32,42 @@ type testEnv struct { timeProvider sm_gpa_utils.TimeProvider sms map[gpa.NodeID]gpa.GPA stores map[gpa.NodeID]state.Store + snapms map[gpa.NodeID]sm_snapshots.SnapshotManagerTest + snaprchs map[gpa.NodeID]<-chan error + snaprsis map[gpa.NodeID]sm_snapshots.SnapshotInfo tc *gpa.TestContext log *logger.Logger } -func newTestEnv(t *testing.T, nodeIDs []gpa.NodeID, createWALFun func() sm_gpa_utils.TestBlockWAL, parametersOpt ...StateManagerParameters) *testEnv { +func newTestEnv( + t *testing.T, + nodeIDs []gpa.NodeID, + createWALFun func() sm_gpa_utils.TestBlockWAL, + createSnapMFun func(origStore, nodeStore state.Store, tp sm_gpa_utils.TimeProvider, log *logger.Logger) sm_snapshots.SnapshotManagerTest, + parametersOpt ...StateManagerParameters, +) *testEnv { + createWALVariedFun := func(gpa.NodeID) sm_gpa_utils.TestBlockWAL { + return createWALFun() + } + createSnapMVariedFun := func(nodeID gpa.NodeID, origStore, nodeStore state.Store, tp sm_gpa_utils.TimeProvider, log *logger.Logger) sm_snapshots.SnapshotManagerTest { + return createSnapMFun(origStore, nodeStore, tp, log) + } + return newVariedTestEnv(t, nodeIDs, createWALVariedFun, createSnapMVariedFun, parametersOpt...) 
+} + +func newVariedTestEnv( + t *testing.T, + nodeIDs []gpa.NodeID, + createWALFun func(gpa.NodeID) sm_gpa_utils.TestBlockWAL, + createSnapMFun func(nodeID gpa.NodeID, origStore, nodeStore state.Store, tp sm_gpa_utils.TimeProvider, log *logger.Logger) sm_snapshots.SnapshotManagerTest, + parametersOpt ...StateManagerParameters, +) *testEnv { var bf *sm_gpa_utils.BlockFactory sms := make(map[gpa.NodeID]gpa.GPA) stores := make(map[gpa.NodeID]state.Store) + snapms := make(map[gpa.NodeID]sm_snapshots.SnapshotManagerTest) + snaprchs := make(map[gpa.NodeID]<-chan error) + snaprsis := make(map[gpa.NodeID]sm_snapshots.SnapshotInfo) var parameters StateManagerParameters var chainInitParameters dict.Dict if len(parametersOpt) > 0 { @@ -52,29 +81,59 @@ func newTestEnv(t *testing.T, nodeIDs []gpa.NodeID, createWALFun func() sm_gpa_u bf = sm_gpa_utils.NewBlockFactory(t, chainInitParameters) chainID := bf.GetChainID() - log := testlogger.NewLogger(t).Named("c-" + chainID.ShortString()) + log := testlogger.NewLogger(t) parameters.TimeProvider = sm_gpa_utils.NewArtifficialTimeProvider() for _, nodeID := range nodeIDs { var err error smLog := log.Named(nodeID.ShortString()) nr := sm_utils.NewNodeRandomiser(nodeID, nodeIDs, smLog) - wal := createWALFun() + wal := createWALFun(nodeID) store := state.NewStore(mapdb.NewMapDB()) + snapshotManager := createSnapMFun(nodeID, bf.GetStore(), store, parameters.TimeProvider, smLog) + snapshotExistsFun := snapshotManager.SnapshotExists origin.InitChain(store, chainInitParameters, 0) stores[nodeID] = store - sms[nodeID], err = New(chainID, nr, wal, store, mockStateManagerMetrics(), smLog, parameters) + sms[nodeID], err = New(chainID, nr, wal, snapshotExistsFun, store, mockStateManagerMetrics(), smLog, parameters) require.NoError(t, err) + snapms[nodeID] = snapshotManager + snaprchs[nodeID] = nil + snaprsis[nodeID] = nil } - return &testEnv{ + result := &testEnv{ t: t, bf: bf, nodeIDs: nodeIDs, timeProvider: parameters.TimeProvider, sms: sms, + snapms: snapms, + snaprchs: snaprchs, + snaprsis: snaprsis, stores: stores, - tc: gpa.NewTestContext(sms), log: log, } + result.tc = gpa.NewTestContext(sms).WithOutputHandler(func(nodeID gpa.NodeID, outputOrig gpa.Output) { + output, ok := outputOrig.(StateManagerOutput) + require.True(result.t, ok) + result.checkSnapshotsLoaded() + snapshotManager, ok := result.snapms[nodeID] + require.True(result.t, ok) + snapshotRespChannel, ok := result.snaprchs[nodeID] + require.True(result.t, ok) + if snapshotRespChannel == nil { + snapshotInfo := output.TakeSnapshotToLoad() + if snapshotInfo != nil { + result.snaprchs[nodeID] = snapshotManager.LoadSnapshotAsync(snapshotInfo) + result.snaprsis[nodeID] = snapshotInfo + } + } + for _, snapshotInfo := range output.TakeBlocksCommitted() { + snapshotManager.BlockCommittedAsync(snapshotInfo) + } + if output.TakeUpdateSnapshots() { + snapshotManager.UpdateAsync() + } + }) + return result } func (teT *testEnv) finalize() { @@ -93,6 +152,24 @@ func (teT *testEnv) doesNotContainBlock(nodeID gpa.NodeID, block state.Block) { require.False(teT.t, store.HasTrieRoot(block.TrieRoot())) } +func (teT *testEnv) checkSnapshotsLoaded() { + inputs := make(map[gpa.NodeID]gpa.Input) + for nodeID, ch := range teT.snaprchs { + select { + case result, ok := <-ch: + if ok { + snapshotInfo, ok := teT.snaprsis[nodeID] + require.True(teT.t, ok) + input := sm_inputs.NewSnapshotManagerSnapshotDone(snapshotInfo.GetStateIndex(), snapshotInfo.GetCommitment(), result) + inputs[nodeID] = input + } + teT.snaprchs[nodeID] = nil + 
default: + } + } + teT.tc.WithInputs(inputs).RunAll() +} + func (teT *testEnv) sendBlocksToNode(nodeID gpa.NodeID, timeStep time.Duration, blocks ...state.Block) { // If `ConsensusBlockProduced` is sent to the node, the node has definitely obtained all the blocks // needed to commit this block. This is ensured by consensus. @@ -178,16 +255,11 @@ func (teT *testEnv) sendConsensusDecidedState(commitment *state.L1Commitment, no } func (teT *testEnv) ensureCompletedConsensusDecidedState(respChan <-chan state.State, expectedCommitment *state.L1Commitment, maxTimeIterations int, timeStep time.Duration) bool { - expectedState := teT.bf.GetState(expectedCommitment) return teT.ensureTrue("response from ConsensusDecidedState", func() bool { select { case s := <-respChan: - // Should be require.True(teT.t, expected.Equals(s)) - expectedTrieRoot := expectedState.TrieRoot() - receivedTrieRoot := s.TrieRoot() - require.Equal(teT.t, expectedState.BlockIndex(), s.BlockIndex()) - teT.t.Logf("Checking trie roots: expected %s, obtained %s", expectedTrieRoot, receivedTrieRoot) - require.True(teT.t, expectedTrieRoot.Equals(receivedTrieRoot)) + sm_gpa_utils.CheckStateInStore(teT.t, teT.bf.GetStore(), s) + require.True(teT.t, expectedCommitment.TrieRoot().Equals(s.TrieRoot())) return true default: return false @@ -216,14 +288,13 @@ func (teT *testEnv) ensureCompletedChainFetchStateDiff(respChan <-chan *sm_input lastNewBlockTrieRoot := expectedNewBlocks[len(expectedNewBlocks)-1].TrieRoot() teT.t.Logf("Checking trie roots: expected %s, obtained %s", lastNewBlockTrieRoot, newStateTrieRoot) require.True(teT.t, newStateTrieRoot.Equals(lastNewBlockTrieRoot)) + sm_gpa_utils.CheckStateInStore(teT.t, teT.bf.GetStore(), cfsdr.GetNewState()) requireEqualsFun := func(expected, received []state.Block) { teT.t.Logf("\tExpected %v elements, obtained %v elements", len(expected), len(received)) require.Equal(teT.t, len(expected), len(received)) for i := range expected { - expectedCommitment := expected[i].L1Commitment() - receivedCommitment := received[i].L1Commitment() - teT.t.Logf("\tchecking %v-th element: expected %s, received %s", i, expectedCommitment, receivedCommitment) - require.True(teT.t, expectedCommitment.Equals(receivedCommitment)) + teT.t.Logf("\tchecking %v-th element: expected %s, received %s", i, expected[i].L1Commitment(), received[i].L1Commitment()) + sm_gpa_utils.CheckBlocksEqual(teT.t, expected[i], received[i]) } } teT.t.Log("Checking added blocks...") diff --git a/packages/chain/statemanager/sm_snapshots/interface.go b/packages/chain/statemanager/sm_snapshots/interface.go new file mode 100644 index 0000000000..390c2a876e --- /dev/null +++ b/packages/chain/statemanager/sm_snapshots/interface.go @@ -0,0 +1,57 @@ +package sm_snapshots + +import ( + "io" + + "github.com/iotaledger/wasp/packages/state" + "github.com/iotaledger/wasp/packages/trie" +) + +type SnapshotManager interface { + UpdateAsync() + BlockCommittedAsync(SnapshotInfo) + SnapshotExists(uint32, *state.L1Commitment) bool + LoadSnapshotAsync(SnapshotInfo) <-chan error +} + +type SnapshotManagerTest interface { + SnapshotManager + SnapshotReady(SnapshotInfo) + IsSnapshotReady(SnapshotInfo) bool + SetAfterSnapshotCreated(func(SnapshotInfo)) +} + +type SnapshotInfo interface { + GetStateIndex() uint32 + GetCommitment() *state.L1Commitment + GetTrieRoot() trie.Hash + GetBlockHash() state.BlockHash + String() string + Equals(SnapshotInfo) bool +} + +type snapshotManagerCore interface { + createSnapshotsNeeded() bool + handleUpdate() + 
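// NOTE (illustrative sketch, not from this diff): from the caller's point of view the public SnapshotManager
// API above is fully asynchronous. A consumer would typically ask SnapshotExists(index, commitment) first and,
// if a snapshot is available, wait on the channel returned by LoadSnapshotAsync, e.g.:
//
//	if mgr.SnapshotExists(info.GetStateIndex(), info.GetCommitment()) {
//		if err := <-mgr.LoadSnapshotAsync(info); err != nil {
//			log.Errorf("loading snapshot %s failed: %v", info, err)
//		}
//	}
//
// (mgr, info and log are hypothetical names.) The returned channel is buffered with capacity 1, and the runner
// closes it after the request has been handled.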
handleBlockCommitted(SnapshotInfo) + handleLoadSnapshot(SnapshotInfo, chan<- error) +} + +type snapshotter interface { + storeSnapshot(SnapshotInfo, io.Writer) error + loadSnapshot(SnapshotInfo, io.Reader) error +} + +// Putting a plain slice into a map is inconvenient: whenever you append to the slice, +// you have to put the grown slice back into the map; SliceStruct wraps the slice to avoid that. +type SliceStruct[E any] interface { + Add(E) + Get(int) E + Set(int, E) + Length() int + ForEach(func(int, E) bool) bool + Clone() SliceStruct[E] // Returns a new SliceStruct with exactly the same elements + CloneDeep(func(E) E) SliceStruct[E] // Returns a new SliceStruct with every element of the old SliceStruct cloned using provided function + ContainsBy(func(E) bool) bool + Find(func(E) bool) (E, bool) +} diff --git a/packages/chain/statemanager/sm_snapshots/progress_reporter.go b/packages/chain/statemanager/sm_snapshots/progress_reporter.go new file mode 100644 index 0000000000..68673164b0 --- /dev/null +++ b/packages/chain/statemanager/sm_snapshots/progress_reporter.go @@ -0,0 +1,47 @@ +package sm_snapshots + +import ( + "io" + "time" + + "github.com/dustin/go-humanize" + + "github.com/iotaledger/hive.go/logger" +) + +type progressReporter struct { + log *logger.Logger + header string + lastReport time.Time + + expected uint64 + total uint64 + prevTotal uint64 +} + +var _ io.Writer = &progressReporter{} + +const logStatusPeriodConst = 1 * time.Second + +func NewProgressReporter(log *logger.Logger, header string, expected uint64) io.Writer { + return &progressReporter{ + log: log, + header: header, + lastReport: time.Time{}, + expected: expected, + total: 0, + prevTotal: 0, + } +} + +func (pr *progressReporter) Write(p []byte) (int, error) { + pr.total += uint64(len(p)) // account for the bytes just passed through the TeeReader; without this the log would always report 0 bytes + now := time.Now() + timeDiff := now.Sub(pr.lastReport) + if timeDiff >= logStatusPeriodConst { + bps := uint64(float64(pr.total-pr.prevTotal) / timeDiff.Seconds()) + pr.log.Debugf("%s: downloaded %s of %s (%s/s)", pr.header, humanize.Bytes(pr.total), humanize.Bytes(pr.expected), humanize.Bytes(bps)) + pr.lastReport = now + pr.prevTotal = pr.total + } + return len(p), nil +} diff --git a/packages/chain/statemanager/sm_snapshots/slice_struct.go b/packages/chain/statemanager/sm_snapshots/slice_struct.go new file mode 100644 index 0000000000..d848fd30ac --- /dev/null +++ b/packages/chain/statemanager/sm_snapshots/slice_struct.go @@ -0,0 +1,69 @@ +package sm_snapshots + +import ( + "github.com/samber/lo" +) + +type sliceStructImpl[E any] struct { + slice []E +} + +var _ SliceStruct[int] = &sliceStructImpl[int]{} + +func NewSliceStruct[E any](elems ...E) SliceStruct[E] { + return &sliceStructImpl[E]{slice: elems} +} + +func NewSliceStructLength[E any](length int) SliceStruct[E] { + return NewSliceStructLengthCapacity[E](length, length) +} + +func NewSliceStructLengthCapacity[E any](length, capacity int) SliceStruct[E] { + return &sliceStructImpl[E]{slice: make([]E, length, capacity)} +} + +func (s *sliceStructImpl[E]) Add(elem E) { + s.slice = append(s.slice, elem) +} + +func (s *sliceStructImpl[E]) Get(index int) E { + return s.slice[index] +} + +func (s *sliceStructImpl[E]) Set(index int, elem E) { + s.slice[index] = elem +} + +func (s *sliceStructImpl[E]) Length() int { + return len(s.slice) +} + +func (s *sliceStructImpl[E]) ForEach(forEachFun func(int, E) bool) bool { + for index, elem := range s.slice { + if !forEachFun(index, elem) { + return false + } + } + return true +} + +func (s *sliceStructImpl[E]) Clone() SliceStruct[E] { + return s.CloneDeep(func(elem E) E { return elem }) // NOTE: 
this is not deep cloning as the passed function is a simple identity +} + +func (s *sliceStructImpl[E]) CloneDeep(cloneFun func(E) E) SliceStruct[E] { + result := make([]E, s.Length()) + s.ForEach(func(index int, elem E) bool { + result[index] = cloneFun(elem) + return true + }) + return NewSliceStruct(result...) +} + +func (s *sliceStructImpl[E]) ContainsBy(fun func(E) bool) bool { + return lo.ContainsBy(s.slice, fun) +} + +func (s *sliceStructImpl[E]) Find(fun func(E) bool) (E, bool) { + return lo.Find(s.slice, fun) +} diff --git a/packages/chain/statemanager/sm_snapshots/snapshot_info.go b/packages/chain/statemanager/sm_snapshots/snapshot_info.go new file mode 100644 index 0000000000..0b1ab191b1 --- /dev/null +++ b/packages/chain/statemanager/sm_snapshots/snapshot_info.go @@ -0,0 +1,52 @@ +package sm_snapshots + +import ( + "fmt" + + "github.com/iotaledger/wasp/packages/state" + "github.com/iotaledger/wasp/packages/trie" +) + +type snapshotInfoImpl struct { + index uint32 + commitment *state.L1Commitment +} + +var _ SnapshotInfo = &snapshotInfoImpl{} + +func NewSnapshotInfo(index uint32, commitment *state.L1Commitment) SnapshotInfo { + return &snapshotInfoImpl{ + index: index, + commitment: commitment, + } +} + +func (si *snapshotInfoImpl) GetStateIndex() uint32 { + return si.index +} + +func (si *snapshotInfoImpl) GetCommitment() *state.L1Commitment { + return si.commitment +} + +func (si *snapshotInfoImpl) GetTrieRoot() trie.Hash { + return si.GetCommitment().TrieRoot() +} + +func (si *snapshotInfoImpl) GetBlockHash() state.BlockHash { + return si.GetCommitment().BlockHash() +} + +func (si *snapshotInfoImpl) String() string { + return fmt.Sprintf("%v %s", si.GetStateIndex(), si.GetCommitment()) +} + +func (si *snapshotInfoImpl) Equals(other SnapshotInfo) bool { + if si == nil { + return other == nil + } + if si.GetStateIndex() != other.GetStateIndex() { + return false + } + return si.GetCommitment().Equals(other.GetCommitment()) +} diff --git a/packages/chain/statemanager/sm_snapshots/snapshot_manager.go b/packages/chain/statemanager/sm_snapshots/snapshot_manager.go new file mode 100644 index 0000000000..a3751f61cd --- /dev/null +++ b/packages/chain/statemanager/sm_snapshots/snapshot_manager.go @@ -0,0 +1,439 @@ +package sm_snapshots + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/iotaledger/hive.go/ds/shrinkingmap" + "github.com/iotaledger/hive.go/logger" + "github.com/iotaledger/hive.go/runtime/ioutils" + "github.com/iotaledger/wasp/packages/isc" + "github.com/iotaledger/wasp/packages/shutdown" + "github.com/iotaledger/wasp/packages/state" +) + +type commitmentSources struct { + commitment *state.L1Commitment + sources []string +} + +type snapshotManagerImpl struct { + *snapshotManagerRunner + + log *logger.Logger + ctx context.Context + chainID isc.ChainID + + lastIndexSnapshotted uint32 + lastIndexSnapshottedMutex sync.Mutex + createPeriod uint32 + snapshotter snapshotter + + availableSnapshots *shrinkingmap.ShrinkingMap[uint32, SliceStruct[*commitmentSources]] + availableSnapshotsMutex sync.RWMutex + + localPath string + networkPaths []string +} + +var ( + _ snapshotManagerCore = &snapshotManagerImpl{} + _ SnapshotManager = &snapshotManagerImpl{} +) + +const ( + constDownloadTimeout = 10 * time.Minute + constSnapshotIndexHashFileNameSepparator = "-" + constSnapshotFileSuffix = ".snap" + constSnapshotTmpFileSuffix = ".tmp" + constIndexFileName = "INDEX" // Index file 
contains a new-line separated list of snapshot files + constLocalAddress = "local://" +) + +func NewSnapshotManager( + ctx context.Context, + shutdownCoordinator *shutdown.Coordinator, + chainID isc.ChainID, + createPeriod uint32, + baseLocalPath string, + baseNetworkPaths []string, + store state.Store, + log *logger.Logger, +) (SnapshotManager, error) { + chainIDString := chainID.String() + localPath := filepath.Join(baseLocalPath, chainIDString) + networkPaths := make([]string, len(baseNetworkPaths)) + var err error + for i := range baseNetworkPaths { + networkPaths[i], err = url.JoinPath(baseNetworkPaths[i], chainIDString) + if err != nil { + return nil, fmt.Errorf("cannot append chain ID to network path %s: %v", baseNetworkPaths[i], err) + } + } + snapMLog := log.Named("Snap") + result := &snapshotManagerImpl{ + log: snapMLog, + ctx: ctx, + chainID: chainID, + lastIndexSnapshotted: 0, + lastIndexSnapshottedMutex: sync.Mutex{}, + createPeriod: createPeriod, + snapshotter: newSnapshotter(store), + availableSnapshots: shrinkingmap.New[uint32, SliceStruct[*commitmentSources]](), + availableSnapshotsMutex: sync.RWMutex{}, + localPath: localPath, + networkPaths: networkPaths, + } + if result.createSnapshotsNeeded() { + if err := ioutils.CreateDirectory(localPath, 0o777); err != nil { + return nil, fmt.Errorf("cannot create folder %s: %v", localPath, err) + } + result.cleanTempFiles() // To be able to make snapshots, which were not finished. See comment in `handleBlockCommitted` function + snapMLog.Debugf("Snapshot manager created; folder %v is used for snapshots", localPath) + } else { + snapMLog.Debugf("Snapshot manager created; no snapshots will be produced") + } + result.snapshotManagerRunner = newSnapshotManagerRunner(ctx, shutdownCoordinator, result, snapMLog) + return result, nil +} + +// ------------------------------------- +// Implementations of SnapshotManager interface +// ------------------------------------- + +func (smiT *snapshotManagerImpl) SnapshotExists(stateIndex uint32, commitment *state.L1Commitment) bool { + smiT.availableSnapshotsMutex.RLock() + defer smiT.availableSnapshotsMutex.RUnlock() + + commitments, exists := smiT.availableSnapshots.Get(stateIndex) + if !exists { + return false + } + return commitments.ContainsBy(func(elem *commitmentSources) bool { return elem.commitment.Equals(commitment) && len(elem.sources) > 0 }) +} + +// NOTE: other implementations are inherited from snapshotManagerRunner + +// ------------------------------------- +// Implementations of snapshotManagerCore interface +// ------------------------------------- + +func (smiT *snapshotManagerImpl) createSnapshotsNeeded() bool { + return smiT.createPeriod > 0 +} + +func (smiT *snapshotManagerImpl) handleUpdate() { + result := shrinkingmap.New[uint32, SliceStruct[*commitmentSources]]() + smiT.handleUpdateLocal(result) + smiT.handleUpdateNetwork(result) + + smiT.availableSnapshotsMutex.Lock() + smiT.availableSnapshots = result + smiT.availableSnapshotsMutex.Unlock() +} + +// Snapshot manager makes snapshot of every `period`th state only, if this state hasn't +// been snapshotted before. The snapshot file name includes state index and state hash. +// Snapshot manager first writes the state to temporary file and only then moves it to +// permanent location. Writing is done in separate thread to not interfere with +// normal State manager routine, as it may be lengthy. 
If snapshot manager detects that +// the temporary file, needed to create a snapshot, already exists, it assumes +// that another go routine is already making a snapshot and returns. For this reason +// it is important to delete all temporary files on snapshot manager start. +func (smiT *snapshotManagerImpl) handleBlockCommitted(snapshotInfo SnapshotInfo) { + stateIndex := snapshotInfo.GetStateIndex() + var lastIndexSnapshotted uint32 + smiT.lastIndexSnapshottedMutex.Lock() + lastIndexSnapshotted = smiT.lastIndexSnapshotted + smiT.lastIndexSnapshottedMutex.Unlock() + if (stateIndex > lastIndexSnapshotted) && (stateIndex%smiT.createPeriod == 0) { // TODO: what if snapshotted state has been reverted? + commitment := snapshotInfo.GetCommitment() + smiT.log.Debugf("Creating snapshot %v %s...", stateIndex, commitment) + tmpFileName := tempSnapshotFileName(stateIndex, commitment.BlockHash()) + tmpFilePath := filepath.Join(smiT.localPath, tmpFileName) + exists, _, _ := ioutils.PathExists(tmpFilePath) + if exists { + smiT.log.Debugf("Creating snapshot %v %s: skipped making snapshot as it is already being produced", stateIndex, commitment) + return + } + f, err := os.OpenFile(tmpFilePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o666) + if err != nil { + smiT.log.Errorf("Creating snapshot %v %s: failed to create temporary snapshot file %s: %v", stateIndex, commitment, tmpFilePath, err) + f.Close() + return + } + go func() { + defer f.Close() + + smiT.log.Debugf("Creating snapshot %v %s: storing it to file", stateIndex, commitment) + err := smiT.snapshotter.storeSnapshot(snapshotInfo, f) + if err != nil { + smiT.log.Errorf("Creating snapshot %v %s: failed to write snapshot to temporary file %s: %v", stateIndex, commitment, tmpFilePath, err) + return + } + + finalFileName := snapshotFileName(stateIndex, commitment.BlockHash()) + finalFilePath := filepath.Join(smiT.localPath, finalFileName) + err = os.Rename(tmpFilePath, finalFilePath) + if err != nil { + smiT.log.Errorf("Creating snapshot %v %s: failed to move temporary snapshot file %s to permanent location %s: %v", + stateIndex, commitment, tmpFilePath, finalFilePath, err) + return + } + + smiT.lastIndexSnapshottedMutex.Lock() + if stateIndex > smiT.lastIndexSnapshotted { + smiT.lastIndexSnapshotted = stateIndex + } + smiT.lastIndexSnapshottedMutex.Unlock() + smiT.log.Infof("Creating snapshot %v %s: snapshot created in %s", stateIndex, commitment, finalFilePath) + }() + } +} + +func (smiT *snapshotManagerImpl) handleLoadSnapshot(snapshotInfo SnapshotInfo, callback chan<- error) { + smiT.log.Debugf("Loading snapshot %s", snapshotInfo) + // smiT.availableSnapshotsMutex.RLock() // Probably locking is not needed as it happens on the same thread as editing available snapshots + commitments, exists := smiT.availableSnapshots.Get(snapshotInfo.GetStateIndex()) + // smiT.availableSnapshotsMutex.RUnlock() + if !exists { + err := fmt.Errorf("failed to obtain snapshot commitments of index %v", snapshotInfo.GetStateIndex()) + smiT.log.Errorf("Loading snapshot %s: %v", snapshotInfo, err) + callback <- err + return + } + cs, exists := commitments.Find(func(c *commitmentSources) bool { + return c.commitment.Equals(snapshotInfo.GetCommitment()) + }) + if !exists { + err := fmt.Errorf("failed to obtain sources of snapshot %s", snapshotInfo) + smiT.log.Errorf("Loading snapshot %s: %v", snapshotInfo, err) + callback <- err + return + } + + loadSnapshotFun := func(r io.Reader) error { + err := smiT.snapshotter.loadSnapshot(snapshotInfo, r) + if err != nil { + return 
fmt.Errorf("loading snapshot failed: %v", err) + } + return nil + } + loadLocalFun := func(path string) error { + f, err := os.Open(path) + if err != nil { + return fmt.Errorf("failed to open snapshot file %s", path) + } + defer f.Close() + return loadSnapshotFun(f) + } + loadNetworkFun := func(ctx context.Context, url string) error { + closeFun, reader, err := downloadFile(ctx, smiT.log, url, constDownloadTimeout) + defer closeFun() + if err != nil { + return err + } + return loadSnapshotFun(reader) + } + loadFun := func(source string) error { + if strings.HasPrefix(source, constLocalAddress) { + filePath := strings.TrimPrefix(source, constLocalAddress) + smiT.log.Debugf("Loading snapshot %s: reading local file %s", snapshotInfo, filePath) + return loadLocalFun(filePath) + } + smiT.log.Debugf("Loading snapshot %s: downloading file %s", snapshotInfo, source) + return loadNetworkFun(smiT.ctx, source) + } + + var err error + for _, source := range cs.sources { + e := loadFun(source) + if e == nil { + smiT.log.Debugf("Loading snapshot %s succeeded", snapshotInfo) + callback <- nil + return + } + smiT.log.Errorf("Loading snapshot %s: %v", snapshotInfo, e) + err = errors.Join(err, e) + } + callback <- err +} + +// ------------------------------------- +// Internal functions +// ------------------------------------- + +func (smiT *snapshotManagerImpl) cleanTempFiles() { + tempFileRegExp := tempSnapshotFileNameString("*", "*") + tempFileRegExpWithPath := filepath.Join(smiT.localPath, tempFileRegExp) + tempFiles, err := filepath.Glob(tempFileRegExpWithPath) + if err != nil { + smiT.log.Errorf("Failed to obtain temporary snapshot file list: %v", err) + return + } + + removed := 0 + for _, tempFile := range tempFiles { + err = os.Remove(tempFile) + if err != nil { + smiT.log.Warnf("Failed to remove temporary snapshot file %s: %v", tempFile, err) + } else { + removed++ + } + } + smiT.log.Debugf("Removed %v out of %v temporary snapshot files", removed, len(tempFiles)) +} + +func (smiT *snapshotManagerImpl) handleUpdateLocal(result *shrinkingmap.ShrinkingMap[uint32, SliceStruct[*commitmentSources]]) { + fileRegExp := snapshotFileNameString("*", "*") + fileRegExpWithPath := filepath.Join(smiT.localPath, fileRegExp) + files, err := filepath.Glob(fileRegExpWithPath) + if err != nil { + if smiT.createSnapshotsNeeded() { + smiT.log.Errorf("Update local: failed to obtain snapshot file list: %v", err) + } else { + // If snapshots are not created, snapshot dir is not supposed to exists; unless, it was created by other runs of Wasp or manually + smiT.log.Warnf("Update local: cannot obtain snapshot file list (possibly, it does not exist): %v", err) + } + return + } + snapshotCount := 0 + for _, file := range files { + func() { // Function to make the defers sooner + f, err := os.Open(file) + if err != nil { + smiT.log.Errorf("Update local: failed to open snapshot file %s: %v", file, err) + } + defer f.Close() + snapshotInfo, err := readSnapshotInfo(f) + if err != nil { + smiT.log.Errorf("Update local: failed to read snapshot info from file %s: %v", file, err) + return + } + addSource(result, snapshotInfo, constLocalAddress+file) + snapshotCount++ + }() + } + smiT.log.Debugf("Update local: %v snapshot files found", snapshotCount) +} + +func (smiT *snapshotManagerImpl) handleUpdateNetwork(result *shrinkingmap.ShrinkingMap[uint32, SliceStruct[*commitmentSources]]) { + for _, networkPath := range smiT.networkPaths { + func() { // Function to make the defers sooner + indexFilePath, err := url.JoinPath(networkPath, 
constIndexFileName) + if err != nil { + smiT.log.Errorf("Update network: unable to join paths %s and %s: %v", networkPath, constIndexFileName, err) + return + } + cancelFun, reader, err := downloadFile(smiT.ctx, smiT.log, indexFilePath, constDownloadTimeout) + defer cancelFun() + if err != nil { + smiT.log.Errorf("Update network: failed to download index file: %v", err) + return + } + snapshotCount := 0 + scanner := bufio.NewScanner(reader) // Defaults to splitting input by newline character + for scanner.Scan() { + func() { + snapshotFileName := scanner.Text() + snapshotFilePath, er := url.JoinPath(networkPath, snapshotFileName) + if er != nil { + smiT.log.Errorf("Update network: unable to join paths %s and %s: %v", networkPath, snapshotFileName, er) + return + } + sCancelFun, sReader, er := downloadFile(smiT.ctx, smiT.log, snapshotFilePath, constDownloadTimeout) + defer sCancelFun() + if er != nil { + smiT.log.Errorf("Update network: failed to download snapshot file: %v", er) + return + } + snapshotInfo, er := readSnapshotInfo(sReader) + if er != nil { + smiT.log.Errorf("Update network: failed to read snapshot info from %s: %v", snapshotFilePath, er) + return + } + addSource(result, snapshotInfo, snapshotFilePath) + snapshotCount++ + }() + } + err = scanner.Err() + if err != nil && !errors.Is(err, io.EOF) { + smiT.log.Errorf("Update network: failed reading index file %s: %v", indexFilePath, err) + } + smiT.log.Debugf("Update network: %v snapshot files found on %s", snapshotCount, networkPath) + }() + } +} + +func tempSnapshotFileName(index uint32, blockHash state.BlockHash) string { + return tempSnapshotFileNameString(fmt.Sprint(index), blockHash.String()) +} + +func tempSnapshotFileNameString(index, blockHash string) string { + return snapshotFileNameString(index, blockHash) + constSnapshotTmpFileSuffix +} + +func snapshotFileName(index uint32, blockHash state.BlockHash) string { + return snapshotFileNameString(fmt.Sprint(index), blockHash.String()) +} + +func snapshotFileNameString(index, blockHash string) string { + return index + constSnapshotIndexHashFileNameSepparator + blockHash + constSnapshotFileSuffix +} + +func downloadFile(ctx context.Context, log *logger.Logger, url string, timeout time.Duration) (context.CancelFunc, io.Reader, error) { + downloadCtx, downloadCtxCancel := context.WithTimeout(ctx, timeout) + + request, err := http.NewRequestWithContext(downloadCtx, http.MethodGet, url, http.NoBody) + if err != nil { + return downloadCtxCancel, nil, fmt.Errorf("failed creating request with url %s: %v", url, err) + } + + response, err := http.DefaultClient.Do(request) //nolint:bodyclose// it will be closed, when the caller calls `cancelFun` + if err != nil { + return downloadCtxCancel, nil, fmt.Errorf("http request to file url %s failed: %v", url, err) + } + cancelFun := func() { + response.Body.Close() + downloadCtxCancel() + } + + if response.StatusCode != http.StatusOK { + return cancelFun, nil, fmt.Errorf("http request to %s got status code %v", url, response.StatusCode) + } + + progressReporter := NewProgressReporter(log, fmt.Sprintf("downloading file %s", url), uint64(response.ContentLength)) + reader := io.TeeReader(response.Body, progressReporter) + return cancelFun, reader, nil +} + +func addSource(result *shrinkingmap.ShrinkingMap[uint32, SliceStruct[*commitmentSources]], si SnapshotInfo, path string) { + makeNewComSourcesFun := func() *commitmentSources { + return &commitmentSources{ + commitment: si.GetCommitment(), + sources: []string{path}, + } + } + 
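// NOTE (worked example, not part of this diff; the hash value is hypothetical): with the naming helpers above,
// a snapshot of state index 1000 whose block hash prints as "abc123" is written to <localPath>/1000-abc123.snap,
// using <localPath>/1000-abc123.snap.tmp while it is still being produced, and handleUpdateLocal records it as a
// source with the "local://" prefix. Each remote location (a configured network path joined with the chain ID)
// is expected to serve an INDEX file listing one such snapshot file name per line, next to the snapshot files
// themselves.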
comSourcesArray, exists := result.Get(si.GetStateIndex()) + if exists { + comSources, ok := comSourcesArray.Find(func(elem *commitmentSources) bool { return elem.commitment.Equals(si.GetCommitment()) }) + if ok { + comSources.sources = append(comSources.sources, path) + } else { + comSourcesArray.Add(makeNewComSourcesFun()) + } + } else { + comSourcesArray = NewSliceStruct[*commitmentSources](makeNewComSourcesFun()) + result.Set(si.GetStateIndex(), comSourcesArray) + } +} diff --git a/packages/chain/statemanager/sm_snapshots/snapshot_manager_empty.go b/packages/chain/statemanager/sm_snapshots/snapshot_manager_empty.go new file mode 100644 index 0000000000..7cfd3942dc --- /dev/null +++ b/packages/chain/statemanager/sm_snapshots/snapshot_manager_empty.go @@ -0,0 +1,26 @@ +package sm_snapshots + +import ( + "github.com/iotaledger/wasp/packages/state" +) + +type snapshotManagerEmpty struct{} + +var ( + _ SnapshotManager = &snapshotManagerEmpty{} + _ SnapshotManagerTest = &snapshotManagerEmpty{} +) + +func NewEmptySnapshotManager() SnapshotManagerTest { return &snapshotManagerEmpty{} } +func (*snapshotManagerEmpty) UpdateAsync() {} +func (*snapshotManagerEmpty) BlockCommittedAsync(SnapshotInfo) {} +func (*snapshotManagerEmpty) SnapshotExists(uint32, *state.L1Commitment) bool { return false } +func (*snapshotManagerEmpty) SnapshotReady(SnapshotInfo) {} +func (*snapshotManagerEmpty) IsSnapshotReady(SnapshotInfo) bool { return false } +func (*snapshotManagerEmpty) SetAfterSnapshotCreated(func(SnapshotInfo)) {} + +func (*snapshotManagerEmpty) LoadSnapshotAsync(SnapshotInfo) <-chan error { + callback := make(chan error, 1) + callback <- nil + return callback +} diff --git a/packages/chain/statemanager/sm_snapshots/snapshot_manager_mocked.go b/packages/chain/statemanager/sm_snapshots/snapshot_manager_mocked.go new file mode 100644 index 0000000000..d57a631915 --- /dev/null +++ b/packages/chain/statemanager/sm_snapshots/snapshot_manager_mocked.go @@ -0,0 +1,175 @@ +package sm_snapshots + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/iotaledger/hive.go/kvstore/mapdb" + "github.com/iotaledger/hive.go/logger" + "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa/sm_gpa_utils" + "github.com/iotaledger/wasp/packages/state" +) + +type mockedSnapshotManager struct { + *snapshotManagerRunner + + t *testing.T + createPeriod uint32 + + availableSnapshots map[uint32]SliceStruct[*state.L1Commitment] + availableSnapshotsMutex sync.RWMutex + readySnapshots map[uint32]SliceStruct[*state.L1Commitment] + readySnapshotsMutex sync.Mutex + + origStore state.Store + nodeStore state.Store + + snapshotCommitTime time.Duration + snapshotLoadTime time.Duration + timeProvider sm_gpa_utils.TimeProvider + afterSnapshotCreatedFun func(SnapshotInfo) +} + +var ( + _ snapshotManagerCore = &mockedSnapshotManager{} + _ SnapshotManager = &mockedSnapshotManager{} + _ SnapshotManagerTest = &mockedSnapshotManager{} +) + +func NewMockedSnapshotManager( + t *testing.T, + createPeriod uint32, + origStore state.Store, + nodeStore state.Store, + snapshotCommitTime time.Duration, + snapshotLoadTime time.Duration, + timeProvider sm_gpa_utils.TimeProvider, + log *logger.Logger, +) SnapshotManagerTest { + result := &mockedSnapshotManager{ + t: t, + createPeriod: createPeriod, + availableSnapshots: make(map[uint32]SliceStruct[*state.L1Commitment]), + availableSnapshotsMutex: sync.RWMutex{}, + readySnapshots: make(map[uint32]SliceStruct[*state.L1Commitment]), + 
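// NOTE (not part of this diff): snapshotManagerEmpty above answers LoadSnapshotAsync with a buffered channel
// that already contains nil, so a caller waiting for the result of a "load" against the no-op manager gets an
// immediate success instead of blocking.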
readySnapshotsMutex: sync.Mutex{}, + origStore: origStore, + nodeStore: nodeStore, + snapshotCommitTime: snapshotCommitTime, + snapshotLoadTime: snapshotLoadTime, + timeProvider: timeProvider, + afterSnapshotCreatedFun: func(SnapshotInfo) {}, + } + result.snapshotManagerRunner = newSnapshotManagerRunner(context.Background(), nil, result, log.Named("MSnap")) + return result +} + +// ------------------------------------- +// Implementations of SnapshotManager interface +// ------------------------------------- + +func (msmT *mockedSnapshotManager) SnapshotExists(stateIndex uint32, commitment *state.L1Commitment) bool { + msmT.availableSnapshotsMutex.RLock() + defer msmT.availableSnapshotsMutex.RUnlock() + + commitments, ok := msmT.availableSnapshots[stateIndex] + if !ok { + return false + } + return commitments.ContainsBy(func(comm *state.L1Commitment) bool { return comm.Equals(commitment) }) +} + +// NOTE: other implementations are inherited from snapshotManagerRunner + +// ------------------------------------- +// Implementations of SnapshotManagerTest interface +// ------------------------------------- + +func (msmT *mockedSnapshotManager) SnapshotReady(snapshotInfo SnapshotInfo) { + msmT.readySnapshotsMutex.Lock() + defer msmT.readySnapshotsMutex.Unlock() + + commitments, ok := msmT.readySnapshots[snapshotInfo.GetStateIndex()] + if ok { + if !commitments.ContainsBy(func(comm *state.L1Commitment) bool { return comm.Equals(snapshotInfo.GetCommitment()) }) { + commitments.Add(snapshotInfo.GetCommitment()) + } + } else { + msmT.readySnapshots[snapshotInfo.GetStateIndex()] = NewSliceStruct(snapshotInfo.GetCommitment()) + } +} + +func (msmT *mockedSnapshotManager) IsSnapshotReady(snapshotInfo SnapshotInfo) bool { + msmT.readySnapshotsMutex.Lock() + defer msmT.readySnapshotsMutex.Unlock() + + commitments, ok := msmT.readySnapshots[snapshotInfo.GetStateIndex()] + if !ok { + return false + } + return commitments.ContainsBy(func(elem *state.L1Commitment) bool { return elem.Equals(snapshotInfo.GetCommitment()) }) +} + +func (msmT *mockedSnapshotManager) SetAfterSnapshotCreated(fun func(SnapshotInfo)) { + msmT.afterSnapshotCreatedFun = fun +} + +// ------------------------------------- +// Implementations of snapshotManagerCore interface +// ------------------------------------- + +func (msmT *mockedSnapshotManager) createSnapshotsNeeded() bool { + return msmT.createPeriod > 0 +} + +func (msmT *mockedSnapshotManager) handleUpdate() { + msmT.readySnapshotsMutex.Lock() + defer msmT.readySnapshotsMutex.Unlock() + + availableSnapshots := make(map[uint32]SliceStruct[*state.L1Commitment]) + count := 0 + for index, commitments := range msmT.readySnapshots { + clonedCommitments := commitments.Clone() + availableSnapshots[index] = clonedCommitments + count += clonedCommitments.Length() + } + msmT.log.Debugf("Update: %v snapshots found", count) + + msmT.availableSnapshotsMutex.Lock() + defer msmT.availableSnapshotsMutex.Unlock() + msmT.availableSnapshots = availableSnapshots +} + +func (msmT *mockedSnapshotManager) handleBlockCommitted(snapshotInfo SnapshotInfo) { + stateIndex := snapshotInfo.GetStateIndex() + if stateIndex%msmT.createPeriod == 0 { + msmT.log.Debugf("Creating snapshot %s...", snapshotInfo) + go func() { + <-msmT.timeProvider.After(msmT.snapshotCommitTime) + msmT.SnapshotReady(snapshotInfo) + msmT.afterSnapshotCreatedFun(snapshotInfo) + msmT.log.Debugf("Creating snapshot %s: completed", snapshotInfo) + }() + } +} + +func (msmT *mockedSnapshotManager) handleLoadSnapshot(snapshotInfo 
SnapshotInfo, callback chan<- error) { + msmT.log.Debugf("Loading snapshot %s...", snapshotInfo) + commitments, ok := msmT.availableSnapshots[snapshotInfo.GetStateIndex()] + require.True(msmT.t, ok) + require.True(msmT.t, commitments.ContainsBy(func(elem *state.L1Commitment) bool { + return elem.Equals(snapshotInfo.GetCommitment()) + })) + <-msmT.timeProvider.After(msmT.snapshotLoadTime) + snapshot := mapdb.NewMapDB() + err := msmT.origStore.TakeSnapshot(snapshotInfo.GetTrieRoot(), snapshot) + require.NoError(msmT.t, err) + err = msmT.nodeStore.RestoreSnapshot(snapshotInfo.GetTrieRoot(), snapshot) + require.NoError(msmT.t, err) + callback <- nil + msmT.log.Debugf("Loading snapshot %s: snapshot loaded", snapshotInfo) +} diff --git a/packages/chain/statemanager/sm_snapshots/snapshot_manager_runner.go b/packages/chain/statemanager/sm_snapshots/snapshot_manager_runner.go new file mode 100644 index 0000000000..fa55f89bed --- /dev/null +++ b/packages/chain/statemanager/sm_snapshots/snapshot_manager_runner.go @@ -0,0 +1,115 @@ +package sm_snapshots + +import ( + "context" + + "github.com/iotaledger/hive.go/logger" + "github.com/iotaledger/wasp/packages/shutdown" + "github.com/iotaledger/wasp/packages/util/pipe" +) + +type snapshotInfoCallback struct { + SnapshotInfo + callback chan<- error +} + +type snapshotManagerRunner struct { + log *logger.Logger + ctx context.Context + shutdownCoordinator *shutdown.Coordinator + + updatePipe pipe.Pipe[bool] + blockCommittedPipe pipe.Pipe[SnapshotInfo] + loadSnapshotPipe pipe.Pipe[*snapshotInfoCallback] + + core snapshotManagerCore +} + +func newSnapshotManagerRunner( + ctx context.Context, + shutdownCoordinator *shutdown.Coordinator, + core snapshotManagerCore, + log *logger.Logger, +) *snapshotManagerRunner { + result := &snapshotManagerRunner{ + log: log, + ctx: ctx, + shutdownCoordinator: shutdownCoordinator, + updatePipe: pipe.NewInfinitePipe[bool](), + blockCommittedPipe: pipe.NewInfinitePipe[SnapshotInfo](), + loadSnapshotPipe: pipe.NewInfinitePipe[*snapshotInfoCallback](), + core: core, + } + go result.run() + return result +} + +// ------------------------------------- +// Implementations of SnapshotManager interface +// ------------------------------------- + +func (smrT *snapshotManagerRunner) UpdateAsync() { + smrT.updatePipe.In() <- true +} + +func (smrT *snapshotManagerRunner) BlockCommittedAsync(snapshotInfo SnapshotInfo) { + if smrT.core.createSnapshotsNeeded() { + smrT.blockCommittedPipe.In() <- snapshotInfo + } +} + +func (smrT *snapshotManagerRunner) LoadSnapshotAsync(snapshotInfo SnapshotInfo) <-chan error { + callback := make(chan error, 1) + smrT.loadSnapshotPipe.In() <- &snapshotInfoCallback{ + SnapshotInfo: snapshotInfo, + callback: callback, + } + return callback +} + +// ------------------------------------- +// Internal functions +// ------------------------------------- + +func (smrT *snapshotManagerRunner) run() { + updatePipeCh := smrT.updatePipe.Out() + blockCommittedPipeCh := smrT.blockCommittedPipe.Out() + loadSnapshotPipeCh := smrT.loadSnapshotPipe.Out() + for { + if smrT.ctx.Err() != nil { + if smrT.shutdownCoordinator == nil { + return + } + if smrT.shutdownCoordinator.CheckNestedDone() { + smrT.log.Debugf("Stopping snapshot manager, because context was closed") + smrT.shutdownCoordinator.Done() + return + } + } + select { + case _, ok := <-updatePipeCh: + if ok { + smrT.core.handleUpdate() + } else { + updatePipeCh = nil + } + case snapshotInfo, ok := <-blockCommittedPipeCh: + if ok { + 
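// NOTE (explanatory, not part of this diff): all snapshotManagerCore handlers (handleUpdate,
// handleBlockCommitted, handleLoadSnapshot) are invoked from this single run() goroutine, so they are
// serialized with respect to each other. That is why handleLoadSnapshot in snapshot_manager.go can read
// availableSnapshots without taking the read lock (see the commented-out RLock there), while SnapshotExists,
// which is called from other goroutines, does lock.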
smrT.core.handleBlockCommitted(snapshotInfo) + } else { + blockCommittedPipeCh = nil + } + case snapshotInfoC, ok := <-loadSnapshotPipeCh: + if ok { + func() { + defer close(snapshotInfoC.callback) + smrT.core.handleLoadSnapshot(snapshotInfoC.SnapshotInfo, snapshotInfoC.callback) + }() + } else { + loadSnapshotPipeCh = nil + } + case <-smrT.ctx.Done(): + continue + } + } +} diff --git a/packages/chain/statemanager/sm_snapshots/snapshot_manager_test.go b/packages/chain/statemanager/sm_snapshots/snapshot_manager_test.go new file mode 100644 index 0000000000..b4d2a3d223 --- /dev/null +++ b/packages/chain/statemanager/sm_snapshots/snapshot_manager_test.go @@ -0,0 +1,167 @@ +package sm_snapshots + +import ( + "bufio" + "context" + "fmt" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/iotaledger/hive.go/kvstore/mapdb" + "github.com/iotaledger/hive.go/logger" + "github.com/iotaledger/hive.go/runtime/ioutils" + "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa/sm_gpa_utils" + "github.com/iotaledger/wasp/packages/isc" + "github.com/iotaledger/wasp/packages/state" + "github.com/iotaledger/wasp/packages/testutil/testlogger" +) + +const localSnapshotsPathConst = "testSnapshots" + +func TestSnapshotManagerLocal(t *testing.T) { + createFun := func(chainID isc.ChainID, store state.Store, log *logger.Logger) SnapshotManager { + snapshotManager, err := NewSnapshotManager(context.Background(), nil, chainID, 0, localSnapshotsPathConst, []string{}, store, log) + require.NoError(t, err) + return snapshotManager + } + defer cleanupAfterTest(t) + + testSnapshotManagerSimple(t, createFun, func(isc.ChainID, []SnapshotInfo) {}) +} + +func TestSnapshotManagerNetwork(t *testing.T) { + log := testlogger.NewLogger(t) + defer log.Sync() + + err := ioutils.CreateDirectory(localSnapshotsPathConst, 0o777) + require.NoError(t, err) + + port := ":9999" + handler := http.FileServer(http.Dir(localSnapshotsPathConst)) + go http.ListenAndServe(port, handler) + + createFun := func(chainID isc.ChainID, store state.Store, log *logger.Logger) SnapshotManager { + networkPaths := []string{"http://localhost" + port + "/"} + snapshotManager, err := NewSnapshotManager(context.Background(), nil, chainID, 0, "nonexistent", networkPaths, store, log) + require.NoError(t, err) + return snapshotManager + } + defer cleanupAfterTest(t) + + createIndexFileFun := func(chainID isc.ChainID, snapshotInfos []SnapshotInfo) { + indexFilePath := filepath.Join(localSnapshotsPathConst, chainID.String(), constIndexFileName) + f, err := os.OpenFile(indexFilePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o666) + require.NoError(t, err) + defer f.Close() + w := bufio.NewWriter(f) + for _, snapshotInfo := range snapshotInfos { + w.WriteString(snapshotFileName(snapshotInfo.GetStateIndex(), snapshotInfo.GetBlockHash()) + "\n") + } + w.Flush() + } + testSnapshotManagerSimple(t, createFun, createIndexFileFun) +} + +func testSnapshotManagerSimple( + t *testing.T, + createNewNodeFun func(isc.ChainID, state.Store, *logger.Logger) SnapshotManager, + snapshotsAvailableFun func(isc.ChainID, []SnapshotInfo), +) { + log := testlogger.NewLogger(t) + defer log.Sync() + + numberOfBlocks := 10 + snapshotCreatePeriod := 2 + + var err error + factory := sm_gpa_utils.NewBlockFactory(t) + blocks := factory.GetBlocks(numberOfBlocks, 1) + storeOrig := factory.GetStore() + snapshotManagerOrig, err := NewSnapshotManager(context.Background(), nil, factory.GetChainID(), uint32(snapshotCreatePeriod), 
localSnapshotsPathConst, []string{}, storeOrig, log) + require.NoError(t, err) + + // "Running" node, making snapshots + for _, block := range blocks { + snapshotManagerOrig.BlockCommittedAsync(NewSnapshotInfo(block.StateIndex(), block.L1Commitment())) + } + for i := snapshotCreatePeriod - 1; i < numberOfBlocks; i += snapshotCreatePeriod { + require.True(t, waitForBlock(t, snapshotManagerOrig, blocks[i], 10, 50*time.Millisecond)) + } + createdSnapshots := make([]SnapshotInfo, 0) + for _, block := range blocks { + exists := snapshotManagerOrig.SnapshotExists(block.StateIndex(), block.L1Commitment()) + if block.StateIndex()%uint32(snapshotCreatePeriod) == 0 { + require.True(t, exists) + createdSnapshots = append(createdSnapshots, NewSnapshotInfo(block.StateIndex(), block.L1Commitment())) + } else { + require.False(t, exists) + } + } + snapshotsAvailableFun(factory.GetChainID(), createdSnapshots) + + // Node is restarted + storeNew := state.NewStore(mapdb.NewMapDB()) + snapshotManagerNew := createNewNodeFun(factory.GetChainID(), storeNew, log) + + // Wait for node to read the list of snapshots + lastBlock := blocks[len(blocks)-1] + require.True(t, waitForBlock(t, snapshotManagerNew, lastBlock, 10, 50*time.Millisecond)) + require.True(t, loadAndWaitLoaded(t, snapshotManagerNew, NewSnapshotInfo(lastBlock.StateIndex(), lastBlock.L1Commitment()), 10, 50*time.Millisecond)) + + // Check the loaded snapshot + for i := 0; i < len(blocks)-1; i++ { + require.False(t, storeNew.HasTrieRoot(blocks[i].TrieRoot())) + } + require.True(t, storeNew.HasTrieRoot(lastBlock.TrieRoot())) + + sm_gpa_utils.CheckBlockInStore(t, storeNew, lastBlock) + sm_gpa_utils.CheckStateInStores(t, storeOrig, storeNew, lastBlock.L1Commitment()) +} + +func waitForBlock(t *testing.T, snapshotManager SnapshotManager, block state.Block, maxIterations int, sleep time.Duration) bool { + updateAndWaitFun := func() { + snapshotManager.UpdateAsync() + time.Sleep(sleep) + } + snapshotExistsFun := func() bool { return snapshotManager.SnapshotExists(block.StateIndex(), block.L1Commitment()) } + return ensureTrue(t, fmt.Sprintf("block %v to be committed", block.StateIndex()), snapshotExistsFun, maxIterations, updateAndWaitFun) +} + +func loadAndWaitLoaded(t *testing.T, snapshotManager SnapshotManager, snapshotInfo SnapshotInfo, maxIterations int, sleep time.Duration) bool { + respChan := snapshotManager.LoadSnapshotAsync(snapshotInfo) + loadCompletedFun := func() bool { + select { + case result := <-respChan: + require.NoError(t, result) + return true + default: + return false + } + } + waitFun := func() { time.Sleep(sleep) } + return ensureTrue(t, fmt.Sprintf("state %v to be loaded", snapshotInfo.GetStateIndex()), loadCompletedFun, maxIterations, waitFun) +} + +func ensureTrue(t *testing.T, title string, predicate func() bool, maxIterations int, step func()) bool { + if predicate() { + return true + } + for i := 1; i < maxIterations; i++ { + t.Logf("Waiting for %s iteration %v", title, i) + step() + if predicate() { + return true + } + } + return false +} + +func cleanupAfterTest(t *testing.T) { + err := os.RemoveAll(localSnapshotsPathConst) + require.NoError(t, err) +} diff --git a/packages/chain/statemanager/sm_snapshots/snapshot_store_test.go b/packages/chain/statemanager/sm_snapshots/snapshot_store_test.go new file mode 100644 index 0000000000..b22cd75258 --- /dev/null +++ b/packages/chain/statemanager/sm_snapshots/snapshot_store_test.go @@ -0,0 +1,98 @@ +package sm_snapshots + +import ( + "testing" + + 
"github.com/stretchr/testify/require" + + "github.com/iotaledger/hive.go/kvstore" + "github.com/iotaledger/hive.go/kvstore/mapdb" + "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa/sm_gpa_utils" + "github.com/iotaledger/wasp/packages/state" +) + +func TestNewerSnapshotKeepsOlderSnapshot(t *testing.T) { + twoSnapshotsCheckEnds(t, func(t *testing.T, _storeOrig, storeNew state.Store, intermediateSnapshot, lastSnapshot kvstore.KVStore, blocks []state.Block) { + intermediateTrieRoot := blocks[0].TrieRoot() + lastTrieRoot := blocks[len(blocks)-1].TrieRoot() + + err := storeNew.RestoreSnapshot(intermediateTrieRoot, intermediateSnapshot) + require.NoError(t, err) + require.True(t, storeNew.HasTrieRoot(intermediateTrieRoot)) + + err = storeNew.RestoreSnapshot(lastTrieRoot, lastSnapshot) + require.NoError(t, err) + require.True(t, storeNew.HasTrieRoot(intermediateTrieRoot)) + require.True(t, storeNew.HasTrieRoot(lastTrieRoot)) + }) +} + +func TestOlderSnapshotKeepsNewerSnapshot(t *testing.T) { + twoSnapshotsCheckEnds(t, func(t *testing.T, _storeOrig, storeNew state.Store, intermediateSnapshot, lastSnapshot kvstore.KVStore, blocks []state.Block) { + intermediateTrieRoot := blocks[0].TrieRoot() + lastTrieRoot := blocks[len(blocks)-1].TrieRoot() + + err := storeNew.RestoreSnapshot(lastTrieRoot, lastSnapshot) + require.NoError(t, err) + require.True(t, storeNew.HasTrieRoot(lastTrieRoot)) + + err = storeNew.RestoreSnapshot(intermediateTrieRoot, intermediateSnapshot) + require.NoError(t, err) + require.True(t, storeNew.HasTrieRoot(intermediateTrieRoot)) + require.True(t, storeNew.HasTrieRoot(lastTrieRoot)) + }) +} + +func TestFillTheBlocksBetweenSnapshots(t *testing.T) { + twoSnapshotsCheckEnds(t, func(t *testing.T, storeOrig, storeNew state.Store, intermediateSnapshot, lastSnapshot kvstore.KVStore, blocks []state.Block) { + intermediateTrieRoot := blocks[0].TrieRoot() + lastTrieRoot := blocks[len(blocks)-1].TrieRoot() + err := storeNew.RestoreSnapshot(lastTrieRoot, lastSnapshot) + require.NoError(t, err) + err = storeNew.RestoreSnapshot(intermediateTrieRoot, intermediateSnapshot) + require.NoError(t, err) + require.True(t, storeNew.HasTrieRoot(intermediateTrieRoot)) + require.True(t, storeNew.HasTrieRoot(lastTrieRoot)) + for i := 1; i < len(blocks); i++ { + stateDraft, err := storeNew.NewEmptyStateDraft(blocks[i].PreviousL1Commitment()) + require.NoError(t, err) + blocks[i].Mutations().ApplyTo(stateDraft) + block := storeNew.Commit(stateDraft) + require.True(t, blocks[i].TrieRoot().Equals(block.TrieRoot())) + require.True(t, blocks[i].Hash().Equals(block.Hash())) + } + for i := 1; i < len(blocks)-1; i++ { // blocks[i] and blocsk[len(blocks)-1] will be checked in `twoSnapshotsCheckEnds` + sm_gpa_utils.CheckBlockInStore(t, storeNew, blocks[i]) + sm_gpa_utils.CheckStateInStores(t, storeOrig, storeNew, blocks[i].L1Commitment()) + } + }) +} + +func twoSnapshotsCheckEnds(t *testing.T, performTestFun func(t *testing.T, storeOrig, storeNew state.Store, intermediateSnapshot, lastSnapshot kvstore.KVStore, blocks []state.Block)) { + numberOfBlocks := 10 + intermediateBlockIndex := 4 + + factory := sm_gpa_utils.NewBlockFactory(t) + blocks := factory.GetBlocks(numberOfBlocks, 1) + storeOrig := factory.GetStore() + storeNew := state.NewStore(mapdb.NewMapDB()) + + intermediateBlock := blocks[intermediateBlockIndex] + intermediateCommitment := intermediateBlock.L1Commitment() + intermediateSnapshot := mapdb.NewMapDB() + err := storeOrig.TakeSnapshot(intermediateCommitment.TrieRoot(), intermediateSnapshot) 
+ require.NoError(t, err) + + lastBlock := blocks[len(blocks)-1] + lastCommitment := lastBlock.L1Commitment() + lastSnapshot := mapdb.NewMapDB() + err = storeOrig.TakeSnapshot(lastCommitment.TrieRoot(), lastSnapshot) + require.NoError(t, err) + + performTestFun(t, storeOrig, storeNew, intermediateSnapshot, lastSnapshot, blocks[intermediateBlockIndex:]) + + sm_gpa_utils.CheckBlockInStore(t, storeNew, intermediateBlock) + sm_gpa_utils.CheckStateInStores(t, storeOrig, storeNew, intermediateCommitment) + sm_gpa_utils.CheckBlockInStore(t, storeNew, lastBlock) + sm_gpa_utils.CheckStateInStores(t, storeOrig, storeNew, lastCommitment) +} diff --git a/packages/chain/statemanager/sm_snapshots/snapshotter.go b/packages/chain/statemanager/sm_snapshots/snapshotter.go new file mode 100644 index 0000000000..e387099bbb --- /dev/null +++ b/packages/chain/statemanager/sm_snapshots/snapshotter.go @@ -0,0 +1,177 @@ +package sm_snapshots + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/iotaledger/hive.go/kvstore" + "github.com/iotaledger/hive.go/kvstore/mapdb" + "github.com/iotaledger/wasp/packages/state" +) + +type snapshotterImpl struct { + store state.Store +} + +var _ snapshotter = &snapshotterImpl{} + +const constLengthArrayLength = 4 // bytes + +func newSnapshotter(store state.Store) snapshotter { + return &snapshotterImpl{store: store} +} + +func (sn *snapshotterImpl) storeSnapshot(snapshotInfo SnapshotInfo, w io.Writer) error { + snapshot := mapdb.NewMapDB() + err := sn.store.TakeSnapshot(snapshotInfo.GetTrieRoot(), snapshot) + if err != nil { + return fmt.Errorf("failed to read store: %w", err) + } + err = writeSnapshot(snapshotInfo, snapshot, w) + if err != nil { + return fmt.Errorf("failed writing snapshot: %w", err) + } + return nil +} + +func (sn *snapshotterImpl) loadSnapshot(snapshotInfo SnapshotInfo, r io.Reader) error { + readSnapshotInfo, snapshot, err := readSnapshot(r) + if err != nil { + return fmt.Errorf("failed reading snapshot: %w", err) + } + if !readSnapshotInfo.Equals(snapshotInfo) { + return fmt.Errorf("snapshot read %s is different than expected %v", readSnapshotInfo, snapshotInfo) + } + err = sn.store.RestoreSnapshot(readSnapshotInfo.GetTrieRoot(), snapshot) + if err != nil { + return fmt.Errorf("failed restoring snapshot: %w", err) + } + return nil +} + +func writeSnapshot(snapshotInfo SnapshotInfo, snapshot kvstore.KVStore, w io.Writer) error { + indexArray := make([]byte, 4) // Size of block index, which is of type uint32: 4 bytes + binary.LittleEndian.PutUint32(indexArray, snapshotInfo.GetStateIndex()) + err := writeBytes(indexArray, w) + if err != nil { + return fmt.Errorf("failed writing block index %v: %w", snapshotInfo.GetStateIndex(), err) + } + + trieRootBytes := snapshotInfo.GetCommitment().Bytes() + err = writeBytes(trieRootBytes, w) + if err != nil { + return fmt.Errorf("failed writing L1 commitment %s: %w", snapshotInfo.GetCommitment(), err) + } + + iterErr := snapshot.Iterate(kvstore.EmptyPrefix, func(key kvstore.Key, value kvstore.Value) bool { + e := writeBytes(key, w) + if e != nil { + err = fmt.Errorf("failed writing key %v: %w", key, e) + return false + } + + e = writeBytes(value, w) + if e != nil { + err = fmt.Errorf("failed writing key's %v value %v: %w", key, value, e) + return false + } + + return true + }) + + if iterErr != nil { + return iterErr + } + + return err +} + +func readSnapshotInfo(r io.Reader) (SnapshotInfo, error) { + indexArray, err := readBytes(r) + if err != nil { + return nil, fmt.Errorf("failed to read 
block index: %w", err) + } + if len(indexArray) != 4 { // Size of block index, which is of type uint32: 4 bytes + return nil, fmt.Errorf("block index is %v instead of 4 bytes", len(indexArray)) + } + index := binary.LittleEndian.Uint32(indexArray) + + trieRootArray, err := readBytes(r) + if err != nil { + return nil, fmt.Errorf("failed to read trie root: %w", err) + } + commitment, err := state.L1CommitmentFromBytes(trieRootArray) + if err != nil { + return nil, fmt.Errorf("failed to parse L1 commitment: %w", err) + } + + return NewSnapshotInfo(index, commitment), nil +} + +func readSnapshot(r io.Reader) (SnapshotInfo, kvstore.KVStore, error) { + snapshotInfo, err := readSnapshotInfo(r) + if err != nil { + return nil, nil, err + } + snapshot := mapdb.NewMapDB() + for key, err := readBytes(r); !errors.Is(err, io.EOF); key, err = readBytes(r) { + if err != nil { + return nil, nil, fmt.Errorf("failed to read key: %w", err) + } + + value, err := readBytes(r) + if err != nil { + return nil, nil, fmt.Errorf("failed to read value of key %v: %w", key, err) + } + + err = snapshot.Set(key, value) + if err != nil { + return nil, nil, fmt.Errorf("failed to set key's %v value %v to snapshot: %w", key, value, err) + } + } + return snapshotInfo, snapshot, nil +} + +func writeBytes(bytes []byte, w io.Writer) error { + lengthArray := make([]byte, constLengthArrayLength) + binary.LittleEndian.PutUint32(lengthArray, uint32(len(bytes))) + n, err := w.Write(lengthArray) + if n != constLengthArrayLength { + return fmt.Errorf("only %v of total %v bytes of length written", n, constLengthArrayLength) + } + if err != nil { + return fmt.Errorf("failed writing length: %w", err) + } + + n, err = w.Write(bytes) + if n != len(bytes) { + return fmt.Errorf("only %v of total %v bytes of array written", n, len(bytes)) + } + if err != nil { + return fmt.Errorf("failed writing array: %w", err) + } + + return nil +} + +func readBytes(r io.Reader) ([]byte, error) { + w := new(bytes.Buffer) + read, err := io.CopyN(w, r, constLengthArrayLength) + lengthArray := w.Bytes() + if err != nil { + return nil, fmt.Errorf("read only %v bytes out of %v of length, error: %w", read, constLengthArrayLength, err) + } + + length := int64(binary.LittleEndian.Uint32(lengthArray)) + w = new(bytes.Buffer) + read, err = io.CopyN(w, r, length) + array := w.Bytes() + if err != nil { + return nil, fmt.Errorf("only %v of %v bytes of array read, error: %w", read, len(array), err) + } + + return array, nil +} diff --git a/packages/chain/statemanager/sm_snapshots/snapshotter_test.go b/packages/chain/statemanager/sm_snapshots/snapshotter_test.go new file mode 100644 index 0000000000..2af5001b0f --- /dev/null +++ b/packages/chain/statemanager/sm_snapshots/snapshotter_test.go @@ -0,0 +1,48 @@ +package sm_snapshots + +import ( + "os" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/iotaledger/hive.go/kvstore/mapdb" + "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa/sm_gpa_utils" + "github.com/iotaledger/wasp/packages/state" + "github.com/iotaledger/wasp/packages/testutil/testlogger" +) + +func TestWriteReadDifferentStores(t *testing.T) { + log := testlogger.NewLogger(t) + defer log.Sync() + + var err error + numberOfBlocks := 10 + factory := sm_gpa_utils.NewBlockFactory(t) + blocks := factory.GetBlocks(numberOfBlocks, 1) + lastBlock := blocks[numberOfBlocks-1] + lastCommitment := lastBlock.L1Commitment() + snapshotInfo := NewSnapshotInfo(blocks[numberOfBlocks-1].StateIndex(), lastCommitment) + snapshotterOrig := 
newSnapshotter(factory.GetStore()) + fileName := "TestWriteReadDifferentStores.snap" + f, err := os.OpenFile(fileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o666) + require.NoError(t, err) + err = snapshotterOrig.storeSnapshot(snapshotInfo, f) + require.NoError(t, err) + err = f.Close() + require.NoError(t, err) + + store := state.NewStore(mapdb.NewMapDB()) + snapshotterNew := newSnapshotter(store) + f, err = os.Open(fileName) + require.NoError(t, err) + err = snapshotterNew.loadSnapshot(snapshotInfo, f) + require.NoError(t, err) + err = f.Close() + require.NoError(t, err) + err = os.Remove(fileName) + require.NoError(t, err) + + sm_gpa_utils.CheckBlockInStore(t, store, lastBlock) + sm_gpa_utils.CheckStateInStores(t, factory.GetStore(), store, lastCommitment) +} diff --git a/packages/chain/statemanager/sm_utils/node_randomiser.go b/packages/chain/statemanager/sm_utils/node_randomiser.go index c719afe013..74fa605f99 100644 --- a/packages/chain/statemanager/sm_utils/node_randomiser.go +++ b/packages/chain/statemanager/sm_utils/node_randomiser.go @@ -28,7 +28,7 @@ func NewNodeRandomiserNoInit(me gpa.NodeID, log *logger.Logger) NodeRandomiser { me: me, nodeIDs: nil, // Will be set in result.UpdateNodeIDs([]gpa.NodeID). permutation: nil, // Will be set in result.UpdateNodeIDs([]gpa.NodeID). - log: log.Named("nr"), + log: log.Named("NR"), } } diff --git a/packages/chain/statemanager/state_manager.go b/packages/chain/statemanager/state_manager.go index eca9dbe217..a906aed4b8 100644 --- a/packages/chain/statemanager/state_manager.go +++ b/packages/chain/statemanager/state_manager.go @@ -12,6 +12,7 @@ import ( "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa" "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa/sm_gpa_utils" "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa/sm_inputs" + "github.com/iotaledger/wasp/packages/chain/statemanager/sm_snapshots" "github.com/iotaledger/wasp/packages/chain/statemanager/sm_utils" "github.com/iotaledger/wasp/packages/cryptolib" "github.com/iotaledger/wasp/packages/gpa" @@ -85,6 +86,9 @@ type stateManager struct { messagePipe pipe.Pipe[*peering.PeerMessageIn] nodePubKeysPipe pipe.Pipe[*reqChainNodesUpdated] preliminaryBlockPipe pipe.Pipe[*reqPreliminaryBlock] + snapshotManager sm_snapshots.SnapshotManager + snapshotRespChannel <-chan error + snapshotRespInfo sm_snapshots.SnapshotInfo wal sm_gpa_utils.BlockWAL net peering.NetworkProvider netPeeringID peering.PeeringID @@ -111,6 +115,7 @@ func New( peerPubKeys []*cryptolib.PublicKey, net peering.NetworkProvider, wal sm_gpa_utils.BlockWAL, + snapshotManager sm_snapshots.SnapshotManager, store state.Store, shutdownCoordinator *shutdown.Coordinator, metrics *metrics.ChainStateManagerMetrics, @@ -118,14 +123,18 @@ func New( log *logger.Logger, parameters sm_gpa.StateManagerParameters, ) (StateMgr, error) { - nr := sm_utils.NewNodeRandomiserNoInit(gpa.NodeIDFromPublicKey(me), log) - stateManagerGPA, err := sm_gpa.New(chainID, nr, wal, store, metrics, log, parameters) + smLog := log.Named("SM") + nr := sm_utils.NewNodeRandomiserNoInit(gpa.NodeIDFromPublicKey(me), smLog) + snapshotExistsFun := func(stateIndex uint32, commitment *state.L1Commitment) bool { + return snapshotManager.SnapshotExists(stateIndex, commitment) + } + stateManagerGPA, err := sm_gpa.New(chainID, nr, wal, snapshotExistsFun, store, metrics, smLog, parameters) if err != nil { - log.Errorf("failed to create state manager GPA: %w", err) + smLog.Errorf("failed to create state manager GPA: %w", err) return nil, err } 
result := &stateManager{ - log: log, + log: smLog, chainID: chainID, stateManagerGPA: stateManagerGPA, nodeRandomiser: nr, @@ -133,6 +142,7 @@ func New( messagePipe: pipe.NewInfinitePipe[*peering.PeerMessageIn](), nodePubKeysPipe: pipe.NewInfinitePipe[*reqChainNodesUpdated](), preliminaryBlockPipe: pipe.NewInfinitePipe[*reqPreliminaryBlock](), + snapshotManager: snapshotManager, wal: wal, net: net, netPeeringID: peering.HashPeeringIDFromBytes(chainID.Bytes(), []byte("StateManager")), // ChainID × StateManager @@ -274,6 +284,11 @@ func (smT *stateManager) run() { //nolint:gocyclo } else { preliminaryBlockPipeCh = nil } + case result, ok := <-smT.snapshotRespChannel: + if ok { + smT.handleSnapshotDone(result) + } + smT.snapshotRespChannel = nil case now, ok := <-timerTickCh: if ok { smT.handleTimerTick(now) @@ -294,6 +309,7 @@ func (smT *stateManager) run() { //nolint:gocyclo func (smT *stateManager) handleInput(input gpa.Input) { outMsgs := smT.stateManagerGPA.Input(input) smT.sendMessages(outMsgs) + smT.handleOutput() } func (smT *stateManager) handleMessage(peerMsg *peering.PeerMessageIn) { @@ -305,6 +321,24 @@ func (smT *stateManager) handleMessage(peerMsg *peering.PeerMessageIn) { msg.SetSender(gpa.NodeIDFromPublicKey(peerMsg.SenderPubKey)) outMsgs := smT.stateManagerGPA.Message(msg) smT.sendMessages(outMsgs) + smT.handleOutput() +} + +func (smT *stateManager) handleOutput() { + output := smT.stateManagerGPA.Output().(sm_gpa.StateManagerOutput) + if smT.snapshotRespChannel == nil { + snapshotInfo := output.TakeSnapshotToLoad() + if snapshotInfo != nil { + smT.snapshotRespChannel = smT.snapshotManager.LoadSnapshotAsync(snapshotInfo) + smT.snapshotRespInfo = snapshotInfo + } + } + for _, snapshotInfo := range output.TakeBlocksCommitted() { + smT.snapshotManager.BlockCommittedAsync(snapshotInfo) + } + if output.TakeUpdateSnapshots() { + smT.snapshotManager.UpdateAsync() + } } func (smT *stateManager) handleNodePublicKeys(req *reqChainNodesUpdated) { @@ -356,6 +390,10 @@ func (smT *stateManager) handlePreliminaryBlock(msg *reqPreliminaryBlock) { msg.Respond(nil) } +func (smT *stateManager) handleSnapshotDone(result error) { + smT.handleInput(sm_inputs.NewSnapshotManagerSnapshotDone(smT.snapshotRespInfo.GetStateIndex(), smT.snapshotRespInfo.GetCommitment(), result)) +} + func (smT *stateManager) handleTimerTick(now time.Time) { smT.handleInput(sm_inputs.NewStateManagerTimerTick(now)) } diff --git a/packages/chain/statemanager/state_manager_test.go b/packages/chain/statemanager/state_manager_test.go index d75046143c..1c5ddad6cb 100644 --- a/packages/chain/statemanager/state_manager_test.go +++ b/packages/chain/statemanager/state_manager_test.go @@ -10,8 +10,10 @@ import ( "github.com/stretchr/testify/require" "github.com/iotaledger/hive.go/kvstore/mapdb" + "github.com/iotaledger/hive.go/logger" "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa" "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa/sm_gpa_utils" + "github.com/iotaledger/wasp/packages/chain/statemanager/sm_snapshots" "github.com/iotaledger/wasp/packages/cryptolib" "github.com/iotaledger/wasp/packages/isc" "github.com/iotaledger/wasp/packages/metrics" @@ -22,7 +24,7 @@ import ( "github.com/iotaledger/wasp/packages/testutil/testpeers" ) -func TestCruelWorld(t *testing.T) { +func TestCruelWorld(t *testing.T) { //nolint:gocyclo log := testlogger.NewLogger(t) defer log.Sync() @@ -39,6 +41,11 @@ func TestCruelWorld(t *testing.T) { consensusDecidedStateCount := 50 mempoolStateRequestDelay := 50 * time.Millisecond 
mempoolStateRequestCount := 50 + snapshotCreateNodeCount := 2 + snapshotCreatePeriod := uint32(7) + snapshotCommitTime := 170 * time.Millisecond + snapshotLoadTime := 320 * time.Millisecond + snapshotUpdatePeriod := 510 * time.Millisecond peeringURLs, peerIdentities := testpeers.SetupKeys(uint16(nodeCount)) peerPubKeys := make([]*cryptolib.PublicKey, len(peerIdentities)) @@ -55,13 +62,26 @@ func TestCruelWorld(t *testing.T) { bf := sm_gpa_utils.NewBlockFactory(t) sms := make([]StateMgr, nodeCount) stores := make([]state.Store, nodeCount) + snapMs := make([]sm_snapshots.SnapshotManagerTest, nodeCount) parameters := sm_gpa.NewStateManagerParameters() parameters.StateManagerTimerTickPeriod = timerTickPeriod parameters.StateManagerGetBlockRetry = getBlockPeriod + parameters.SnapshotManagerUpdatePeriod = snapshotUpdatePeriod + NewMockedSnapshotManagerFun := func(createSnapshots bool, store state.Store, log *logger.Logger) sm_snapshots.SnapshotManagerTest { + var createPeriod uint32 + if createSnapshots { + createPeriod = snapshotCreatePeriod + } else { + createPeriod = 0 + } + return sm_snapshots.NewMockedSnapshotManager(t, createPeriod, bf.GetStore(), store, snapshotCommitTime, snapshotLoadTime, parameters.TimeProvider, log) + } for i := range sms { t.Logf("Creating %v-th state manager for node %s", i, peeringURLs[i]) var err error + logNode := log.Named(peeringURLs[i]) stores[i] = state.NewStore(mapdb.NewMapDB()) + snapMs[i] = NewMockedSnapshotManagerFun(i < snapshotCreateNodeCount, stores[i], logNode) origin.InitChain(stores[i], nil, 0) chainMetrics := metrics.NewChainMetricsProvider().GetChainMetrics(isc.EmptyChainID()) sms[i], err = New( @@ -71,15 +91,23 @@ func TestCruelWorld(t *testing.T) { peerPubKeys, netProviders[i], sm_gpa_utils.NewMockedTestBlockWAL(), + snapMs[i], stores[i], nil, chainMetrics.StateManager, chainMetrics.Pipe, - log.Named(peeringURLs[i]), + logNode, parameters, ) require.NoError(t, err) } + for i := 0; i < snapshotCreateNodeCount; i++ { + snapMs[i].SetAfterSnapshotCreated(func(snapshotInfo sm_snapshots.SnapshotInfo) { + for j := snapshotCreateNodeCount; j < len(snapMs); j++ { + snapMs[j].SnapshotReady(snapshotInfo) + } + }) + } blocks := bf.GetBlocks(blockCount, 1) stateDrafts := make([]state.StateDraft, blockCount) blockProduced := make([]*atomic.Bool, blockCount) @@ -144,7 +172,13 @@ func TestCruelWorld(t *testing.T) { newStateOutput := bf.GetAliasOutput(blocks[newBlockIndex].L1Commitment()) responseCh := sms[nodeIndex].(*stateManager).ChainFetchStateDiff(context.Background(), oldStateOutput, newStateOutput) results := <-responseCh - if !bf.GetState(blocks[newBlockIndex].L1Commitment()).TrieRoot().Equals(results.GetNewState().TrieRoot()) { // TODO: should compare states instead of trie roots + expectedNewState, err := bf.GetStore().StateByTrieRoot(blocks[newBlockIndex].TrieRoot()) + if err != nil { + t.Logf("Mempool state request for new block %v and old block %v to node %v wasn't able to retrieve expected new state: %v", + newBlockIndex+1, oldBlockIndex+1, peeringURLs[nodeIndex], err) + return false + } + if !sm_gpa_utils.StatesEqual(expectedNewState, results.GetNewState()) { t.Logf("Mempool state request for new block %v and old block %v to node %v return wrong new state: expected trie root %s, received %s", newBlockIndex+1, oldBlockIndex+1, peeringURLs[nodeIndex], blocks[newBlockIndex].TrieRoot(), results.GetNewState().TrieRoot()) return false diff --git a/packages/chains/chains.go b/packages/chains/chains.go index 257f3e2218..5b4a8dc563 100644 --- 
a/packages/chains/chains.go +++ b/packages/chains/chains.go @@ -18,6 +18,7 @@ import ( "github.com/iotaledger/wasp/packages/chain/cmt_log" "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa" "github.com/iotaledger/wasp/packages/chain/statemanager/sm_gpa/sm_gpa_utils" + "github.com/iotaledger/wasp/packages/chain/statemanager/sm_snapshots" "github.com/iotaledger/wasp/packages/chains/access_mgr" "github.com/iotaledger/wasp/packages/cryptolib" "github.com/iotaledger/wasp/packages/database" @@ -55,6 +56,7 @@ type Chains struct { trustedNetworkListenerCancel context.CancelFunc chainStateStoreProvider database.ChainStateKVStoreProvider + walLoadToStore bool walEnabled bool walFolderPath string smBlockCacheMaxSize int @@ -65,6 +67,10 @@ type Chains struct { smStateManagerTimerTickPeriod time.Duration smPruningMinStatesToKeep int smPruningMaxStatesToDelete int + snapshotPeriod uint32 + snapshotFolderPath string + snapshotNetworkPaths []string + snapshotUpdatePeriod time.Duration chainRecordRegistryProvider registry.ChainRecordRegistryProvider dkShareRegistryProvider registry.DKShareRegistryProvider @@ -103,6 +109,7 @@ func New( networkProvider peering.NetworkProvider, trustedNetworkManager peering.TrustedNetworkManager, chainStateStoreProvider database.ChainStateKVStoreProvider, + walLoadToStore bool, walEnabled bool, walFolderPath string, smBlockCacheMaxSize int, @@ -113,6 +120,10 @@ func New( smStateManagerTimerTickPeriod time.Duration, smPruningMinStatesToKeep int, smPruningMaxStatesToDelete int, + snapshotPeriod uint32, + snapshotFolderPath string, + snapshotNetworkPaths []string, + snapshotUpdatePeriod time.Duration, chainRecordRegistryProvider registry.ChainRecordRegistryProvider, dkShareRegistryProvider registry.DKShareRegistryProvider, nodeIdentityProvider registry.NodeIdentityProvider, @@ -147,6 +158,7 @@ func New( networkProvider: networkProvider, trustedNetworkManager: trustedNetworkManager, chainStateStoreProvider: chainStateStoreProvider, + walLoadToStore: walLoadToStore, walEnabled: walEnabled, walFolderPath: walFolderPath, smBlockCacheMaxSize: smBlockCacheMaxSize, @@ -157,6 +169,10 @@ func New( smStateManagerTimerTickPeriod: smStateManagerTimerTickPeriod, smPruningMinStatesToKeep: smPruningMinStatesToKeep, smPruningMaxStatesToDelete: smPruningMaxStatesToDelete, + snapshotPeriod: snapshotPeriod, + snapshotFolderPath: snapshotFolderPath, + snapshotNetworkPaths: snapshotNetworkPaths, + snapshotUpdatePeriod: snapshotUpdatePeriod, chainRecordRegistryProvider: chainRecordRegistryProvider, dkShareRegistryProvider: dkShareRegistryProvider, nodeIdentityProvider: nodeIdentityProvider, @@ -249,7 +265,7 @@ func (c *Chains) activateAllFromRegistry() error { } // activateWithoutLocking activates a chain in the node. 
-func (c *Chains) activateWithoutLocking(chainID isc.ChainID) error { +func (c *Chains) activateWithoutLocking(chainID isc.ChainID) error { //nolint:funlen if c.ctx == nil { return errors.New("run chains first") } @@ -286,7 +302,7 @@ func (c *Chains) activateWithoutLocking(chainID isc.ChainID) error { chainLog := c.log.Named(chainID.ShortString()) var chainWAL sm_gpa_utils.BlockWAL if c.walEnabled { - chainWAL, err = sm_gpa_utils.NewBlockWAL(chainLog.Named("WAL"), c.walFolderPath, chainID, chainMetrics.BlockWAL) + chainWAL, err = sm_gpa_utils.NewBlockWAL(chainLog, c.walFolderPath, chainID, chainMetrics.BlockWAL) if err != nil { panic(fmt.Errorf("cannot create WAL: %w", err)) } @@ -303,28 +319,48 @@ func (c *Chains) activateWithoutLocking(chainID isc.ChainID) error { stateManagerParameters.StateManagerTimerTickPeriod = c.smStateManagerTimerTickPeriod stateManagerParameters.PruningMinStatesToKeep = c.smPruningMinStatesToKeep stateManagerParameters.PruningMaxStatesToDelete = c.smPruningMaxStatesToDelete + stateManagerParameters.SnapshotManagerUpdatePeriod = c.snapshotUpdatePeriod + // Initialize Snapshotter + chainStore := indexedstore.New(state.NewStoreWithMetrics(chainKVStore, chainMetrics.State)) chainCtx, chainCancel := context.WithCancel(c.ctx) validatorAgentID := accounts.CommonAccount() if c.validatorFeeAddr != nil { validatorAgentID = isc.NewAgentID(c.validatorFeeAddr) } + chainShutdownCoordinator := c.shutdownCoordinator.Nested(fmt.Sprintf("Chain-%s", chainID.AsAddress().String())) + chainSnapshotManager, err := sm_snapshots.NewSnapshotManager( + chainCtx, + chainShutdownCoordinator.Nested("SnapMgr"), + chainID, + c.snapshotPeriod, + c.snapshotFolderPath, + c.snapshotNetworkPaths, + chainStore, + chainLog, + ) + if err != nil { + panic(fmt.Errorf("cannot create Snapshotter: %w", err)) + } + newChain, err := chain.New( chainCtx, chainLog, chainID, - indexedstore.New(state.NewStoreWithMetrics(chainKVStore, chainMetrics.State)), + chainStore, c.nodeConnection, c.nodeIdentityProvider.NodeIdentity(), c.processorConfig, c.dkShareRegistryProvider, c.consensusStateRegistry, + c.walLoadToStore, chainWAL, + chainSnapshotManager, c.chainListener, chainRecord.AccessNodes, c.networkProvider, chainMetrics, - c.shutdownCoordinator.Nested(fmt.Sprintf("Chain-%s", chainID.AsAddress().String())), + chainShutdownCoordinator, func() { c.chainMetricsProvider.RegisterChain(chainID) }, func() { c.chainMetricsProvider.UnregisterChain(chainID) }, c.deriveAliasOutputByQuorum, diff --git a/packages/snapshot/snapshot.go b/packages/snapshot/snapshot.go deleted file mode 100644 index 4fc75ad0f2..0000000000 --- a/packages/snapshot/snapshot.go +++ /dev/null @@ -1,157 +0,0 @@ -package snapshot - -import ( - "errors" - "fmt" - "io" - "path" - "time" - - "github.com/iotaledger/wasp/packages/isc/coreutil" - "github.com/iotaledger/wasp/packages/kv" - "github.com/iotaledger/wasp/packages/kv/codec" - "github.com/iotaledger/wasp/packages/state" -) - -type ConsoleReportParams struct { - Console io.Writer - StatsEveryKVPairs int -} - -func FileName(stateIndex uint32) string { - return fmt.Sprintf("%d.snapshot", stateIndex) -} - -// WriteKVToStream dumps k/v pairs of the state into the -// file. 
Keys are not sorted, so the result in general is not deterministic -func WriteKVToStream(store kv.KVIterator, stream kv.StreamWriter, p ...ConsoleReportParams) error { - par := ConsoleReportParams{ - Console: io.Discard, - StatsEveryKVPairs: 100, - } - if len(p) > 0 { - par = p[0] - } - var err error - store.Iterate("", func(k kv.Key, v []byte) bool { - if err = stream.Write([]byte(k), v); err != nil { - return false - } - if par.StatsEveryKVPairs > 0 { - kvCount, bCount := stream.Stats() - if kvCount%par.StatsEveryKVPairs == 0 { - fmt.Fprintf(par.Console, "[WriteKVToStream] k/v pairs: %d, bytes: %d\n", kvCount, bCount) - } - } - return true - }) - if err != nil { - fmt.Fprintf(par.Console, "[WriteKVToStream] error while writing: %v\n", err) - return err - } - return nil -} - -func WriteSnapshot(sr state.State, dir string, p ...ConsoleReportParams) error { - par := ConsoleReportParams{ - Console: io.Discard, - StatsEveryKVPairs: 100, - } - if len(p) > 0 { - par = p[0] - } - stateIndex := sr.BlockIndex() - timestamp := sr.Timestamp() - fmt.Fprintf(par.Console, "[WriteSnapshot] state index: %d\n", stateIndex) - fmt.Fprintf(par.Console, "[WriteSnapshot] timestamp: %v\n", timestamp) - fname := path.Join(dir, FileName(stateIndex)) - fmt.Fprintf(par.Console, "[WriteSnapshot] will be writing to file: %s\n", fname) - - fstream, err := kv.CreateKVStreamFile(fname) - if err != nil { - return err - } - defer fstream.File.Close() - - fmt.Printf("[WriteSnapshot] writing to file ") - if err := WriteKVToStream(sr, fstream, par); err != nil { - return err - } - tKV, tBytes := fstream.Stats() - fmt.Fprintf(par.Console, "[WriteSnapshot] TOTAL: kv records: %d, bytes: %d\n", tKV, tBytes) - return nil -} - -type FileProperties struct { - FileName string - StateIndex uint32 - TimeStamp time.Time - NumRecords int - MaxKeyLen int - Bytes int -} - -func Scan(rdr kv.StreamIterator) (*FileProperties, error) { - ret := &FileProperties{} - var stateIndexFound, timestampFound bool - var errR error - - err := rdr.Iterate(func(k, v []byte) bool { - if string(k) == coreutil.StatePrefixBlockIndex { - if stateIndexFound { - errR = errors.New("duplicate record with state index") - return false - } - if ret.StateIndex, errR = codec.DecodeUint32(v); errR != nil { - return false - } - stateIndexFound = true - } - if string(k) == coreutil.StatePrefixTimestamp { - if timestampFound { - errR = errors.New("duplicate record with timestamp") - return false - } - if ret.TimeStamp, errR = codec.DecodeTime(v); errR != nil { - return false - } - timestampFound = true - } - if len(v) == 0 { - errR = errors.New("empty value encountered") - return false - } - ret.NumRecords++ - if len(k) > ret.MaxKeyLen { - ret.MaxKeyLen = len(k) - } - ret.Bytes += len(k) + len(v) + 6 - return true - }) - if err != nil { - return nil, err - } - if errR != nil { - return nil, errR - } - return ret, nil -} - -func ScanFile(fname string) (*FileProperties, error) { - stream, err := kv.OpenKVStreamFile(fname) - if err != nil { - return nil, err - } - defer stream.File.Close() - - ret, err := Scan(stream) - if err != nil { - return nil, err - } - ret.FileName = fname - return ret, nil -} - -func BlockFileName(chainid string, index uint32, h state.BlockHash) string { - return fmt.Sprintf("%08d.%s.%s.mut", index, h.String(), chainid) -} diff --git a/packages/snapshot/snapshot_test.go b/packages/snapshot/snapshot_test.go deleted file mode 100644 index ec051e7e28..0000000000 --- a/packages/snapshot/snapshot_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package snapshot - 
-import ( - "math/rand" - "os" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "github.com/iotaledger/hive.go/kvstore/mapdb" - "github.com/iotaledger/wasp/packages/kv" - "github.com/iotaledger/wasp/packages/origin" - "github.com/iotaledger/wasp/packages/state" - "github.com/iotaledger/wasp/packages/util" -) - -func Test1(t *testing.T) { - db := mapdb.NewMapDB() - st := state.NewStore(db) - origin.InitChain(st, nil, 0) - - tm := util.NewTimer() - count := 0 - totalBytes := 0 - - latest, err := st.LatestBlock() - require.NoError(t, err) - sd, err := st.NewStateDraft(time.Now(), latest.L1Commitment()) - require.NoError(t, err) - - seed := time.Now().UnixNano() - t.Log("seed:", seed) - rnd := util.NewPseudoRand(seed) - for i := 0; i < 1000; i++ { - k := randByteSlice(rnd, 4+1, 48) // key is hname + key - v := randByteSlice(rnd, 1, 128) - - sd.Set(kv.Key(k), v) - count++ - totalBytes += len(k) + len(v) + 6 - } - - t.Logf("write %d kv pairs, %d Mbytes, to in-memory state took %v", count, totalBytes/(1024*1024), tm.Duration()) - - tm = util.NewTimer() - block := st.Commit(sd) - err = st.SetLatest(block.TrieRoot()) - require.NoError(t, err) - t.Logf("commit and save state to in-memory db took %v", tm.Duration()) - - rdr, err := st.LatestState() - require.NoError(t, err) - - stateidx := rdr.BlockIndex() - ts := rdr.Timestamp() - - fname := FileName(stateidx) - t.Logf("file: %s", fname) - - tm = util.NewTimer() - err = WriteSnapshot(rdr, "", ConsoleReportParams{ - Console: os.Stdout, - StatsEveryKVPairs: 1_000_000, - }) - require.NoError(t, err) - t.Logf("write snapshot took %v", tm.Duration()) - defer os.Remove(fname) - - v, err := ScanFile(fname) - require.NoError(t, err) - require.EqualValues(t, stateidx, v.StateIndex) - require.True(t, ts.Equal(v.TimeStamp)) -} - -func randByteSlice(rnd *rand.Rand, minLength, maxLength int) []byte { - n := rnd.Intn(maxLength-minLength) + minLength - b := make([]byte, n) - rnd.Read(b) - return b -} diff --git a/packages/state/db.go b/packages/state/db.go index 57ecb3daf8..537f1a77ed 100644 --- a/packages/state/db.go +++ b/packages/state/db.go @@ -59,20 +59,24 @@ func (db *storeDB) setLatestTrieRoot(root trie.Hash) { db.mustSet(keyLatestTrieRoot(), root.Bytes()) } -func (db *storeDB) trieStore() trie.KVStore { +func trieStore(db kvstore.KVStore) trie.KVStore { return trie.NewHiveKVStoreAdapter(db, []byte{chaindb.PrefixTrie}) } func (db *storeDB) trieUpdatable(root trie.Hash) (*trie.TrieUpdatable, error) { - return trie.NewTrieUpdatable(db.trieStore(), root) + return trie.NewTrieUpdatable(trieStore(db), root) } func (db *storeDB) initTrie() trie.Hash { - return trie.MustInitRoot(db.trieStore()) + return trie.MustInitRoot(trieStore(db)) } func (db *storeDB) trieReader(root trie.Hash) (*trie.TrieReader, error) { - return trie.NewTrieReader(db.trieStore(), root) + return trieReader(trieStore(db), root) +} + +func trieReader(trieStore trie.KVStore, root trie.Hash) (*trie.TrieReader, error) { + return trie.NewTrieReader(trieStore, root) } func (db *storeDB) hasBlock(root trie.Hash) bool { @@ -136,3 +140,39 @@ func (db *storeDB) buffered() (*bufferedKVStore, *storeDB) { buf := newBufferedKVStore(db) return buf, &storeDB{buf} } + +func (db *storeDB) takeSnapshot(root trie.Hash, snapshot kvstore.KVStore) error { + if !db.hasBlock(root) { + return fmt.Errorf("cannot take snapshot: trie root not found: %s", root) + } + blockKey := keyBlockByTrieRoot(root) + err := snapshot.Set(blockKey, db.mustGet(blockKey)) + if err != nil { + return err + } + + trie, err 
:= db.trieReader(root) + if err != nil { + return err + } + trie.CopyToStore(trieStore(snapshot)) + return nil +} + +func (db *storeDB) restoreSnapshot(root trie.Hash, snapshot kvstore.KVStore) error { + blockKey := keyBlockByTrieRoot(root) + blockBytes, err := snapshot.Get(blockKey) + if err != nil { + return err + } + db.mustSet(blockKey, blockBytes) + + trieSnapshot, err := trieReader(trieStore(snapshot), root) + if err != nil { + return err + } + trieSnapshot.CopyToStore(trieStore(db)) + + db.setLatestTrieRoot(root) + return nil +} diff --git a/packages/state/state_test.go b/packages/state/state_test.go index e6d5da7462..ca8468eb77 100644 --- a/packages/state/state_test.go +++ b/packages/state/state_test.go @@ -435,3 +435,42 @@ func TestPruning2(t *testing.T) { t.Logf("committed block: %d", len(trieRoots)) } } + +func TestSnapshot(t *testing.T) { + snapshot := mapdb.NewMapDB() + + trieRoot, blockHash := func() (trie.Hash, state.BlockHash) { + db := mapdb.NewMapDB() + cs := mustChainStore{initializedStore(db)} + for i := byte(1); i <= 10; i++ { + d := cs.NewStateDraft(time.Now(), cs.LatestBlock().L1Commitment()) + d.Set(kv.Key(fmt.Sprintf("k%d", i)), []byte("v")) + d.Set("k", []byte{i}) + block := cs.Commit(d) + err := cs.SetLatest(block.TrieRoot()) + require.NoError(t, err) + } + block := cs.LatestBlock() + err := cs.TakeSnapshot(block.TrieRoot(), snapshot) + require.NoError(t, err) + return block.TrieRoot(), block.Hash() + }() + + db := mapdb.NewMapDB() + cs := mustChainStore{state.NewStore(db)} + err := cs.RestoreSnapshot(trieRoot, snapshot) + require.NoError(t, err) + + block := cs.LatestBlock() + require.EqualValues(t, 10, block.StateIndex()) + require.EqualValues(t, blockHash, block.Hash()) + + _, err = cs.Store.BlockByTrieRoot(block.PreviousL1Commitment().TrieRoot()) + require.ErrorContains(t, err, "not found") + + state := cs.LatestState() + for i := byte(1); i <= 10; i++ { + require.EqualValues(t, []byte("v"), state.Get(kv.Key(fmt.Sprintf("k%d", i)))) + } + require.EqualValues(t, []byte{10}, state.Get("k")) +} diff --git a/packages/state/store.go b/packages/state/store.go index e63b49d8c2..130e7ff44f 100644 --- a/packages/state/store.go +++ b/packages/state/store.go @@ -123,7 +123,7 @@ func (s *store) extractBlock(d StateDraft) (Block, *buffered.Mutations, trie.Com for k := range d.Mutations().Dels { trie.Delete([]byte(k)) } - trieRoot, stats := trie.Commit(bufDB.trieStore()) + trieRoot, stats := trie.Commit(trieStore(bufDB)) block := &block{ trieRoot: trieRoot, mutations: d.Mutations(), @@ -154,7 +154,7 @@ func (s *store) Commit(d StateDraft) Block { func (s *store) Prune(trieRoot trie.Hash) (trie.PruneStats, error) { start := time.Now() buf, bufDB := s.db.buffered() - stats, err := trie.Prune(bufDB.trieStore(), trieRoot) + stats, err := trie.Prune(trieStore(bufDB), trieRoot) if err != nil { return trie.PruneStats{}, err } @@ -207,3 +207,11 @@ func (s *store) LatestState() (State, error) { func (s *store) LatestTrieRoot() (trie.Hash, error) { return s.db.latestTrieRoot() } + +func (s *store) TakeSnapshot(root trie.Hash, snapshot kvstore.KVStore) error { + return s.db.takeSnapshot(root, snapshot) +} + +func (s *store) RestoreSnapshot(root trie.Hash, snapshot kvstore.KVStore) error { + return s.db.restoreSnapshot(root, snapshot) +} diff --git a/packages/state/types.go b/packages/state/types.go index 2566dad2d5..07f76038ae 100644 --- a/packages/state/types.go +++ b/packages/state/types.go @@ -7,6 +7,7 @@ import ( "io" "time" + "github.com/iotaledger/hive.go/kvstore" 
"github.com/iotaledger/wasp/packages/kv" "github.com/iotaledger/wasp/packages/kv/buffered" "github.com/iotaledger/wasp/packages/trie" @@ -71,6 +72,13 @@ type Store interface { // Prune deletes the trie with the given root from the DB Prune(trie.Hash) (trie.PruneStats, error) + + // TakeSnapshot takes a snapshot of the block and trie at the given trie root. + TakeSnapshot(trie.Hash, kvstore.KVStore) error + + // RestoreSnapshot restores the block and trie from the given snapshot. + // It is not required for the previous trie root to be present in the DB. + RestoreSnapshot(trie.Hash, kvstore.KVStore) error } // A Block contains the mutations between the previous and current states, diff --git a/packages/trie/trie.go b/packages/trie/trie.go index b7663b4d37..4c95104b6d 100644 --- a/packages/trie/trie.go +++ b/packages/trie/trie.go @@ -107,3 +107,20 @@ func (tr *TrieReader) DebugDump() { return IterateContinue }) } + +func (tr *TrieReader) CopyToStore(snapshot KVStore) { + triePartition := makeWriterPartition(snapshot, partitionTrieNodes) + valuePartition := makeWriterPartition(snapshot, partitionValues) + refcounts := newRefcounts(snapshot) + tr.IterateNodes(func(_ []byte, n *NodeData, depth int) IterateNodesAction { + nodeKey := n.Commitment.Bytes() + triePartition.Set(nodeKey, tr.nodeStore.trieStore.Get(nodeKey)) + if n.Terminal != nil && !n.Terminal.IsValue { + n.Terminal.ExtractValue() + valueKey := n.Terminal.Bytes() + valuePartition.Set(valueKey, tr.nodeStore.valueStore.Get(valueKey)) + } + refcounts.incNodeAndValue(n) + return IterateContinue + }) +}