Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add supported builtins config #392

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 59 additions & 62 deletions crates/gateway/src/compilation.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::panic;
use std::sync::OnceLock;

use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractClassV1};
use cairo_lang_starknet_classes::casm_contract_class::{
Expand All @@ -11,80 +10,78 @@ use starknet_sierra_compile::compile::compile_sierra_to_casm;
use starknet_sierra_compile::errors::CompilationUtilError;
use starknet_sierra_compile::utils::into_contract_class_for_compilation;

use crate::compilation_config::GatewayCompilerConfig;
use crate::errors::{GatewayError, GatewayResult};
use crate::utils::is_subsequence;

#[cfg(test)]
#[path = "compilation_test.rs"]
mod compilation_test;

/// Formats the contract class for compilation, compiles it, and returns the compiled contract class
/// wrapped in a [`ClassInfo`].
/// Assumes the contract class is of a Sierra program which is compiled to Casm.
pub fn compile_contract_class(declare_tx: &RPCDeclareTransaction) -> GatewayResult<ClassInfo> {
let RPCDeclareTransaction::V3(tx) = declare_tx;
let starknet_api_contract_class = &tx.contract_class;
let cairo_lang_contract_class =
into_contract_class_for_compilation(starknet_api_contract_class);
pub struct GatewayCompiler {
pub config: GatewayCompilerConfig,
}

// Compile Sierra to Casm.
let catch_unwind_result =
panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class));
let casm_contract_class = match catch_unwind_result {
Ok(compilation_result) => compilation_result?,
Err(_) => {
// TODO(Arni): Log the panic.
return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic));
}
};
validate_casm_class(&casm_contract_class)?;
impl GatewayCompiler {
/// Formats the contract class for compilation, compiles it, and returns the compiled contract
/// class wrapped in a [`ClassInfo`].
/// Assumes the contract class is of a Sierra program which is compiled to Casm.
pub fn compile_contract_class(
&self,
declare_tx: &RPCDeclareTransaction,
) -> GatewayResult<ClassInfo> {
let RPCDeclareTransaction::V3(tx) = declare_tx;
let starknet_api_contract_class = &tx.contract_class;
let cairo_lang_contract_class =
into_contract_class_for_compilation(starknet_api_contract_class);

let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash());
if hash_result != tx.compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: tx.compiled_class_hash,
hash_result,
});
}
// Compile Sierra to Casm.
let catch_unwind_result =
panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class));
let casm_contract_class = match catch_unwind_result {
Ok(compilation_result) => compilation_result?,
Err(_) => {
// TODO(Arni): Log the panic.
return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic));
}
};
self.validate_casm_class(&casm_contract_class)?;

// Convert Casm contract class to Starknet contract class directly.
let blockifier_contract_class =
ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?);
let class_info = ClassInfo::new(
&blockifier_contract_class,
starknet_api_contract_class.sierra_program.len(),
starknet_api_contract_class.abi.len(),
)?;
Ok(class_info)
}
let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash());
if hash_result != tx.compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: tx.compiled_class_hash,
hash_result,
});
}

// TODO(Arni): Add to a config.
// TODO(Arni): Use the Builtin enum from Starknet-api, and explicitly tag each builtin as supported
// or unsupported so that the compiler would alert us on new builtins.
fn supported_builtins() -> &'static Vec<String> {
static SUPPORTED_BUILTINS: OnceLock<Vec<String>> = OnceLock::new();
SUPPORTED_BUILTINS.get_or_init(|| {
// The OS expects this order for the builtins.
const SUPPORTED_BUILTIN_NAMES: [&str; 7] =
["pedersen", "range_check", "ecdsa", "bitwise", "ec_op", "poseidon", "segment_arena"];
SUPPORTED_BUILTIN_NAMES.iter().map(|builtin| builtin.to_string()).collect::<Vec<String>>()
})
}
// Convert Casm contract class to Starknet contract class directly.
let blockifier_contract_class =
ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?);
let class_info = ClassInfo::new(
&blockifier_contract_class,
starknet_api_contract_class.sierra_program.len(),
starknet_api_contract_class.abi.len(),
)?;
Ok(class_info)
}

