Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: gateway compiler handle declare tx #439

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 55 additions & 45 deletions crates/gateway/src/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractCl
use cairo_lang_starknet_classes::casm_contract_class::{
CasmContractClass, CasmContractEntryPoints,
};
use cairo_lang_starknet_classes::contract_class::ContractClass as CairoLangContractClass;
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::RPCDeclareTransaction;
use starknet_sierra_compile::compile::compile_sierra_to_casm;
Expand All @@ -19,6 +20,7 @@ use crate::utils::is_subsequence;
#[path = "compilation_test.rs"]
mod compilation_test;

// TODO(Arni): Pass the compiler with dependancy injection.
#[derive(Clone)]
pub struct GatewayCompiler {
#[allow(dead_code)]
Expand All @@ -29,64 +31,56 @@ impl GatewayCompiler {
/// Formats the contract class for compilation, compiles it, and returns the compiled contract
/// class wrapped in a [`ClassInfo`].
/// Assumes the contract class is of a Sierra program which is compiled to Casm.
pub fn compile_contract_class(
pub fn process_declare_tx(
&self,
declare_tx: &RPCDeclareTransaction,
) -> GatewayResult<ClassInfo> {
let RPCDeclareTransaction::V3(tx) = declare_tx;
let starknet_api_contract_class = &tx.contract_class;
let cairo_lang_contract_class =
into_contract_class_for_compilation(starknet_api_contract_class);
let rpc_contract_class = &tx.contract_class;
let cairo_lang_contract_class = into_contract_class_for_compilation(rpc_contract_class);

// Compile Sierra to Casm.
let casm_contract_class = self.compile(cairo_lang_contract_class)?;

validate_compiled_class_hash(&casm_contract_class, &tx.compiled_class_hash)?;
validate_casm_class(&casm_contract_class)?;

Ok(ClassInfo::new(
&ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?),
rpc_contract_class.sierra_program.len(),
rpc_contract_class.abi.len(),
)?)
}

// TODO(Arni): Pass the compilation args from the config.
fn compile(
&self,
cairo_lang_contract_class: CairoLangContractClass,
) -> Result<CasmContractClass, GatewayError> {
let catch_unwind_result =
panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class));
let casm_contract_class = match catch_unwind_result {
Ok(compilation_result) => compilation_result?,
Err(_) => {
// TODO(Arni): Log the panic.
return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic));
}
};
self.validate_casm_class(&casm_contract_class)?;
let casm_contract_class =
catch_unwind_result.map_err(|_| CompilationUtilError::CompilationPanic)??;

let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash());
if hash_result != tx.compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: tx.compiled_class_hash,
hash_result,
});
}

// Convert Casm contract class to Starknet contract class directly.
let blockifier_contract_class =
ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?);
let class_info = ClassInfo::new(
&blockifier_contract_class,
starknet_api_contract_class.sierra_program.len(),
starknet_api_contract_class.abi.len(),
)?;
Ok(class_info)
Ok(casm_contract_class)
}
}

// TODO(Arni): Add test.
fn validate_casm_class(&self, contract_class: &CasmContractClass) -> Result<(), GatewayError> {
let CasmContractEntryPoints { external, l1_handler, constructor } =
&contract_class.entry_points_by_type;
let entry_points_iterator =
external.iter().chain(l1_handler.iter()).chain(constructor.iter());
// TODO(Arni): Add test.
fn validate_casm_class(contract_class: &CasmContractClass) -> Result<(), GatewayError> {
let CasmContractEntryPoints { external, l1_handler, constructor } =
&contract_class.entry_points_by_type;
let entry_points_iterator = external.iter().chain(l1_handler.iter()).chain(constructor.iter());

for entry_point in entry_points_iterator {
let builtins = &entry_point.builtins;
if !is_subsequence(builtins, supported_builtins()) {
return Err(GatewayError::UnsupportedBuiltins {
builtins: builtins.clone(),
supported_builtins: supported_builtins().to_vec(),
});
}
for entry_point in entry_points_iterator {
let builtins = &entry_point.builtins;
if !is_subsequence(builtins, supported_builtins()) {
return Err(GatewayError::UnsupportedBuiltins {
builtins: builtins.clone(),
supported_builtins: supported_builtins().to_vec(),
});
}
Ok(())
}
Ok(())
}

// TODO(Arni): Add to a config.
Expand All @@ -101,3 +95,19 @@ fn supported_builtins() -> &'static Vec<String> {
SUPPORTED_BUILTIN_NAMES.iter().map(|builtin| builtin.to_string()).collect::<Vec<String>>()
})
}

/// Validates that the compiled class hash of the compiled contract class matches the supplied
/// compiled class hash.
fn validate_compiled_class_hash(
casm_contract_class: &CasmContractClass,
supplied_compiled_class_hash: &CompiledClassHash,
) -> Result<(), GatewayError> {
let compiled_class_hash = CompiledClassHash(casm_contract_class.compiled_class_hash());
if compiled_class_hash != *supplied_compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: *supplied_compiled_class_hash,
hash_result: compiled_class_hash,
});
}
Ok(())
}
20 changes: 10 additions & 10 deletions crates/gateway/src/compilation_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@ fn gateway_compiler() -> GatewayCompiler {
GatewayCompiler { config: Default::default() }
}

