Skip to content

Commit

Permalink
refactor: refactor the stateful tx validator to get executable tx as …
Browse files Browse the repository at this point in the history
…input
  • Loading branch information
ArniStarkware committed Sep 15, 2024
1 parent 1b40d77 commit e503534
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 167 deletions.
2 changes: 2 additions & 0 deletions crates/gateway/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ pub struct StatefulTransactionValidatorConfig {
pub max_nonce_for_validation_skip: Nonce,
pub validate_max_n_steps: u32,
pub max_recursion_depth: usize,
// TODO(Arni): Move this member out of the stateful transaction validator config. Move it into
// the gateway config. This is used during the transalation from external_tx to executable_tx.
pub chain_info: ChainInfo,
}

Expand Down
13 changes: 1 addition & 12 deletions crates/gateway/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,21 +145,10 @@ fn process_tx(
}
}

let optional_class_info = match executable_tx {
starknet_api::executable_transaction::Transaction::Declare(tx) => {
Some(tx.class_info.try_into().map_err(|e| {
error!("Failed to convert Starknet API ClassInfo to Blockifier ClassInfo: {:?}", e);
GatewaySpecError::UnexpectedError { data: "Internal server error.".to_owned() }
})?)
}
_ => None,
};

let validator = stateful_tx_validator.instantiate_validator(state_reader_factory)?;
// TODO(Yael 31/7/24): refactor after IntrnalTransaction is ready, delete validate_info and
// compute all the info outside of run_validate.
let validate_info =
stateful_tx_validator.run_validate(&copy_of_rpc_tx, optional_class_info, validator)?;
let validate_info = stateful_tx_validator.run_validate(&executable_tx, validator)?;

