Skip to content

Commit

Permalink
Add errOnRevert flag to vm.Execute
Browse files Browse the repository at this point in the history
to treat reverted txns as hard-failures
  • Loading branch information
omerfirmak committed Dec 4, 2023
1 parent e1125b6 commit 4d10e42
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 26 deletions.
8 changes: 4 additions & 4 deletions mocks/mock_vm.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions node/throttled_vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ func (tvm *ThrottledVM) Call(contractAddr, classHash, selector *felt.Felt, calld

func (tvm *ThrottledVM) Execute(txns []core.Transaction, declaredClasses []core.Class, blockNumber, blockTimestamp uint64,
sequencerAddress *felt.Felt, state core.StateReader, network utils.Network, paidFeesOnL1 []*felt.Felt,
skipChargeFee, skipValidate bool, gasPriceWEI *felt.Felt, gasPriceSTRK *felt.Felt, legacyTraceJSON bool,
skipChargeFee, skipValidate, errOnRevert bool, gasPriceWEI *felt.Felt, gasPriceSTRK *felt.Felt, legacyTraceJSON bool,
) ([]*felt.Felt, []json.RawMessage, error) {
var ret []*felt.Felt
var traces []json.RawMessage
throttler := (*utils.Throttler[vm.VM])(tvm)
return ret, traces, throttler.Do(func(vm *vm.VM) error {
var err error
ret, traces, err = (*vm).Execute(txns, declaredClasses, blockNumber, blockTimestamp, sequencerAddress,
state, network, paidFeesOnL1, skipChargeFee, skipValidate, gasPriceWEI, gasPriceSTRK, legacyTraceJSON)
state, network, paidFeesOnL1, skipChargeFee, skipValidate, errOnRevert, gasPriceWEI, gasPriceSTRK, legacyTraceJSON)
return err
})
}
14 changes: 7 additions & 7 deletions rpc/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1286,7 +1286,7 @@ func (h *Handler) TransactionStatus(ctx context.Context, hash felt.Felt) (*Trans
func (h *Handler) EstimateFee(broadcastedTxns []BroadcastedTransaction,
simulationFlags []SimulationFlag, id BlockID,
) ([]FeeEstimate, *jsonrpc.Error) {
result, err := h.SimulateTransactions(id, broadcastedTxns, append(simulationFlags, SkipFeeChargeFlag))
result, err := h.simulateTransactions(id, broadcastedTxns, append(simulationFlags, SkipFeeChargeFlag), false, true)
if err != nil {
return nil, err
}
Expand All @@ -1297,7 +1297,7 @@ func (h *Handler) EstimateFee(broadcastedTxns []BroadcastedTransaction,
}

func (h *Handler) LegacyEstimateFee(broadcastedTxns []BroadcastedTransaction, id BlockID) ([]FeeEstimate, *jsonrpc.Error) {
result, err := h.LegacySimulateTransactions(id, broadcastedTxns, []SimulationFlag{SkipFeeChargeFlag})
result, err := h.simulateTransactions(id, broadcastedTxns, []SimulationFlag{SkipFeeChargeFlag}, true, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1396,17 +1396,17 @@ func (h *Handler) traceTransaction(ctx context.Context, hash *felt.Felt, legacyT
func (h *Handler) SimulateTransactions(id BlockID, transactions []BroadcastedTransaction,
simulationFlags []SimulationFlag,
) ([]SimulatedTransaction, *jsonrpc.Error) {
return h.simulateTransactions(id, transactions, simulationFlags, false)
return h.simulateTransactions(id, transactions, simulationFlags, false, false)
}

func (h *Handler) LegacySimulateTransactions(id BlockID, transactions []BroadcastedTransaction,
simulationFlags []SimulationFlag,
) ([]SimulatedTransaction, *jsonrpc.Error) {
return h.simulateTransactions(id, transactions, simulationFlags, true)
return h.simulateTransactions(id, transactions, simulationFlags, true, false)
}

func (h *Handler) simulateTransactions(id BlockID, transactions []BroadcastedTransaction, //nolint: gocyclo
simulationFlags []SimulationFlag, legacyTraceJSON bool,
simulationFlags []SimulationFlag, legacyTraceJSON, errOnRevert bool,
) ([]SimulatedTransaction, *jsonrpc.Error) {
skipFeeCharge := slices.Contains(simulationFlags, SkipFeeChargeFlag)
skipValidate := slices.Contains(simulationFlags, SkipValidateFlag)
Expand Down Expand Up @@ -1456,7 +1456,7 @@ func (h *Handler) simulateTransactions(id BlockID, transactions []BroadcastedTra
sequencerAddress = core.NetworkBlockHashMetaInfo(h.network).FallBackSequencerAddress
}
overallFees, traces, err := h.vm.Execute(txns, classes, blockNumber, header.Timestamp, sequencerAddress,
state, h.network, paidFeesOnL1, skipFeeCharge, skipValidate, header.GasPrice, header.GasPriceSTRK, legacyTraceJSON)
state, h.network, paidFeesOnL1, skipFeeCharge, skipValidate, errOnRevert, header.GasPrice, header.GasPriceSTRK, legacyTraceJSON)
if err != nil {
if errors.Is(err, utils.ErrResourceBusy) {
return nil, ErrInternal.CloneWithData(err.Error())
Expand Down Expand Up @@ -1606,7 +1606,7 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block,
}

_, traces, err := h.vm.Execute(block.Transactions, classes, blockNumber, block.Header.Timestamp,
sequencerAddress, state, h.network, paidFeesOnL1, false, false, block.Header.GasPrice, block.Header.GasPriceSTRK, legacyJSON)
sequencerAddress, state, h.network, paidFeesOnL1, false, false, false, block.Header.GasPrice, block.Header.GasPriceSTRK, legacyJSON)
if err != nil {
if errors.Is(err, utils.ErrResourceBusy) {
return nil, ErrInternal.CloneWithData(err.Error())
Expand Down
17 changes: 9 additions & 8 deletions rpc/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3045,10 +3045,11 @@ func TestEstimateMessageFee(t *testing.T) {

expectedGasConsumed := new(felt.Felt).SetUint64(37)
mockVM.EXPECT().Execute(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(),
gomock.Any(), utils.Mainnet, gomock.Any(), gomock.Any(), gomock.Any(), latestHeader.GasPrice, latestHeader.GasPriceSTRK, false).DoAndReturn(
gomock.Any(), utils.Mainnet, gomock.Any(), gomock.Any(), gomock.Any(), true, latestHeader.GasPrice,
latestHeader.GasPriceSTRK, false).DoAndReturn(
func(txns []core.Transaction, declaredClasses []core.Class, blockNumber, blockTimestamp uint64,
sequencerAddress *felt.Felt, state core.StateReader, network utils.Network, paidFeesOnL1 []*felt.Felt,
skipChargeFee, skipValidate bool, gasPriceWei, gasPriceSTRK *felt.Felt, legacyTraceJson bool,
skipChargeFee, skipValidate, errOnRevert bool, gasPriceWei, gasPriceSTRK *felt.Felt, legacyTraceJson bool,
) ([]*felt.Felt, []json.RawMessage, error) {
require.Len(t, txns, 1)
assert.NotNil(t, txns[0].(*core.L1HandlerTransaction))
Expand Down Expand Up @@ -3177,7 +3178,7 @@ func TestTraceTransaction(t *testing.T) {
"fee_transfer_invocation": {"contract_address": "0x49d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", "entry_point_selector": "0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e", "calldata": ["0x1176a1bd84444c89232ec27754698e5d2e7e1a7f1539f12027f28b23ec9f3d8", "0x2cb6", "0x0"], "caller_address": "0xd747220b2744d8d8d48c8a52bd3869fb98aea915665ab2485d5eadb49def6a", "class_hash": "0xd0e183745e9dae3e4e78a8ffedcce0903fc4900beace4e0abf192d4c202da3", "entry_point_type": "EXTERNAL", "call_type": "CALL", "result": ["0x1"], "calls": [{"contract_address": "0x49d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", "entry_point_selector": "0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e", "calldata": ["0x1176a1bd84444c89232ec27754698e5d2e7e1a7f1539f12027f28b23ec9f3d8", "0x2cb6", "0x0"], "caller_address": "0xd747220b2744d8d8d48c8a52bd3869fb98aea915665ab2485d5eadb49def6a", "class_hash": "0x2760f25d5a4fb2bdde5f561fd0b44a3dee78c28903577d37d669939d97036a0", "entry_point_type": "EXTERNAL", "call_type": "DELEGATE", "result": ["0x1"], "calls": [], "events": [{"keys": ["0x99cd8bde557814842a3121e8ddfd433a539b8c9f14bf31ebf108d12e6196e9"], "data": ["0xd747220b2744d8d8d48c8a52bd3869fb98aea915665ab2485d5eadb49def6a", "0x1176a1bd84444c89232ec27754698e5d2e7e1a7f1539f12027f28b23ec9f3d8", "0x2cb6", "0x0"]}], "messages": []}], "events": [], "messages": []}
}`)
mockVM.EXPECT().Execute([]core.Transaction{tx}, []core.Class{declaredClass.Class}, header.Number, header.Timestamp, header.SequencerAddress,
gomock.Any(), utils.Mainnet, []*felt.Felt{}, false, false, gomock.Any(), gomock.Any(), false).Return(nil, []json.RawMessage{vmTrace}, nil)
gomock.Any(), utils.Mainnet, []*felt.Felt{}, false, false, false, gomock.Any(), gomock.Any(), false).Return(nil, []json.RawMessage{vmTrace}, nil)

trace, err := handler.TraceTransaction(context.Background(), *hash)
require.Nil(t, err)
Expand All @@ -3202,23 +3203,23 @@ func TestSimulateTransactions(t *testing.T) {
sequencerAddress := core.NetworkBlockHashMetaInfo(network).FallBackSequencerAddress

t.Run("ok with zero values, skip fee", func(t *testing.T) {
mockVM.EXPECT().Execute(nil, nil, uint64(0), uint64(0), sequencerAddress, mockState, network, []*felt.Felt{}, true, false, nil, nil, false).
mockVM.EXPECT().Execute(nil, nil, uint64(0), uint64(0), sequencerAddress, mockState, network, []*felt.Felt{}, true, false, false, nil, nil, false).
Return([]*felt.Felt{}, []json.RawMessage{}, nil)

_, err := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipFeeChargeFlag})
require.Nil(t, err)
})

t.Run("ok with zero values, skip validate", func(t *testing.T) {
mockVM.EXPECT().Execute(nil, nil, uint64(0), uint64(0), sequencerAddress, mockState, network, []*felt.Felt{}, false, true, nil, nil, false).
mockVM.EXPECT().Execute(nil, nil, uint64(0), uint64(0), sequencerAddress, mockState, network, []*felt.Felt{}, false, true, false, nil, nil, false).
Return([]*felt.Felt{}, []json.RawMessage{}, nil)

_, err := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag})
require.Nil(t, err)
})

t.Run("transaction execution error", func(t *testing.T) {
mockVM.EXPECT().Execute(nil, nil, uint64(0), uint64(0), sequencerAddress, mockState, network, []*felt.Felt{}, false, true, nil, nil, false).
mockVM.EXPECT().Execute(nil, nil, uint64(0), uint64(0), sequencerAddress, mockState, network, []*felt.Felt{}, false, true, false, nil, nil, false).
Return(nil, nil, vm.TransactionExecutionError{
Index: 44,
Cause: errors.New("oops"),
Expand Down Expand Up @@ -3293,7 +3294,7 @@ func TestTraceBlockTransactions(t *testing.T) {
"fee_transfer_invocation": {}
}`)
mockVM.EXPECT().Execute(block.Transactions, []core.Class{declaredClass.Class}, height+1, header.Timestamp, sequencerAddress,
gomock.Any(), network, paidL1Fees, false, false, header.GasPrice, header.GasPriceSTRK, false).Return(nil, []json.RawMessage{vmTrace, vmTrace}, nil)
gomock.Any(), network, paidL1Fees, false, false, false, header.GasPrice, header.GasPriceSTRK, false).Return(nil, []json.RawMessage{vmTrace, vmTrace}, nil)

result, err := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Hash: blockHash})
require.Nil(t, err)
Expand Down Expand Up @@ -3337,7 +3338,7 @@ func TestTraceBlockTransactions(t *testing.T) {
"fee_transfer_invocation":{"entry_point_selector":"0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e","calldata":["0x5dcd266a80b8a5f29f04d779c6b166b80150c24f2180a75e82427242dab20a9","0x15be","0x0"],"caller_address":"0xdac9bcffb3d967f19a7fe21002c98c984d5a9458a88e6fc5d1c478a97ed412","class_hash":"0xd0e183745e9dae3e4e78a8ffedcce0903fc4900beace4e0abf192d4c202da3","entry_point_type":"EXTERNAL","call_type":"CALL","result":["0x1"],"calls":[{"entry_point_selector":"0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e","calldata":["0x5dcd266a80b8a5f29f04d779c6b166b80150c24f2180a75e82427242dab20a9","0x15be","0x0"],"caller_address":"0xdac9bcffb3d967f19a7fe21002c98c984d5a9458a88e6fc5d1c478a97ed412","class_hash":"0x2760f25d5a4fb2bdde5f561fd0b44a3dee78c28903577d37d669939d97036a0","entry_point_type":"EXTERNAL","call_type":"DELEGATE","result":["0x1"],"calls":[],"events":[{"keys":["0x99cd8bde557814842a3121e8ddfd433a539b8c9f14bf31ebf108d12e6196e9"],"data":["0xdac9bcffb3d967f19a7fe21002c98c984d5a9458a88e6fc5d1c478a97ed412","0x5dcd266a80b8a5f29f04d779c6b166b80150c24f2180a75e82427242dab20a9","0x15be","0x0"]}],"messages":[]}],"events":[],"messages":[]}}
}`)
mockVM.EXPECT().Execute([]core.Transaction{tx}, []core.Class{declaredClass.Class}, header.Number, header.Timestamp, header.SequencerAddress,
gomock.Any(), network, []*felt.Felt{}, false, false, header.GasPrice, header.GasPriceSTRK, false).Return(nil, []json.RawMessage{vmTrace}, nil)
gomock.Any(), network, []*felt.Felt{}, false, false, false, header.GasPrice, header.GasPriceSTRK, false).Return(nil, []json.RawMessage{vmTrace}, nil)

expectedResult := []rpc.TracedBlockTransaction{
{
Expand Down
15 changes: 15 additions & 0 deletions vm/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ pub extern "C" fn cairoVMExecute(
paid_fees_on_l1_json: *const c_char,
skip_charge_fee: c_uchar,
skip_validate: c_uchar,
err_on_revert: c_uchar,
gas_price_wei: *const c_uchar,
gas_price_strk: *const c_uchar,
legacy_json: c_uchar,
Expand Down Expand Up @@ -206,6 +207,10 @@ pub extern "C" fn cairoVMExecute(
let charge_fee = skip_charge_fee == 0;
let validate = skip_validate == 0;

println!("err_on_revert != 0 {}", err_on_revert != 0);
println!("charge_fee {}", charge_fee);
println!("validate {}", validate);

let mut trace_buffer = Vec::with_capacity(10_000);

for (txn_index, txn_and_query_bit) in txns_and_query_bits.iter().enumerate() {
Expand Down Expand Up @@ -283,6 +288,16 @@ pub extern "C" fn cairoVMExecute(
return;
}
Ok(mut t) => {
if t.is_reverted() && err_on_revert != 0{
report_error(
reader_handle,
format!("reverted: {}",t.revert_error.unwrap())
.as_str(),
txn_index as i64
);
return;
}

// we are estimating fee, override actual fee calculation
if !charge_fee {
t.actual_fee = calculate_tx_fee(&t.actual_resources, &block_context, &fee_type).unwrap();
Expand Down
12 changes: 9 additions & 3 deletions vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ package vm
//
// extern void cairoVMExecute(char* txns_json, char* classes_json, uintptr_t readerHandle, unsigned long long block_number,
// unsigned long long block_timestamp, char* chain_id, char* sequencer_address, char* paid_fees_on_l1_json,
// unsigned char skip_charge_fee, unsigned char skip_validate, char* gas_price_wei, char* gas_price_strk, unsigned char legacy_json);
// unsigned char skip_charge_fee, unsigned char skip_validate, unsigned char err_on_revert, char* gas_price_wei,
// char* gas_price_strk, unsigned char legacy_json);
//
// #cgo vm_debug LDFLAGS: -L./rust/target/debug -ljuno_starknet_rs -lm -ldl
// #cgo !vm_debug LDFLAGS: -L./rust/target/release -ljuno_starknet_rs -lm -ldl
Expand All @@ -33,7 +34,7 @@ type VM interface {
) ([]*felt.Felt, error)
Execute(txns []core.Transaction, declaredClasses []core.Class, blockNumber, blockTimestamp uint64,
sequencerAddress *felt.Felt, state core.StateReader, network utils.Network, paidFeesOnL1 []*felt.Felt,
skipChargeFee, skipValidate bool, gasPriceWEI *felt.Felt, gasPriceSTRK *felt.Felt, legacyTraceJSON bool,
skipChargeFee, skipValidate, errOnRevert bool, gasPriceWEI *felt.Felt, gasPriceSTRK *felt.Felt, legacyTraceJSON bool,
) ([]*felt.Felt, []json.RawMessage, error)
}

Expand Down Expand Up @@ -163,7 +164,7 @@ func (v *vm) Call(contractAddr, classHash, selector *felt.Felt, calldata []felt.
// Execute executes a given transaction set and returns the gas spent per transaction
func (v *vm) Execute(txns []core.Transaction, declaredClasses []core.Class, blockNumber, blockTimestamp uint64,
sequencerAddress *felt.Felt, state core.StateReader, network utils.Network, paidFeesOnL1 []*felt.Felt,
skipChargeFee, skipValidate bool, gasPriceWEI *felt.Felt, gasPriceSTRK *felt.Felt, legacyTraceJSON bool,
skipChargeFee, skipValidate, errOnRevert bool, gasPriceWEI *felt.Felt, gasPriceSTRK *felt.Felt, legacyTraceJSON bool,
) ([]*felt.Felt, []json.RawMessage, error) {
context := &callContext{
state: state,
Expand Down Expand Up @@ -204,6 +205,10 @@ func (v *vm) Execute(txns []core.Transaction, declaredClasses []core.Class, bloc
skipValidateByte = 1
}

var errOnRevertByte byte
if errOnRevert {
errOnRevertByte = 1
}
var legacyTraceJSONByte byte
if legacyTraceJSON {
legacyTraceJSONByte = 1
Expand All @@ -220,6 +225,7 @@ func (v *vm) Execute(txns []core.Transaction, declaredClasses []core.Class, bloc
paidFeesOnL1CStr,
C.uchar(skipChargeFeeByte),
C.uchar(skipValidateByte),
C.uchar(errOnRevertByte),
(*C.char)(unsafe.Pointer(&gasPriceWEIBytes[0])),
(*C.char)(unsafe.Pointer(&gasPriceSTRKBytes[0])),
C.uchar(legacyTraceJSONByte),
Expand Down
4 changes: 2 additions & 2 deletions vm/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,11 @@ func TestExecute(t *testing.T) {
address = utils.HexToFelt(t, "0x46a89ae102987331d369645031b49c27738ed096f2789c24449966da4c6de6b")
timestamp = uint64(1666877926)
)
_, _, err := New(nil).Execute([]core.Transaction{}, []core.Class{}, 0, timestamp, address, state, network, []*felt.Felt{}, false, false, &felt.Zero, &felt.Zero, false)
_, _, err := New(nil).Execute([]core.Transaction{}, []core.Class{}, 0, timestamp, address, state, network, []*felt.Felt{}, false, false, false, &felt.Zero, &felt.Zero, false)
require.NoError(t, err)
})
t.Run("zero data", func(t *testing.T) {
_, _, err := New(nil).Execute(nil, nil, 0, 0, &felt.Zero, state, network, []*felt.Felt{}, false, false, &felt.Zero, &felt.Zero, false)
_, _, err := New(nil).Execute(nil, nil, 0, 0, &felt.Zero, state, network, []*felt.Felt{}, false, false, false, &felt.Zero, &felt.Zero, false)
require.NoError(t, err)
})
}

0 comments on commit 4d10e42

Please sign in to comment.