diff --git a/x/evm/ante/sequence.go b/x/evm/ante/sequence.go index 63f5611..c5e85a3 100644 --- a/x/evm/ante/sequence.go +++ b/x/evm/ante/sequence.go @@ -43,6 +43,7 @@ func (isd IncrementSequenceDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, sim } // set a flag in context to indicate that sequence has been incremented in ante handler - ctx = ctx.WithValue(ContextKeySequenceIncremented, true) + incremented := true // use pointer to enable revert after first call + ctx = ctx.WithValue(ContextKeySequenceIncremented, &incremented) return next(ctx, tx, simulate) } diff --git a/x/evm/keeper/msg_server.go b/x/evm/keeper/msg_server.go index 1724f8f..e0dd20b 100644 --- a/x/evm/keeper/msg_server.go +++ b/x/evm/keeper/msg_server.go @@ -33,7 +33,7 @@ func (ms *msgServerImpl) Create(ctx context.Context, msg *types.MsgCreate) (*typ } // handle cosmos<>evm different sequence increment logic - ctx, err = ms.handleSequenceIncremented(ctx, sender, true) + err = ms.handleSequenceIncremented(ctx, sender, true) if err != nil { return nil, err } @@ -98,7 +98,7 @@ func (ms *msgServerImpl) Create2(ctx context.Context, msg *types.MsgCreate2) (*t } // handle cosmos<>evm different sequence increment logic - ctx, err = ms.handleSequenceIncremented(ctx, sender, true) + err = ms.handleSequenceIncremented(ctx, sender, true) if err != nil { return nil, err } @@ -163,7 +163,7 @@ func (ms *msgServerImpl) Call(ctx context.Context, msg *types.MsgCall) (*types.M } // handle cosmos<>evm different sequence increment logic - ctx, err = ms.handleSequenceIncremented(ctx, sender, false) + err = ms.handleSequenceIncremented(ctx, sender, false) if err != nil { return nil, err } @@ -258,20 +258,36 @@ func (ms *msgServerImpl) testFeeDenom(ctx context.Context, params types.Params) // In the Cosmos SDK, the sequence number is incremented in the ante handler. // In the EVM, the sequence number is incremented during the execution of create and create2 messages. -// However, for call messages, the sequence number is incremented in the ante handler like the Cosmos SDK. -// To prevent double incrementing the sequence number during EVM execution, we need to decrement it here for create messages. -func (k *msgServerImpl) handleSequenceIncremented(ctx context.Context, sender sdk.AccAddress, isCreate bool) (context.Context, error) { +// +// If the sequence number is already incremented in the ante handler and the message is create, decrement the sequence number to prevent double incrementing. +// If the sequence number is not incremented in the ante handler and the message is call, increment the sequence number to ensure proper sequencing. +func (k *msgServerImpl) handleSequenceIncremented(ctx context.Context, sender sdk.AccAddress, isCreate bool) error { sdkCtx := sdk.UnwrapSDKContext(ctx) + if sdkCtx.Value(evmante.ContextKeySequenceIncremented) == nil { + return nil + } - // decrement sequence of the sender - if isCreate && sdkCtx.Value(evmante.ContextKeySequenceIncremented) != nil { + incremented := sdkCtx.Value(evmante.ContextKeySequenceIncremented).(*bool) + if isCreate && *incremented { + // if the sequence is already incremented, decrement it to prevent double incrementing the sequence number at create. acc := k.accountKeeper.GetAccount(ctx, sender) if err := acc.SetSequence(acc.GetSequence() - 1); err != nil { - return ctx, err + return err + } + + k.accountKeeper.SetAccount(ctx, acc) + } else if !isCreate && !*incremented { + // if the sequence is not incremented and the message is call, increment the sequence number. + acc := k.accountKeeper.GetAccount(ctx, sender) + if err := acc.SetSequence(acc.GetSequence() + 1); err != nil { + return err } k.accountKeeper.SetAccount(ctx, acc) } - return sdkCtx.WithValue(evmante.ContextKeySequenceIncremented, nil), nil + // set the flag to false + *incremented = false + + return nil } diff --git a/x/evm/keeper/msg_server_test.go b/x/evm/keeper/msg_server_test.go index e9cc9fa..31f895b 100644 --- a/x/evm/keeper/msg_server_test.go +++ b/x/evm/keeper/msg_server_test.go @@ -6,10 +6,14 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/crypto" + "github.com/holiman/uint256" + + evmante "github.com/initia-labs/minievm/x/evm/ante" "github.com/initia-labs/minievm/x/evm/contracts/counter" "github.com/initia-labs/minievm/x/evm/keeper" "github.com/initia-labs/minievm/x/evm/types" + "github.com/stretchr/testify/require" ) @@ -176,3 +180,128 @@ func Test_MsgServer_UpdateParams(t *testing.T) { }) require.ErrorContains(t, err, "sudoMint and sudoBurn") } + +func Test_MsgServer_NonceIncrement_Call(t *testing.T) { + ctx, input := createDefaultTestInput(t) + _, _, addr := keyPubAddr() + caller := common.BytesToAddress(addr.Bytes()) + + counterBz, err := hexutil.Decode(counter.CounterBin) + require.NoError(t, err) + + retBz, contractAddr, _, err := input.EVMKeeper.EVMCreate(ctx, caller, counterBz, nil, nil) + require.NoError(t, err) + require.NotEmpty(t, retBz) + require.Len(t, contractAddr, 20) + + parsed, err := counter.CounterMetaData.GetAbi() + require.NoError(t, err) + + // increment sequence + incremented := true + ctx = ctx.WithValue(evmante.ContextKeySequenceIncremented, &incremented) + acc := input.AccountKeeper.GetAccount(ctx, addr) + seq := acc.GetSequence() + 1 + acc.SetSequence(seq) + input.AccountKeeper.SetAccount(ctx, acc) + + inputBz, err := parsed.Pack("increase") + require.NoError(t, err) + + // should not increment sequence + msgServer := keeper.NewMsgServerImpl(&input.EVMKeeper) + res, err := msgServer.Call(ctx, &types.MsgCall{ + Sender: addr.String(), + ContractAddr: contractAddr.Hex(), + Input: hexutil.Encode(inputBz), + }) + require.NoError(t, err) + require.Equal(t, "0x", res.Result) + require.NotEmpty(t, res.Logs) + + acc = input.AccountKeeper.GetAccount(ctx, addr) + require.Equal(t, seq, acc.GetSequence()) + + // call again should increment sequence + res, err = msgServer.Call(ctx, &types.MsgCall{ + Sender: addr.String(), + ContractAddr: contractAddr.Hex(), + Input: hexutil.Encode(inputBz), + }) + require.NoError(t, err) + require.Equal(t, "0x", res.Result) + require.NotEmpty(t, res.Logs) + + acc = input.AccountKeeper.GetAccount(ctx, addr) + require.Equal(t, seq+1, acc.GetSequence()) + + // create should increment sequence + _, err = msgServer.Create(ctx, &types.MsgCreate{ + Sender: addr.String(), + Code: counter.CounterBin, + }) + require.NoError(t, err) + + acc = input.AccountKeeper.GetAccount(ctx, addr) + require.Equal(t, seq+2, acc.GetSequence()) +} + +func Test_MsgServer_NonceIncrement_Create(t *testing.T) { + ctx, input := createDefaultTestInput(t) + _, _, addr := keyPubAddr() + caller := common.BytesToAddress(addr.Bytes()) + + counterBz, err := hexutil.Decode(counter.CounterBin) + require.NoError(t, err) + + retBz, contractAddr, _, err := input.EVMKeeper.EVMCreate(ctx, caller, counterBz, nil, nil) + require.NoError(t, err) + require.NotEmpty(t, retBz) + require.Len(t, contractAddr, 20) + + parsed, err := counter.CounterMetaData.GetAbi() + require.NoError(t, err) + + // increment sequence + incremented := true + ctx = ctx.WithValue(evmante.ContextKeySequenceIncremented, &incremented) + acc := input.AccountKeeper.GetAccount(ctx, addr) + seq := acc.GetSequence() + 1 + acc.SetSequence(seq) + input.AccountKeeper.SetAccount(ctx, acc) + + inputBz, err := parsed.Pack("increase") + require.NoError(t, err) + + // should not increment sequence + msgServer := keeper.NewMsgServerImpl(&input.EVMKeeper) + _, err = msgServer.Create(ctx, &types.MsgCreate{ + Sender: addr.String(), + Code: counter.CounterBin, + }) + require.NoError(t, err) + + acc = input.AccountKeeper.GetAccount(ctx, addr) + require.Equal(t, seq, acc.GetSequence()) + + // call again should increment sequence + _, err = msgServer.Call(ctx, &types.MsgCall{ + Sender: addr.String(), + ContractAddr: contractAddr.Hex(), + Input: hexutil.Encode(inputBz), + }) + require.NoError(t, err) + + acc = input.AccountKeeper.GetAccount(ctx, addr) + require.Equal(t, seq+1, acc.GetSequence()) + + // create should increment sequence + _, err = msgServer.Create(ctx, &types.MsgCreate{ + Sender: addr.String(), + Code: counter.CounterBin, + }) + require.NoError(t, err) + + acc = input.AccountKeeper.GetAccount(ctx, addr) + require.Equal(t, seq+2, acc.GetSequence()) +}