// TODO(Arni): Redesign this test once the compiler is passed with dependancy injection.
#[rstest]
fn test_compile_contract_class_compiled_class_hash_missmatch(gateway_compiler: GatewayCompiler) {
fn test_compile_contract_class_compiled_class_hash_mismatch(gateway_compiler: GatewayCompiler) {
let mut tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
);
let expected_hash_result = tx.compiled_class_hash;
let supplied_hash = CompiledClassHash::default();

tx.compiled_class_hash = supplied_hash;
let expected_hash = tx.compiled_class_hash;
let wrong_supplied_hash = CompiledClassHash::default();
tx.compiled_class_hash = wrong_supplied_hash;
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = gateway_compiler.compile_contract_class(&declare_tx);
let result = gateway_compiler.process_declare_tx(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompiledClassHashMismatch { supplied, hash_result }
if supplied == supplied_hash && hash_result == expected_hash_result
if supplied == wrong_supplied_hash && hash_result == expected_hash
);
}

Expand All @@ -45,7 +45,7 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec();
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = gateway_compiler.compile_contract_class(&declare_tx);
let result = gateway_compiler.process_declare_tx(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError(
Expand All @@ -55,15 +55,15 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
}

#[rstest]
fn test_compile_contract_class(gateway_compiler: GatewayCompiler) {
fn test_process_declare_tx_success(gateway_compiler: GatewayCompiler) {
let declare_tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(declare_tx) => declare_tx
);
let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx;
let contract_class = &declare_tx_v3.contract_class;

let class_info = gateway_compiler.compile_contract_class(&declare_tx).unwrap();
let class_info = gateway_compiler.process_declare_tx(&declare_tx).unwrap();
assert_matches!(class_info.contract_class(), ContractClass::V1(_));
assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len());
assert_eq!(class_info.abi_length(), contract_class.abi.len());
Expand Down
2 changes: 1 addition & 1 deletion crates/gateway/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ fn process_tx(
// Compile Sierra to Casm.
let optional_class_info = match &tx {
RPCTransaction::Declare(declare_tx) => {
Some(gateway_compiler.compile_contract_class(declare_tx)?)
Some(gateway_compiler.process_declare_tx(declare_tx)?)
}
_ => None,
};
Expand Down
2 changes: 1 addition & 1 deletion crates/gateway/src/stateful_transaction_validator_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ fn test_stateful_tx_validator(
let optional_class_info = match &external_tx {
RPCTransaction::Declare(declare_tx) => Some(
GatewayCompiler { config: GatewayCompilerConfig {} }
.compile_contract_class(declare_tx)
.process_declare_tx(declare_tx)
.unwrap(),
),
_ => None,
Expand Down
16 changes: 13 additions & 3 deletions crates/mempool_test_utils/src/starknet_api_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,21 @@ pub fn executable_resource_bounds_mapping() -> ResourceBoundsMapping {
)
}

pub fn declare_tx() -> RPCTransaction {
/// Get the contract class used for testing.
pub fn contract_class() -> ContractClass {
env::set_current_dir(get_absolute_path(TEST_FILES_FOLDER)).expect("Couldn't set working dir.");
let json_file_path = Path::new(CONTRACT_CLASS_FILE);
let contract_class = serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap();
let compiled_class_hash = CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS));
serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap()
}

/// Get the compiled class hash corresponding to the contract class used for testing.
pub fn compiled_class_hash() -> CompiledClassHash {
CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS))
}

pub fn declare_tx() -> RPCTransaction {
let contract_class = contract_class();
let compiled_class_hash = compiled_class_hash();

let account_contract = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1);
let account_address = account_contract.get_instance_address(0);
Expand Down
10 changes: 5 additions & 5 deletions crates/starknet_sierra_compile/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,25 @@ use cairo_lang_starknet_classes::contract_class::{
};
use cairo_lang_utils::bigint::BigUintAsHex;
use starknet_api::rpc_transaction::{
ContractClass as StarknetApiContractClass, EntryPointByType as StarknetApiEntryPointByType,
ContractClass as RpcContractClass, EntryPointByType as StarknetApiEntryPointByType,
};
use starknet_api::state::EntryPoint as StarknetApiEntryPoint;
use starknet_types_core::felt::Felt;

/// Retruns a [`CairoLangContractClass`] struct ready for Sierra to Casm compilation. Note the `abi`
/// field is None as it is not relevant for the compilation.
pub fn into_contract_class_for_compilation(
starknet_api_contract_class: &StarknetApiContractClass,
rpc_contract_class: &RpcContractClass,
) -> CairoLangContractClass {
let sierra_program =
starknet_api_contract_class.sierra_program.iter().map(felt_to_big_uint_as_hex).collect();
rpc_contract_class.sierra_program.iter().map(felt_to_big_uint_as_hex).collect();
let entry_points_by_type =
into_cairo_lang_contract_entry_points(&starknet_api_contract_class.entry_points_by_type);
into_cairo_lang_contract_entry_points(&rpc_contract_class.entry_points_by_type);

CairoLangContractClass {
sierra_program,
sierra_program_debug_info: None,
contract_class_version: starknet_api_contract_class.contract_class_version.clone(),
contract_class_version: rpc_contract_class.contract_class_version.clone(),
entry_points_by_type,
abi: None,
}
Expand Down
Loading