From 708b5f762d69e35dea48138b1c8c2071033be46a Mon Sep 17 00:00:00 2001 From: rianhughes Date: Wed, 21 Feb 2024 17:52:09 +0200 Subject: [PATCH] abilengths wip --- vm/class.go | 12 +++++++++++ vm/rust/Cargo.toml | 2 +- vm/rust/src/juno_state_reader.rs | 1 - vm/rust/src/lib.rs | 37 +++++++++++++++++++++++--------- vm/vm.go | 22 ++++++++++++++++--- 5 files changed, 59 insertions(+), 15 deletions(-) diff --git a/vm/class.go b/vm/class.go index 34c0d28514..d45de8b97d 100644 --- a/vm/class.go +++ b/vm/class.go @@ -16,6 +16,7 @@ func marshalCompiledClass(class core.Class) (json.RawMessage, error) { if err != nil { return nil, err } + return json.Marshal(compiledCairo0Class) case *core.Cairo1Class: if c.Compiled == nil { @@ -28,3 +29,14 @@ func marshalCompiledClass(class core.Class) (json.RawMessage, error) { return nil, fmt.Errorf("unsupported class type %T", c) } } + +func marshalABILength(class core.Class) (json.RawMessage, error) { + switch c := class.(type) { + case *core.Cairo0Class: + return json.Marshal(len(c.Abi)) + case *core.Cairo1Class: + return json.Marshal(len(c.Abi)) + default: + return nil, fmt.Errorf("unsupported class type %T", c) + } +} diff --git a/vm/rust/Cargo.toml b/vm/rust/Cargo.toml index eccc00cd0e..881a782076 100644 --- a/vm/rust/Cargo.toml +++ b/vm/rust/Cargo.toml @@ -8,7 +8,7 @@ edition = "2021" [dependencies] serde = "1.0.171" serde_json = { version = "1.0.96", features = ["raw_value"] } -blockifier = "=0.5.0-rc.1" +blockifier = "=0.5.0-rc.3" starknet_api = "=0.8.0" cairo-vm = "=0.9.2" indexmap = "2.1.0" diff --git a/vm/rust/src/juno_state_reader.rs b/vm/rust/src/juno_state_reader.rs index 109724895c..4e1764b1be 100644 --- a/vm/rust/src/juno_state_reader.rs +++ b/vm/rust/src/juno_state_reader.rs @@ -2,7 +2,6 @@ use std::{ ffi::{c_char, c_uchar, c_void, CStr}, slice, sync::Mutex, - mem, }; use blockifier::execution::contract_class::ContractClass; diff --git a/vm/rust/src/lib.rs b/vm/rust/src/lib.rs index b474f3899e..ee19fa8b06 100644 --- a/vm/rust/src/lib.rs +++ b/vm/rust/src/lib.rs @@ -13,6 +13,7 @@ use blockifier::{ block::{BlockInfo,GasPrices}, context::{BlockContext, ChainInfo, FeeTokenAddresses, TransactionContext}, execution::{ + contract_class::ClassInfo, common_hints::ExecutionMode, entry_point::{CallEntryPoint, CallType, EntryPointExecutionContext}, }, @@ -25,7 +26,7 @@ use blockifier::{ ValidateTransactionError, }, objects::{DeprecatedTransactionInfo, HasRelatedFeeType, TransactionInfo}, transaction_execution::Transaction, - transactions::{ExecutableTransaction,ClassInfo} + transactions::ExecutableTransaction }, versioned_constants::VersionedConstants @@ -52,7 +53,7 @@ extern "C" { fn JunoAppendActualFee(reader_handle: usize, ptr: *const c_uchar); } -const GLOBAL_CONTRACT_CACHE_SIZE: usize= 100; // Todo ? default used to set this to 100. +const GLOBAL_CONTRACT_CACHE_SIZE: usize= 100; @@ -171,6 +172,7 @@ pub struct TxnAndQueryBit { pub txn: StarknetApiTransaction, pub txn_hash: TransactionHash, pub query_bit: bool, + pub abi_length : usize, } #[no_mangle] @@ -263,8 +265,15 @@ pub extern "C" fn cairoVMExecute( let mut trace_buffer = Vec::with_capacity(10_000); for (txn_index, txn_and_query_bit) in txns_and_query_bits.iter().enumerate() { + + let mut abi_length = 0; + let contract_class = match txn_and_query_bit.txn.clone() { + StarknetApiTransaction::Declare(_) => { + + abi_length = txn_and_query_bit.abi_length; + if classes.is_empty() { report_error(reader_handle, "missing declared class", txn_index as i64); return; @@ -293,18 +302,26 @@ pub extern "C" fn cairoVMExecute( }; - let sierra_program_length = 0; // Todo: Should be a new parameter? - let abi_length = 0; // Todo: Should be a new parameter? - let class_info = ClassInfo { - contract_class: contract_class.unwrap(), - sierra_program_length, - abi_length, - }; + let contract_class_unwrap = contract_class.unwrap(); + let sierra_program_length = contract_class_unwrap.bytecode_length(); + + + let class_info_result = ClassInfo::new(&contract_class_unwrap, sierra_program_length, abi_length); + let mut class_info: Option = None; + match class_info_result { + Ok(info) => { + class_info = Some(info); + }, + Err(e) => { + report_error(reader_handle, e.to_string().as_str(), txn_index as i64); + return; + } + } let txn = transaction_from_api( txn_and_query_bit.txn.clone(), txn_and_query_bit.txn_hash, - Some(class_info), + class_info, paid_fee_on_l1, txn_and_query_bit.query_bit, ); diff --git a/vm/vm.go b/vm/vm.go index e3cb91eec4..881d29d88b 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -177,7 +177,7 @@ func (v *vm) Execute(txns []core.Transaction, declaredClasses []core.Class, bloc handle := cgo.NewHandle(context) defer handle.Delete() - txnsJSON, classesJSON, err := marshalTxnsAndDeclaredClasses(txns, declaredClasses) + txnsJSON, classesJSON, abiLengthJSON, err := marshalTxnsAndDeclaredClasses(txns, declaredClasses) if err != nil { return nil, nil, err } @@ -190,6 +190,7 @@ func (v *vm) Execute(txns []core.Transaction, declaredClasses []core.Class, bloc paidFeesOnL1CStr := cstring(paidFeesOnL1Bytes) txnsJSONCstr := cstring(txnsJSON) classesJSONCStr := cstring(classesJSON) + abiLengthsJSONCStr := cstring(abiLengthJSON) sequencerAddressBytes := sequencerAddress.Bytes() @@ -211,6 +212,7 @@ func (v *vm) Execute(txns []core.Transaction, declaredClasses []core.Class, bloc chainID := C.CString(network.L2ChainID) C.cairoVMExecute(txnsJSONCstr, classesJSONCStr, + abiLengthsJSONCStr, C.uintptr_t(handle), C.ulonglong(blockNumber), C.ulonglong(blockTimestamp), @@ -267,7 +269,7 @@ func boolToByte(b bool) byte { return 0 } -func marshalTxnsAndDeclaredClasses(txns []core.Transaction, declaredClasses []core.Class) (json.RawMessage, json.RawMessage, error) { +func marshalTxnsAndDeclaredClasses(txns []core.Transaction, declaredClasses []core.Class) (json.RawMessage, json.RawMessage,json.RawMessage,, error) { txnJSONs := []json.RawMessage{} for _, txn := range txns { txnJSON, err := marshalTxn(txn) @@ -286,6 +288,15 @@ func marshalTxnsAndDeclaredClasses(txns []core.Transaction, declaredClasses []co classJSONs = append(classJSONs, declaredClassJSON) } + abiLengthsJSONs := []json.RawMessage{} + for _, declaredClass := range declaredClasses { + abiLengthJSON, cErr := marshalABILength(declaredClass) + if cErr != nil { + return nil, nil, cErr + } + abiLengthsJSONs = append(abiLengthsJSONs, abiLengthJSON) + } + txnsJSON, err := json.Marshal(txnJSONs) if err != nil { return nil, nil, err @@ -295,5 +306,10 @@ func marshalTxnsAndDeclaredClasses(txns []core.Transaction, declaredClasses []co return nil, nil, err } - return txnsJSON, classesJSON, nil + abiLengthsJSON, err := json.Marshal(abiLengthsJSONs) + if err != nil { + return nil, nil, err + } + + return txnsJSON, classesJSON, abiLengthsJSON, nil }