Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Introduce alias for StateReferences in mempool and cleanup types #395

Merged
merged 5 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/protocol/engine/ledger/ledger/ledger.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func NewProvider() module.Provider[*engine.Engine, ledger.Ledger] {

l.setRetainTransactionFailureFunc(e.Retainer.RetainTransactionFailure)

l.memPool = mempoolv1.New(l.validateStardustTransaction, l.executeStardustVM, l.extractInputReferences, l.resolveState, e.Workers.CreateGroup("MemPool"), l.conflictDAG, e, l.errorHandler, mempoolv1.WithForkAllTransactions[ledger.BlockVoteRank](true))
l.memPool = mempoolv1.New(NewVM(l), l.resolveState, e.Workers.CreateGroup("MemPool"), l.conflictDAG, l.errorHandler, mempoolv1.WithForkAllTransactions[ledger.BlockVoteRank](true))
e.EvictionState.Events.SlotEvicted.Hook(l.memPool.Evict)

l.manaManager = mana.NewManager(l.apiProvider, l.resolveAccountOutput)
Expand Down
26 changes: 18 additions & 8 deletions pkg/protocol/engine/ledger/ledger/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,17 @@ import (
"github.com/iotaledger/iota.go/v4/vm/stardust"
)

func (l *Ledger) extractInputReferences(transaction mempool.Transaction) (inputReferences []iotago.Input, err error) {
type VM struct {
ledger *Ledger
}

func NewVM(ledger *Ledger) *VM {
return &VM{
ledger: ledger,
}
}

func (v *VM) Inputs(transaction mempool.Transaction) (inputReferences []iotago.Input, err error) {
stardustTransaction, ok := transaction.(*iotago.Transaction)
if !ok {
return nil, iotago.ErrTxTypeInvalid
Expand All @@ -27,7 +37,7 @@ func (l *Ledger) extractInputReferences(transaction mempool.Transaction) (inputR
return inputReferences, nil
}

func (l *Ledger) validateStardustTransaction(signedTransaction mempool.SignedTransaction, resolvedInputStates []mempool.State) (executionContext context.Context, err error) {
func (v *VM) ValidateSignatures(signedTransaction mempool.SignedTransaction, resolvedInputStates []mempool.State) (executionContext context.Context, err error) {
signedStardustTransaction, ok := signedTransaction.(*iotago.SignedTransaction)
if !ok {
return nil, iotago.ErrTxTypeInvalid
Expand Down Expand Up @@ -57,7 +67,7 @@ func (l *Ledger) validateStardustTransaction(signedTransaction mempool.SignedTra

bicInputSet := make(iotagovm.BlockIssuanceCreditInputSet)
for _, inp := range bicInputs {
accountData, exists, accountErr := l.accountsLedger.Account(inp.AccountID, commitmentInput.Slot)
accountData, exists, accountErr := v.ledger.accountsLedger.Account(inp.AccountID, commitmentInput.Slot)
if accountErr != nil {
return nil, ierrors.Join(iotago.ErrBICInputInvalid, ierrors.Wrapf(accountErr, "could not get BIC input for account %s in slot %d", inp.AccountID, commitmentInput.Slot))
}
Expand Down Expand Up @@ -87,7 +97,7 @@ func (l *Ledger) validateStardustTransaction(signedTransaction mempool.SignedTra
accountID = iotago.AccountIDFromOutputID(outputID)
}

reward, _, _, rewardErr := l.sybilProtection.ValidatorReward(accountID, stakingFeature.StakedAmount, stakingFeature.StartEpoch, stakingFeature.EndEpoch)
reward, _, _, rewardErr := v.ledger.sybilProtection.ValidatorReward(accountID, stakingFeature.StakedAmount, stakingFeature.StartEpoch, stakingFeature.EndEpoch)
if rewardErr != nil {
return nil, ierrors.Wrapf(iotago.ErrFailedToClaimStakingReward, "failed to get Validator reward for AccountOutput %s at index %d (StakedAmount: %d, StartEpoch: %d, EndEpoch: %d", outputID, inp.Index, stakingFeature.StakedAmount, stakingFeature.StartEpoch, stakingFeature.EndEpoch)
}
Expand All @@ -102,10 +112,10 @@ func (l *Ledger) validateStardustTransaction(signedTransaction mempool.SignedTra

delegationEnd := castOutput.EndEpoch
if delegationEnd == 0 {
delegationEnd = l.apiProvider.APIForSlot(commitmentInput.Slot).TimeProvider().EpochFromSlot(commitmentInput.Slot) - iotago.EpochIndex(1)
delegationEnd = v.ledger.apiProvider.APIForSlot(commitmentInput.Slot).TimeProvider().EpochFromSlot(commitmentInput.Slot) - iotago.EpochIndex(1)
}

reward, _, _, rewardErr := l.sybilProtection.DelegatorReward(castOutput.ValidatorAddress.AccountID(), castOutput.DelegatedAmount, castOutput.StartEpoch, delegationEnd)
reward, _, _, rewardErr := v.ledger.sybilProtection.DelegatorReward(castOutput.ValidatorAddress.AccountID(), castOutput.DelegatedAmount, castOutput.StartEpoch, delegationEnd)
if rewardErr != nil {
return nil, ierrors.Wrapf(iotago.ErrFailedToClaimDelegationReward, "failed to get Delegator reward for DelegationOutput %s at index %d (StakedAmount: %d, StartEpoch: %d, EndEpoch: %d", outputID, inp.Index, castOutput.DelegatedAmount, castOutput.StartEpoch, castOutput.EndEpoch)
}
Expand Down Expand Up @@ -133,7 +143,7 @@ func (l *Ledger) validateStardustTransaction(signedTransaction mempool.SignedTra
return executionContext, nil
}

func (l *Ledger) executeStardustVM(executionContext context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) {
func (v *VM) Execute(executionContext context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) {
stardustTransaction, ok := transaction.(*iotago.Transaction)
if !ok {
return nil, iotago.ErrTxTypeInvalid
Expand Down Expand Up @@ -161,7 +171,7 @@ func (l *Ledger) executeStardustVM(executionContext context.Context, transaction

for index, output := range createdOutputs {
outputs = append(outputs, utxoledger.CreateOutput(
l.apiProvider,
v.ledger.apiProvider,
iotago.OutputIDFromTransactionIDAndIndex(transactionID, uint16(index)),
iotago.EmptyBlockID(),
0,
Expand Down
2 changes: 1 addition & 1 deletion pkg/protocol/engine/mempool/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type MemPool[VoteRank conflictdag.VoteRankType[VoteRank]] interface {

MarkAttachmentIncluded(blockID iotago.BlockID) bool

StateMetadata(reference iotago.Input) (state StateMetadata, err error)
StateMetadata(reference StateReference) (state StateMetadata, err error)

TransactionMetadata(id iotago.TransactionID) (transaction TransactionMetadata, exists bool)

Expand Down
2 changes: 0 additions & 2 deletions pkg/protocol/engine/mempool/state_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
)

type StateMetadata interface {
StateID() StateID

State() State

ConflictIDs() reactive.Set[iotago.TransactionID]
Expand Down
8 changes: 2 additions & 6 deletions pkg/protocol/engine/mempool/state_reference.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package mempool

import (
"github.com/iotaledger/iota-core/pkg/core/promise"
iotago "github.com/iotaledger/iota.go/v4"
)
import iotago "github.com/iotaledger/iota.go/v4"

// StateReferenceResolver is a function that resolves a StateReference to a State.
type StateReferenceResolver func(reference iotago.Input) *promise.Promise[State]
type StateReference = iotago.Input
8 changes: 8 additions & 0 deletions pkg/protocol/engine/mempool/state_resolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package mempool

import (
"github.com/iotaledger/iota-core/pkg/core/promise"
)

// StateResolver is a function that resolves a StateReference to a Promise with the State.
type StateResolver func(reference StateReference) *promise.Promise[State]
8 changes: 4 additions & 4 deletions pkg/protocol/engine/mempool/tests/testframework.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type TestFramework struct {
Instance mempool.MemPool[vote.MockedRank]
ConflictDAG conflictdag.ConflictDAG[iotago.TransactionID, mempool.StateID, vote.MockedRank]

referencesByAlias map[string]iotago.Input
referencesByAlias map[string]mempool.StateReference
stateIDByAlias map[string]mempool.StateID
signedTransactionByAlias map[string]mempool.SignedTransaction
transactionByAlias map[string]mempool.Transaction
Expand All @@ -39,7 +39,7 @@ func NewTestFramework(test *testing.T, instance mempool.MemPool[vote.MockedRank]
t := &TestFramework{
Instance: instance,
ConflictDAG: conflictDAG,
referencesByAlias: make(map[string]iotago.Input),
referencesByAlias: make(map[string]mempool.StateReference),
stateIDByAlias: make(map[string]mempool.StateID),
signedTransactionByAlias: make(map[string]mempool.SignedTransaction),
transactionByAlias: make(map[string]mempool.Transaction),
Expand Down Expand Up @@ -311,7 +311,7 @@ func (t *TestFramework) setupHookedEvents() {
})
}

func (t *TestFramework) stateReference(alias string) iotago.Input {
func (t *TestFramework) stateReference(alias string) mempool.StateReference {
if alias == "genesis" {
return &iotago.UTXOInput{}
}
Expand Down Expand Up @@ -393,7 +393,7 @@ func (t *TestFramework) Cleanup() {

iotago.UnregisterIdentifierAliases()

t.referencesByAlias = make(map[string]iotago.Input)
t.referencesByAlias = make(map[string]mempool.StateReference)
t.stateIDByAlias = make(map[string]mempool.StateID)
t.transactionByAlias = make(map[string]mempool.Transaction)
t.signedTransactionByAlias = make(map[string]mempool.SignedTransaction)
Expand Down
6 changes: 3 additions & 3 deletions pkg/protocol/engine/mempool/tests/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (s *SignedTransaction) String() string {

type Transaction struct {
id iotago.TransactionID
inputs []iotago.Input
inputs []mempool.StateReference
outputCount uint16
invalidTransaction bool
}
Expand All @@ -33,7 +33,7 @@ func NewSignedTransaction(transaction mempool.Transaction) *SignedTransaction {
}
}

func NewTransaction(outputCount uint16, inputs ...iotago.Input) *Transaction {
func NewTransaction(outputCount uint16, inputs ...mempool.StateReference) *Transaction {
return &Transaction{
id: tpkg.RandTransactionID(),
inputs: inputs,
Expand All @@ -45,7 +45,7 @@ func (t *Transaction) ID() (iotago.TransactionID, error) {
return t.id, nil
}

func (t *Transaction) Inputs() ([]iotago.Input, error) {
func (t *Transaction) Inputs() ([]mempool.StateReference, error) {
return t.inputs, nil
}

Expand Down
23 changes: 17 additions & 6 deletions pkg/protocol/engine/mempool/tests/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,33 @@ import (
"github.com/iotaledger/iota-core/pkg/protocol/engine/mempool"
)

func TransactionValidator(_ mempool.SignedTransaction, _ []mempool.State) (executionContext context.Context, err error) {
type VM struct{}

func (V *VM) Inputs(transaction mempool.Transaction) ([]mempool.StateReference, error) {
testTransaction, ok := transaction.(*Transaction)
if !ok {
return nil, ierrors.New("invalid transaction type in MockedVM")
}

return testTransaction.Inputs()
}

func (V *VM) ValidateSignatures(_ mempool.SignedTransaction, _ []mempool.State) (executionContext context.Context, err error) {
return context.Background(), nil
}

func TransactionExecutor(_ context.Context, inputTransaction mempool.Transaction) (outputs []mempool.State, err error) {
transaction, ok := inputTransaction.(*Transaction)
func (V *VM) Execute(_ context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) {
typedTransaction, ok := transaction.(*Transaction)
if !ok {
return nil, ierrors.New("invalid transaction type in MockedVM")
}

if transaction.invalidTransaction {
if typedTransaction.invalidTransaction {
return nil, ierrors.New("invalid transaction")
}

for i := uint16(0); i < transaction.outputCount; i++ {
id, err := transaction.ID()
for i := uint16(0); i < typedTransaction.outputCount; i++ {
id, err := typedTransaction.ID()
if err != nil {
return nil, err
}
Expand Down
2 changes: 0 additions & 2 deletions pkg/protocol/engine/mempool/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,3 @@ type Transaction interface {
// ID returns the identifier of the Transaction.
ID() (iotago.TransactionID, error)
}

type TransactionInputReferenceRetriever func(transaction Transaction) ([]iotago.Input, error)
64 changes: 27 additions & 37 deletions pkg/protocol/engine/mempool/v1/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,11 @@ import (

// MemPool is a component that manages the state of transactions that are not yet included in the ledger state.
type MemPool[VoteRank conflictdag.VoteRankType[VoteRank]] struct {
signedTransactionAttached *event.Event1[mempool.SignedTransactionMetadata]

transactionAttached *event.Event1[mempool.TransactionMetadata]

// executeStateTransition is the TransactionExecutor that is used to execute the state transition of transactions.
executeStateTransition mempool.TransactionExecutor
// vm is the virtual machine that is used to validate and execute transactions.
vm mempool.VM

// resolveState is the function that is used to request state from the ledger.
resolveState mempool.StateReferenceResolver

inputsOfTransaction mempool.TransactionInputReferenceRetriever

validateSignedTransaction mempool.TransactionValidator
resolveState mempool.StateResolver

// attachments is the storage that is used to keep track of the attachments of transactions.
attachments *memstorage.IndexedStorage[iotago.SlotIndex, iotago.BlockID, *SignedTransactionMetadata]
Expand Down Expand Up @@ -66,27 +58,22 @@ type MemPool[VoteRank conflictdag.VoteRankType[VoteRank]] struct {

optForkAllTransactions bool

apiProvider iotago.APIProvider
signedTransactionAttached *event.Event1[mempool.SignedTransactionMetadata]

transactionAttached *event.Event1[mempool.TransactionMetadata]
}

// New is the constructor of the MemPool.
func New[VoteRank conflictdag.VoteRankType[VoteRank]](
transactionValidator mempool.TransactionValidator,
transactionExecutor mempool.TransactionExecutor,
transactionInputReferenceRetriever mempool.TransactionInputReferenceRetriever,
stateResolver mempool.StateReferenceResolver,
vm mempool.VM,
stateResolver mempool.StateResolver,
workers *workerpool.Group,
conflictDAG conflictdag.ConflictDAG[iotago.TransactionID, mempool.StateID, VoteRank],
apiProvider iotago.APIProvider,
errorHandler func(error),
opts ...options.Option[MemPool[VoteRank]],
) *MemPool[VoteRank] {
return options.Apply(&MemPool[VoteRank]{
signedTransactionAttached: event.New1[mempool.SignedTransactionMetadata](),
transactionAttached: event.New1[mempool.TransactionMetadata](),
validateSignedTransaction: transactionValidator,
executeStateTransition: transactionExecutor,
inputsOfTransaction: transactionInputReferenceRetriever,
vm: vm,
resolveState: stateResolver,
attachments: memstorage.NewIndexedStorage[iotago.SlotIndex, iotago.BlockID, *SignedTransactionMetadata](),
cachedTransactions: shrinkingmap.New[iotago.TransactionID, *TransactionMetadata](),
Expand All @@ -95,8 +82,9 @@ func New[VoteRank conflictdag.VoteRankType[VoteRank]](
stateDiffs: shrinkingmap.New[iotago.SlotIndex, *StateDiff](),
executionWorkers: workers.CreatePool("executionWorkers", 1),
conflictDAG: conflictDAG,
apiProvider: apiProvider,
errorHandler: errorHandler,
signedTransactionAttached: event.New1[mempool.SignedTransactionMetadata](),
transactionAttached: event.New1[mempool.TransactionMetadata](),
}, opts, (*MemPool[VoteRank]).setup)
}

Expand Down Expand Up @@ -140,7 +128,7 @@ func (m *MemPool[VoteRank]) TransactionMetadata(id iotago.TransactionID) (transa
}

// StateMetadata returns the metadata of the output state with the given ID.
func (m *MemPool[VoteRank]) StateMetadata(stateReference iotago.Input) (state mempool.StateMetadata, err error) {
func (m *MemPool[VoteRank]) StateMetadata(stateReference mempool.StateReference) (state mempool.StateMetadata, err error) {
stateRequest, exists := m.cachedStateRequests.Get(stateReference.StateID())

// create a new request that does not wait for missing states
Expand Down Expand Up @@ -202,7 +190,7 @@ func (m *MemPool[VoteRank]) storeTransaction(signedTransaction mempool.SignedTra
return nil, false, false, ierrors.Errorf("blockID %d is older than last evicted slot %d", blockID.Slot(), m.lastEvictedSlot)
}

inputReferences, err := m.inputsOfTransaction(transaction)
inputReferences, err := m.vm.Inputs(transaction)
if err != nil {
return nil, false, false, ierrors.Wrap(err, "failed to get input references of transaction")
}
Expand Down Expand Up @@ -261,7 +249,7 @@ func (m *MemPool[VoteRank]) solidifyInputs(transaction *TransactionMetadata) {

func (m *MemPool[VoteRank]) executeTransaction(executionContext context.Context, transaction *TransactionMetadata) {
m.executionWorkers.Submit(func() {
if outputStates, err := m.executeStateTransition(executionContext, transaction.Transaction()); err != nil {
if outputStates, err := m.vm.Execute(executionContext, transaction.Transaction()); err != nil {
transaction.setInvalid(err)
} else {
transaction.setExecuted(outputStates)
Expand All @@ -273,11 +261,13 @@ func (m *MemPool[VoteRank]) executeTransaction(executionContext context.Context,

func (m *MemPool[VoteRank]) bookTransaction(transaction *TransactionMetadata) {
if m.optForkAllTransactions {
m.forkTransaction(transaction, ds.NewSet(lo.Map(transaction.inputs, (*StateMetadata).StateID)...))
m.forkTransaction(transaction, ds.NewSet(lo.Map(transaction.inputs, func(stateMetadata *StateMetadata) mempool.StateID {
return stateMetadata.state.StateID()
})...))
} else {
lo.ForEach(transaction.inputs, func(input *StateMetadata) {
input.OnDoubleSpent(func() {
m.forkTransaction(transaction, ds.NewSet(input.StateID()))
m.forkTransaction(transaction, ds.NewSet(input.state.StateID()))
})
})
}
Expand All @@ -299,7 +289,7 @@ func (m *MemPool[VoteRank]) forkTransaction(transactionMetadata *TransactionMeta

func (m *MemPool[VoteRank]) publishOutputStates(transaction *TransactionMetadata) {
for _, output := range transaction.outputs {
stateRequest, isNew := m.cachedStateRequests.GetOrCreate(output.StateID(), lo.NoVariadic(promise.New[*StateMetadata]))
stateRequest, isNew := m.cachedStateRequests.GetOrCreate(output.State().StateID(), lo.NoVariadic(promise.New[*StateMetadata]))
stateRequest.Resolve(output)

if isNew {
Expand All @@ -308,7 +298,7 @@ func (m *MemPool[VoteRank]) publishOutputStates(transaction *TransactionMetadata
}
}

func (m *MemPool[VoteRank]) requestState(stateRef iotago.Input, waitIfMissing ...bool) *promise.Promise[*StateMetadata] {
func (m *MemPool[VoteRank]) requestState(stateRef mempool.StateReference, waitIfMissing ...bool) *promise.Promise[*StateMetadata] {
return promise.New(func(p *promise.Promise[*StateMetadata]) {
request := m.resolveState(stateRef)

Expand Down Expand Up @@ -442,23 +432,23 @@ func (m *MemPool[VoteRank]) setupTransaction(transaction *TransactionMetadata) {
})
}

func (m *MemPool[VoteRank]) setupOutputState(state *StateMetadata) {
state.OnCommitted(func() {
if !m.cachedStateRequests.Delete(state.StateID(), state.HasNoSpenders) && m.cachedStateRequests.Has(state.StateID()) {
state.onAllSpendersRemoved(func() { m.cachedStateRequests.Delete(state.StateID(), state.HasNoSpenders) })
func (m *MemPool[VoteRank]) setupOutputState(stateMetadata *StateMetadata) {
stateMetadata.OnCommitted(func() {
if !m.cachedStateRequests.Delete(stateMetadata.state.StateID(), stateMetadata.HasNoSpenders) && m.cachedStateRequests.Has(stateMetadata.state.StateID()) {
stateMetadata.onAllSpendersRemoved(func() { m.cachedStateRequests.Delete(stateMetadata.state.StateID(), stateMetadata.HasNoSpenders) })
}
})

state.OnOrphaned(func() {
m.cachedStateRequests.Delete(state.StateID())
stateMetadata.OnOrphaned(func() {
m.cachedStateRequests.Delete(stateMetadata.state.StateID())
})
}

func (m *MemPool[VoteRank]) setupSignedTransaction(signedTransactionMetadata *SignedTransactionMetadata, transaction *TransactionMetadata) {
transaction.addSigningTransaction(signedTransactionMetadata)

transaction.OnSolid(func() {
executionContext, err := m.validateSignedTransaction(signedTransactionMetadata.SignedTransaction(), lo.Map(signedTransactionMetadata.transactionMetadata.inputs, (*StateMetadata).State))
executionContext, err := m.vm.ValidateSignatures(signedTransactionMetadata.SignedTransaction(), lo.Map(signedTransactionMetadata.transactionMetadata.inputs, (*StateMetadata).State))
if err != nil {
_ = signedTransactionMetadata.signaturesInvalid.Set(err)
return
Expand Down
Loading
Loading