Skip to content

Commit

Permalink
Merge pull request #2775 from jorgemmsilva/fix/mempool-nonce
Browse files Browse the repository at this point in the history
feat: mempool nonce override
  • Loading branch information
jorgemmsilva authored Aug 7, 2023
2 parents a3e3a80 + eb463ce commit 672f92d
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 53 deletions.
81 changes: 29 additions & 52 deletions packages/chain/mempool/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ import (
"time"

"github.com/samber/lo"
"golang.org/x/exp/slices"

"github.com/iotaledger/hive.go/logger"
consGR "github.com/iotaledger/wasp/packages/chain/cons/cons_gr"
Expand Down Expand Up @@ -129,7 +128,7 @@ type mempoolImpl struct {
tangleTime time.Time
timePool TimePool
onLedgerPool RequestPool[isc.OnLedgerRequest]
offLedgerPool RequestPool[isc.OffLedgerRequest]
offLedgerPool *TypedPoolByNonce[isc.OffLedgerRequest]
distSync gpa.GPA
chainHeadAO *isc.AliasOutputWithID
chainHeadState state.State
Expand Down Expand Up @@ -214,7 +213,7 @@ func New(
tangleTime: time.Time{},
timePool: NewTimePool(metrics.SetTimePoolSize, log.Named("TIM")),
onLedgerPool: NewTypedPool[isc.OnLedgerRequest](waitReq, metrics.SetOnLedgerPoolSize, metrics.SetOnLedgerReqTime, log.Named("ONL")),
offLedgerPool: NewTypedPool[isc.OffLedgerRequest](waitReq, metrics.SetOffLedgerPoolSize, metrics.SetOffLedgerReqTime, log.Named("OFF")),
offLedgerPool: NewTypedPoolByNonce[isc.OffLedgerRequest](waitReq, metrics.SetOffLedgerPoolSize, metrics.SetOffLedgerReqTime, log.Named("OFF")),
chainHeadAO: nil,
serverNodesUpdatedPipe: pipe.NewInfinitePipe[*reqServerNodesUpdated](),
serverNodes: []*cryptolib.PublicKey{},
Expand Down Expand Up @@ -480,11 +479,11 @@ func (mpi *mempoolImpl) shouldAddOffledgerRequest(req isc.OffLedgerRequest) erro
return fmt.Errorf("bad nonce, expected: %d", accountNonce)
}

governanceState := governance.NewStateAccess(mpi.chainHeadState)
// check user has on-chain balance
accountsState := accounts.NewStateAccess(mpi.chainHeadState)
if !accountsState.AccountExists(req.SenderAccount()) {
// make an exception for gov calls (sender is chan owner and target is gov contract)
governanceState := governance.NewStateAccess(mpi.chainHeadState)
chainOwner := governanceState.ChainOwnerID()
isGovRequest := req.SenderAccount().Equals(chainOwner) && req.CallTarget().Contract == governance.Contract.Hname()
if !isGovRequest {
Expand Down Expand Up @@ -530,17 +529,12 @@ func (mpi *mempoolImpl) handleConsensusProposal(recv *reqConsensusProposal) {
mpi.handleConsensusProposalForChainHead(recv)
}

type reqRefNonce struct {
ref *isc.RequestRef
nonce uint64
}

func (mpi *mempoolImpl) refsToPropose() []*isc.RequestRef {
//
// The case for matching ChainHeadAO and request BaseAO
reqRefs := []*isc.RequestRef{}
if !mpi.tangleTime.IsZero() { // Wait for tangle-time to process the on ledger requests.
mpi.onLedgerPool.Filter(func(request isc.OnLedgerRequest, ts time.Time) bool {
mpi.onLedgerPool.Filter(func(request isc.OnLedgerRequest, _ time.Time) bool {
if isc.RequestIsExpired(request, mpi.tangleTime) {
return false // Drop it from the mempool
}
Expand All @@ -551,53 +545,36 @@ func (mpi *mempoolImpl) refsToPropose() []*isc.RequestRef {
})
}

expectedAccountNonces := map[string]uint64{} // string is isc.AgentID.String()
requestsNonces := map[string][]reqRefNonce{} // string is isc.AgentID.String()

mpi.offLedgerPool.Filter(func(request isc.OffLedgerRequest, ts time.Time) bool {
ref := isc.RequestRefFromRequest(request)
reqRefs = append(reqRefs, ref)

// collect the nonces for each account
senderKey := request.SenderAccount().String()
_, ok := expectedAccountNonces[senderKey]
if !ok {
// get the current state nonce so we can detect gaps with it
expectedAccountNonces[senderKey] = mpi.nonce(request.SenderAccount())
mpi.offLedgerPool.Iterate(func(account string, entries []*OrderedPoolEntry[isc.OffLedgerRequest]) {
agentID, err := isc.AgentIDFromString(account)
if err != nil {
panic(fmt.Errorf("invalid agentID string: %s", err.Error()))
}
requestsNonces[senderKey] = append(requestsNonces[senderKey], reqRefNonce{ref: ref, nonce: request.Nonce()})

return true // Keep them for now
})

// remove any gaps in the nonces of each account
{
doNotPropose := []*isc.RequestRef{}
for account, refNonces := range requestsNonces {
// sort by nonce
slices.SortFunc(refNonces, func(a, b reqRefNonce) bool {
return a.nonce < b.nonce
})
// check for gaps with the state nonce
if expectedAccountNonces[account] != refNonces[0].nonce {
// if the first one doesn't match the nonce required from the state, don't propose any of the following
for _, ref := range refNonces {
doNotPropose = append(doNotPropose, ref.ref)
}
accountNonce := mpi.nonce(agentID)
for _, e := range entries {
reqNonce := e.req.Nonce()
if reqNonce < accountNonce {
// nonce too old, delete
mpi.log.Debugf("refsToPropose, account: %s, removing old nonce from pool: %d", account, e.req.Nonce())
mpi.offLedgerPool.Remove(e.req)
}
if e.old {
// this request was marked as "old", do not propose it
mpi.log.Debugf("refsToPropose, account: %s, skipping old request: %s", account, e.req.ID().String())
continue
}
// check for gaps within the request list
for i := 1; i < len(refNonces); i++ {
if refNonces[i].nonce != refNonces[i-1].nonce+1 {
doNotPropose = append(doNotPropose, refNonces[i].ref)
}
if reqNonce == accountNonce {
// expected nonce, add it to the list to propose
mpi.log.Debugf("refsToPropose, account: %s, proposing reqID %s with nonce %d: d", account, e.req.ID().String(), e.req.Nonce())
reqRefs = append(reqRefs, isc.RequestRefFromRequest(e.req))
accountNonce++ // increment the account nonce to match the next valid request
}
if reqNonce > accountNonce {
mpi.log.Debugf("refsToPropose, account: %s, req %s has a nonce %d which is too high, won't be proposed", account, e.req.ID().String(), e.req.Nonce())
return // no more valid nonces for this account, continue to the next account
}
}
// remove undesirable requests from the proposal
reqRefs = lo.Filter(reqRefs, func(x *isc.RequestRef, _ int) bool {
return !slices.Contains(doNotPropose, x)
})
}
})

return reqRefs
}
Expand Down
68 changes: 68 additions & 0 deletions packages/chain/mempool/mempool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,74 @@ func TestMempoolsNonceGaps(t *testing.T) {
// nonce 10 was never proposed
}

func TestMempoolOverrideNonce(t *testing.T) {
// 1 node setup
// send nonce 0
// send another request with the same nonce 0
// assert the last request is proposed
te := newEnv(t, 1, 0, true)
defer te.close()

tangleTime := time.Now()
for _, node := range te.mempools {
node.ServerNodesUpdated(te.peerPubKeys, te.peerPubKeys)
node.TangleTimeUpdated(tangleTime)
}
awaitTrackHeadChannels := make([]<-chan bool, len(te.mempools))
// deposit some funds so off-ledger requests can go through
t.Log("TrackNewChainHead")
for i, node := range te.mempools {
awaitTrackHeadChannels[i] = node.TrackNewChainHead(te.stateForAO(i, te.originAO), nil, te.originAO, []state.Block{}, []state.Block{})
}
for i := range te.mempools {
<-awaitTrackHeadChannels[i]
}

output := transaction.BasicOutputFromPostData(
te.governor.Address(),
isc.HnameNil,
isc.RequestParameters{
TargetAddress: te.chainID.AsAddress(),
Assets: isc.NewAssetsBaseTokens(10 * isc.Million),
},
)
onLedgerReq, err := isc.OnLedgerFromUTXO(output, tpkg.RandOutputID(uint16(0)))
require.NoError(t, err)
for _, node := range te.mempools {
node.ReceiveOnLedgerRequest(onLedgerReq)
}
currentAO := blockFn(te, []isc.Request{onLedgerReq}, te.originAO, tangleTime)

initialReq := isc.NewOffLedgerRequest(
isc.RandomChainID(),
isc.Hn("foo"),
isc.Hn("bar"),
dict.New(),
0,
gas.LimitsDefault.MaxGasPerRequest,
).Sign(te.governor)

require.NoError(t, te.mempools[0].ReceiveOffLedgerRequest(initialReq))
time.Sleep(200 * time.Millisecond) // give some time for the requests to reach the pool

overwritingReq := isc.NewOffLedgerRequest(
isc.RandomChainID(),
isc.Hn("baz"),
isc.Hn("bar"),
dict.New(),
0,
gas.LimitsDefault.MaxGasPerRequest,
).Sign(te.governor)

require.NoError(t, te.mempools[0].ReceiveOffLedgerRequest(overwritingReq))
time.Sleep(200 * time.Millisecond) // give some time for the requests to reach the pool
reqRefs := <-te.mempools[0].ConsensusProposalAsync(te.ctx, currentAO)
proposedReqs := <-te.mempools[0].ConsensusRequestsAsync(te.ctx, reqRefs)
require.Len(t, proposedReqs, 1)
require.Equal(t, overwritingReq, proposedReqs[0])
require.NotEqual(t, initialReq, proposedReqs[0])
}

////////////////////////////////////////////////////////////////////////////////
// testEnv

Expand Down
169 changes: 169 additions & 0 deletions packages/chain/mempool/typed_pool_by_nonce.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// Copyright 2020 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

package mempool

import (
"fmt"
"time"

"golang.org/x/exp/slices"

"github.com/iotaledger/hive.go/ds/shrinkingmap"
"github.com/iotaledger/hive.go/logger"
"github.com/iotaledger/wasp/packages/isc"
)

// keeps a map of requests ordered by nonce for each account
type TypedPoolByNonce[V isc.OffLedgerRequest] struct {
waitReq WaitReq
refLUT *shrinkingmap.ShrinkingMap[isc.RequestRefKey, *OrderedPoolEntry[V]]
// reqsByAcountOrdered keeps an ordered map of reqsByAcountOrdered for each account by nonce
reqsByAcountOrdered *shrinkingmap.ShrinkingMap[string, []*OrderedPoolEntry[V]] // string is isc.AgentID.String()
sizeMetric func(int)
timeMetric func(time.Duration)
log *logger.Logger
}

var _ RequestPool[isc.OffLedgerRequest] = &TypedPoolByNonce[isc.OffLedgerRequest]{}

func NewTypedPoolByNonce[V isc.OffLedgerRequest](waitReq WaitReq, sizeMetric func(int), timeMetric func(time.Duration), log *logger.Logger) *TypedPoolByNonce[V] {
return &TypedPoolByNonce[V]{
waitReq: waitReq,
reqsByAcountOrdered: shrinkingmap.New[string, []*OrderedPoolEntry[V]](),
refLUT: shrinkingmap.New[isc.RequestRefKey, *OrderedPoolEntry[V]](),
sizeMetric: sizeMetric,
timeMetric: timeMetric,
log: log,
}
}

type OrderedPoolEntry[V isc.OffLedgerRequest] struct {
req V
old bool
ts time.Time
}

func (p *TypedPoolByNonce[V]) Has(reqRef *isc.RequestRef) bool {
return p.refLUT.Has(reqRef.AsKey())
}

func (p *TypedPoolByNonce[V]) Get(reqRef *isc.RequestRef) V {
entry, exists := p.refLUT.Get(reqRef.AsKey())
if !exists {
return *new(V)
}
return entry.req
}

func (p *TypedPoolByNonce[V]) Add(request V) {
ref := isc.RequestRefFromRequest(request)
entry := &OrderedPoolEntry[V]{req: request, ts: time.Now()}
account := request.SenderAccount().String()

if !p.refLUT.Set(ref.AsKey(), entry) {
p.log.Debugf("NOT ADDED, already exists. reqID: %v as key=%v, senderAccount: ", request.ID(), ref, account)
return // not added already exists
}

defer func() {
p.log.Debugf("ADD %v as key=%v, senderAccount: ", request.ID(), ref, account)
p.sizeMetric(p.refLUT.Size())
p.waitReq.MarkAvailable(request)
}()

reqsForAcount, exists := p.reqsByAcountOrdered.Get(account)
if !exists {
// no other requests for this account
p.reqsByAcountOrdered.Set(account, []*OrderedPoolEntry[V]{entry})
return
}

// add to the account requests, keep the slice ordered

// find the index where the new entry should be added
index, exists := slices.BinarySearchFunc(reqsForAcount, entry,
func(a, b *OrderedPoolEntry[V]) int {
aNonce := a.req.Nonce()
bNonce := b.req.Nonce()
if aNonce == bNonce {
return 0
}
if aNonce > bNonce {
return 1
}
return -1
},
)
if exists {
// same nonce, mark the existing request with overlapping nonce as "old", place the new one
// NOTE: do not delete the request here, as it might already be part of an on-going consensus round
reqsForAcount[index].old = true
}

reqsForAcount = append(reqsForAcount, entry) // add to the end of the list (thus extending the array)

// make room if target position is not at the end
if index != len(reqsForAcount)+1 {
copy(reqsForAcount[index+1:], reqsForAcount[index:])
reqsForAcount[index] = entry
}
p.reqsByAcountOrdered.Set(account, reqsForAcount)
}

func (p *TypedPoolByNonce[V]) Remove(request V) {
refKey := isc.RequestRefFromRequest(request).AsKey()
entry, exists := p.refLUT.Get(refKey)
if !exists {
return // does not exist
}
defer func() {
p.sizeMetric(p.refLUT.Size())
p.timeMetric(time.Since(entry.ts))
}()
if p.refLUT.Delete(refKey) {
p.log.Debugf("DEL %v as key=%v", request.ID(), refKey)
}
account := entry.req.SenderAccount().String()
reqsByAccount, exists := p.reqsByAcountOrdered.Get(account)
if !exists {
p.log.Error("inconsistency trying to DEL %v as key=%v, no request list for account %s", request.ID(), refKey, account)
return
}
// find the request in the accounts map
indexToDel := slices.IndexFunc(reqsByAccount, func(e *OrderedPoolEntry[V]) bool {
return true
})
if indexToDel == -1 {
p.log.Error("inconsistency trying to DEL %v as key=%v, request not found in list for account %s", request.ID(), refKey, account)
return
}
if len(reqsByAccount) == 1 { // just remove the entire array for the account
p.reqsByAcountOrdered.Delete(account)
return
}
reqsByAccount[indexToDel] = nil // remove the pointer reference to allow GC of the entry object
reqsByAccount = slices.Delete(reqsByAccount, indexToDel, indexToDel+1)
p.reqsByAcountOrdered.Set(account, reqsByAccount)
}

func (p *TypedPoolByNonce[V]) Iterate(f func(account string, requests []*OrderedPoolEntry[V])) {
p.reqsByAcountOrdered.ForEach(func(acc string, reqs []*OrderedPoolEntry[V]) bool {
f(acc, reqs)
return true
})
}

func (p *TypedPoolByNonce[V]) Filter(predicate func(request V, ts time.Time) bool) {
p.refLUT.ForEach(func(refKey isc.RequestRefKey, entry *OrderedPoolEntry[V]) bool {
if !predicate(entry.req, entry.ts) {
p.Remove(entry.req)
}
return true
})
p.sizeMetric(p.refLUT.Size())
}

func (p *TypedPoolByNonce[V]) StatusString() string {
return fmt.Sprintf("{|req|=%d}", p.refLUT.Size())
}
2 changes: 1 addition & 1 deletion packages/vm/vmimpl/runreq.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
"github.com/iotaledger/wasp/packages/vm/vmexceptions"
)

// runRequest processes a single isc.Request in the batch
// runRequest processes a single isc.Request in the batch, returning an error means the request will be skipped
func (vmctx *vmContext) runRequest(req isc.Request, requestIndex uint16, maintenanceMode bool) (
res *vm.RequestResult,
unprocessableToRetry []isc.OnLedgerRequest,
Expand Down

0 comments on commit 672f92d

Please sign in to comment.