// TODO(Arni): Add test.
fn validate_casm_class(contract_class: &CasmContractClass) -> Result<(), GatewayError> {
let CasmContractEntryPoints { external, l1_handler, constructor } =
&contract_class.entry_points_by_type;
let entry_points_iterator = external.iter().chain(l1_handler.iter()).chain(constructor.iter());
// TODO(Arni): Add test.
fn validate_casm_class(&self, contract_class: &CasmContractClass) -> Result<(), GatewayError> {
let CasmContractEntryPoints { external, l1_handler, constructor } =
&contract_class.entry_points_by_type;
let entry_points_iterator =
external.iter().chain(l1_handler.iter()).chain(constructor.iter());

for entry_point in entry_points_iterator {
let builtins = &entry_point.builtins;
if !is_subsequence(builtins, supported_builtins()) {
return Err(GatewayError::UnsupportedBuiltins {
builtins: builtins.clone(),
supported_builtins: supported_builtins().to_vec(),
});
for entry_point in entry_points_iterator {
let builtins = &entry_point.builtins;
if !is_subsequence(builtins, &self.config.supported_builtins) {
return Err(GatewayError::UnsupportedBuiltins {
builtins: builtins.clone(),
supported_builtins: self.config.supported_builtins.to_vec(),
});
}
}
Ok(())
}
Ok(())
}
36 changes: 36 additions & 0 deletions crates/gateway/src/compilation_config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use std::collections::BTreeMap;
use std::sync::OnceLock;

use papyrus_config::dumping::SerializeConfig;
use papyrus_config::{ParamPath, SerializedParam};
use serde::{Deserialize, Serialize};
use validator::Validate;

#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
pub struct GatewayCompilerConfig {
pub supported_builtins: Vec<String>,
}

impl Default for GatewayCompilerConfig {
fn default() -> Self {
Self { supported_builtins: supported_builtins().clone() }
}
}

impl SerializeConfig for GatewayCompilerConfig {
fn dump(&self) -> BTreeMap<ParamPath, SerializedParam> {
todo!("Impelement GatewayCompilerConfig::dump");
}
}

// TODO(Arni): Use the Builtin enum from Starknet-api, and explicitly tag each builtin as supported
// or unsupported so that the compiler would alert us on new builtins.
fn supported_builtins() -> &'static Vec<String> {
static SUPPORTED_BUILTINS: OnceLock<Vec<String>> = OnceLock::new();
SUPPORTED_BUILTINS.get_or_init(|| {
// The OS expects this order for the builtins.
const SUPPORTED_BUILTIN_NAMES: [&str; 7] =
["pedersen", "range_check", "ecdsa", "bitwise", "ec_op", "poseidon", "segment_arena"];
SUPPORTED_BUILTIN_NAMES.iter().map(|builtin| builtin.to_string()).collect::<Vec<String>>()
})
}
9 changes: 5 additions & 4 deletions crates/gateway/src/compilation_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction};
use starknet_sierra_compile::errors::CompilationUtilError;

use crate::compilation::compile_contract_class;
use crate::compilation::GatewayCompiler;
use crate::errors::GatewayError;

