Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: increment sequence number at every call and create #102

Merged
merged 2 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion x/evm/ante/sequence.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func (isd IncrementSequenceDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, sim
}

// set a flag in context to indicate that sequence has been incremented in ante handler
ctx = ctx.WithValue(ContextKeySequenceIncremented, true)
incremented := true // use pointer to enable revert after first call
ctx = ctx.WithValue(ContextKeySequenceIncremented, &incremented)
return next(ctx, tx, simulate)
}
36 changes: 26 additions & 10 deletions x/evm/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
}

// handle cosmos<>evm different sequence increment logic
ctx, err = ms.handleSequenceIncremented(ctx, sender, true)
err = ms.handleSequenceIncremented(ctx, sender, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -98,7 +98,7 @@
}

// handle cosmos<>evm different sequence increment logic
ctx, err = ms.handleSequenceIncremented(ctx, sender, true)
err = ms.handleSequenceIncremented(ctx, sender, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -163,7 +163,7 @@
}

// handle cosmos<>evm different sequence increment logic
ctx, err = ms.handleSequenceIncremented(ctx, sender, false)
err = ms.handleSequenceIncremented(ctx, sender, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -258,20 +258,36 @@

// In the Cosmos SDK, the sequence number is incremented in the ante handler.
// In the EVM, the sequence number is incremented during the execution of create and create2 messages.
// However, for call messages, the sequence number is incremented in the ante handler like the Cosmos SDK.
// To prevent double incrementing the sequence number during EVM execution, we need to decrement it here for create messages.
func (k *msgServerImpl) handleSequenceIncremented(ctx context.Context, sender sdk.AccAddress, isCreate bool) (context.Context, error) {
//
// If the sequence number is already incremented in the ante handler and the message is create, decrement the sequence number to prevent double incrementing.
// If the sequence number is not incremented in the ante handler and the message is call, increment the sequence number to ensure proper sequencing.
func (k *msgServerImpl) handleSequenceIncremented(ctx context.Context, sender sdk.AccAddress, isCreate bool) error {
sdkCtx := sdk.UnwrapSDKContext(ctx)
if sdkCtx.Value(evmante.ContextKeySequenceIncremented) == nil {
return nil
}

// decrement sequence of the sender
if isCreate && sdkCtx.Value(evmante.ContextKeySequenceIncremented) != nil {
incremented := sdkCtx.Value(evmante.ContextKeySequenceIncremented).(*bool)
if isCreate && *incremented {
// if the sequence is already incremented, decrement it to prevent double incrementing the sequence number at create.
acc := k.accountKeeper.GetAccount(ctx, sender)
if err := acc.SetSequence(acc.GetSequence() - 1); err != nil {
return ctx, err
return err
}

Check warning on line 276 in x/evm/keeper/msg_server.go

View check run for this annotation

Codecov / codecov/patch

x/evm/keeper/msg_server.go#L275-L276

Added lines #L275 - L276 were not covered by tests

k.accountKeeper.SetAccount(ctx, acc)
} else if !isCreate && !*incremented {
// if the sequence is not incremented and the message is call, increment the sequence number.
acc := k.accountKeeper.GetAccount(ctx, sender)
if err := acc.SetSequence(acc.GetSequence() + 1); err != nil {
return err

Check warning on line 283 in x/evm/keeper/msg_server.go

View check run for this annotation

Codecov / codecov/patch

x/evm/keeper/msg_server.go#L283

Added line #L283 was not covered by tests
}

k.accountKeeper.SetAccount(ctx, acc)
}

return sdkCtx.WithValue(evmante.ContextKeySequenceIncremented, nil), nil
// set the flag to false
*incremented = false

return nil
}
129 changes: 129 additions & 0 deletions x/evm/keeper/msg_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto"

"github.com/holiman/uint256"

evmante "github.com/initia-labs/minievm/x/evm/ante"
"github.com/initia-labs/minievm/x/evm/contracts/counter"
"github.com/initia-labs/minievm/x/evm/keeper"
"github.com/initia-labs/minievm/x/evm/types"

"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -176,3 +180,128 @@ func Test_MsgServer_UpdateParams(t *testing.T) {
})
require.ErrorContains(t, err, "sudoMint and sudoBurn")
}

