Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: code to compilation file from gateway #344

Merged
merged 1 commit into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions crates/gateway/src/compilation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use std::panic;
use std::sync::OnceLock;

use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractClassV1};
use blockifier::execution::execution_utils::felt_to_stark_felt;
use cairo_lang_starknet_classes::casm_contract_class::{
CasmContractClass, CasmContractEntryPoints,
};
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::RPCDeclareTransaction;
use starknet_sierra_compile::compile::compile_sierra_to_casm;
use starknet_sierra_compile::errors::CompilationUtilError;
use starknet_sierra_compile::utils::into_contract_class_for_compilation;

use crate::errors::{GatewayError, GatewayResult};
use crate::utils::is_subsequence;

#[cfg(test)]
#[path = "compilation_test.rs"]
mod compilation_test;

/// Formats the contract class for compilation, compiles it, and returns the compiled contract class
/// wrapped in a [`ClassInfo`].
/// Assumes the contract class is of a Sierra program which is compiled to Casm.
pub fn compile_contract_class(declare_tx: &RPCDeclareTransaction) -> GatewayResult<ClassInfo> {
let RPCDeclareTransaction::V3(tx) = declare_tx;
let starknet_api_contract_class = &tx.contract_class;
let cairo_lang_contract_class =
into_contract_class_for_compilation(starknet_api_contract_class);

// Compile Sierra to Casm.
let catch_unwind_result =
panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class));
let casm_contract_class = match catch_unwind_result {
Ok(compilation_result) => compilation_result?,
Err(_) => {
// TODO(Arni): Log the panic.
return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic));
}
};
validate_casm_class(&casm_contract_class)?;

let hash_result =
CompiledClassHash(felt_to_stark_felt(&casm_contract_class.compiled_class_hash()));
if hash_result != tx.compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: tx.compiled_class_hash,
hash_result,
});
}

// Convert Casm contract class to Starknet contract class directly.
let blockifier_contract_class =
ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?);
let class_info = ClassInfo::new(
&blockifier_contract_class,
starknet_api_contract_class.sierra_program.len(),
starknet_api_contract_class.abi.len(),
)?;
Ok(class_info)
}

// TODO(Arni): Add to a config.
// TODO(Arni): Use the Builtin enum from Starknet-api, and explicitly tag each builtin as supported
// or unsupported so that the compiler would alert us on new builtins.
fn supported_builtins() -> &'static Vec<String> {
static SUPPORTED_BUILTINS: OnceLock<Vec<String>> = OnceLock::new();
SUPPORTED_BUILTINS.get_or_init(|| {
// The OS expects this order for the builtins.
const SUPPORTED_BUILTIN_NAMES: [&str; 7] =
["pedersen", "range_check", "ecdsa", "bitwise", "ec_op", "poseidon", "segment_arena"];
SUPPORTED_BUILTIN_NAMES.iter().map(|builtin| builtin.to_string()).collect::<Vec<String>>()
})
}

// TODO(Arni): Add test.
fn validate_casm_class(contract_class: &CasmContractClass) -> Result<(), GatewayError> {
let CasmContractEntryPoints { external, l1_handler, constructor } =
&contract_class.entry_points_by_type;
let entry_points_iterator = external.iter().chain(l1_handler.iter()).chain(constructor.iter());

for entry_point in entry_points_iterator {
let builtins = &entry_point.builtins;
if !is_subsequence(builtins, supported_builtins()) {
return Err(GatewayError::UnsupportedBuiltins {
builtins: builtins.clone(),
supported_builtins: supported_builtins().to_vec(),
});
}
}
Ok(())
}
64 changes: 64 additions & 0 deletions crates/gateway/src/compilation_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use assert_matches::assert_matches;
use blockifier::execution::contract_class::ContractClass;
use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError;
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction};
use starknet_sierra_compile::errors::CompilationUtilError;
use test_utils::starknet_api_test_utils::declare_tx;

use crate::compilation::compile_contract_class;
use crate::errors::GatewayError;

#[test]
fn test_compile_contract_class_compiled_class_hash_missmatch() {
let mut tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
);
let expected_hash_result = tx.compiled_class_hash;
let supplied_hash = CompiledClassHash::default();

tx.compiled_class_hash = supplied_hash;
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = compile_contract_class(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompiledClassHashMismatch { supplied, hash_result }
if supplied == supplied_hash && hash_result == expected_hash_result
);
}

#[test]
fn test_compile_contract_class_bad_sierra() {
let mut tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
);
// Truncate the sierra program to trigger an error.
tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec();
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = compile_contract_class(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError(
AllowedLibfuncsError::SierraProgramError
))
)
}