// TODO(Arni): Add the Sierra and the Casm to the mempool input.
Ok(MempoolInput {
Expand Down
142 changes: 137 additions & 5 deletions crates/gateway/src/gateway_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,51 @@ use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use blockifier::context::ChainInfo;
use blockifier::execution::contract_class::ClassInfo;
use blockifier::test_utils::CairoVersion;
use blockifier::transaction::account_transaction::AccountTransaction;
use blockifier::transaction::transactions::{
DeclareTransaction as BlockifierDeclareTransaction,
DeployAccountTransaction as BlockifierDeployAccountTransaction,
InvokeTransaction as BlockifierInvokeTransaction,
};
use mempool_test_utils::starknet_api_test_utils::{create_executable_tx, declare_tx, invoke_tx};
use mockall::predicate::eq;
use starknet_api::core::{CompiledClassHash, ContractAddress};
use starknet_api::rpc_transaction::{RpcDeclareTransaction, RpcTransaction};
use starknet_api::transaction::{TransactionHash, ValidResourceBounds};
use starknet_api::core::{
calculate_contract_address,
ChainId,
ClassHash,
CompiledClassHash,
ContractAddress,
};
use starknet_api::rpc_transaction::{
RpcDeclareTransaction,
RpcDeployAccountTransaction,
RpcInvokeTransaction,
RpcTransaction,
};
use starknet_api::transaction::{
DeclareTransaction,
DeclareTransactionV3,
DeployAccountTransaction,
DeployAccountTransactionV3,
InvokeTransactionV3,
TransactionHash,
TransactionHasher,
ValidResourceBounds,
};
use starknet_mempool_types::communication::MockMempoolClient;
use starknet_mempool_types::mempool_types::{Account, AccountState, MempoolInput};
use starknet_sierra_compile::config::SierraToCasmCompilationConfig;
use tracing::error;

use crate::compilation::GatewayCompiler;
use crate::config::{StatefulTransactionValidatorConfig, StatelessTransactionValidatorConfig};
use crate::errors::GatewaySpecError;
use crate::errors::{GatewaySpecError, StatefulTransactionValidatorResult};
use crate::gateway::{add_tx, AppState, SharedMempoolClient};
use crate::state_reader_test_utils::{local_test_state_reader_factory, TestStateReaderFactory};
use crate::stateful_transaction_validator::StatefulTransactionValidator;
use crate::stateless_transaction_validator::StatelessTransactionValidator;
use crate::utils::rpc_tx_to_account_tx;

pub fn app_state(
mempool_client: SharedMempoolClient,
Expand Down Expand Up @@ -130,3 +157,108 @@ fn calculate_hash(rpc_tx: &RpcTransaction) -> TransactionHash {
.unwrap();
account_tx.tx_hash()
}

// TODO(Arni): Remove this function. Replace with a function that take ownership of RpcTransaction.
fn rpc_tx_to_account_tx(
rpc_tx: &RpcTransaction,
// FIXME(yael 15/4/24): calculate class_info inside the function once compilation code is ready
optional_class_info: Option<ClassInfo>,
chain_id: &ChainId,
) -> StatefulTransactionValidatorResult<AccountTransaction> {
match rpc_tx {
RpcTransaction::Declare(RpcDeclareTransaction::V3(tx)) => {
let declare_tx = DeclareTransaction::V3(DeclareTransactionV3 {
class_hash: ClassHash::default(), /* FIXME(yael 15/4/24): call the starknet-api
* function once ready */
resource_bounds: ValidResourceBounds::AllResources(
rpc_tx.resource_bounds().clone(),
),
tip: tx.tip,
signature: tx.signature.clone(),
nonce: tx.nonce,
compiled_class_hash: tx.compiled_class_hash,
sender_address: tx.sender_address,
nonce_data_availability_mode: tx.nonce_data_availability_mode,
fee_data_availability_mode: tx.fee_data_availability_mode,
paymaster_data: tx.paymaster_data.clone(),
account_deployment_data: tx.account_deployment_data.clone(),
});
let tx_hash = declare_tx
.calculate_transaction_hash(chain_id, &declare_tx.version())
.map_err(|e| {
error!("Failed to calculate tx hash: {}", e);
GatewaySpecError::UnexpectedError { data: "Internal server error".to_owned() }
})?;
let class_info =
optional_class_info.expect("declare transaction should contain class info");
let declare_tx = BlockifierDeclareTransaction::new(declare_tx, tx_hash, class_info)
.map_err(|e| {
error!("Failed to convert declare tx hash to blockifier tx type: {}", e);
GatewaySpecError::UnexpectedError { data: "Internal server error".to_owned() }
})?;
Ok(AccountTransaction::Declare(declare_tx))
}
RpcTransaction::DeployAccount(RpcDeployAccountTransaction::V3(tx)) => {
let deploy_account_tx = DeployAccountTransaction::V3(DeployAccountTransactionV3 {
resource_bounds: ValidResourceBounds::AllResources(
rpc_tx.resource_bounds().clone(),
),
tip: tx.tip,
signature: tx.signature.clone(),
nonce: tx.nonce,
class_hash: tx.class_hash,
contract_address_salt: tx.contract_address_salt,
constructor_calldata: tx.constructor_calldata.clone(),
nonce_data_availability_mode: tx.nonce_data_availability_mode,
fee_data_availability_mode: tx.fee_data_availability_mode,
paymaster_data: tx.paymaster_data.clone(),
});
let contract_address = calculate_contract_address(
deploy_account_tx.contract_address_salt(),
deploy_account_tx.class_hash(),
&deploy_account_tx.constructor_calldata(),
ContractAddress::default(),
)
.map_err(|e| {
error!("Failed to calculate contract address: {}", e);
GatewaySpecError::UnexpectedError { data: "Internal server error".to_owned() }
})?;
let tx_hash = deploy_account_tx
.calculate_transaction_hash(chain_id, &deploy_account_tx.version())
.map_err(|e| {
error!("Failed to calculate tx hash: {}", e);
GatewaySpecError::UnexpectedError { data: "Internal server error".to_owned() }
})?;
let deploy_account_tx = BlockifierDeployAccountTransaction::new(
deploy_account_tx,
tx_hash,
contract_address,
);
Ok(AccountTransaction::DeployAccount(deploy_account_tx))
}
RpcTransaction::Invoke(RpcInvokeTransaction::V3(tx)) => {
let invoke_tx = starknet_api::transaction::InvokeTransaction::V3(InvokeTransactionV3 {
resource_bounds: ValidResourceBounds::AllResources(
rpc_tx.resource_bounds().clone(),
),
tip: tx.tip,
signature: tx.signature.clone(),
nonce: tx.nonce,
sender_address: tx.sender_address,
calldata: tx.calldata.clone(),
nonce_data_availability_mode: tx.nonce_data_availability_mode,
fee_data_availability_mode: tx.fee_data_availability_mode,
paymaster_data: tx.paymaster_data.clone(),
account_deployment_data: tx.account_deployment_data.clone(),
});
let tx_hash = invoke_tx
.calculate_transaction_hash(chain_id, &invoke_tx.version())
.map_err(|e| {
error!("Failed to calculate tx hash: {}", e);
GatewaySpecError::UnexpectedError { data: "Internal server error".to_owned() }
})?;
let invoke_tx = BlockifierInvokeTransaction::new(invoke_tx, tx_hash);
Ok(AccountTransaction::Invoke(invoke_tx))
}
}
}
31 changes: 19 additions & 12 deletions crates/gateway/src/stateful_transaction_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,24 @@ use blockifier::blockifier::stateful_validator::{
};
use blockifier::bouncer::BouncerConfig;
use blockifier::context::BlockContext;
use blockifier::execution::contract_class::ClassInfo;
use blockifier::state::cached_state::CachedState;
use blockifier::transaction::account_transaction::AccountTransaction;
use blockifier::versioned_constants::VersionedConstants;
#[cfg(test)]
use mockall::automock;
use starknet_api::core::{ContractAddress, Nonce};
use starknet_api::rpc_transaction::{RpcInvokeTransaction, RpcTransaction};
use starknet_api::executable_transaction::{
InvokeTransaction as ExecutableInvokeTransaction,
Transaction as ExecutableTransaction,
};
use starknet_api::transaction::TransactionHash;
use starknet_types_core::felt::Felt;
use tracing::error;

use crate::config::StatefulTransactionValidatorConfig;
use crate::errors::{GatewaySpecError, StatefulTransactionValidatorResult};
use crate::state_reader::{MempoolStateReader, StateReaderFactory};
use crate::utils::{get_sender_address, rpc_tx_to_account_tx};
use crate::utils::get_sender_address;

#[cfg(test)]
#[path = "stateful_transaction_validator_test.rs"]
Expand Down Expand Up @@ -69,19 +71,24 @@ impl StatefulTransactionValidator {
// conversion is also relevant for the Mempool.
pub fn run_validate<V: StatefulTransactionValidatorTrait>(
&self,
rpc_tx: &RpcTransaction,
optional_class_info: Option<ClassInfo>,
executable_tx: &ExecutableTransaction,
mut validator: V,
) -> StatefulTransactionValidatorResult<ValidateInfo> {
let account_tx =
rpc_tx_to_account_tx(rpc_tx, optional_class_info, &self.config.chain_info.chain_id)?;
let account_tx = AccountTransaction::try_from(
// TODO(Arni): create a try_from for &ExecutableTransaction.
executable_tx.clone(),
)
.map_err(|error| {
error!("Failed to convert executable transaction into account transaction: {}", error);
GatewaySpecError::UnexpectedError { data: "Internal server error".to_owned() }
})?;
let tx_hash = account_tx.tx_hash();
let sender_address = get_sender_address(&account_tx);
let account_nonce = validator.get_nonce(sender_address).map_err(|e| {
error!("Failed to get nonce for sender address {}: {}", sender_address, e);
GatewaySpecError::UnexpectedError { data: "Internal server error.".to_owned() }
})?;
let skip_validate = skip_stateful_validations(rpc_tx, account_nonce);
let skip_validate = skip_stateful_validations(executable_tx, account_nonce);
validator
.validate(account_tx, skip_validate)
.map_err(|err| GatewaySpecError::ValidationFailure { data: err.to_string() })?;
Expand Down Expand Up @@ -119,15 +126,15 @@ impl StatefulTransactionValidator {
// Check if validation of an invoke transaction should be skipped due to deploy_account not being
// proccessed yet. This feature is used to improve UX for users sending deploy_account + invoke at
// once.
fn skip_stateful_validations(tx: &RpcTransaction, account_nonce: Nonce) -> bool {
fn skip_stateful_validations(tx: &ExecutableTransaction, account_nonce: Nonce) -> bool {
match tx {
RpcTransaction::Invoke(RpcInvokeTransaction::V3(tx)) => {
ExecutableTransaction::Invoke(ExecutableInvokeTransaction { tx, .. }) => {
// check if the transaction nonce is 1, meaning it is post deploy_account, and the
// account nonce is zero, meaning the account was not deployed yet. The mempool also
// verifies that the deploy_account transaction exists.
tx.nonce == Nonce(Felt::ONE) && account_nonce == Nonce(Felt::ZERO)
tx.nonce() == Nonce(Felt::ONE) && account_nonce == Nonce(Felt::ZERO)
}
RpcTransaction::DeployAccount(_) | RpcTransaction::Declare(_) => false,
ExecutableTransaction::DeployAccount(_) | ExecutableTransaction::Declare(_) => false,
}
}

Expand Down
46 changes: 30 additions & 16 deletions crates/gateway/src/stateful_transaction_validator_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@ use blockifier::blockifier::stateful_validator::{
use blockifier::context::BlockContext;
use blockifier::test_utils::CairoVersion;
use blockifier::transaction::errors::{TransactionFeeError, TransactionPreValidationError};
use mempool_test_utils::invoke_tx_args;
use mempool_test_utils::starknet_api_test_utils::{
deploy_account_tx,
invoke_tx,
rpc_invoke_tx,
executable_invoke_tx,
TEST_SENDER_ADDRESS,
VALID_L1_GAS_MAX_AMOUNT,
VALID_L1_GAS_MAX_PRICE_PER_UNIT,
Expand All @@ -19,7 +16,7 @@ use num_bigint::BigUint;
use pretty_assertions::assert_eq;
use rstest::{fixture, rstest};
use starknet_api::core::{ContractAddress, Nonce, PatriciaKey};
use starknet_api::rpc_transaction::RpcTransaction;
use starknet_api::executable_transaction::Transaction;
use starknet_api::transaction::TransactionHash;
use starknet_api::{contract_address, felt, patricia_key};
use starknet_types_core::felt::Felt;
Expand Down Expand Up @@ -65,7 +62,7 @@ fn stateful_validator(block_context: BlockContext) -> StatefulTransactionValidat
// TODO(Arni): consider testing declare and deploy account.
#[rstest]
#[case::valid_tx(
invoke_tx(CairoVersion::Cairo1),
executable_invoke_tx(CairoVersion::Cairo1),
Ok(ValidateInfo{
tx_hash: TransactionHash(felt!(
"0x3b93426272b6e281bc9bde29b91a9fb100c2f9689388c62360b2be2f4e7b493"
Expand All @@ -74,9 +71,9 @@ fn stateful_validator(block_context: BlockContext) -> StatefulTransactionValidat
account_nonce: Nonce::default()
})
)]
#[case::invalid_tx(invoke_tx(CairoVersion::Cairo1), Err(STATEFUL_VALIDATOR_FEE_ERROR))]
#[case::invalid_tx(executable_invoke_tx(CairoVersion::Cairo1), Err(STATEFUL_VALIDATOR_FEE_ERROR))]
fn test_stateful_tx_validator(
#[case] rpc_tx: RpcTransaction,
#[case] executable_tx: Transaction,
#[case] expected_result: BlockifierStatefulValidatorResult<ValidateInfo>,
stateful_validator: StatefulTransactionValidator,
) {
Expand All @@ -89,7 +86,7 @@ fn test_stateful_tx_validator(
mock_validator.expect_validate().return_once(|_, _| expected_result.map(|_| ()));
mock_validator.expect_get_nonce().returning(|_| Ok(Nonce(Felt::ZERO)));

let result = stateful_validator.run_validate(&rpc_tx, None, mock_validator);
let result = stateful_validator.run_validate(&executable_tx, mock_validator);
assert_eq!(result, expected_result_as_stateful_transaction_result);
}

Expand Down Expand Up @@ -129,28 +126,45 @@ fn test_instantiate_validator() {

#[rstest]
#[case::should_skip_validation(
rpc_invoke_tx(invoke_tx_args!{nonce: Nonce(Felt::ONE)}),
Transaction::Invoke(starknet_api::test_utils::invoke::executable_invoke_tx(
starknet_api::invoke_tx_args!(nonce: Nonce(Felt::ONE))
)),
Nonce::default(),
true
)]
#[case::should_not_skip_validation_nonce_over_max_nonce_for_skip(
rpc_invoke_tx(invoke_tx_args!{nonce: Nonce(Felt::TWO)}),
Transaction::Invoke(starknet_api::test_utils::invoke::executable_invoke_tx(
starknet_api::invoke_tx_args!(nonce: Nonce(Felt::ZERO))
)),
Nonce::default(),
false
)]
#[case::should_not_skip_validation_non_invoke(deploy_account_tx(), Nonce::default(), false)]
#[case::should_not_skip_validation_non_invoke(
Transaction::DeployAccount(
starknet_api::test_utils::deploy_account::executable_deploy_account_tx(
starknet_api::deploy_account_tx_args!(), Nonce::default()
)
),
Nonce::default(),
false)]
#[case::should_not_skip_validation_account_nonce_1(
rpc_invoke_tx(invoke_tx_args!{sender_address: ContractAddress::from(TEST_SENDER_ADDRESS), nonce: Nonce(Felt::ONE)}),
Transaction::Invoke(starknet_api::test_utils::invoke::executable_invoke_tx(
starknet_api::invoke_tx_args!(
nonce: Nonce(Felt::ONE),
sender_address: TEST_SENDER_ADDRESS.into()
)
)),
Nonce(Felt::ONE),
false
)]
fn test_skip_stateful_validation(
#[case] rpc_tx: RpcTransaction,
#[case] executable_tx: Transaction,
#[case] sender_nonce: Nonce,
#[case] should_skip_validate: bool,
stateful_validator: StatefulTransactionValidator,
) {
let sender_address = rpc_tx.calculate_sender_address().unwrap();
let sender_address = executable_tx.contract_address();

let mut mock_validator = MockStatefulTransactionValidatorTrait::new();
mock_validator
.expect_get_nonce()
Expand All @@ -160,5 +174,5 @@ fn test_skip_stateful_validation(
.expect_validate()
.withf(move |_, skip_validate| *skip_validate == should_skip_validate)
.returning(|_, _| Ok(()));
let _ = stateful_validator.run_validate(&rpc_tx, None, mock_validator);
let _ = stateful_validator.run_validate(&executable_tx, mock_validator);
}
Loading

0 comments on commit e503534

Please sign in to comment.