func Test_MsgServer_NonceIncrement_Call(t *testing.T) {
ctx, input := createDefaultTestInput(t)
_, _, addr := keyPubAddr()
caller := common.BytesToAddress(addr.Bytes())

counterBz, err := hexutil.Decode(counter.CounterBin)
require.NoError(t, err)

retBz, contractAddr, _, err := input.EVMKeeper.EVMCreate(ctx, caller, counterBz, nil, nil)
require.NoError(t, err)
require.NotEmpty(t, retBz)
require.Len(t, contractAddr, 20)

parsed, err := counter.CounterMetaData.GetAbi()
require.NoError(t, err)

// increment sequence
incremented := true
ctx = ctx.WithValue(evmante.ContextKeySequenceIncremented, &incremented)
acc := input.AccountKeeper.GetAccount(ctx, addr)
seq := acc.GetSequence() + 1
acc.SetSequence(seq)
input.AccountKeeper.SetAccount(ctx, acc)

inputBz, err := parsed.Pack("increase")
require.NoError(t, err)

// should not increment sequence
msgServer := keeper.NewMsgServerImpl(&input.EVMKeeper)
res, err := msgServer.Call(ctx, &types.MsgCall{
Sender: addr.String(),
ContractAddr: contractAddr.Hex(),
Input: hexutil.Encode(inputBz),
})
require.NoError(t, err)
require.Equal(t, "0x", res.Result)
require.NotEmpty(t, res.Logs)

acc = input.AccountKeeper.GetAccount(ctx, addr)
require.Equal(t, seq, acc.GetSequence())

// call again should increment sequence
res, err = msgServer.Call(ctx, &types.MsgCall{
Sender: addr.String(),
ContractAddr: contractAddr.Hex(),
Input: hexutil.Encode(inputBz),
})
require.NoError(t, err)
require.Equal(t, "0x", res.Result)
require.NotEmpty(t, res.Logs)

acc = input.AccountKeeper.GetAccount(ctx, addr)
require.Equal(t, seq+1, acc.GetSequence())

// create should increment sequence
_, err = msgServer.Create(ctx, &types.MsgCreate{
Sender: addr.String(),
Code: counter.CounterBin,
})
require.NoError(t, err)

acc = input.AccountKeeper.GetAccount(ctx, addr)
require.Equal(t, seq+2, acc.GetSequence())
}

func Test_MsgServer_NonceIncrement_Create(t *testing.T) {
ctx, input := createDefaultTestInput(t)
_, _, addr := keyPubAddr()
caller := common.BytesToAddress(addr.Bytes())

counterBz, err := hexutil.Decode(counter.CounterBin)
require.NoError(t, err)

retBz, contractAddr, _, err := input.EVMKeeper.EVMCreate(ctx, caller, counterBz, nil, nil)
require.NoError(t, err)
require.NotEmpty(t, retBz)
require.Len(t, contractAddr, 20)

parsed, err := counter.CounterMetaData.GetAbi()
require.NoError(t, err)

// increment sequence
incremented := true
ctx = ctx.WithValue(evmante.ContextKeySequenceIncremented, &incremented)
acc := input.AccountKeeper.GetAccount(ctx, addr)
seq := acc.GetSequence() + 1
acc.SetSequence(seq)
input.AccountKeeper.SetAccount(ctx, acc)

inputBz, err := parsed.Pack("increase")
require.NoError(t, err)

// should not increment sequence
msgServer := keeper.NewMsgServerImpl(&input.EVMKeeper)
_, err = msgServer.Create(ctx, &types.MsgCreate{
Sender: addr.String(),
Code: counter.CounterBin,
})
require.NoError(t, err)

acc = input.AccountKeeper.GetAccount(ctx, addr)
require.Equal(t, seq, acc.GetSequence())

// call again should increment sequence
_, err = msgServer.Call(ctx, &types.MsgCall{
Sender: addr.String(),
ContractAddr: contractAddr.Hex(),
Input: hexutil.Encode(inputBz),
})
require.NoError(t, err)

acc = input.AccountKeeper.GetAccount(ctx, addr)
require.Equal(t, seq+1, acc.GetSequence())

// create should increment sequence
_, err = msgServer.Create(ctx, &types.MsgCreate{
Sender: addr.String(),
Code: counter.CounterBin,
})
require.NoError(t, err)

acc = input.AccountKeeper.GetAccount(ctx, addr)
require.Equal(t, seq+2, acc.GetSequence())
}
beer-1 marked this conversation as resolved.
Show resolved Hide resolved
Loading