Skip to content

Commit

Permalink
implement legacy trace API handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
omerfirmak committed Sep 29, 2023
1 parent 476bcc2 commit 311aeb1
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 50 deletions.
8 changes: 4 additions & 4 deletions mocks/mock_vm.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 35 additions & 10 deletions rpc/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,15 @@ func (h *Handler) EstimateMessageFee(msg MsgFromL1, id BlockID) (*FeeEstimate, *
// It follows the specification defined here:
// https://github.com/starkware-libs/starknet-specs/blob/1ae810e0137cc5d175ace4554892a4f43052be56/api/starknet_trace_api_openrpc.json#L11
func (h *Handler) TraceTransaction(hash felt.Felt) (json.RawMessage, *jsonrpc.Error) {
_, _, blockNumber, err := h.bcReader.Receipt(&hash)
return h.traceTransaction(&hash, false)
}

func (h *Handler) LegacyTraceTransaction(hash felt.Felt) (json.RawMessage, *jsonrpc.Error) {
return h.traceTransaction(&hash, true)
}

func (h *Handler) traceTransaction(hash *felt.Felt, legacyTraceJSON bool) (json.RawMessage, *jsonrpc.Error) {
_, _, blockNumber, err := h.bcReader.Receipt(hash)
if err != nil {
return nil, ErrInvalidTxHash
}
Expand All @@ -1190,13 +1198,13 @@ func (h *Handler) TraceTransaction(hash felt.Felt) (json.RawMessage, *jsonrpc.Er
}

txIndex := slices.IndexFunc(block.Transactions, func(tx core.Transaction) bool {
return tx.Hash().Equal(&hash)
return tx.Hash().Equal(hash)
})
if txIndex == -1 {
return nil, ErrTxnHashNotFound
}

traceResults, traceBlockErr := h.traceBlockTransactions(block, txIndex+1)
traceResults, traceBlockErr := h.traceBlockTransactions(block, txIndex+1, legacyTraceJSON)
if traceBlockErr != nil {
return nil, traceBlockErr
}
Expand All @@ -1206,6 +1214,18 @@ func (h *Handler) TraceTransaction(hash felt.Felt) (json.RawMessage, *jsonrpc.Er

func (h *Handler) SimulateTransactions(id BlockID, transactions []BroadcastedTransaction,
simulationFlags []SimulationFlag,
) ([]SimulatedTransaction, *jsonrpc.Error) {
return h.simulateTransactions(id, transactions, simulationFlags, false)
}

func (h *Handler) LegacySimulateTransactions(id BlockID, transactions []BroadcastedTransaction,
simulationFlags []SimulationFlag,
) ([]SimulatedTransaction, *jsonrpc.Error) {
return h.simulateTransactions(id, transactions, simulationFlags, true)
}

func (h *Handler) simulateTransactions(id BlockID, transactions []BroadcastedTransaction,
simulationFlags []SimulationFlag, legacyTraceJSON bool,
) ([]SimulatedTransaction, *jsonrpc.Error) {
if slices.Contains(simulationFlags, SkipValidateFlag) {
return nil, jsonrpc.Err(jsonrpc.InvalidParams, "Skip validate is not supported")
Expand Down Expand Up @@ -1257,7 +1277,7 @@ func (h *Handler) SimulateTransactions(id BlockID, transactions []BroadcastedTra
sequencerAddress = core.NetworkBlockHashMetaInfo(h.network).FallBackSequencerAddress
}
overallFees, traces, err := h.vm.Execute(txns, classes, blockNumber, header.Timestamp, sequencerAddress,
state, h.network, paidFeesOnL1, skipFeeCharge, header.GasPrice)
state, h.network, paidFeesOnL1, skipFeeCharge, header.GasPrice, legacyTraceJSON)
if err != nil {
rpcErr := *ErrContractError
rpcErr.Data = err.Error()
Expand Down Expand Up @@ -1286,14 +1306,19 @@ func (h *Handler) TraceBlockTransactions(id BlockID) ([]TracedBlockTransaction,
return nil, ErrBlockNotFound
}

return h.traceBlockTransactions(block, len(block.Transactions))
return h.traceBlockTransactions(block, len(block.Transactions), false)
}

func (h *Handler) LegacyTraceBlockTransactions(hash felt.Felt) ([]TracedBlockTransaction, *jsonrpc.Error) {
return h.TraceBlockTransactions(BlockID{Hash: &hash})
block, err := h.bcReader.BlockByHash(&hash)
if err != nil {
return nil, ErrBlockNotFound
}

return h.traceBlockTransactions(block, len(block.Transactions), true)
}

func (h *Handler) traceBlockTransactions(block *core.Block, numTxns int) ([]TracedBlockTransaction, *jsonrpc.Error) {
func (h *Handler) traceBlockTransactions(block *core.Block, numTxns int, legacyTraceJSON bool) ([]TracedBlockTransaction, *jsonrpc.Error) {
isPending := block.Hash == nil

state, closer, err := h.bcReader.StateAtBlockHash(block.ParentHash)
Expand Down Expand Up @@ -1351,7 +1376,7 @@ func (h *Handler) traceBlockTransactions(block *core.Block, numTxns int) ([]Trac
}

_, traces, err := h.vm.Execute(transactions, classes, blockNumber, header.Timestamp,
sequencerAddress, state, h.network, paidFeesOnL1, false, header.GasPrice)
sequencerAddress, state, h.network, paidFeesOnL1, false, header.GasPrice, legacyTraceJSON)
if err != nil {
rpcErr := *ErrContractError
rpcErr.Data = err.Error()
Expand Down Expand Up @@ -1652,12 +1677,12 @@ func (h *Handler) LegacyMethods() ([]jsonrpc.Method, string) { //nolint: funlen
{
Name: "starknet_traceTransaction",
Params: []jsonrpc.Parameter{{Name: "transaction_hash"}},
Handler: h.TraceTransaction,
Handler: h.LegacyTraceTransaction,
},
{
Name: "starknet_simulateTransactions",
Params: []jsonrpc.Parameter{{Name: "block_id"}, {Name: "transactions"}, {Name: "simulation_flags"}},
Handler: h.SimulateTransactions,
Handler: h.LegacySimulateTransactions,
},
{
Name: "starknet_traceBlockTransactions",
Expand Down
12 changes: 6 additions & 6 deletions rpc/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2030,10 +2030,10 @@ func TestEstimateMessageFee(t *testing.T) {

expectedGasConsumed := new(felt.Felt).SetUint64(37)
mockVM.EXPECT().Execute(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(),
gomock.Any(), utils.MAINNET, gomock.Any(), gomock.Any(), latestHeader.GasPrice).DoAndReturn(
gomock.Any(), utils.MAINNET, gomock.Any(), gomock.Any(), latestHeader.GasPrice, gomock.Any()).DoAndReturn(
func(txns []core.Transaction, declaredClasses []core.Class, blockNumber, blockTimestamp uint64,
sequencerAddress *felt.Felt, state core.StateReader, network utils.Network, paidFeesOnL1 []*felt.Felt,
skipChargeFee bool, gasPrice *felt.Felt,
skipChargeFee bool, gasPrice *felt.Felt, legacyTraceJson bool,
) ([]*felt.Felt, []json.RawMessage, error) {
require.Len(t, txns, 1)
assert.NotNil(t, txns[0].(*core.L1HandlerTransaction))
Expand Down Expand Up @@ -2112,7 +2112,7 @@ func TestTraceTransaction(t *testing.T) {
"fee_transfer_invocation": {"contract_address": "0x49d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", "entry_point_selector": "0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e", "calldata": ["0x1176a1bd84444c89232ec27754698e5d2e7e1a7f1539f12027f28b23ec9f3d8", "0x2cb6", "0x0"], "caller_address": "0xd747220b2744d8d8d48c8a52bd3869fb98aea915665ab2485d5eadb49def6a", "class_hash": "0xd0e183745e9dae3e4e78a8ffedcce0903fc4900beace4e0abf192d4c202da3", "entry_point_type": "EXTERNAL", "call_type": "CALL", "result": ["0x1"], "calls": [{"contract_address": "0x49d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", "entry_point_selector": "0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e", "calldata": ["0x1176a1bd84444c89232ec27754698e5d2e7e1a7f1539f12027f28b23ec9f3d8", "0x2cb6", "0x0"], "caller_address": "0xd747220b2744d8d8d48c8a52bd3869fb98aea915665ab2485d5eadb49def6a", "class_hash": "0x2760f25d5a4fb2bdde5f561fd0b44a3dee78c28903577d37d669939d97036a0", "entry_point_type": "EXTERNAL", "call_type": "LIBRARY_CALL", "result": ["0x1"], "calls": [], "events": [{"keys": ["0x99cd8bde557814842a3121e8ddfd433a539b8c9f14bf31ebf108d12e6196e9"], "data": ["0xd747220b2744d8d8d48c8a52bd3869fb98aea915665ab2485d5eadb49def6a", "0x1176a1bd84444c89232ec27754698e5d2e7e1a7f1539f12027f28b23ec9f3d8", "0x2cb6", "0x0"]}], "messages": []}], "events": [], "messages": []}
}`)
mockVM.EXPECT().Execute([]core.Transaction{tx}, []core.Class{declaredClass.Class}, header.Number, header.Timestamp, header.SequencerAddress,
nil, utils.MAINNET, []*felt.Felt{}, false, gomock.Any()).Return(nil, []json.RawMessage{vmTrace}, nil)
nil, utils.MAINNET, []*felt.Felt{}, false, gomock.Any(), false).Return(nil, []json.RawMessage{vmTrace}, nil)

trace, err := handler.TraceTransaction(*hash)
require.Nil(t, err)
Expand Down Expand Up @@ -2142,7 +2142,7 @@ func TestSimulateTransactions(t *testing.T) {
mockReader.EXPECT().HeadsHeader().Return(&core.Header{}, nil)

sequencerAddress := core.NetworkBlockHashMetaInfo(network).FallBackSequencerAddress
mockVM.EXPECT().Execute(nil, nil, uint64(0), uint64(0), sequencerAddress, mockState, network, []*felt.Felt{}, true, nil).
mockVM.EXPECT().Execute(nil, nil, uint64(0), uint64(0), sequencerAddress, mockState, network, []*felt.Felt{}, true, nil, false).
Return([]*felt.Felt{}, []json.RawMessage{}, nil)

_, err := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipFeeChargeFlag})
Expand Down Expand Up @@ -2210,7 +2210,7 @@ func TestTraceBlockTransactions(t *testing.T) {
"fee_transfer_invocation": {}
}`)
mockVM.EXPECT().Execute(block.Transactions, []core.Class{declaredClass.Class}, height+1, header.Timestamp, sequencerAddress,
state, network, paidL1Fees, false, header.GasPrice).Return(nil, []json.RawMessage{vmTrace, vmTrace}, nil)
state, network, paidL1Fees, false, header.GasPrice, false).Return(nil, []json.RawMessage{vmTrace, vmTrace}, nil)

result, err := handler.TraceBlockTransactions(rpc.BlockID{Hash: blockHash})
require.Nil(t, err)
Expand Down Expand Up @@ -2253,7 +2253,7 @@ func TestTraceBlockTransactions(t *testing.T) {
"fee_transfer_invocation":{"entry_point_selector":"0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e","calldata":["0x5dcd266a80b8a5f29f04d779c6b166b80150c24f2180a75e82427242dab20a9","0x15be","0x0"],"caller_address":"0xdac9bcffb3d967f19a7fe21002c98c984d5a9458a88e6fc5d1c478a97ed412","class_hash":"0xd0e183745e9dae3e4e78a8ffedcce0903fc4900beace4e0abf192d4c202da3","entry_point_type":"EXTERNAL","call_type":"CALL","result":["0x1"],"calls":[{"entry_point_selector":"0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e","calldata":["0x5dcd266a80b8a5f29f04d779c6b166b80150c24f2180a75e82427242dab20a9","0x15be","0x0"],"caller_address":"0xdac9bcffb3d967f19a7fe21002c98c984d5a9458a88e6fc5d1c478a97ed412","class_hash":"0x2760f25d5a4fb2bdde5f561fd0b44a3dee78c28903577d37d669939d97036a0","entry_point_type":"EXTERNAL","call_type":"LIBRARY_CALL","result":["0x1"],"calls":[],"events":[{"keys":["0x99cd8bde557814842a3121e8ddfd433a539b8c9f14bf31ebf108d12e6196e9"],"data":["0xdac9bcffb3d967f19a7fe21002c98c984d5a9458a88e6fc5d1c478a97ed412","0x5dcd266a80b8a5f29f04d779c6b166b80150c24f2180a75e82427242dab20a9","0x15be","0x0"]}],"messages":[]}],"events":[],"messages":[]}}
}`)
mockVM.EXPECT().Execute([]core.Transaction{tx}, []core.Class{declaredClass.Class}, header.Number, header.Timestamp, header.SequencerAddress,
nil, network, []*felt.Felt{}, false, header.GasPrice).Return(nil, []json.RawMessage{vmTrace}, nil)
nil, network, []*felt.Felt{}, false, header.GasPrice, false).Return(nil, []json.RawMessage{vmTrace}, nil)

expectedResult := []rpc.TracedBlockTransaction{
{
Expand Down
80 changes: 60 additions & 20 deletions vm/rust/src/jsonrpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,37 @@ pub struct TransactionTrace {
constructor_invocation: Option<FunctionInvocation>,
#[serde(skip_serializing_if = "Option::is_none")]
function_invocation: Option<FunctionInvocation>,
r#type: TransactionType,
#[serde(skip_serializing_if = "Option::is_none")]
r#type: Option<TransactionType>,
#[serde(skip_serializing_if = "Option::is_none")]
state_diff: Option<ThinStateDiff>,
}

impl TransactionTrace {
pub fn make_legacy(&mut self) {
self.state_diff = None;
self.r#type = None;
if let Some(invocation) = &mut self.validate_invocation {
invocation.make_legacy()
}
if let Some(invocation) = &mut self.execute_invocation {
match invocation {
ExecuteInvocation::Ok(fn_invocation) => { fn_invocation.make_legacy() }
_ => {}
}
}
if let Some(invocation) = &mut self.fee_transfer_invocation {
invocation.make_legacy()
}
if let Some(invocation) = &mut self.constructor_invocation {
invocation.make_legacy()
}
if let Some(invocation) = &mut self.function_invocation {
invocation.make_legacy()
}
}
}

impl Default for TransactionTrace {
fn default() -> Self {
Self {
Expand All @@ -52,7 +78,7 @@ impl Default for TransactionTrace {
fee_transfer_invocation: None,
constructor_invocation: None,
function_invocation: None,
r#type: TransactionType::Unknown,
r#type: None,
state_diff: None,
}
}
Expand All @@ -76,13 +102,13 @@ pub fn new_transaction_trace(
match tx {
StarknetApiTransaction::L1Handler(_) => {
trace.function_invocation = info.execute_call_info.map(|v| v.into());
trace.r#type = TransactionType::L1Handler;
trace.r#type = Some(TransactionType::L1Handler);
}
StarknetApiTransaction::DeployAccount(_) => {
trace.validate_invocation = info.validate_call_info.map(|v| v.into());
trace.constructor_invocation = info.execute_call_info.map(|v| v.into());
trace.fee_transfer_invocation = info.fee_transfer_call_info.map(|v| v.into());
trace.r#type = TransactionType::DeployAccount;
trace.r#type = Some(TransactionType::DeployAccount);
}
StarknetApiTransaction::Invoke(_) => {
trace.validate_invocation = info.validate_call_info.map(|v| v.into());
Expand All @@ -93,12 +119,12 @@ pub fn new_transaction_trace(
.map(|v| ExecuteInvocation::Ok(v.into())),
};
trace.fee_transfer_invocation = info.fee_transfer_call_info.map(|v| v.into());
trace.r#type = TransactionType::Invoke;
trace.r#type = Some(TransactionType::Invoke);
}
StarknetApiTransaction::Declare(declare_txn) => {
trace.validate_invocation = info.validate_call_info.map(|v| v.into());
trace.fee_transfer_invocation = info.fee_transfer_call_info.map(|v| v.into());
trace.r#type = TransactionType::Declare;
trace.r#type = Some(TransactionType::Declare);
deprecated_declared_class = if info.revert_error.is_none() {
match declare_txn {
DeclareTransaction::V0(_) => Some(declare_txn.class_hash()),
Expand All @@ -121,7 +147,8 @@ pub fn new_transaction_trace(

#[derive(Serialize)]
pub struct OrderedEvent {
pub order: usize,
#[serde(skip_serializing_if = "Option::is_none")]
pub order: Option<usize>,
#[serde(flatten)]
pub event: EventContent,
}
Expand All @@ -130,7 +157,7 @@ type BlockifierOrderedEvent = blockifier::execution::entry_point::OrderedEvent;
impl From<BlockifierOrderedEvent> for OrderedEvent {
fn from(val: BlockifierOrderedEvent) -> Self {
OrderedEvent {
order: val.order,
order: Some(val.order),
event: val.event,
}
}
Expand All @@ -144,10 +171,24 @@ pub struct FunctionInvocation {
pub class_hash: Option<ClassHash>,
pub entry_point_type: EntryPointType,
pub call_type: String,
pub result: Option<Vec<StarkFelt>>,
pub calls: Option<Vec<FunctionInvocation>>,
pub events: Option<Vec<OrderedEvent>>,
pub messages: Option<Vec<OrderedMessage>>,
pub result: Vec<StarkFelt>,
pub calls: Vec<FunctionInvocation>,
pub events: Vec<OrderedEvent>,
pub messages: Vec<OrderedMessage>,
}

impl FunctionInvocation {
fn make_legacy(&mut self) {
for indx in 0..self.events.len() {
self.events[indx].order = None;
}
for indx in 0..self.messages.len() {
self.messages[indx].order = None;
}
for indx in 0..self.calls.len() {
self.calls[indx].make_legacy();
}
}
}

type BlockifierCallInfo = blockifier::execution::entry_point::CallInfo;
Expand All @@ -162,16 +203,15 @@ impl From<BlockifierCallInfo> for FunctionInvocation {
.to_string(),
caller_address: val.call.caller_address,
class_hash: val.call.class_hash,
result: Some(val.execution.retdata.0),
result: val.execution.retdata.0,
function_call: FunctionCall {
contract_address: val.call.storage_address,
entry_point_selector: val.call.entry_point_selector,
calldata: val.call.calldata,
},
calls: Some(val.inner_calls.into_iter().map(|v| v.into()).collect()),
events: Some(val.execution.events.into_iter().map(|v| v.into()).collect()),
messages: Some(
val.execution
calls: val.inner_calls.into_iter().map(|v| v.into()).collect(),
events: val.execution.events.into_iter().map(|v| v.into()).collect(),
messages: val.execution
.l2_to_l1_messages
.into_iter()
.map(|v| {
Expand All @@ -180,7 +220,6 @@ impl From<BlockifierCallInfo> for FunctionInvocation {
ordered_message
})
.collect(),
),
}
}
}
Expand All @@ -194,7 +233,8 @@ pub struct FunctionCall {

#[derive(Serialize)]
pub struct OrderedMessage {
pub order: usize,
#[serde(skip_serializing_if = "Option::is_none")]
pub order: Option<usize>,
pub from_address: ContractAddress,
pub to_address: EthAddress,
pub payload: L2ToL1Payload,
Expand All @@ -203,7 +243,7 @@ pub struct OrderedMessage {
impl From<OrderedL2ToL1Message> for OrderedMessage {
fn from(val: OrderedL2ToL1Message) -> Self {
OrderedMessage {
order: val.order,
order: Some(val.order),
from_address: ContractAddress(PatriciaKey::default()),
to_address: val.message.to_address,
payload: val.message.payload,
Expand Down
Loading

0 comments on commit 311aeb1

Please sign in to comment.