#[test]
Expand All @@ -21,7 +21,7 @@ fn test_compile_contract_class_compiled_class_hash_missmatch() {
tx.compiled_class_hash = supplied_hash;
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = compile_contract_class(&declare_tx);
let result = GatewayCompiler { config: Default::default() }.compile_contract_class(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompiledClassHashMismatch { supplied, hash_result }
Expand All @@ -39,7 +39,7 @@ fn test_compile_contract_class_bad_sierra() {
tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec();
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = compile_contract_class(&declare_tx);
let result = GatewayCompiler { config: Default::default() }.compile_contract_class(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError(
Expand All @@ -57,7 +57,8 @@ fn test_compile_contract_class() {
let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx;
let contract_class = &declare_tx_v3.contract_class;

let class_info = compile_contract_class(&declare_tx).unwrap();
let class_info =
GatewayCompiler { config: Default::default() }.compile_contract_class(&declare_tx).unwrap();
assert_matches!(class_info.contract_class(), ContractClass::V1(_));
assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len());
assert_eq!(class_info.abi_length(), contract_class.abi.len());
Expand Down
3 changes: 3 additions & 0 deletions crates/gateway/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ use starknet_api::core::{ChainId, ContractAddress, Nonce};
use starknet_types_core::felt::Felt;
use validator::Validate;

use crate::compilation_config::GatewayCompilerConfig;
use crate::compiler_version::VersionId;

#[derive(Clone, Debug, Default, Serialize, Deserialize, Validate, PartialEq)]
pub struct GatewayConfig {
pub network_config: GatewayNetworkConfig,
pub stateless_tx_validator_config: StatelessTransactionValidatorConfig,
pub stateful_tx_validator_config: StatefulTransactionValidatorConfig,
pub compiler_config: GatewayCompilerConfig,
}

impl SerializeConfig for GatewayConfig {
Expand All @@ -30,6 +32,7 @@ impl SerializeConfig for GatewayConfig {
self.stateful_tx_validator_config.dump(),
"stateful_tx_validator_config",
),
append_sub_config_name(self.compiler_config.dump(), "compiler_config"),
]
.into_iter()
.flatten()
Expand Down
7 changes: 5 additions & 2 deletions crates/gateway/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use starknet_mempool_types::communication::SharedMempoolClient;
use starknet_mempool_types::mempool_types::{Account, MempoolInput};
use tracing::info;

use crate::compilation::compile_contract_class;
use crate::compilation::GatewayCompiler;
use crate::config::{GatewayConfig, GatewayNetworkConfig, RpcStateReaderConfig};
use crate::errors::{GatewayError, GatewayResult, GatewayRunError};
use crate::rpc_state_reader::RpcStateReaderFactory;
Expand Down Expand Up @@ -120,7 +120,10 @@ fn process_tx(

// Compile Sierra to Casm.
let optional_class_info = match &tx {
RPCTransaction::Declare(declare_tx) => Some(compile_contract_class(declare_tx)?),
RPCTransaction::Declare(declare_tx) => {
let gateway_compiler = GatewayCompiler { config: Default::default() };
Some(gateway_compiler.compile_contract_class(declare_tx)?)
}
_ => None,
};

Expand Down
8 changes: 6 additions & 2 deletions crates/gateway/src/gateway_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use tokio::sync::mpsc::channel;
use tokio::task;

use crate::config::{StatefulTransactionValidatorConfig, StatelessTransactionValidatorConfig};
use crate::gateway::{add_tx, compile_contract_class, AppState, SharedMempoolClient};
use crate::gateway::{add_tx, AppState, GatewayCompiler, SharedMempoolClient};
use crate::state_reader_test_utils::{
local_test_state_reader_factory, local_test_state_reader_factory_for_deploy_account,
TestStateReaderFactory,
Expand Down Expand Up @@ -110,7 +110,11 @@ async fn to_bytes(res: Response) -> Bytes {

fn calculate_hash(external_tx: &RPCTransaction) -> TransactionHash {
let optional_class_info = match &external_tx {
RPCTransaction::Declare(declare_tx) => Some(compile_contract_class(declare_tx).unwrap()),
RPCTransaction::Declare(declare_tx) => Some(
GatewayCompiler { config: Default::default() }
.compile_contract_class(declare_tx)
.unwrap(),
),
_ => None,
};

Expand Down
1 change: 1 addition & 0 deletions crates/gateway/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod communication;
mod compilation;
mod compilation_config;
mod compiler_version;
pub mod config;
pub mod errors;
Expand Down
7 changes: 5 additions & 2 deletions crates/gateway/src/stateful_transaction_validator_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use starknet_api::felt;
use starknet_api::rpc_transaction::RPCTransaction;
use starknet_api::transaction::TransactionHash;

use crate::compilation::compile_contract_class;
use crate::compilation::GatewayCompiler;
use crate::config::StatefulTransactionValidatorConfig;
use crate::errors::{StatefulTransactionValidatorError, StatefulTransactionValidatorResult};
use crate::state_reader_test_utils::{
Expand Down Expand Up @@ -81,7 +81,10 @@ fn test_stateful_tx_validator(
},
};
let optional_class_info = match &external_tx {
RPCTransaction::Declare(declare_tx) => Some(compile_contract_class(declare_tx).unwrap()),
RPCTransaction::Declare(declare_tx) => {
let gateway_compiler = GatewayCompiler { config: Default::default() };
Some(gateway_compiler.compile_contract_class(declare_tx).unwrap())
}
_ => None,
};

Expand Down
8 changes: 7 additions & 1 deletion crates/tests-integration/src/integration_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,14 @@ async fn create_gateway_config() -> GatewayConfig {
let socket = get_available_socket().await;
let network_config = GatewayNetworkConfig { ip: socket.ip(), port: socket.port() };
let stateful_tx_validator_config = StatefulTransactionValidatorConfig::create_for_testing();
let gateway_compiler_config = Default::default();

GatewayConfig { network_config, stateless_tx_validator_config, stateful_tx_validator_config }
GatewayConfig {
network_config,
stateless_tx_validator_config,
stateful_tx_validator_config,
compiler_config: gateway_compiler_config,
}
}

pub async fn create_config(rpc_server_addr: SocketAddr) -> MempoolNodeConfig {
Expand Down
Loading