Skip to content

Commit

Permalink
refactor: gateway compiler handle declare tx
Browse files Browse the repository at this point in the history
  • Loading branch information
ArniStarkware committed Jul 11, 2024
1 parent e470743 commit f8e4a28
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 59 deletions.
79 changes: 52 additions & 27 deletions crates/gateway/src/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractCl
use cairo_lang_starknet_classes::casm_contract_class::{
CasmContractClass, CasmContractEntryPoints,
};
use cairo_lang_starknet_classes::contract_class::ContractClass as CairoLangContractClass;
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::RPCDeclareTransaction;
use starknet_sierra_compile::compile::compile_sierra_to_casm;
Expand All @@ -29,44 +30,40 @@ impl GatewayCompiler {
/// Formats the contract class for compilation, compiles it, and returns the compiled contract
/// class wrapped in a [`ClassInfo`].
/// Assumes the contract class is of a Sierra program which is compiled to Casm.
pub fn compile_contract_class(
pub fn handle_declare_tx(
&self,
declare_tx: &RPCDeclareTransaction,
) -> GatewayResult<ClassInfo> {
let RPCDeclareTransaction::V3(tx) = declare_tx;
let starknet_api_contract_class = &tx.contract_class;
let cairo_lang_contract_class =
into_contract_class_for_compilation(starknet_api_contract_class);
let rpc_contract_class = &tx.contract_class;
let cairo_lang_contract_class = into_contract_class_for_compilation(rpc_contract_class);

// Compile Sierra to Casm.
let casm_contract_class = self.compile(cairo_lang_contract_class)?;

validate_compiled_class_hash(&casm_contract_class, tx.compiled_class_hash)?;
self.validate_casm_class(&casm_contract_class)?;

build_result_class_info(
casm_contract_class,
rpc_contract_class.sierra_program.len(),
rpc_contract_class.abi.len(),
)
}

/// TODO(Arni): Pass the compilation args from the config.
fn compile(
&self,
cairo_lang_contract_class: CairoLangContractClass,
) -> Result<CasmContractClass, GatewayError> {
let catch_unwind_result =
panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class));
let casm_contract_class = match catch_unwind_result {
Ok(compilation_result) => compilation_result?,
match catch_unwind_result {
Ok(compilation_result) => Ok(compilation_result?),
Err(_) => {
// TODO(Arni): Log the panic.
return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic));
Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic))
}
};
self.validate_casm_class(&casm_contract_class)?;

let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash());
if hash_result != tx.compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: tx.compiled_class_hash,
hash_result,
});
}

// Convert Casm contract class to Starknet contract class directly.
let blockifier_contract_class =
ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?);
let class_info = ClassInfo::new(
&blockifier_contract_class,
starknet_api_contract_class.sierra_program.len(),
starknet_api_contract_class.abi.len(),
)?;
Ok(class_info)
}

// TODO(Arni): Add test.
Expand Down Expand Up @@ -101,3 +98,31 @@ fn supported_builtins() -> &'static Vec<String> {
SUPPORTED_BUILTIN_NAMES.iter().map(|builtin| builtin.to_string()).collect::<Vec<String>>()
})
}

/// Returns a [`ClassInfo`] struct from the compiled contract class.
fn build_result_class_info(
casm_contract_class: CasmContractClass,
sierra_program_len: usize,
abi_len: usize,
) -> GatewayResult<ClassInfo> {
let blockifier_contract_class =
ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?);
let class_info = ClassInfo::new(&blockifier_contract_class, sierra_program_len, abi_len)?;
Ok(class_info)
}

/// Validates that the compiled class hash of the compiled contract class matches the supplied
/// compiled class hash.
fn validate_compiled_class_hash(
casm_contract_class: &CasmContractClass,
suppled_compiled_class_hash: CompiledClassHash,
) -> Result<(), GatewayError> {
let compiled_class_hash = CompiledClassHash(casm_contract_class.compiled_class_hash());
if compiled_class_hash != suppled_compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: suppled_compiled_class_hash,
hash_result: compiled_class_hash,
});
}
Ok(())
}
37 changes: 16 additions & 21 deletions crates/gateway/src/compilation_test.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use assert_matches::assert_matches;
use blockifier::execution::contract_class::ContractClass;
use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError;
use mempool_test_utils::starknet_api_test_utils::declare_tx;
use mempool_test_utils::starknet_api_test_utils::{
compiled_class_hash, contract_class, declare_tx,
};
use rstest::{fixture, rstest};
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction};
use starknet_sierra_compile::errors::CompilationUtilError;
use starknet_sierra_compile::utils::into_contract_class_for_compilation;

use crate::compilation::GatewayCompiler;
use crate::compilation::{validate_compiled_class_hash, GatewayCompiler};
use crate::errors::GatewayError;

