From 2338e8b9decfc48a67046590ab6347cbbc793128 Mon Sep 17 00:00:00 2001 From: Hans Moog <3293976+hmoog@users.noreply.github.com> Date: Tue, 3 Oct 2023 15:04:17 +0200 Subject: [PATCH 1/5] Feat: bundled VM API in VM interface --- pkg/protocol/engine/ledger/ledger/ledger.go | 2 +- pkg/protocol/engine/ledger/ledger/vm.go | 26 +++++++++----- pkg/protocol/engine/mempool/tests/vm.go | 19 ++++++---- pkg/protocol/engine/mempool/transaction.go | 2 -- pkg/protocol/engine/mempool/v1/mempool.go | 36 +++++++------------ .../engine/mempool/v1/mempool_test.go | 20 ++++------- pkg/protocol/engine/mempool/vm.go | 14 ++++++-- 7 files changed, 62 insertions(+), 57 deletions(-) diff --git a/pkg/protocol/engine/ledger/ledger/ledger.go b/pkg/protocol/engine/ledger/ledger/ledger.go index 6136c26b2..5818a92ba 100644 --- a/pkg/protocol/engine/ledger/ledger/ledger.go +++ b/pkg/protocol/engine/ledger/ledger/ledger.go @@ -68,7 +68,7 @@ func NewProvider() module.Provider[*engine.Engine, ledger.Ledger] { l.setRetainTransactionFailureFunc(e.Retainer.RetainTransactionFailure) - l.memPool = mempoolv1.New(l.validateStardustTransaction, l.executeStardustVM, l.extractInputReferences, l.resolveState, e.Workers.CreateGroup("MemPool"), l.conflictDAG, e, l.errorHandler, mempoolv1.WithForkAllTransactions[ledger.BlockVoteRank](true)) + l.memPool = mempoolv1.New(NewVM(l), l.resolveState, e.Workers.CreateGroup("MemPool"), l.conflictDAG, l.errorHandler, mempoolv1.WithForkAllTransactions[ledger.BlockVoteRank](true)) e.EvictionState.Events.SlotEvicted.Hook(l.memPool.Evict) l.manaManager = mana.NewManager(l.apiProvider, l.resolveAccountOutput) diff --git a/pkg/protocol/engine/ledger/ledger/vm.go b/pkg/protocol/engine/ledger/ledger/vm.go index daf263543..20eb5bf3a 100644 --- a/pkg/protocol/engine/ledger/ledger/vm.go +++ b/pkg/protocol/engine/ledger/ledger/vm.go @@ -11,7 +11,17 @@ import ( "github.com/iotaledger/iota.go/v4/vm/stardust" ) -func (l *Ledger) extractInputReferences(transaction mempool.Transaction) (inputReferences []iotago.Input, err error) { +type VM struct { + ledger *Ledger +} + +func NewVM(ledger *Ledger) *VM { + return &VM{ + ledger: ledger, + } +} + +func (v *VM) TransactionInputs(transaction mempool.Transaction) (inputReferences []iotago.Input, err error) { stardustTransaction, ok := transaction.(*iotago.Transaction) if !ok { return nil, iotago.ErrTxTypeInvalid @@ -27,7 +37,7 @@ func (l *Ledger) extractInputReferences(transaction mempool.Transaction) (inputR return inputReferences, nil } -func (l *Ledger) validateStardustTransaction(signedTransaction mempool.SignedTransaction, resolvedInputStates []mempool.State) (executionContext context.Context, err error) { +func (v *VM) ValidateSignatures(signedTransaction mempool.SignedTransaction, resolvedInputStates []mempool.State) (executionContext context.Context, err error) { signedStardustTransaction, ok := signedTransaction.(*iotago.SignedTransaction) if !ok { return nil, iotago.ErrTxTypeInvalid @@ -57,7 +67,7 @@ func (l *Ledger) validateStardustTransaction(signedTransaction mempool.SignedTra bicInputSet := make(iotagovm.BlockIssuanceCreditInputSet) for _, inp := range bicInputs { - accountData, exists, accountErr := l.accountsLedger.Account(inp.AccountID, commitmentInput.Slot) + accountData, exists, accountErr := v.ledger.accountsLedger.Account(inp.AccountID, commitmentInput.Slot) if accountErr != nil { return nil, ierrors.Join(iotago.ErrBICInputInvalid, ierrors.Wrapf(accountErr, "could not get BIC input for account %s in slot %d", inp.AccountID, commitmentInput.Slot)) } @@ -87,7 +97,7 @@ func (l *Ledger) validateStardustTransaction(signedTransaction mempool.SignedTra accountID = iotago.AccountIDFromOutputID(outputID) } - reward, _, _, rewardErr := l.sybilProtection.ValidatorReward(accountID, stakingFeature.StakedAmount, stakingFeature.StartEpoch, stakingFeature.EndEpoch) + reward, _, _, rewardErr := v.ledger.sybilProtection.ValidatorReward(accountID, stakingFeature.StakedAmount, stakingFeature.StartEpoch, stakingFeature.EndEpoch) if rewardErr != nil { return nil, ierrors.Wrapf(iotago.ErrFailedToClaimStakingReward, "failed to get Validator reward for AccountOutput %s at index %d (StakedAmount: %d, StartEpoch: %d, EndEpoch: %d", outputID, inp.Index, stakingFeature.StakedAmount, stakingFeature.StartEpoch, stakingFeature.EndEpoch) } @@ -102,10 +112,10 @@ func (l *Ledger) validateStardustTransaction(signedTransaction mempool.SignedTra delegationEnd := castOutput.EndEpoch if delegationEnd == 0 { - delegationEnd = l.apiProvider.APIForSlot(commitmentInput.Slot).TimeProvider().EpochFromSlot(commitmentInput.Slot) - iotago.EpochIndex(1) + delegationEnd = v.ledger.apiProvider.APIForSlot(commitmentInput.Slot).TimeProvider().EpochFromSlot(commitmentInput.Slot) - iotago.EpochIndex(1) } - reward, _, _, rewardErr := l.sybilProtection.DelegatorReward(castOutput.ValidatorAddress.AccountID(), castOutput.DelegatedAmount, castOutput.StartEpoch, delegationEnd) + reward, _, _, rewardErr := v.ledger.sybilProtection.DelegatorReward(castOutput.ValidatorAddress.AccountID(), castOutput.DelegatedAmount, castOutput.StartEpoch, delegationEnd) if rewardErr != nil { return nil, ierrors.Wrapf(iotago.ErrFailedToClaimDelegationReward, "failed to get Delegator reward for DelegationOutput %s at index %d (StakedAmount: %d, StartEpoch: %d, EndEpoch: %d", outputID, inp.Index, castOutput.DelegatedAmount, castOutput.StartEpoch, castOutput.EndEpoch) } @@ -133,7 +143,7 @@ func (l *Ledger) validateStardustTransaction(signedTransaction mempool.SignedTra return executionContext, nil } -func (l *Ledger) executeStardustVM(executionContext context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) { +func (v *VM) ExecuteTransaction(executionContext context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) { stardustTransaction, ok := transaction.(*iotago.Transaction) if !ok { return nil, iotago.ErrTxTypeInvalid @@ -161,7 +171,7 @@ func (l *Ledger) executeStardustVM(executionContext context.Context, transaction for index, output := range createdOutputs { outputs = append(outputs, utxoledger.CreateOutput( - l.apiProvider, + v.ledger.apiProvider, iotago.OutputIDFromTransactionIDAndIndex(transactionID, uint16(index)), iotago.EmptyBlockID(), 0, diff --git a/pkg/protocol/engine/mempool/tests/vm.go b/pkg/protocol/engine/mempool/tests/vm.go index 4985e835f..b8f91a84a 100644 --- a/pkg/protocol/engine/mempool/tests/vm.go +++ b/pkg/protocol/engine/mempool/tests/vm.go @@ -6,24 +6,31 @@ import ( "github.com/iotaledger/hive.go/ierrors" ledgertests "github.com/iotaledger/iota-core/pkg/protocol/engine/ledger/tests" "github.com/iotaledger/iota-core/pkg/protocol/engine/mempool" + iotago "github.com/iotaledger/iota.go/v4" ) -func TransactionValidator(_ mempool.SignedTransaction, _ []mempool.State) (executionContext context.Context, err error) { +type VM struct{} + +func (V *VM) TransactionInputs(transaction mempool.Transaction) ([]iotago.Input, error) { + return transaction.(*Transaction).Inputs() +} + +func (V *VM) ValidateSignatures(signedTransaction mempool.SignedTransaction, resolvedInputs []mempool.State) (executionContext context.Context, err error) { return context.Background(), nil } -func TransactionExecutor(_ context.Context, inputTransaction mempool.Transaction) (outputs []mempool.State, err error) { - transaction, ok := inputTransaction.(*Transaction) +func (V *VM) ExecuteTransaction(executionContext context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) { + typedTransaction, ok := transaction.(*Transaction) if !ok { return nil, ierrors.New("invalid transaction type in MockedVM") } - if transaction.invalidTransaction { + if typedTransaction.invalidTransaction { return nil, ierrors.New("invalid transaction") } - for i := uint16(0); i < transaction.outputCount; i++ { - id, err := transaction.ID() + for i := uint16(0); i < typedTransaction.outputCount; i++ { + id, err := typedTransaction.ID() if err != nil { return nil, err } diff --git a/pkg/protocol/engine/mempool/transaction.go b/pkg/protocol/engine/mempool/transaction.go index fdad023a8..60b2dbde3 100644 --- a/pkg/protocol/engine/mempool/transaction.go +++ b/pkg/protocol/engine/mempool/transaction.go @@ -13,5 +13,3 @@ type Transaction interface { // ID returns the identifier of the Transaction. ID() (iotago.TransactionID, error) } - -type TransactionInputReferenceRetriever func(transaction Transaction) ([]iotago.Input, error) diff --git a/pkg/protocol/engine/mempool/v1/mempool.go b/pkg/protocol/engine/mempool/v1/mempool.go index 6e93935dc..76a9f677f 100644 --- a/pkg/protocol/engine/mempool/v1/mempool.go +++ b/pkg/protocol/engine/mempool/v1/mempool.go @@ -21,20 +21,12 @@ import ( // MemPool is a component that manages the state of transactions that are not yet included in the ledger state. type MemPool[VoteRank conflictdag.VoteRankType[VoteRank]] struct { - signedTransactionAttached *event.Event1[mempool.SignedTransactionMetadata] - - transactionAttached *event.Event1[mempool.TransactionMetadata] - - // executeStateTransition is the TransactionExecutor that is used to execute the state transition of transactions. - executeStateTransition mempool.TransactionExecutor + // vm is the virtual machine that is used to validate and execute transactions. + vm mempool.VM // resolveState is the function that is used to request state from the ledger. resolveState mempool.StateReferenceResolver - inputsOfTransaction mempool.TransactionInputReferenceRetriever - - validateSignedTransaction mempool.TransactionValidator - // attachments is the storage that is used to keep track of the attachments of transactions. attachments *memstorage.IndexedStorage[iotago.SlotIndex, iotago.BlockID, *SignedTransactionMetadata] @@ -66,27 +58,22 @@ type MemPool[VoteRank conflictdag.VoteRankType[VoteRank]] struct { optForkAllTransactions bool - apiProvider iotago.APIProvider + signedTransactionAttached *event.Event1[mempool.SignedTransactionMetadata] + + transactionAttached *event.Event1[mempool.TransactionMetadata] } // New is the constructor of the MemPool. func New[VoteRank conflictdag.VoteRankType[VoteRank]]( - transactionValidator mempool.TransactionValidator, - transactionExecutor mempool.TransactionExecutor, - transactionInputReferenceRetriever mempool.TransactionInputReferenceRetriever, + vm mempool.VM, stateResolver mempool.StateReferenceResolver, workers *workerpool.Group, conflictDAG conflictdag.ConflictDAG[iotago.TransactionID, mempool.StateID, VoteRank], - apiProvider iotago.APIProvider, errorHandler func(error), opts ...options.Option[MemPool[VoteRank]], ) *MemPool[VoteRank] { return options.Apply(&MemPool[VoteRank]{ - signedTransactionAttached: event.New1[mempool.SignedTransactionMetadata](), - transactionAttached: event.New1[mempool.TransactionMetadata](), - validateSignedTransaction: transactionValidator, - executeStateTransition: transactionExecutor, - inputsOfTransaction: transactionInputReferenceRetriever, + vm: vm, resolveState: stateResolver, attachments: memstorage.NewIndexedStorage[iotago.SlotIndex, iotago.BlockID, *SignedTransactionMetadata](), cachedTransactions: shrinkingmap.New[iotago.TransactionID, *TransactionMetadata](), @@ -95,8 +82,9 @@ func New[VoteRank conflictdag.VoteRankType[VoteRank]]( stateDiffs: shrinkingmap.New[iotago.SlotIndex, *StateDiff](), executionWorkers: workers.CreatePool("executionWorkers", 1), conflictDAG: conflictDAG, - apiProvider: apiProvider, errorHandler: errorHandler, + signedTransactionAttached: event.New1[mempool.SignedTransactionMetadata](), + transactionAttached: event.New1[mempool.TransactionMetadata](), }, opts, (*MemPool[VoteRank]).setup) } @@ -202,7 +190,7 @@ func (m *MemPool[VoteRank]) storeTransaction(signedTransaction mempool.SignedTra return nil, false, false, ierrors.Errorf("blockID %d is older than last evicted slot %d", blockID.Slot(), m.lastEvictedSlot) } - inputReferences, err := m.inputsOfTransaction(transaction) + inputReferences, err := m.vm.TransactionInputs(transaction) if err != nil { return nil, false, false, ierrors.Wrap(err, "failed to get input references of transaction") } @@ -261,7 +249,7 @@ func (m *MemPool[VoteRank]) solidifyInputs(transaction *TransactionMetadata) { func (m *MemPool[VoteRank]) executeTransaction(executionContext context.Context, transaction *TransactionMetadata) { m.executionWorkers.Submit(func() { - if outputStates, err := m.executeStateTransition(executionContext, transaction.Transaction()); err != nil { + if outputStates, err := m.vm.ExecuteTransaction(executionContext, transaction.Transaction()); err != nil { transaction.setInvalid(err) } else { transaction.setExecuted(outputStates) @@ -458,7 +446,7 @@ func (m *MemPool[VoteRank]) setupSignedTransaction(signedTransactionMetadata *Si transaction.addSigningTransaction(signedTransactionMetadata) transaction.OnSolid(func() { - executionContext, err := m.validateSignedTransaction(signedTransactionMetadata.SignedTransaction(), lo.Map(signedTransactionMetadata.transactionMetadata.inputs, (*StateMetadata).State)) + executionContext, err := m.vm.ValidateSignatures(signedTransactionMetadata.SignedTransaction(), lo.Map(signedTransactionMetadata.transactionMetadata.inputs, (*StateMetadata).State)) if err != nil { _ = signedTransactionMetadata.signaturesInvalid.Set(err) return diff --git a/pkg/protocol/engine/mempool/v1/mempool_test.go b/pkg/protocol/engine/mempool/v1/mempool_test.go index bb955f1a3..a79ff1068 100644 --- a/pkg/protocol/engine/mempool/v1/mempool_test.go +++ b/pkg/protocol/engine/mempool/v1/mempool_test.go @@ -19,8 +19,6 @@ import ( "github.com/iotaledger/iota-core/pkg/protocol/engine/mempool/conflictdag/conflictdagv1" mempooltests "github.com/iotaledger/iota-core/pkg/protocol/engine/mempool/tests" iotago "github.com/iotaledger/iota.go/v4" - "github.com/iotaledger/iota.go/v4/api" - "github.com/iotaledger/iota.go/v4/tpkg" ) func TestMemPoolV1_InterfaceWithoutForkingEverything(t *testing.T) { @@ -36,11 +34,9 @@ func TestMempoolV1_ResourceCleanup(t *testing.T) { ledgerState := ledgertests.New(ledgertests.NewMockedState(iotago.TransactionID{}, 0)) conflictDAG := conflictdagv1.New[iotago.TransactionID, mempool.StateID, vote.MockedRank](func() int { return 0 }) - mempoolInstance := New[vote.MockedRank](mempooltests.TransactionValidator, mempooltests.TransactionExecutor, func(transaction mempool.Transaction) ([]iotago.Input, error) { - return transaction.(*mempooltests.Transaction).Inputs() - }, func(reference iotago.Input) *promise.Promise[mempool.State] { + mempoolInstance := New[vote.MockedRank](new(mempooltests.VM), func(reference iotago.Input) *promise.Promise[mempool.State] { return ledgerState.ResolveOutputState(reference.StateID()) - }, workers, conflictDAG, api.SingleVersionProvider(tpkg.TestAPI), func(error) {}) + }, workers, conflictDAG, func(error) {}) tf := mempooltests.NewTestFramework(t, mempoolInstance, conflictDAG, ledgerState, workers) @@ -107,11 +103,9 @@ func newTestFramework(t *testing.T) *mempooltests.TestFramework { ledgerState := ledgertests.New(ledgertests.NewMockedState(iotago.TransactionID{}, 0)) conflictDAG := conflictdagv1.New[iotago.TransactionID, mempool.StateID, vote.MockedRank](account.NewAccounts().SelectCommittee().SeatCount) - return mempooltests.NewTestFramework(t, New[vote.MockedRank](mempooltests.TransactionValidator, mempooltests.TransactionExecutor, func(transaction mempool.Transaction) ([]iotago.Input, error) { - return transaction.(*mempooltests.Transaction).Inputs() - }, func(reference iotago.Input) *promise.Promise[mempool.State] { + return mempooltests.NewTestFramework(t, New[vote.MockedRank](new(mempooltests.VM), func(reference iotago.Input) *promise.Promise[mempool.State] { return ledgerState.ResolveOutputState(reference.StateID()) - }, workers, conflictDAG, api.SingleVersionProvider(tpkg.TestAPI), func(error) {}), conflictDAG, ledgerState, workers) + }, workers, conflictDAG, func(error) {}), conflictDAG, ledgerState, workers) } func newForkingTestFramework(t *testing.T) *mempooltests.TestFramework { @@ -120,9 +114,7 @@ func newForkingTestFramework(t *testing.T) *mempooltests.TestFramework { ledgerState := ledgertests.New(ledgertests.NewMockedState(iotago.TransactionID{}, 0)) conflictDAG := conflictdagv1.New[iotago.TransactionID, mempool.StateID, vote.MockedRank](account.NewAccounts().SelectCommittee().SeatCount) - return mempooltests.NewTestFramework(t, New[vote.MockedRank](mempooltests.TransactionValidator, mempooltests.TransactionExecutor, func(transaction mempool.Transaction) ([]iotago.Input, error) { - return transaction.(*mempooltests.Transaction).Inputs() - }, func(reference iotago.Input) *promise.Promise[mempool.State] { + return mempooltests.NewTestFramework(t, New[vote.MockedRank](new(mempooltests.VM), func(reference iotago.Input) *promise.Promise[mempool.State] { return ledgerState.ResolveOutputState(reference.StateID()) - }, workers, conflictDAG, api.SingleVersionProvider(tpkg.TestAPI), func(error) {}, WithForkAllTransactions[vote.MockedRank](true)), conflictDAG, ledgerState, workers) + }, workers, conflictDAG, func(error) {}, WithForkAllTransactions[vote.MockedRank](true)), conflictDAG, ledgerState, workers) } diff --git a/pkg/protocol/engine/mempool/vm.go b/pkg/protocol/engine/mempool/vm.go index 0d05fb49c..49ffc1d80 100644 --- a/pkg/protocol/engine/mempool/vm.go +++ b/pkg/protocol/engine/mempool/vm.go @@ -2,8 +2,18 @@ package mempool import ( "context" + + iotago "github.com/iotaledger/iota.go/v4" ) -type TransactionValidator func(signedTransaction SignedTransaction, resolvedInputs []State) (executionContext context.Context, err error) +// VM is the interface that defines the virtual machine that is used to validate and execute transactions. +type VM interface { + // TransactionInputs returns the inputs of the given transaction. + TransactionInputs(transaction Transaction) ([]iotago.Input, error) + + // ValidateSignatures validates the signatures of the given SignedTransaction and returns the execution context. + ValidateSignatures(signedTransaction SignedTransaction, resolvedInputs []State) (executionContext context.Context, err error) -type TransactionExecutor func(executionContext context.Context, transaction Transaction) (outputs []State, err error) + // ExecuteTransaction executes the transaction in the given execution context and returns the resulting states. + ExecuteTransaction(executionContext context.Context, transaction Transaction) (outputs []State, err error) +} From e2ed66c32d533574965cb618fac290e47b67d882 Mon Sep 17 00:00:00 2001 From: Hans Moog <3293976+hmoog@users.noreply.github.com> Date: Tue, 3 Oct 2023 15:37:10 +0200 Subject: [PATCH 2/5] Feat: added alias for StateReference --- pkg/protocol/engine/ledger/ledger/vm.go | 4 +-- pkg/protocol/engine/mempool/mempool.go | 2 +- pkg/protocol/engine/mempool/state_metadata.go | 2 -- .../engine/mempool/state_reference.go | 8 ++--- pkg/protocol/engine/mempool/state_resolver.go | 8 +++++ .../engine/mempool/tests/testframework.go | 8 ++--- .../engine/mempool/tests/transaction.go | 6 ++-- pkg/protocol/engine/mempool/tests/vm.go | 5 ++- pkg/protocol/engine/mempool/v1/mempool.go | 32 ++++++++++--------- .../engine/mempool/v1/mempool_test.go | 24 +++++++------- pkg/protocol/engine/mempool/v1/state_diff.go | 24 +++++++------- .../engine/mempool/v1/state_metadata.go | 8 ----- .../engine/mempool/v1/transaction_metadata.go | 4 +-- .../mempool/v1/transaction_metadata_test.go | 2 +- pkg/protocol/engine/mempool/vm.go | 12 +++---- 15 files changed, 71 insertions(+), 78 deletions(-) create mode 100644 pkg/protocol/engine/mempool/state_resolver.go diff --git a/pkg/protocol/engine/ledger/ledger/vm.go b/pkg/protocol/engine/ledger/ledger/vm.go index 20eb5bf3a..8e4ef5d94 100644 --- a/pkg/protocol/engine/ledger/ledger/vm.go +++ b/pkg/protocol/engine/ledger/ledger/vm.go @@ -21,7 +21,7 @@ func NewVM(ledger *Ledger) *VM { } } -func (v *VM) TransactionInputs(transaction mempool.Transaction) (inputReferences []iotago.Input, err error) { +func (v *VM) StateReferences(transaction mempool.Transaction) (inputReferences []iotago.Input, err error) { stardustTransaction, ok := transaction.(*iotago.Transaction) if !ok { return nil, iotago.ErrTxTypeInvalid @@ -143,7 +143,7 @@ func (v *VM) ValidateSignatures(signedTransaction mempool.SignedTransaction, res return executionContext, nil } -func (v *VM) ExecuteTransaction(executionContext context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) { +func (v *VM) Execute(executionContext context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) { stardustTransaction, ok := transaction.(*iotago.Transaction) if !ok { return nil, iotago.ErrTxTypeInvalid diff --git a/pkg/protocol/engine/mempool/mempool.go b/pkg/protocol/engine/mempool/mempool.go index 4f4bb8e18..d9b611009 100644 --- a/pkg/protocol/engine/mempool/mempool.go +++ b/pkg/protocol/engine/mempool/mempool.go @@ -15,7 +15,7 @@ type MemPool[VoteRank conflictdag.VoteRankType[VoteRank]] interface { MarkAttachmentIncluded(blockID iotago.BlockID) bool - StateMetadata(reference iotago.Input) (state StateMetadata, err error) + StateMetadata(reference StateReference) (state StateMetadata, err error) TransactionMetadata(id iotago.TransactionID) (transaction TransactionMetadata, exists bool) diff --git a/pkg/protocol/engine/mempool/state_metadata.go b/pkg/protocol/engine/mempool/state_metadata.go index 6227fd432..1a8b8cb2e 100644 --- a/pkg/protocol/engine/mempool/state_metadata.go +++ b/pkg/protocol/engine/mempool/state_metadata.go @@ -6,8 +6,6 @@ import ( ) type StateMetadata interface { - StateID() StateID - State() State ConflictIDs() reactive.Set[iotago.TransactionID] diff --git a/pkg/protocol/engine/mempool/state_reference.go b/pkg/protocol/engine/mempool/state_reference.go index 9c9de8d7e..37e7f596e 100644 --- a/pkg/protocol/engine/mempool/state_reference.go +++ b/pkg/protocol/engine/mempool/state_reference.go @@ -1,9 +1,5 @@ package mempool -import ( - "github.com/iotaledger/iota-core/pkg/core/promise" - iotago "github.com/iotaledger/iota.go/v4" -) +import iotago "github.com/iotaledger/iota.go/v4" -// StateReferenceResolver is a function that resolves a StateReference to a State. -type StateReferenceResolver func(reference iotago.Input) *promise.Promise[State] +type StateReference = iotago.Input diff --git a/pkg/protocol/engine/mempool/state_resolver.go b/pkg/protocol/engine/mempool/state_resolver.go new file mode 100644 index 000000000..0102800ff --- /dev/null +++ b/pkg/protocol/engine/mempool/state_resolver.go @@ -0,0 +1,8 @@ +package mempool + +import ( + "github.com/iotaledger/iota-core/pkg/core/promise" +) + +// StateResolver is a function that resolves a StateReference to a Promise with the State. +type StateResolver func(reference StateReference) *promise.Promise[State] diff --git a/pkg/protocol/engine/mempool/tests/testframework.go b/pkg/protocol/engine/mempool/tests/testframework.go index 57451f283..02125a3b3 100644 --- a/pkg/protocol/engine/mempool/tests/testframework.go +++ b/pkg/protocol/engine/mempool/tests/testframework.go @@ -22,7 +22,7 @@ type TestFramework struct { Instance mempool.MemPool[vote.MockedRank] ConflictDAG conflictdag.ConflictDAG[iotago.TransactionID, mempool.StateID, vote.MockedRank] - referencesByAlias map[string]iotago.Input + referencesByAlias map[string]mempool.StateReference stateIDByAlias map[string]mempool.StateID signedTransactionByAlias map[string]mempool.SignedTransaction transactionByAlias map[string]mempool.Transaction @@ -39,7 +39,7 @@ func NewTestFramework(test *testing.T, instance mempool.MemPool[vote.MockedRank] t := &TestFramework{ Instance: instance, ConflictDAG: conflictDAG, - referencesByAlias: make(map[string]iotago.Input), + referencesByAlias: make(map[string]mempool.StateReference), stateIDByAlias: make(map[string]mempool.StateID), signedTransactionByAlias: make(map[string]mempool.SignedTransaction), transactionByAlias: make(map[string]mempool.Transaction), @@ -311,7 +311,7 @@ func (t *TestFramework) setupHookedEvents() { }) } -func (t *TestFramework) stateReference(alias string) iotago.Input { +func (t *TestFramework) stateReference(alias string) mempool.StateReference { if alias == "genesis" { return &iotago.UTXOInput{} } @@ -393,7 +393,7 @@ func (t *TestFramework) Cleanup() { iotago.UnregisterIdentifierAliases() - t.referencesByAlias = make(map[string]iotago.Input) + t.referencesByAlias = make(map[string]mempool.StateReference) t.stateIDByAlias = make(map[string]mempool.StateID) t.transactionByAlias = make(map[string]mempool.Transaction) t.signedTransactionByAlias = make(map[string]mempool.SignedTransaction) diff --git a/pkg/protocol/engine/mempool/tests/transaction.go b/pkg/protocol/engine/mempool/tests/transaction.go index a21fb2523..0726b032b 100644 --- a/pkg/protocol/engine/mempool/tests/transaction.go +++ b/pkg/protocol/engine/mempool/tests/transaction.go @@ -21,7 +21,7 @@ func (s *SignedTransaction) String() string { type Transaction struct { id iotago.TransactionID - inputs []iotago.Input + inputs []mempool.StateReference outputCount uint16 invalidTransaction bool } @@ -33,7 +33,7 @@ func NewSignedTransaction(transaction mempool.Transaction) *SignedTransaction { } } -func NewTransaction(outputCount uint16, inputs ...iotago.Input) *Transaction { +func NewTransaction(outputCount uint16, inputs ...mempool.StateReference) *Transaction { return &Transaction{ id: tpkg.RandTransactionID(), inputs: inputs, @@ -45,7 +45,7 @@ func (t *Transaction) ID() (iotago.TransactionID, error) { return t.id, nil } -func (t *Transaction) Inputs() ([]iotago.Input, error) { +func (t *Transaction) Inputs() ([]mempool.StateReference, error) { return t.inputs, nil } diff --git a/pkg/protocol/engine/mempool/tests/vm.go b/pkg/protocol/engine/mempool/tests/vm.go index b8f91a84a..594e49d5f 100644 --- a/pkg/protocol/engine/mempool/tests/vm.go +++ b/pkg/protocol/engine/mempool/tests/vm.go @@ -6,12 +6,11 @@ import ( "github.com/iotaledger/hive.go/ierrors" ledgertests "github.com/iotaledger/iota-core/pkg/protocol/engine/ledger/tests" "github.com/iotaledger/iota-core/pkg/protocol/engine/mempool" - iotago "github.com/iotaledger/iota.go/v4" ) type VM struct{} -func (V *VM) TransactionInputs(transaction mempool.Transaction) ([]iotago.Input, error) { +func (V *VM) StateReferences(transaction mempool.Transaction) ([]mempool.StateReference, error) { return transaction.(*Transaction).Inputs() } @@ -19,7 +18,7 @@ func (V *VM) ValidateSignatures(signedTransaction mempool.SignedTransaction, res return context.Background(), nil } -func (V *VM) ExecuteTransaction(executionContext context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) { +func (V *VM) Execute(executionContext context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) { typedTransaction, ok := transaction.(*Transaction) if !ok { return nil, ierrors.New("invalid transaction type in MockedVM") diff --git a/pkg/protocol/engine/mempool/v1/mempool.go b/pkg/protocol/engine/mempool/v1/mempool.go index 76a9f677f..83a9a763d 100644 --- a/pkg/protocol/engine/mempool/v1/mempool.go +++ b/pkg/protocol/engine/mempool/v1/mempool.go @@ -25,7 +25,7 @@ type MemPool[VoteRank conflictdag.VoteRankType[VoteRank]] struct { vm mempool.VM // resolveState is the function that is used to request state from the ledger. - resolveState mempool.StateReferenceResolver + resolveState mempool.StateResolver // attachments is the storage that is used to keep track of the attachments of transactions. attachments *memstorage.IndexedStorage[iotago.SlotIndex, iotago.BlockID, *SignedTransactionMetadata] @@ -66,7 +66,7 @@ type MemPool[VoteRank conflictdag.VoteRankType[VoteRank]] struct { // New is the constructor of the MemPool. func New[VoteRank conflictdag.VoteRankType[VoteRank]]( vm mempool.VM, - stateResolver mempool.StateReferenceResolver, + stateResolver mempool.StateResolver, workers *workerpool.Group, conflictDAG conflictdag.ConflictDAG[iotago.TransactionID, mempool.StateID, VoteRank], errorHandler func(error), @@ -128,7 +128,7 @@ func (m *MemPool[VoteRank]) TransactionMetadata(id iotago.TransactionID) (transa } // StateMetadata returns the metadata of the output state with the given ID. -func (m *MemPool[VoteRank]) StateMetadata(stateReference iotago.Input) (state mempool.StateMetadata, err error) { +func (m *MemPool[VoteRank]) StateMetadata(stateReference mempool.StateReference) (state mempool.StateMetadata, err error) { stateRequest, exists := m.cachedStateRequests.Get(stateReference.StateID()) // create a new request that does not wait for missing states @@ -190,7 +190,7 @@ func (m *MemPool[VoteRank]) storeTransaction(signedTransaction mempool.SignedTra return nil, false, false, ierrors.Errorf("blockID %d is older than last evicted slot %d", blockID.Slot(), m.lastEvictedSlot) } - inputReferences, err := m.vm.TransactionInputs(transaction) + inputReferences, err := m.vm.StateReferences(transaction) if err != nil { return nil, false, false, ierrors.Wrap(err, "failed to get input references of transaction") } @@ -249,7 +249,7 @@ func (m *MemPool[VoteRank]) solidifyInputs(transaction *TransactionMetadata) { func (m *MemPool[VoteRank]) executeTransaction(executionContext context.Context, transaction *TransactionMetadata) { m.executionWorkers.Submit(func() { - if outputStates, err := m.vm.ExecuteTransaction(executionContext, transaction.Transaction()); err != nil { + if outputStates, err := m.vm.Execute(executionContext, transaction.Transaction()); err != nil { transaction.setInvalid(err) } else { transaction.setExecuted(outputStates) @@ -261,11 +261,13 @@ func (m *MemPool[VoteRank]) executeTransaction(executionContext context.Context, func (m *MemPool[VoteRank]) bookTransaction(transaction *TransactionMetadata) { if m.optForkAllTransactions { - m.forkTransaction(transaction, ds.NewSet(lo.Map(transaction.inputs, (*StateMetadata).StateID)...)) + m.forkTransaction(transaction, ds.NewSet(lo.Map(transaction.inputs, func(stateMetadata *StateMetadata) mempool.StateID { + return stateMetadata.state.StateID() + })...)) } else { lo.ForEach(transaction.inputs, func(input *StateMetadata) { input.OnDoubleSpent(func() { - m.forkTransaction(transaction, ds.NewSet(input.StateID())) + m.forkTransaction(transaction, ds.NewSet(input.state.StateID())) }) }) } @@ -287,7 +289,7 @@ func (m *MemPool[VoteRank]) forkTransaction(transactionMetadata *TransactionMeta func (m *MemPool[VoteRank]) publishOutputStates(transaction *TransactionMetadata) { for _, output := range transaction.outputs { - stateRequest, isNew := m.cachedStateRequests.GetOrCreate(output.StateID(), lo.NoVariadic(promise.New[*StateMetadata])) + stateRequest, isNew := m.cachedStateRequests.GetOrCreate(output.State().StateID(), lo.NoVariadic(promise.New[*StateMetadata])) stateRequest.Resolve(output) if isNew { @@ -296,7 +298,7 @@ func (m *MemPool[VoteRank]) publishOutputStates(transaction *TransactionMetadata } } -func (m *MemPool[VoteRank]) requestState(stateRef iotago.Input, waitIfMissing ...bool) *promise.Promise[*StateMetadata] { +func (m *MemPool[VoteRank]) requestState(stateRef mempool.StateReference, waitIfMissing ...bool) *promise.Promise[*StateMetadata] { return promise.New(func(p *promise.Promise[*StateMetadata]) { request := m.resolveState(stateRef) @@ -430,15 +432,15 @@ func (m *MemPool[VoteRank]) setupTransaction(transaction *TransactionMetadata) { }) } -func (m *MemPool[VoteRank]) setupOutputState(state *StateMetadata) { - state.OnCommitted(func() { - if !m.cachedStateRequests.Delete(state.StateID(), state.HasNoSpenders) && m.cachedStateRequests.Has(state.StateID()) { - state.onAllSpendersRemoved(func() { m.cachedStateRequests.Delete(state.StateID(), state.HasNoSpenders) }) +func (m *MemPool[VoteRank]) setupOutputState(stateMetadata *StateMetadata) { + stateMetadata.OnCommitted(func() { + if !m.cachedStateRequests.Delete(stateMetadata.state.StateID(), stateMetadata.HasNoSpenders) && m.cachedStateRequests.Has(stateMetadata.state.StateID()) { + stateMetadata.onAllSpendersRemoved(func() { m.cachedStateRequests.Delete(stateMetadata.state.StateID(), stateMetadata.HasNoSpenders) }) } }) - state.OnOrphaned(func() { - m.cachedStateRequests.Delete(state.StateID()) + stateMetadata.OnOrphaned(func() { + m.cachedStateRequests.Delete(stateMetadata.state.StateID()) }) } diff --git a/pkg/protocol/engine/mempool/v1/mempool_test.go b/pkg/protocol/engine/mempool/v1/mempool_test.go index a79ff1068..c4f64dee2 100644 --- a/pkg/protocol/engine/mempool/v1/mempool_test.go +++ b/pkg/protocol/engine/mempool/v1/mempool_test.go @@ -34,11 +34,11 @@ func TestMempoolV1_ResourceCleanup(t *testing.T) { ledgerState := ledgertests.New(ledgertests.NewMockedState(iotago.TransactionID{}, 0)) conflictDAG := conflictdagv1.New[iotago.TransactionID, mempool.StateID, vote.MockedRank](func() int { return 0 }) - mempoolInstance := New[vote.MockedRank](new(mempooltests.VM), func(reference iotago.Input) *promise.Promise[mempool.State] { + memPoolInstance := New[vote.MockedRank](new(mempooltests.VM), func(reference mempool.StateReference) *promise.Promise[mempool.State] { return ledgerState.ResolveOutputState(reference.StateID()) }, workers, conflictDAG, func(error) {}) - tf := mempooltests.NewTestFramework(t, mempoolInstance, conflictDAG, ledgerState, workers) + tf := mempooltests.NewTestFramework(t, memPoolInstance, conflictDAG, ledgerState, workers) issueTransactions := func(startIndex, transactionCount int, prevStateAlias string) (int, string) { index := startIndex @@ -58,7 +58,7 @@ func TestMempoolV1_ResourceCleanup(t *testing.T) { tf.CommitSlot(iotago.SlotIndex(index)) tf.Instance.Evict(iotago.SlotIndex(index)) - require.Nil(t, mempoolInstance.attachments.Get(iotago.SlotIndex(index), false)) + require.Nil(t, memPoolInstance.attachments.Get(iotago.SlotIndex(index), false)) } return index, prevStateAlias @@ -70,19 +70,19 @@ func TestMempoolV1_ResourceCleanup(t *testing.T) { txIndex, prevStateAlias := issueTransactions(1, 10, "genesis") tf.WaitChildren() - require.Equal(t, 0, mempoolInstance.cachedTransactions.Size()) - require.Equal(t, 0, mempoolInstance.stateDiffs.Size()) - require.Equal(t, 0, mempoolInstance.cachedStateRequests.Size()) + require.Equal(t, 0, memPoolInstance.cachedTransactions.Size()) + require.Equal(t, 0, memPoolInstance.stateDiffs.Size()) + require.Equal(t, 0, memPoolInstance.cachedStateRequests.Size()) txIndex, prevStateAlias = issueTransactions(txIndex, 10, prevStateAlias) tf.WaitChildren() - require.Equal(t, 0, mempoolInstance.cachedTransactions.Size()) - require.Equal(t, 0, mempoolInstance.stateDiffs.Size()) - require.Equal(t, 0, mempoolInstance.cachedStateRequests.Size()) + require.Equal(t, 0, memPoolInstance.cachedTransactions.Size()) + require.Equal(t, 0, memPoolInstance.stateDiffs.Size()) + require.Equal(t, 0, memPoolInstance.cachedStateRequests.Size()) attachmentsSlotCount := 0 - mempoolInstance.attachments.ForEach(func(index iotago.SlotIndex, storage *shrinkingmap.ShrinkingMap[iotago.BlockID, *SignedTransactionMetadata]) { + memPoolInstance.attachments.ForEach(func(index iotago.SlotIndex, storage *shrinkingmap.ShrinkingMap[iotago.BlockID, *SignedTransactionMetadata]) { attachmentsSlotCount++ }) @@ -103,7 +103,7 @@ func newTestFramework(t *testing.T) *mempooltests.TestFramework { ledgerState := ledgertests.New(ledgertests.NewMockedState(iotago.TransactionID{}, 0)) conflictDAG := conflictdagv1.New[iotago.TransactionID, mempool.StateID, vote.MockedRank](account.NewAccounts().SelectCommittee().SeatCount) - return mempooltests.NewTestFramework(t, New[vote.MockedRank](new(mempooltests.VM), func(reference iotago.Input) *promise.Promise[mempool.State] { + return mempooltests.NewTestFramework(t, New[vote.MockedRank](new(mempooltests.VM), func(reference mempool.StateReference) *promise.Promise[mempool.State] { return ledgerState.ResolveOutputState(reference.StateID()) }, workers, conflictDAG, func(error) {}), conflictDAG, ledgerState, workers) } @@ -114,7 +114,7 @@ func newForkingTestFramework(t *testing.T) *mempooltests.TestFramework { ledgerState := ledgertests.New(ledgertests.NewMockedState(iotago.TransactionID{}, 0)) conflictDAG := conflictdagv1.New[iotago.TransactionID, mempool.StateID, vote.MockedRank](account.NewAccounts().SelectCommittee().SeatCount) - return mempooltests.NewTestFramework(t, New[vote.MockedRank](new(mempooltests.VM), func(reference iotago.Input) *promise.Promise[mempool.State] { + return mempooltests.NewTestFramework(t, New[vote.MockedRank](new(mempooltests.VM), func(reference mempool.StateReference) *promise.Promise[mempool.State] { return ledgerState.ResolveOutputState(reference.StateID()) }, workers, conflictDAG, func(error) {}, WithForkAllTransactions[vote.MockedRank](true)), conflictDAG, ledgerState, workers) } diff --git a/pkg/protocol/engine/mempool/v1/state_diff.go b/pkg/protocol/engine/mempool/v1/state_diff.go index b152f489d..f9efc845d 100644 --- a/pkg/protocol/engine/mempool/v1/state_diff.go +++ b/pkg/protocol/engine/mempool/v1/state_diff.go @@ -56,17 +56,17 @@ func (s *StateDiff) Mutations() ads.Set[iotago.TransactionID] { } func (s *StateDiff) updateCompactedStateChanges(transaction *TransactionMetadata, direction int) { - transaction.Inputs().Range(func(input mempool.StateMetadata) { - s.compactStateChanges(input, s.stateUsageCounters.Compute(input.StateID(), func(currentValue int, _ bool) int { + for _, input := range transaction.inputs { + s.compactStateChanges(input, s.stateUsageCounters.Compute(input.State().StateID(), func(currentValue int, _ bool) int { return currentValue - direction })) - }) + } - transaction.Outputs().Range(func(output mempool.StateMetadata) { - s.compactStateChanges(output, s.stateUsageCounters.Compute(output.StateID(), func(currentValue int, _ bool) int { + for _, output := range transaction.outputs { + s.compactStateChanges(output, s.stateUsageCounters.Compute(output.State().StateID(), func(currentValue int, _ bool) int { return currentValue + direction })) - }) + } } func (s *StateDiff) AddTransaction(transaction *TransactionMetadata, errorHandler func(error)) error { @@ -97,19 +97,19 @@ func (s *StateDiff) RollbackTransaction(transaction *TransactionMetadata) error return nil } -func (s *StateDiff) compactStateChanges(output mempool.StateMetadata, newValue int) { - if output.State().Type() != iotago.InputUTXO { +func (s *StateDiff) compactStateChanges(output *StateMetadata, newValue int) { + if output.state.Type() != iotago.InputUTXO { return } switch { case newValue > 0: - s.createdOutputs.Set(output.StateID(), output) + s.createdOutputs.Set(output.State().StateID(), output) case newValue < 0: - s.spentOutputs.Set(output.StateID(), output) + s.spentOutputs.Set(output.State().StateID(), output) default: - s.createdOutputs.Delete(output.StateID()) - s.spentOutputs.Delete(output.StateID()) + s.createdOutputs.Delete(output.State().StateID()) + s.spentOutputs.Delete(output.State().StateID()) } } diff --git a/pkg/protocol/engine/mempool/v1/state_metadata.go b/pkg/protocol/engine/mempool/v1/state_metadata.go index 67a3f6aca..f7f1a6593 100644 --- a/pkg/protocol/engine/mempool/v1/state_metadata.go +++ b/pkg/protocol/engine/mempool/v1/state_metadata.go @@ -59,14 +59,6 @@ func (s *StateMetadata) setup(optSource ...*TransactionMetadata) *StateMetadata return s } -func (s *StateMetadata) StateID() mempool.StateID { - return s.state.StateID() -} - -func (s *StateMetadata) Type() iotago.StateType { - return iotago.InputUTXO -} - func (s *StateMetadata) State() mempool.State { return s.state } diff --git a/pkg/protocol/engine/mempool/v1/transaction_metadata.go b/pkg/protocol/engine/mempool/v1/transaction_metadata.go index 98e968577..c14b6fa08 100644 --- a/pkg/protocol/engine/mempool/v1/transaction_metadata.go +++ b/pkg/protocol/engine/mempool/v1/transaction_metadata.go @@ -16,7 +16,7 @@ import ( type TransactionMetadata struct { id iotago.TransactionID - inputReferences []iotago.Input + inputReferences []mempool.StateReference inputs []*StateMetadata outputs []*StateMetadata transaction mempool.Transaction @@ -57,7 +57,7 @@ func (t *TransactionMetadata) ValidAttachments() []iotago.BlockID { return t.validAttachments.Keys() } -func NewTransactionMetadata(transaction mempool.Transaction, referencedInputs []iotago.Input) (*TransactionMetadata, error) { +func NewTransactionMetadata(transaction mempool.Transaction, referencedInputs []mempool.StateReference) (*TransactionMetadata, error) { transactionID, transactionIDErr := transaction.ID() if transactionIDErr != nil { return nil, ierrors.Errorf("failed to retrieve transaction ID: %w", transactionIDErr) diff --git a/pkg/protocol/engine/mempool/v1/transaction_metadata_test.go b/pkg/protocol/engine/mempool/v1/transaction_metadata_test.go index dae4107c4..988bbb42f 100644 --- a/pkg/protocol/engine/mempool/v1/transaction_metadata_test.go +++ b/pkg/protocol/engine/mempool/v1/transaction_metadata_test.go @@ -15,7 +15,7 @@ func TestAttachments(t *testing.T) { "2": iotago.SlotIdentifierRepresentingData(2, []byte("block2")), } - transactionMetadata, err := NewTransactionMetadata(mempooltests.NewTransaction(2), []iotago.Input{}) + transactionMetadata, err := NewTransactionMetadata(mempooltests.NewTransaction(2), nil) require.NoError(t, err) signedTransactionMetadata, err := NewSignedTransactionMetadata(mempooltests.NewSignedTransaction(transactionMetadata.Transaction()), transactionMetadata) diff --git a/pkg/protocol/engine/mempool/vm.go b/pkg/protocol/engine/mempool/vm.go index 49ffc1d80..95f62ec74 100644 --- a/pkg/protocol/engine/mempool/vm.go +++ b/pkg/protocol/engine/mempool/vm.go @@ -2,18 +2,16 @@ package mempool import ( "context" - - iotago "github.com/iotaledger/iota.go/v4" ) // VM is the interface that defines the virtual machine that is used to validate and execute transactions. type VM interface { - // TransactionInputs returns the inputs of the given transaction. - TransactionInputs(transaction Transaction) ([]iotago.Input, error) + // StateReferences returns the inputs of the given transaction. + StateReferences(transaction Transaction) ([]StateReference, error) // ValidateSignatures validates the signatures of the given SignedTransaction and returns the execution context. - ValidateSignatures(signedTransaction SignedTransaction, resolvedInputs []State) (executionContext context.Context, err error) + ValidateSignatures(signedTransaction SignedTransaction, inputs []State) (executionContext context.Context, err error) - // ExecuteTransaction executes the transaction in the given execution context and returns the resulting states. - ExecuteTransaction(executionContext context.Context, transaction Transaction) (outputs []State, err error) + // Execute executes the transaction in the given execution context and returns the resulting states. + Execute(executionContext context.Context, transaction Transaction) (outputs []State, err error) } From 6b7a3c6d85b6b1e5110b0073e124a91acd4613f3 Mon Sep 17 00:00:00 2001 From: Hans Moog <3293976+hmoog@users.noreply.github.com> Date: Tue, 3 Oct 2023 15:41:27 +0200 Subject: [PATCH 3/5] Fix: fixed more types --- pkg/protocol/engine/mempool/v1/state_diff.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pkg/protocol/engine/mempool/v1/state_diff.go b/pkg/protocol/engine/mempool/v1/state_diff.go index f9efc845d..e0dcc847e 100644 --- a/pkg/protocol/engine/mempool/v1/state_diff.go +++ b/pkg/protocol/engine/mempool/v1/state_diff.go @@ -57,13 +57,13 @@ func (s *StateDiff) Mutations() ads.Set[iotago.TransactionID] { func (s *StateDiff) updateCompactedStateChanges(transaction *TransactionMetadata, direction int) { for _, input := range transaction.inputs { - s.compactStateChanges(input, s.stateUsageCounters.Compute(input.State().StateID(), func(currentValue int, _ bool) int { + s.compactStateChanges(input, s.stateUsageCounters.Compute(input.state.StateID(), func(currentValue int, _ bool) int { return currentValue - direction })) } for _, output := range transaction.outputs { - s.compactStateChanges(output, s.stateUsageCounters.Compute(output.State().StateID(), func(currentValue int, _ bool) int { + s.compactStateChanges(output, s.stateUsageCounters.Compute(output.state.StateID(), func(currentValue int, _ bool) int { return currentValue + direction })) } @@ -91,6 +91,7 @@ func (s *StateDiff) RollbackTransaction(transaction *TransactionMetadata) error if _, err := s.mutations.Delete(transaction.ID()); err != nil { return ierrors.Wrapf(err, "failed to delete transaction from state diff's mutations, txID: %s", transaction.ID()) } + s.updateCompactedStateChanges(transaction, -1) } @@ -104,12 +105,12 @@ func (s *StateDiff) compactStateChanges(output *StateMetadata, newValue int) { switch { case newValue > 0: - s.createdOutputs.Set(output.State().StateID(), output) + s.createdOutputs.Set(output.state.StateID(), output) case newValue < 0: - s.spentOutputs.Set(output.State().StateID(), output) + s.spentOutputs.Set(output.state.StateID(), output) default: - s.createdOutputs.Delete(output.State().StateID()) - s.spentOutputs.Delete(output.State().StateID()) + s.createdOutputs.Delete(output.state.StateID()) + s.spentOutputs.Delete(output.state.StateID()) } } From 58b7668dbd8d4c7514c2532923736784031055fb Mon Sep 17 00:00:00 2001 From: Hans Moog <3293976+hmoog@users.noreply.github.com> Date: Tue, 3 Oct 2023 15:46:33 +0200 Subject: [PATCH 4/5] Feat: removed unused params --- pkg/protocol/engine/mempool/tests/vm.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/protocol/engine/mempool/tests/vm.go b/pkg/protocol/engine/mempool/tests/vm.go index 594e49d5f..f8a05cbfe 100644 --- a/pkg/protocol/engine/mempool/tests/vm.go +++ b/pkg/protocol/engine/mempool/tests/vm.go @@ -14,11 +14,11 @@ func (V *VM) StateReferences(transaction mempool.Transaction) ([]mempool.StateRe return transaction.(*Transaction).Inputs() } -func (V *VM) ValidateSignatures(signedTransaction mempool.SignedTransaction, resolvedInputs []mempool.State) (executionContext context.Context, err error) { +func (V *VM) ValidateSignatures(_ mempool.SignedTransaction, _ []mempool.State) (executionContext context.Context, err error) { return context.Background(), nil } -func (V *VM) Execute(executionContext context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) { +func (V *VM) Execute(_ context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) { typedTransaction, ok := transaction.(*Transaction) if !ok { return nil, ierrors.New("invalid transaction type in MockedVM") From eddc037f90a4c67f0d43976ca52241a7ff508227 Mon Sep 17 00:00:00 2001 From: Hans Moog <3293976+hmoog@users.noreply.github.com> Date: Tue, 3 Oct 2023 15:53:44 +0200 Subject: [PATCH 5/5] Refactor: refactored more code --- pkg/protocol/engine/ledger/ledger/vm.go | 2 +- pkg/protocol/engine/mempool/tests/vm.go | 9 +++++++-- pkg/protocol/engine/mempool/v1/mempool.go | 2 +- pkg/protocol/engine/mempool/vm.go | 4 ++-- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/pkg/protocol/engine/ledger/ledger/vm.go b/pkg/protocol/engine/ledger/ledger/vm.go index 8e4ef5d94..34fd045c1 100644 --- a/pkg/protocol/engine/ledger/ledger/vm.go +++ b/pkg/protocol/engine/ledger/ledger/vm.go @@ -21,7 +21,7 @@ func NewVM(ledger *Ledger) *VM { } } -func (v *VM) StateReferences(transaction mempool.Transaction) (inputReferences []iotago.Input, err error) { +func (v *VM) Inputs(transaction mempool.Transaction) (inputReferences []iotago.Input, err error) { stardustTransaction, ok := transaction.(*iotago.Transaction) if !ok { return nil, iotago.ErrTxTypeInvalid diff --git a/pkg/protocol/engine/mempool/tests/vm.go b/pkg/protocol/engine/mempool/tests/vm.go index f8a05cbfe..e1e558954 100644 --- a/pkg/protocol/engine/mempool/tests/vm.go +++ b/pkg/protocol/engine/mempool/tests/vm.go @@ -10,8 +10,13 @@ import ( type VM struct{} -func (V *VM) StateReferences(transaction mempool.Transaction) ([]mempool.StateReference, error) { - return transaction.(*Transaction).Inputs() +func (V *VM) Inputs(transaction mempool.Transaction) ([]mempool.StateReference, error) { + testTransaction, ok := transaction.(*Transaction) + if !ok { + return nil, ierrors.New("invalid transaction type in MockedVM") + } + + return testTransaction.Inputs() } func (V *VM) ValidateSignatures(_ mempool.SignedTransaction, _ []mempool.State) (executionContext context.Context, err error) { diff --git a/pkg/protocol/engine/mempool/v1/mempool.go b/pkg/protocol/engine/mempool/v1/mempool.go index 83a9a763d..a831184d1 100644 --- a/pkg/protocol/engine/mempool/v1/mempool.go +++ b/pkg/protocol/engine/mempool/v1/mempool.go @@ -190,7 +190,7 @@ func (m *MemPool[VoteRank]) storeTransaction(signedTransaction mempool.SignedTra return nil, false, false, ierrors.Errorf("blockID %d is older than last evicted slot %d", blockID.Slot(), m.lastEvictedSlot) } - inputReferences, err := m.vm.StateReferences(transaction) + inputReferences, err := m.vm.Inputs(transaction) if err != nil { return nil, false, false, ierrors.Wrap(err, "failed to get input references of transaction") } diff --git a/pkg/protocol/engine/mempool/vm.go b/pkg/protocol/engine/mempool/vm.go index 95f62ec74..a294ab057 100644 --- a/pkg/protocol/engine/mempool/vm.go +++ b/pkg/protocol/engine/mempool/vm.go @@ -6,8 +6,8 @@ import ( // VM is the interface that defines the virtual machine that is used to validate and execute transactions. type VM interface { - // StateReferences returns the inputs of the given transaction. - StateReferences(transaction Transaction) ([]StateReference, error) + // Inputs returns the referenced inputs of the given transaction. + Inputs(transaction Transaction) ([]StateReference, error) // ValidateSignatures validates the signatures of the given SignedTransaction and returns the execution context. ValidateSignatures(signedTransaction SignedTransaction, inputs []State) (executionContext context.Context, err error)