#[test]
fn test_compile_contract_class() {
let declare_tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(declare_tx) => declare_tx
);
let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx;
let contract_class = &declare_tx_v3.contract_class;

let class_info = compile_contract_class(&declare_tx).unwrap();
assert_matches!(class_info.contract_class(), ContractClass::V1(_));
assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len());
assert_eq!(class_info.abi_length(), contract_class.abi.len());
}
2 changes: 2 additions & 0 deletions crates/gateway/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ pub enum GatewayError {
UnsupportedBuiltins { builtins: Vec<String>, supported_builtins: Vec<String> },
}

pub type GatewayResult<T> = Result<T, GatewayError>;

impl IntoResponse for GatewayError {
// TODO(Arni, 1/5/2024): Be more fine tuned about the error response. Not all Gateway errors
// are internal server errors.
Expand Down
93 changes: 5 additions & 88 deletions crates/gateway/src/gateway.rs
Original file line number Diff line number Diff line change
@@ -1,41 +1,30 @@
use std::clone::Clone;
use std::net::SocketAddr;
use std::panic;
use std::sync::{Arc, OnceLock};
use std::sync::Arc;

use async_trait::async_trait;
use axum::extract::State;
use axum::routing::{get, post};
use axum::{Json, Router};
use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractClassV1};
use blockifier::execution::execution_utils::felt_to_stark_felt;
use cairo_lang_starknet_classes::casm_contract_class::{
CasmContractClass, CasmContractEntryPoints,
};
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction};
use starknet_api::rpc_transaction::RPCTransaction;
use starknet_api::transaction::TransactionHash;
use starknet_mempool_infra::component_runner::{ComponentRunner, ComponentStartError};
use starknet_mempool_types::communication::SharedMempoolClient;
use starknet_mempool_types::mempool_types::{Account, MempoolInput};
use starknet_sierra_compile::compile::compile_sierra_to_casm;
use starknet_sierra_compile::errors::CompilationUtilError;
use starknet_sierra_compile::utils::into_contract_class_for_compilation;

use crate::compilation::compile_contract_class;
use crate::config::{GatewayConfig, GatewayNetworkConfig, RpcStateReaderConfig};
use crate::errors::{GatewayError, GatewayRunError};
use crate::errors::{GatewayError, GatewayResult, GatewayRunError};
use crate::rpc_state_reader::RpcStateReaderFactory;
use crate::state_reader::StateReaderFactory;
use crate::stateful_transaction_validator::StatefulTransactionValidator;
use crate::stateless_transaction_validator::StatelessTransactionValidator;
use crate::utils::{external_tx_to_thin_tx, get_sender_address, is_subsequence};
use crate::utils::{external_tx_to_thin_tx, get_sender_address};

#[cfg(test)]
#[path = "gateway_test.rs"]
pub mod gateway_test;

pub type GatewayResult<T> = Result<T, GatewayError>;

pub struct Gateway {
pub config: GatewayConfig,
app_state: AppState,
Expand Down Expand Up @@ -145,47 +134,6 @@ fn process_tx(
})
}

/// Formats the contract class for compilation, compiles it, and returns the compiled contract class
/// wrapped in a [`ClassInfo`].
/// Assumes the contract class is of a Sierra program which is compiled to Casm.
pub fn compile_contract_class(declare_tx: &RPCDeclareTransaction) -> GatewayResult<ClassInfo> {
let RPCDeclareTransaction::V3(tx) = declare_tx;
let starknet_api_contract_class = &tx.contract_class;
let cairo_lang_contract_class =
into_contract_class_for_compilation(starknet_api_contract_class);

// Compile Sierra to Casm.
let catch_unwind_result =
panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class));
let casm_contract_class = match catch_unwind_result {
Ok(compilation_result) => compilation_result?,
Err(_) => {
// TODO(Arni): Log the panic.
return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic));
}
};
validate_casm_class(&casm_contract_class)?;

let hash_result =
CompiledClassHash(felt_to_stark_felt(&casm_contract_class.compiled_class_hash()));
if hash_result != tx.compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: tx.compiled_class_hash,
hash_result,
});
}

// Convert Casm contract class to Starknet contract class directly.
let blockifier_contract_class =
ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?);
let class_info = ClassInfo::new(
&blockifier_contract_class,
starknet_api_contract_class.sierra_program.len(),
starknet_api_contract_class.abi.len(),
)?;
Ok(class_info)
}