#[fixture]
Expand All @@ -17,17 +20,12 @@ fn gateway_compiler() -> GatewayCompiler {

#[rstest]
fn test_compile_contract_class_compiled_class_hash_missmatch(gateway_compiler: GatewayCompiler) {
let mut tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
);
let expected_hash_result = tx.compiled_class_hash;
let casm_contract_class =
gateway_compiler.compile(into_contract_class_for_compilation(&contract_class())).unwrap();
let expected_hash_result = compiled_class_hash();
let supplied_hash = CompiledClassHash::default();

tx.compiled_class_hash = supplied_hash;
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = gateway_compiler.compile_contract_class(&declare_tx);
let result = validate_compiled_class_hash(&casm_contract_class, supplied_hash);
assert_matches!(
result.unwrap_err(),
GatewayError::CompiledClassHashMismatch { supplied, hash_result }
Expand All @@ -37,15 +35,12 @@ fn test_compile_contract_class_compiled_class_hash_missmatch(gateway_compiler: G

#[rstest]
fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
let mut tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
);
// Truncate the sierra program to trigger an error.
tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec();
let declare_tx = RPCDeclareTransaction::V3(tx);
// Create a currupted contract class.
let mut contract_class = contract_class();
contract_class.sierra_program = contract_class.sierra_program[..100].to_vec();

let result = gateway_compiler.compile_contract_class(&declare_tx);
let cairo_lang_contract_class = into_contract_class_for_compilation(&contract_class);
let result = gateway_compiler.compile(cairo_lang_contract_class);
assert_matches!(
result.unwrap_err(),
GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError(
Expand All @@ -55,15 +50,15 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
}

#[rstest]
fn test_compile_contract_class(gateway_compiler: GatewayCompiler) {
fn test_handle_declare_tx(gateway_compiler: GatewayCompiler) {
let declare_tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(declare_tx) => declare_tx
);
let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx;
let contract_class = &declare_tx_v3.contract_class;

let class_info = gateway_compiler.compile_contract_class(&declare_tx).unwrap();
let class_info = gateway_compiler.handle_declare_tx(&declare_tx).unwrap();
assert_matches!(class_info.contract_class(), ContractClass::V1(_));
assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len());
assert_eq!(class_info.abi_length(), contract_class.abi.len());
Expand Down
2 changes: 1 addition & 1 deletion crates/gateway/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ fn process_tx(
// Compile Sierra to Casm.
let optional_class_info = match &tx {
RPCTransaction::Declare(declare_tx) => {
Some(gateway_compiler.compile_contract_class(declare_tx)?)
Some(gateway_compiler.handle_declare_tx(declare_tx)?)
}
_ => None,
};
Expand Down
2 changes: 1 addition & 1 deletion crates/gateway/src/gateway_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ fn calculate_hash(
) -> TransactionHash {
let optional_class_info = match &external_tx {
RPCTransaction::Declare(declare_tx) => {
Some(gateway_compiler.compile_contract_class(declare_tx).unwrap())
Some(gateway_compiler.handle_declare_tx(declare_tx).unwrap())
}
_ => None,
};
Expand Down
2 changes: 1 addition & 1 deletion crates/gateway/src/stateful_transaction_validator_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ fn test_stateful_tx_validator(
let optional_class_info = match &external_tx {
RPCTransaction::Declare(declare_tx) => Some(
GatewayCompiler { config: GatewayCompilerConfig {} }
.compile_contract_class(declare_tx)
.handle_declare_tx(declare_tx)
.unwrap(),
),
_ => None,
Expand Down
16 changes: 13 additions & 3 deletions crates/mempool_test_utils/src/starknet_api_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,21 @@ pub fn executable_resource_bounds_mapping() -> ResourceBoundsMapping {
)
}

pub fn declare_tx() -> RPCTransaction {
/// Get the contract class used for testing.
pub fn contract_class() -> ContractClass {
env::set_current_dir(get_absolute_path(TEST_FILES_FOLDER)).expect("Couldn't set working dir.");
let json_file_path = Path::new(CONTRACT_CLASS_FILE);
let contract_class = serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap();
let compiled_class_hash = CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS));
serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap()
}

/// Get the compiled class hash corresponding to the contract class used for testing.
pub fn compiled_class_hash() -> CompiledClassHash {
CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS))
}

pub fn declare_tx() -> RPCTransaction {
let contract_class = contract_class();
let compiled_class_hash = compiled_class_hash();

let account_contract = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1);
let account_address = account_contract.get_instance_address(0);
Expand Down
10 changes: 5 additions & 5 deletions crates/starknet_sierra_compile/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,25 @@ use cairo_lang_starknet_classes::contract_class::{
};
use cairo_lang_utils::bigint::BigUintAsHex;
use starknet_api::rpc_transaction::{
ContractClass as StarknetApiContractClass, EntryPointByType as StarknetApiEntryPointByType,
ContractClass as RpcContractClass, EntryPointByType as StarknetApiEntryPointByType,
};
use starknet_api::state::EntryPoint as StarknetApiEntryPoint;
use starknet_types_core::felt::Felt;

/// Retruns a [`CairoLangContractClass`] struct ready for Sierra to Casm compilation. Note the `abi`
/// field is None as it is not relevant for the compilation.
pub fn into_contract_class_for_compilation(
starknet_api_contract_class: &StarknetApiContractClass,
rpc_contract_class: &RpcContractClass,
) -> CairoLangContractClass {
let sierra_program =
starknet_api_contract_class.sierra_program.iter().map(felt_to_big_uint_as_hex).collect();
rpc_contract_class.sierra_program.iter().map(felt_to_big_uint_as_hex).collect();
let entry_points_by_type =
into_cairo_lang_contract_entry_points(&starknet_api_contract_class.entry_points_by_type);
into_cairo_lang_contract_entry_points(&rpc_contract_class.entry_points_by_type);

CairoLangContractClass {
sierra_program,
sierra_program_debug_info: None,
contract_class_version: starknet_api_contract_class.contract_class_version.clone(),
contract_class_version: rpc_contract_class.contract_class_version.clone(),
entry_points_by_type,
abi: None,
}
Expand Down

0 comments on commit f8e4a28

Please sign in to comment.