diff --git a/config/default_config.json b/config/default_config.json index 22889fb1..75bc665e 100644 --- a/config/default_config.json +++ b/config/default_config.json @@ -9,6 +9,16 @@ "privacy": "Public", "value": true }, + "gateway_config.compiler_config.max_bytecode_size": { + "description": "The maximum bytecode size allowed for a contract.", + "privacy": "Public", + "value": 81920 + }, + "gateway_config.compiler_config.max_raw_class_size": { + "description": "The maximum raw class size allowed for a contract.", + "privacy": "Public", + "value": 4089446 + }, "gateway_config.network_config.ip": { "description": "The gateway server ip.", "privacy": "Public", @@ -49,11 +59,21 @@ "privacy": "Public", "value": 0 }, + "gateway_config.stateless_tx_validator_config.max_bytecode_size": { + "description": "The maximum bytecode size allowed for a contract.", + "privacy": "Public", + "value": 0 + }, "gateway_config.stateless_tx_validator_config.max_calldata_length": { "description": "Validates that a transaction has calldata length less than or equal to this value.", "privacy": "Public", "value": 0 }, + "gateway_config.stateless_tx_validator_config.max_raw_class_size": { + "description": "The maximum raw class size allowed for a contract.", + "privacy": "Public", + "value": 0 + }, "gateway_config.stateless_tx_validator_config.max_sierra_version.major": { "description": "The major version of the configuration.", "privacy": "Public", diff --git a/crates/gateway/src/compilation.rs b/crates/gateway/src/compilation.rs index ff1a126d..37e49ca5 100644 --- a/crates/gateway/src/compilation.rs +++ b/crates/gateway/src/compilation.rs @@ -19,6 +19,7 @@ use crate::utils::is_subsequence; #[path = "compilation_test.rs"] mod compilation_test; +// TODO(Define a function for `compile_contract_class` - which ignores the `config` parameter). #[derive(Clone)] pub struct GatewayCompiler { pub config: GatewayCompilerConfig, @@ -47,7 +48,7 @@ impl GatewayCompiler { return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic)); } }; - self.validate_casm_class(&casm_contract_class)?; + self.validate_casm(&casm_contract_class)?; let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash()); if hash_result != tx.compiled_class_hash { @@ -68,6 +69,12 @@ impl GatewayCompiler { Ok(class_info) } + fn validate_casm(&self, casm_contract_class: &CasmContractClass) -> Result<(), GatewayError> { + self.validate_casm_class(casm_contract_class)?; + self.validate_casm_class_size(casm_contract_class)?; + Ok(()) + } + // TODO(Arni): Add test. fn validate_casm_class(&self, contract_class: &CasmContractClass) -> Result<(), GatewayError> { let CasmContractEntryPoints { external, l1_handler, constructor } = @@ -86,6 +93,30 @@ impl GatewayCompiler { } Ok(()) } + + fn validate_casm_class_size( + &self, + casm_contract_class: &CasmContractClass, + ) -> Result<(), GatewayError> { + let bytecode_size = casm_contract_class.bytecode.len(); + if bytecode_size > self.config.max_bytecode_size { + return Err(GatewayError::CasmBytecodeSizeTooLarge { + bytecode_size, + max_bytecode_size: self.config.max_bytecode_size, + }); + } + let contract_class_object_size = serde_json::to_string(&casm_contract_class) + .expect("Unexpected error serializing Casm contract class.") + .len(); + if contract_class_object_size > self.config.max_raw_class_size { + return Err(GatewayError::CasmContractClassObjectSizeTooLarge { + contract_class_object_size, + max_contract_class_object_size: self.config.max_raw_class_size, + }); + } + + Ok(()) + } } // TODO(Arni): Add to a config. diff --git a/crates/gateway/src/compilation_config.rs b/crates/gateway/src/compilation_config.rs index 155b81cd..c4c2fadd 100644 --- a/crates/gateway/src/compilation_config.rs +++ b/crates/gateway/src/compilation_config.rs @@ -1,15 +1,37 @@ use std::collections::BTreeMap; -use papyrus_config::dumping::SerializeConfig; -use papyrus_config::{ParamPath, SerializedParam}; +use papyrus_config::dumping::{ser_param, SerializeConfig}; +use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam}; use serde::{Deserialize, Serialize}; use validator::Validate; -#[derive(Clone, Debug, Default, Serialize, Deserialize, Validate, PartialEq)] -pub struct GatewayCompilerConfig {} +#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)] +pub struct GatewayCompilerConfig { + pub max_bytecode_size: usize, + pub max_raw_class_size: usize, +} + +impl Default for GatewayCompilerConfig { + fn default() -> Self { + Self { max_bytecode_size: 81920, max_raw_class_size: 4089446 } + } +} impl SerializeConfig for GatewayCompilerConfig { fn dump(&self) -> BTreeMap { - BTreeMap::new() + BTreeMap::from_iter([ + ser_param( + "max_bytecode_size", + &self.max_bytecode_size, + "The maximum bytecode size allowed for a contract.", + ParamPrivacyInput::Public, + ), + ser_param( + "max_raw_class_size", + &self.max_raw_class_size, + "The maximum raw class size allowed for a contract.", + ParamPrivacyInput::Public, + ), + ]) } } diff --git a/crates/gateway/src/compilation_test.rs b/crates/gateway/src/compilation_test.rs index 3cf682d9..cc53b869 100644 --- a/crates/gateway/src/compilation_test.rs +++ b/crates/gateway/src/compilation_test.rs @@ -2,11 +2,13 @@ use assert_matches::assert_matches; use blockifier::execution::contract_class::ContractClass; use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError; use mempool_test_utils::starknet_api_test_utils::declare_tx; +use rstest::rstest; use starknet_api::core::CompiledClassHash; use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction}; use starknet_sierra_compile::errors::CompilationUtilError; use crate::compilation::GatewayCompiler; +use crate::compilation_config::GatewayCompilerConfig; use crate::errors::GatewayError; #[test] @@ -21,7 +23,10 @@ fn test_compile_contract_class_compiled_class_hash_missmatch() { tx.compiled_class_hash = supplied_hash; let declare_tx = RPCDeclareTransaction::V3(tx); - let result = GatewayCompiler { config: Default::default() }.compile_contract_class(&declare_tx); + let result = GatewayCompiler { + config: GatewayCompilerConfig { max_bytecode_size: 4800, max_raw_class_size: 111037 }, + } + .compile_contract_class(&declare_tx); assert_matches!( result.unwrap_err(), GatewayError::CompiledClassHashMismatch { supplied, hash_result } @@ -29,6 +34,50 @@ fn test_compile_contract_class_compiled_class_hash_missmatch() { ); } +#[rstest] +#[case::bytecode_size( + GatewayCompilerConfig { max_bytecode_size: 1, max_raw_class_size: usize::MAX}, + GatewayError::CasmBytecodeSizeTooLarge { bytecode_size: 4800, max_bytecode_size: 1 } +)] +#[case::raw_class_size( + GatewayCompilerConfig { max_bytecode_size: usize::MAX, max_raw_class_size: 1}, + GatewayError::CasmContractClassObjectSizeTooLarge { + contract_class_object_size: 111037, max_contract_class_object_size: 1 + } +)] +fn test_compile_contract_class_size_validation( + #[case] sierra_to_casm_compilation_config: GatewayCompilerConfig, + #[case] expected_error: GatewayError, +) { + let declare_tx = match declare_tx() { + RPCTransaction::Declare(declare_tx) => declare_tx, + _ => panic!("Invalid transaction type"), + }; + + let gateway_compiler = GatewayCompiler { config: sierra_to_casm_compilation_config }; + let result = gateway_compiler.compile_contract_class(&declare_tx); + if let GatewayError::CasmBytecodeSizeTooLarge { + bytecode_size: expected_bytecode_size, .. + } = expected_error + { + assert_matches!( + result.unwrap_err(), + GatewayError::CasmBytecodeSizeTooLarge { bytecode_size, .. } + if bytecode_size == expected_bytecode_size + ) + } else if let GatewayError::CasmContractClassObjectSizeTooLarge { + contract_class_object_size: expected_contract_class_object_size, + .. + } = expected_error + { + assert_matches!( + result.unwrap_err(), + GatewayError::CasmContractClassObjectSizeTooLarge { contract_class_object_size, .. } + if contract_class_object_size == expected_contract_class_object_size + ) + } +} + #[test] fn test_compile_contract_class_bad_sierra() { let mut tx = assert_matches!( @@ -57,8 +106,11 @@ fn test_compile_contract_class() { let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx; let contract_class = &declare_tx_v3.contract_class; - let class_info = - GatewayCompiler { config: Default::default() }.compile_contract_class(&declare_tx).unwrap(); + let class_info = GatewayCompiler { + config: GatewayCompilerConfig { max_bytecode_size: 4800, max_raw_class_size: 111037 }, + } + .compile_contract_class(&declare_tx) + .unwrap(); assert_matches!(class_info.contract_class(), ContractClass::V1(_)); assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len()); assert_eq!(class_info.abi_length(), contract_class.abi.len()); diff --git a/crates/gateway/src/config.rs b/crates/gateway/src/config.rs index b66e87a0..365db2c9 100644 --- a/crates/gateway/src/config.rs +++ b/crates/gateway/src/config.rs @@ -125,6 +125,18 @@ impl SerializeConfig for StatelessTransactionValidatorConfig { value.", ParamPrivacyInput::Public, ), + ser_param( + "max_bytecode_size", + &self.max_bytecode_size, + "The maximum bytecode size allowed for a contract.", + ParamPrivacyInput::Public, + ), + ser_param( + "max_raw_class_size", + &self.max_raw_class_size, + "The maximum raw class size allowed for a contract.", + ParamPrivacyInput::Public, + ), ]); vec![ members, diff --git a/crates/gateway/src/errors.rs b/crates/gateway/src/errors.rs index 24d5ef74..9517a969 100644 --- a/crates/gateway/src/errors.rs +++ b/crates/gateway/src/errors.rs @@ -19,6 +19,19 @@ use crate::compiler_version::{VersionId, VersionIdError}; /// Errors directed towards the end-user, as a result of gateway requests. #[derive(Debug, Error)] pub enum GatewayError { + #[error( + "Cannot declare Casm contract class with bytecode size of {bytecode_size}; max allowed \ + size: {max_bytecode_size}." + )] + CasmBytecodeSizeTooLarge { bytecode_size: usize, max_bytecode_size: usize }, + #[error( + "Cannot declare Casm contract class with size of {contract_class_object_size}; max \ + allowed size: {max_contract_class_object_size}." + )] + CasmContractClassObjectSizeTooLarge { + contract_class_object_size: usize, + max_contract_class_object_size: usize, + }, #[error(transparent)] CompilationError(#[from] CompilationUtilError), #[error( diff --git a/crates/gateway/src/gateway_test.rs b/crates/gateway/src/gateway_test.rs index e9e70d28..a5af1b68 100644 --- a/crates/gateway/src/gateway_test.rs +++ b/crates/gateway/src/gateway_test.rs @@ -53,7 +53,9 @@ pub fn app_state( stateful_tx_validator: Arc::new(StatefulTransactionValidator { config: StatefulTransactionValidatorConfig::create_for_testing(), }), - gateway_compiler: GatewayCompiler { config: GatewayCompilerConfig {} }, + gateway_compiler: GatewayCompiler { + config: GatewayCompilerConfig { max_bytecode_size: 10000, max_raw_class_size: 1000000 }, + }, state_reader_factory: Arc::new(state_reader_factory), mempool_client, } @@ -113,9 +115,14 @@ async fn to_bytes(res: Response) -> Bytes { fn calculate_hash(external_tx: &RPCTransaction) -> TransactionHash { let optional_class_info = match &external_tx { RPCTransaction::Declare(declare_tx) => Some( - GatewayCompiler { config: GatewayCompilerConfig {} } - .compile_contract_class(declare_tx) - .unwrap(), + GatewayCompiler { + config: GatewayCompilerConfig { + max_bytecode_size: 4800, + max_raw_class_size: 111037, + }, + } + .compile_contract_class(declare_tx) + .unwrap(), ), _ => None, }; diff --git a/crates/gateway/src/stateful_transaction_validator_test.rs b/crates/gateway/src/stateful_transaction_validator_test.rs index 790d45ed..13b8990e 100644 --- a/crates/gateway/src/stateful_transaction_validator_test.rs +++ b/crates/gateway/src/stateful_transaction_validator_test.rs @@ -81,11 +81,15 @@ fn test_stateful_tx_validator( }, }; let optional_class_info = match &external_tx { - RPCTransaction::Declare(declare_tx) => Some( - GatewayCompiler { config: GatewayCompilerConfig {} } - .compile_contract_class(declare_tx) - .unwrap(), - ), + RPCTransaction::Declare(declare_tx) => { + let gateway_compiler = GatewayCompiler { + config: GatewayCompilerConfig { + max_bytecode_size: 4800, + max_raw_class_size: 111037, + }, + }; + Some(gateway_compiler.compile_contract_class(declare_tx).unwrap()) + } _ => None, };