pub fn create_gateway(
config: GatewayConfig,
rpc_state_reader_config: RpcStateReaderConfig,
Expand All @@ -203,34 +151,3 @@ impl ComponentRunner for Gateway {
Ok(())
}
}

// TODO(Arni): Add to a config.
// TODO(Arni): Use the Builtin enum from Starknet-api, and explicitly tag each builtin as supported
// or unsupported so that the compiler would alert us on new builtins.
fn supported_builtins() -> &'static Vec<String> {
static SUPPORTED_BUILTINS: OnceLock<Vec<String>> = OnceLock::new();
SUPPORTED_BUILTINS.get_or_init(|| {
// The OS expects this order for the builtins.
const SUPPORTED_BUILTIN_NAMES: [&str; 7] =
["pedersen", "range_check", "ecdsa", "bitwise", "ec_op", "poseidon", "segment_arena"];
SUPPORTED_BUILTIN_NAMES.iter().map(|builtin| builtin.to_string()).collect::<Vec<String>>()
})
}

// TODO(Arni): Add test.
fn validate_casm_class(contract_class: &CasmContractClass) -> Result<(), GatewayError> {
let CasmContractEntryPoints { external, l1_handler, constructor } =
&contract_class.entry_points_by_type;
let entry_points_iterator = external.iter().chain(l1_handler.iter()).chain(constructor.iter());

for entry_point in entry_points_iterator {
let builtins = &entry_point.builtins;
if !is_subsequence(builtins, supported_builtins()) {
return Err(GatewayError::UnsupportedBuiltins {
builtins: builtins.clone(),
supported_builtins: supported_builtins().to_vec(),
});
}
}
Ok(())
}
62 changes: 1 addition & 61 deletions crates/gateway/src/gateway_test.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
use std::sync::Arc;

use assert_matches::assert_matches;
use axum::body::{Bytes, HttpBody};
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use blockifier::context::ChainInfo;
use blockifier::execution::contract_class::ContractClass;
use blockifier::test_utils::CairoVersion;
use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError;
use rstest::{fixture, rstest};
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction};
use starknet_api::rpc_transaction::RPCTransaction;
use starknet_api::transaction::TransactionHash;
use starknet_mempool::communication::create_mempool_server;
use starknet_mempool::mempool::Mempool;
use starknet_mempool_infra::component_server::ComponentServerStarter;
use starknet_mempool_types::communication::{MempoolClientImpl, MempoolRequestAndResponseSender};
use starknet_sierra_compile::errors::CompilationUtilError;
use test_utils::starknet_api_test_utils::{declare_tx, deploy_account_tx, invoke_tx};
use tokio::sync::mpsc::channel;
use tokio::task;

use crate::config::{StatefulTransactionValidatorConfig, StatelessTransactionValidatorConfig};
use crate::errors::GatewayError;
use crate::gateway::{add_tx, compile_contract_class, AppState, SharedMempoolClient};
use crate::state_reader_test_utils::{
local_test_state_reader_factory, local_test_state_reader_factory_for_deploy_account,
Expand Down Expand Up @@ -110,60 +104,6 @@ async fn test_add_tx(
assert_eq!(tx_hash, serde_json::from_slice(response_bytes).unwrap());
}

#[test]
fn test_compile_contract_class_compiled_class_hash_missmatch() {
let mut tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
);
let expected_hash_result = tx.compiled_class_hash;
let supplied_hash = CompiledClassHash::default();

tx.compiled_class_hash = supplied_hash;
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = compile_contract_class(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompiledClassHashMismatch { supplied, hash_result }
if supplied == supplied_hash && hash_result == expected_hash_result
);
}

#[test]
fn test_compile_contract_class_bad_sierra() {
let mut tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
);
// Truncate the sierra program to trigger an error.
tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec();
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = compile_contract_class(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError(
AllowedLibfuncsError::SierraProgramError
))
)
}

#[test]
fn test_compile_contract_class() {
let declare_tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(declare_tx) => declare_tx
);
let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx;
let contract_class = &declare_tx_v3.contract_class;

let class_info = compile_contract_class(&declare_tx).unwrap();
assert_matches!(class_info.contract_class(), ContractClass::V1(_));
assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len());
assert_eq!(class_info.abi_length(), contract_class.abi.len());
}

async fn to_bytes(res: Response) -> Bytes {
res.into_body().collect().await.unwrap().to_bytes()
}
Expand Down
1 change: 1 addition & 0 deletions crates/gateway/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod communication;
pub mod compilation;
pub mod compiler_version;
pub mod config;
pub mod errors;
Expand Down
Loading
Loading