Skip to content

Commit

Permalink
Feat: added alias for StateReference
Browse files Browse the repository at this point in the history
  • Loading branch information
hmoog committed Oct 3, 2023
1 parent 2338e8b commit e2ed66c
Show file tree
Hide file tree
Showing 15 changed files with 71 additions and 78 deletions.
4 changes: 2 additions & 2 deletions pkg/protocol/engine/ledger/ledger/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func NewVM(ledger *Ledger) *VM {
}
}

func (v *VM) TransactionInputs(transaction mempool.Transaction) (inputReferences []iotago.Input, err error) {
func (v *VM) StateReferences(transaction mempool.Transaction) (inputReferences []iotago.Input, err error) {
stardustTransaction, ok := transaction.(*iotago.Transaction)
if !ok {
return nil, iotago.ErrTxTypeInvalid
Expand Down Expand Up @@ -143,7 +143,7 @@ func (v *VM) ValidateSignatures(signedTransaction mempool.SignedTransaction, res
return executionContext, nil
}

func (v *VM) ExecuteTransaction(executionContext context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) {
func (v *VM) Execute(executionContext context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) {
stardustTransaction, ok := transaction.(*iotago.Transaction)
if !ok {
return nil, iotago.ErrTxTypeInvalid
Expand Down
2 changes: 1 addition & 1 deletion pkg/protocol/engine/mempool/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type MemPool[VoteRank conflictdag.VoteRankType[VoteRank]] interface {

MarkAttachmentIncluded(blockID iotago.BlockID) bool

StateMetadata(reference iotago.Input) (state StateMetadata, err error)
StateMetadata(reference StateReference) (state StateMetadata, err error)

TransactionMetadata(id iotago.TransactionID) (transaction TransactionMetadata, exists bool)

Expand Down
2 changes: 0 additions & 2 deletions pkg/protocol/engine/mempool/state_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
)

type StateMetadata interface {
StateID() StateID

State() State

ConflictIDs() reactive.Set[iotago.TransactionID]
Expand Down
8 changes: 2 additions & 6 deletions pkg/protocol/engine/mempool/state_reference.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package mempool

import (
"github.com/iotaledger/iota-core/pkg/core/promise"
iotago "github.com/iotaledger/iota.go/v4"
)
import iotago "github.com/iotaledger/iota.go/v4"

// StateReferenceResolver is a function that resolves a StateReference to a State.
type StateReferenceResolver func(reference iotago.Input) *promise.Promise[State]
type StateReference = iotago.Input
8 changes: 8 additions & 0 deletions pkg/protocol/engine/mempool/state_resolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package mempool

import (
"github.com/iotaledger/iota-core/pkg/core/promise"
)

// StateResolver is a function that resolves a StateReference to a Promise with the State.
type StateResolver func(reference StateReference) *promise.Promise[State]
8 changes: 4 additions & 4 deletions pkg/protocol/engine/mempool/tests/testframework.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type TestFramework struct {
Instance mempool.MemPool[vote.MockedRank]
ConflictDAG conflictdag.ConflictDAG[iotago.TransactionID, mempool.StateID, vote.MockedRank]

referencesByAlias map[string]iotago.Input
referencesByAlias map[string]mempool.StateReference
stateIDByAlias map[string]mempool.StateID
signedTransactionByAlias map[string]mempool.SignedTransaction
transactionByAlias map[string]mempool.Transaction
Expand All @@ -39,7 +39,7 @@ func NewTestFramework(test *testing.T, instance mempool.MemPool[vote.MockedRank]
t := &TestFramework{
Instance: instance,
ConflictDAG: conflictDAG,
referencesByAlias: make(map[string]iotago.Input),
referencesByAlias: make(map[string]mempool.StateReference),
stateIDByAlias: make(map[string]mempool.StateID),
signedTransactionByAlias: make(map[string]mempool.SignedTransaction),
transactionByAlias: make(map[string]mempool.Transaction),
Expand Down Expand Up @@ -311,7 +311,7 @@ func (t *TestFramework) setupHookedEvents() {
})
}

func (t *TestFramework) stateReference(alias string) iotago.Input {
func (t *TestFramework) stateReference(alias string) mempool.StateReference {
if alias == "genesis" {
return &iotago.UTXOInput{}
}
Expand Down Expand Up @@ -393,7 +393,7 @@ func (t *TestFramework) Cleanup() {

iotago.UnregisterIdentifierAliases()

t.referencesByAlias = make(map[string]iotago.Input)
t.referencesByAlias = make(map[string]mempool.StateReference)
t.stateIDByAlias = make(map[string]mempool.StateID)
t.transactionByAlias = make(map[string]mempool.Transaction)
t.signedTransactionByAlias = make(map[string]mempool.SignedTransaction)
Expand Down
6 changes: 3 additions & 3 deletions pkg/protocol/engine/mempool/tests/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (s *SignedTransaction) String() string {

type Transaction struct {
id iotago.TransactionID
inputs []iotago.Input
inputs []mempool.StateReference
outputCount uint16
invalidTransaction bool
}
Expand All @@ -33,7 +33,7 @@ func NewSignedTransaction(transaction mempool.Transaction) *SignedTransaction {
}
}

func NewTransaction(outputCount uint16, inputs ...iotago.Input) *Transaction {
func NewTransaction(outputCount uint16, inputs ...mempool.StateReference) *Transaction {
return &Transaction{
id: tpkg.RandTransactionID(),
inputs: inputs,
Expand All @@ -45,7 +45,7 @@ func (t *Transaction) ID() (iotago.TransactionID, error) {
return t.id, nil
}

func (t *Transaction) Inputs() ([]iotago.Input, error) {
func (t *Transaction) Inputs() ([]mempool.StateReference, error) {
return t.inputs, nil
}

Expand Down
5 changes: 2 additions & 3 deletions pkg/protocol/engine/mempool/tests/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,19 @@ import (
"github.com/iotaledger/hive.go/ierrors"
ledgertests "github.com/iotaledger/iota-core/pkg/protocol/engine/ledger/tests"
"github.com/iotaledger/iota-core/pkg/protocol/engine/mempool"
iotago "github.com/iotaledger/iota.go/v4"
)

type VM struct{}

func (V *VM) TransactionInputs(transaction mempool.Transaction) ([]iotago.Input, error) {
func (V *VM) StateReferences(transaction mempool.Transaction) ([]mempool.StateReference, error) {
return transaction.(*Transaction).Inputs()

Check failure on line 14 in pkg/protocol/engine/mempool/tests/vm.go

View workflow job for this annotation

GitHub Actions / golangci

[golangci] pkg/protocol/engine/mempool/tests/vm.go#L14

type assertion must be checked (forcetypeassert)
Raw output
pkg/protocol/engine/mempool/tests/vm.go:14:9: type assertion must be checked (forcetypeassert)
	return transaction.(*Transaction).Inputs()
	       ^
}

func (V *VM) ValidateSignatures(signedTransaction mempool.SignedTransaction, resolvedInputs []mempool.State) (executionContext context.Context, err error) {

Check failure on line 17 in pkg/protocol/engine/mempool/tests/vm.go

View workflow job for this annotation

GitHub Actions / golangci

[golangci] pkg/protocol/engine/mempool/tests/vm.go#L17

unused-parameter: parameter 'signedTransaction' seems to be unused, consider removing or renaming it as _ (revive)
Raw output
pkg/protocol/engine/mempool/tests/vm.go:17:33: unused-parameter: parameter 'signedTransaction' seems to be unused, consider removing or renaming it as _ (revive)
func (V *VM) ValidateSignatures(signedTransaction mempool.SignedTransaction, resolvedInputs []mempool.State) (executionContext context.Context, err error) {
                                ^
return context.Background(), nil
}

func (V *VM) ExecuteTransaction(executionContext context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) {
func (V *VM) Execute(executionContext context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) {

Check failure on line 21 in pkg/protocol/engine/mempool/tests/vm.go

View workflow job for this annotation

GitHub Actions / golangci

[golangci] pkg/protocol/engine/mempool/tests/vm.go#L21

unused-parameter: parameter 'executionContext' seems to be unused, consider removing or renaming it as _ (revive)
Raw output
pkg/protocol/engine/mempool/tests/vm.go:21:22: unused-parameter: parameter 'executionContext' seems to be unused, consider removing or renaming it as _ (revive)
func (V *VM) Execute(executionContext context.Context, transaction mempool.Transaction) (outputs []mempool.State, err error) {
                     ^
typedTransaction, ok := transaction.(*Transaction)
if !ok {
return nil, ierrors.New("invalid transaction type in MockedVM")
Expand Down
32 changes: 17 additions & 15 deletions pkg/protocol/engine/mempool/v1/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type MemPool[VoteRank conflictdag.VoteRankType[VoteRank]] struct {
vm mempool.VM

// resolveState is the function that is used to request state from the ledger.
resolveState mempool.StateReferenceResolver
resolveState mempool.StateResolver

// attachments is the storage that is used to keep track of the attachments of transactions.
attachments *memstorage.IndexedStorage[iotago.SlotIndex, iotago.BlockID, *SignedTransactionMetadata]
Expand Down Expand Up @@ -66,7 +66,7 @@ type MemPool[VoteRank conflictdag.VoteRankType[VoteRank]] struct {
// New is the constructor of the MemPool.
func New[VoteRank conflictdag.VoteRankType[VoteRank]](
vm mempool.VM,
stateResolver mempool.StateReferenceResolver,
stateResolver mempool.StateResolver,
workers *workerpool.Group,
conflictDAG conflictdag.ConflictDAG[iotago.TransactionID, mempool.StateID, VoteRank],
errorHandler func(error),
Expand Down Expand Up @@ -128,7 +128,7 @@ func (m *MemPool[VoteRank]) TransactionMetadata(id iotago.TransactionID) (transa
}

// StateMetadata returns the metadata of the output state with the given ID.
func (m *MemPool[VoteRank]) StateMetadata(stateReference iotago.Input) (state mempool.StateMetadata, err error) {
func (m *MemPool[VoteRank]) StateMetadata(stateReference mempool.StateReference) (state mempool.StateMetadata, err error) {
stateRequest, exists := m.cachedStateRequests.Get(stateReference.StateID())

// create a new request that does not wait for missing states
Expand Down Expand Up @@ -190,7 +190,7 @@ func (m *MemPool[VoteRank]) storeTransaction(signedTransaction mempool.SignedTra
return nil, false, false, ierrors.Errorf("blockID %d is older than last evicted slot %d", blockID.Slot(), m.lastEvictedSlot)
}

inputReferences, err := m.vm.TransactionInputs(transaction)
inputReferences, err := m.vm.StateReferences(transaction)
if err != nil {
return nil, false, false, ierrors.Wrap(err, "failed to get input references of transaction")
}
Expand Down Expand Up @@ -249,7 +249,7 @@ func (m *MemPool[VoteRank]) solidifyInputs(transaction *TransactionMetadata) {

func (m *MemPool[VoteRank]) executeTransaction(executionContext context.Context, transaction *TransactionMetadata) {
m.executionWorkers.Submit(func() {
if outputStates, err := m.vm.ExecuteTransaction(executionContext, transaction.Transaction()); err != nil {
if outputStates, err := m.vm.Execute(executionContext, transaction.Transaction()); err != nil {
transaction.setInvalid(err)
} else {
transaction.setExecuted(outputStates)
Expand All @@ -261,11 +261,13 @@ func (m *MemPool[VoteRank]) executeTransaction(executionContext context.Context,

func (m *MemPool[VoteRank]) bookTransaction(transaction *TransactionMetadata) {
if m.optForkAllTransactions {
m.forkTransaction(transaction, ds.NewSet(lo.Map(transaction.inputs, (*StateMetadata).StateID)...))
m.forkTransaction(transaction, ds.NewSet(lo.Map(transaction.inputs, func(stateMetadata *StateMetadata) mempool.StateID {
return stateMetadata.state.StateID()
})...))
} else {
lo.ForEach(transaction.inputs, func(input *StateMetadata) {
input.OnDoubleSpent(func() {
m.forkTransaction(transaction, ds.NewSet(input.StateID()))
m.forkTransaction(transaction, ds.NewSet(input.state.StateID()))
})
})
}
Expand All @@ -287,7 +289,7 @@ func (m *MemPool[VoteRank]) forkTransaction(transactionMetadata *TransactionMeta

func (m *MemPool[VoteRank]) publishOutputStates(transaction *TransactionMetadata) {
for _, output := range transaction.outputs {
stateRequest, isNew := m.cachedStateRequests.GetOrCreate(output.StateID(), lo.NoVariadic(promise.New[*StateMetadata]))
stateRequest, isNew := m.cachedStateRequests.GetOrCreate(output.State().StateID(), lo.NoVariadic(promise.New[*StateMetadata]))
stateRequest.Resolve(output)

if isNew {
Expand All @@ -296,7 +298,7 @@ func (m *MemPool[VoteRank]) publishOutputStates(transaction *TransactionMetadata
}
}

func (m *MemPool[VoteRank]) requestState(stateRef iotago.Input, waitIfMissing ...bool) *promise.Promise[*StateMetadata] {
func (m *MemPool[VoteRank]) requestState(stateRef mempool.StateReference, waitIfMissing ...bool) *promise.Promise[*StateMetadata] {
return promise.New(func(p *promise.Promise[*StateMetadata]) {
request := m.resolveState(stateRef)

Expand Down Expand Up @@ -430,15 +432,15 @@ func (m *MemPool[VoteRank]) setupTransaction(transaction *TransactionMetadata) {
})
}

func (m *MemPool[VoteRank]) setupOutputState(state *StateMetadata) {
state.OnCommitted(func() {
if !m.cachedStateRequests.Delete(state.StateID(), state.HasNoSpenders) && m.cachedStateRequests.Has(state.StateID()) {
state.onAllSpendersRemoved(func() { m.cachedStateRequests.Delete(state.StateID(), state.HasNoSpenders) })
func (m *MemPool[VoteRank]) setupOutputState(stateMetadata *StateMetadata) {
stateMetadata.OnCommitted(func() {
if !m.cachedStateRequests.Delete(stateMetadata.state.StateID(), stateMetadata.HasNoSpenders) && m.cachedStateRequests.Has(stateMetadata.state.StateID()) {
stateMetadata.onAllSpendersRemoved(func() { m.cachedStateRequests.Delete(stateMetadata.state.StateID(), stateMetadata.HasNoSpenders) })
}
})

state.OnOrphaned(func() {
m.cachedStateRequests.Delete(state.StateID())
stateMetadata.OnOrphaned(func() {
m.cachedStateRequests.Delete(stateMetadata.state.StateID())
})
}

Expand Down
24 changes: 12 additions & 12 deletions pkg/protocol/engine/mempool/v1/mempool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ func TestMempoolV1_ResourceCleanup(t *testing.T) {

ledgerState := ledgertests.New(ledgertests.NewMockedState(iotago.TransactionID{}, 0))
conflictDAG := conflictdagv1.New[iotago.TransactionID, mempool.StateID, vote.MockedRank](func() int { return 0 })
mempoolInstance := New[vote.MockedRank](new(mempooltests.VM), func(reference iotago.Input) *promise.Promise[mempool.State] {
memPoolInstance := New[vote.MockedRank](new(mempooltests.VM), func(reference mempool.StateReference) *promise.Promise[mempool.State] {
return ledgerState.ResolveOutputState(reference.StateID())
}, workers, conflictDAG, func(error) {})

tf := mempooltests.NewTestFramework(t, mempoolInstance, conflictDAG, ledgerState, workers)
tf := mempooltests.NewTestFramework(t, memPoolInstance, conflictDAG, ledgerState, workers)

issueTransactions := func(startIndex, transactionCount int, prevStateAlias string) (int, string) {
index := startIndex
Expand All @@ -58,7 +58,7 @@ func TestMempoolV1_ResourceCleanup(t *testing.T) {
tf.CommitSlot(iotago.SlotIndex(index))
tf.Instance.Evict(iotago.SlotIndex(index))

require.Nil(t, mempoolInstance.attachments.Get(iotago.SlotIndex(index), false))
require.Nil(t, memPoolInstance.attachments.Get(iotago.SlotIndex(index), false))

}
return index, prevStateAlias
Expand All @@ -70,19 +70,19 @@ func TestMempoolV1_ResourceCleanup(t *testing.T) {
txIndex, prevStateAlias := issueTransactions(1, 10, "genesis")
tf.WaitChildren()

require.Equal(t, 0, mempoolInstance.cachedTransactions.Size())
require.Equal(t, 0, mempoolInstance.stateDiffs.Size())
require.Equal(t, 0, mempoolInstance.cachedStateRequests.Size())
require.Equal(t, 0, memPoolInstance.cachedTransactions.Size())
require.Equal(t, 0, memPoolInstance.stateDiffs.Size())
require.Equal(t, 0, memPoolInstance.cachedStateRequests.Size())

txIndex, prevStateAlias = issueTransactions(txIndex, 10, prevStateAlias)
tf.WaitChildren()

require.Equal(t, 0, mempoolInstance.cachedTransactions.Size())
require.Equal(t, 0, mempoolInstance.stateDiffs.Size())
require.Equal(t, 0, mempoolInstance.cachedStateRequests.Size())
require.Equal(t, 0, memPoolInstance.cachedTransactions.Size())
require.Equal(t, 0, memPoolInstance.stateDiffs.Size())
require.Equal(t, 0, memPoolInstance.cachedStateRequests.Size())

attachmentsSlotCount := 0
mempoolInstance.attachments.ForEach(func(index iotago.SlotIndex, storage *shrinkingmap.ShrinkingMap[iotago.BlockID, *SignedTransactionMetadata]) {
memPoolInstance.attachments.ForEach(func(index iotago.SlotIndex, storage *shrinkingmap.ShrinkingMap[iotago.BlockID, *SignedTransactionMetadata]) {
attachmentsSlotCount++
})

Expand All @@ -103,7 +103,7 @@ func newTestFramework(t *testing.T) *mempooltests.TestFramework {
ledgerState := ledgertests.New(ledgertests.NewMockedState(iotago.TransactionID{}, 0))
conflictDAG := conflictdagv1.New[iotago.TransactionID, mempool.StateID, vote.MockedRank](account.NewAccounts().SelectCommittee().SeatCount)

return mempooltests.NewTestFramework(t, New[vote.MockedRank](new(mempooltests.VM), func(reference iotago.Input) *promise.Promise[mempool.State] {
return mempooltests.NewTestFramework(t, New[vote.MockedRank](new(mempooltests.VM), func(reference mempool.StateReference) *promise.Promise[mempool.State] {
return ledgerState.ResolveOutputState(reference.StateID())
}, workers, conflictDAG, func(error) {}), conflictDAG, ledgerState, workers)
}
Expand All @@ -114,7 +114,7 @@ func newForkingTestFramework(t *testing.T) *mempooltests.TestFramework {
ledgerState := ledgertests.New(ledgertests.NewMockedState(iotago.TransactionID{}, 0))
conflictDAG := conflictdagv1.New[iotago.TransactionID, mempool.StateID, vote.MockedRank](account.NewAccounts().SelectCommittee().SeatCount)

return mempooltests.NewTestFramework(t, New[vote.MockedRank](new(mempooltests.VM), func(reference iotago.Input) *promise.Promise[mempool.State] {
return mempooltests.NewTestFramework(t, New[vote.MockedRank](new(mempooltests.VM), func(reference mempool.StateReference) *promise.Promise[mempool.State] {
return ledgerState.ResolveOutputState(reference.StateID())
}, workers, conflictDAG, func(error) {}, WithForkAllTransactions[vote.MockedRank](true)), conflictDAG, ledgerState, workers)
}
24 changes: 12 additions & 12 deletions pkg/protocol/engine/mempool/v1/state_diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,17 @@ func (s *StateDiff) Mutations() ads.Set[iotago.TransactionID] {
}

func (s *StateDiff) updateCompactedStateChanges(transaction *TransactionMetadata, direction int) {
transaction.Inputs().Range(func(input mempool.StateMetadata) {
s.compactStateChanges(input, s.stateUsageCounters.Compute(input.StateID(), func(currentValue int, _ bool) int {
for _, input := range transaction.inputs {
s.compactStateChanges(input, s.stateUsageCounters.Compute(input.State().StateID(), func(currentValue int, _ bool) int {
return currentValue - direction
}))
})
}

transaction.Outputs().Range(func(output mempool.StateMetadata) {
s.compactStateChanges(output, s.stateUsageCounters.Compute(output.StateID(), func(currentValue int, _ bool) int {
for _, output := range transaction.outputs {
s.compactStateChanges(output, s.stateUsageCounters.Compute(output.State().StateID(), func(currentValue int, _ bool) int {
return currentValue + direction
}))
})
}
}

func (s *StateDiff) AddTransaction(transaction *TransactionMetadata, errorHandler func(error)) error {
Expand Down Expand Up @@ -97,19 +97,19 @@ func (s *StateDiff) RollbackTransaction(transaction *TransactionMetadata) error
return nil
}

func (s *StateDiff) compactStateChanges(output mempool.StateMetadata, newValue int) {
if output.State().Type() != iotago.InputUTXO {
func (s *StateDiff) compactStateChanges(output *StateMetadata, newValue int) {
if output.state.Type() != iotago.InputUTXO {
return
}

switch {
case newValue > 0:
s.createdOutputs.Set(output.StateID(), output)
s.createdOutputs.Set(output.State().StateID(), output)
case newValue < 0:
s.spentOutputs.Set(output.StateID(), output)
s.spentOutputs.Set(output.State().StateID(), output)
default:
s.createdOutputs.Delete(output.StateID())
s.spentOutputs.Delete(output.StateID())
s.createdOutputs.Delete(output.State().StateID())
s.spentOutputs.Delete(output.State().StateID())
}
}

Expand Down
8 changes: 0 additions & 8 deletions pkg/protocol/engine/mempool/v1/state_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,6 @@ func (s *StateMetadata) setup(optSource ...*TransactionMetadata) *StateMetadata
return s
}

func (s *StateMetadata) StateID() mempool.StateID {
return s.state.StateID()
}

func (s *StateMetadata) Type() iotago.StateType {
return iotago.InputUTXO
}

func (s *StateMetadata) State() mempool.State {
return s.state
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/protocol/engine/mempool/v1/transaction_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (

type TransactionMetadata struct {
id iotago.TransactionID
inputReferences []iotago.Input
inputReferences []mempool.StateReference
inputs []*StateMetadata
outputs []*StateMetadata
transaction mempool.Transaction
Expand Down Expand Up @@ -57,7 +57,7 @@ func (t *TransactionMetadata) ValidAttachments() []iotago.BlockID {
return t.validAttachments.Keys()
}

func NewTransactionMetadata(transaction mempool.Transaction, referencedInputs []iotago.Input) (*TransactionMetadata, error) {
func NewTransactionMetadata(transaction mempool.Transaction, referencedInputs []mempool.StateReference) (*TransactionMetadata, error) {
transactionID, transactionIDErr := transaction.ID()
if transactionIDErr != nil {
return nil, ierrors.Errorf("failed to retrieve transaction ID: %w", transactionIDErr)
Expand Down
Loading

0 comments on commit e2ed66c

Please sign in to comment.