diff --git a/Cargo.lock b/Cargo.lock index f61e5cb64e..bd5f131dd6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2761,6 +2761,7 @@ dependencies = [ "dojo-world", "jsonrpsee", "katana-core", + "katana-primitives", "katana-rpc", "scarb", "scarb-ui", @@ -5451,6 +5452,7 @@ dependencies = [ "clap_complete", "console", "katana-core", + "katana-primitives", "katana-rpc", "metrics 0.4.4", "metrics-process", @@ -5567,6 +5569,8 @@ dependencies = [ "serde_with", "starknet", "starknet_api", + "strum 0.25.0", + "strum_macros 0.25.3", "thiserror", ] diff --git a/crates/dojo-test-utils/Cargo.toml b/crates/dojo-test-utils/Cargo.toml index 3282db9d93..5e549eb37c 100644 --- a/crates/dojo-test-utils/Cargo.toml +++ b/crates/dojo-test-utils/Cargo.toml @@ -18,6 +18,7 @@ dojo-lang = { path = "../dojo-lang" } dojo-world = { path = "../dojo-world", features = [ "manifest", "migration" ] } jsonrpsee = { version = "0.16.2", features = [ "server" ] } katana-core = { path = "../katana/core" } +katana-primitives = { path = "../katana/primitives" } katana-rpc = { path = "../katana/rpc" } scarb-ui.workspace = true scarb.workspace = true diff --git a/crates/dojo-test-utils/src/sequencer.rs b/crates/dojo-test-utils/src/sequencer.rs index 29e1b97c96..fcb83c6163 100644 --- a/crates/dojo-test-utils/src/sequencer.rs +++ b/crates/dojo-test-utils/src/sequencer.rs @@ -4,6 +4,7 @@ use jsonrpsee::core::Error; pub use katana_core::backend::config::{Environment, StarknetConfig}; use katana_core::sequencer::KatanaSequencer; pub use katana_core::sequencer::SequencerConfig; +use katana_primitives::chain::ChainId; use katana_rpc::api::ApiKind; use katana_rpc::config::ServerConfig; use katana_rpc::{spawn, NodeHandle}; @@ -79,7 +80,7 @@ impl TestSequencer { pub fn get_default_test_starknet_config() -> StarknetConfig { StarknetConfig { disable_fee: true, - env: Environment { chain_id: "SN_GOERLI".into(), ..Default::default() }, + env: Environment { chain_id: ChainId::GOERLI, ..Default::default() }, ..Default::default() } } diff --git a/crates/katana/Cargo.toml b/crates/katana/Cargo.toml index 0e2b102a61..5196c53e50 100644 --- a/crates/katana/Cargo.toml +++ b/crates/katana/Cargo.toml @@ -11,6 +11,7 @@ clap.workspace = true clap_complete.workspace = true console.workspace = true katana-core = { path = "core" } +katana-primitives = { path = "primitives" } katana-rpc = { path = "rpc" } metrics = { path = "../metrics" } metrics-process.workspace = true diff --git a/crates/katana/core/src/backend/config.rs b/crates/katana/core/src/backend/config.rs index c466056286..3a0bba627c 100644 --- a/crates/katana/core/src/backend/config.rs +++ b/crates/katana/core/src/backend/config.rs @@ -1,6 +1,6 @@ use blockifier::block_context::{BlockContext, FeeTokenAddresses, GasPrices}; +use katana_primitives::chain::ChainId; use starknet_api::block::{BlockNumber, BlockTimestamp}; -use starknet_api::core::ChainId; use url::Url; use crate::constants::{ @@ -24,7 +24,7 @@ impl StarknetConfig { pub fn block_context(&self) -> BlockContext { BlockContext { block_number: BlockNumber::default(), - chain_id: ChainId(self.env.chain_id.clone()), + chain_id: self.env.chain_id.into(), block_timestamp: BlockTimestamp::default(), sequencer_address: (*SEQUENCER_ADDRESS).into(), // As the fee has two currencies, we also have to adjust their addresses. @@ -67,7 +67,7 @@ impl Default for StarknetConfig { #[derive(Debug, Clone)] pub struct Environment { - pub chain_id: String, + pub chain_id: ChainId, pub gas_price: u128, pub invoke_max_steps: u32, pub validate_max_steps: u32, @@ -77,7 +77,7 @@ impl Default for Environment { fn default() -> Self { Self { gas_price: DEFAULT_GAS_PRICE, - chain_id: "KATANA".to_string(), + chain_id: ChainId::parse("KATANA").unwrap(), invoke_max_steps: DEFAULT_INVOKE_MAX_STEPS, validate_max_steps: DEFAULT_VALIDATE_MAX_STEPS, } diff --git a/crates/katana/core/src/backend/mod.rs b/crates/katana/core/src/backend/mod.rs index 99d6009687..ae04161dbe 100644 --- a/crates/katana/core/src/backend/mod.rs +++ b/crates/katana/core/src/backend/mod.rs @@ -4,6 +4,7 @@ use blockifier::block_context::BlockContext; use katana_primitives::block::{ Block, FinalityStatus, GasPrices, Header, PartialHeader, SealedBlockWithStatus, }; +use katana_primitives::chain::ChainId; use katana_primitives::contract::ContractAddress; use katana_primitives::receipt::Receipt; use katana_primitives::state::StateUpdatesWithDeclaredClasses; @@ -20,7 +21,6 @@ use starknet::core::utils::parse_cairo_short_string; use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::{JsonRpcClient, Provider}; use starknet_api::block::{BlockNumber, BlockTimestamp}; -use starknet_api::core::ChainId; use tracing::{info, trace}; pub mod config; @@ -40,6 +40,8 @@ pub struct Backend { pub config: RwLock, /// stores all block related data in memory pub blockchain: Blockchain, + /// The chain id. + pub chain_id: ChainId, /// The chain environment values. pub env: Arc>, pub block_context_generator: RwLock, @@ -57,7 +59,9 @@ impl Backend { .with_balance(*DEFAULT_PREFUNDED_ACCOUNT_BALANCE) .generate(); - let blockchain: Blockchain = if let Some(forked_url) = &config.fork_rpc_url { + let (blockchain, chain_id): (Blockchain, ChainId) = if let Some(forked_url) = + &config.fork_rpc_url + { let provider = Arc::new(JsonRpcClient::new(HttpTransport::new(forked_url.clone()))); let forked_chain_id = provider.chain_id().await.unwrap(); @@ -79,7 +83,8 @@ impl Backend { block_context.block_number = BlockNumber(block.block_number); block_context.block_timestamp = BlockTimestamp(block.timestamp); block_context.sequencer_address = ContractAddress(block.sequencer_address).into(); - block_context.chain_id = ChainId(parse_cairo_short_string(&forked_chain_id).unwrap()); + block_context.chain_id = + starknet_api::core::ChainId(parse_cairo_short_string(&forked_chain_id).unwrap()); trace!( target: "backend", @@ -89,7 +94,7 @@ impl Backend { forked_url ); - Blockchain::new_from_forked( + let blockchain = Blockchain::new_from_forked( ForkedProvider::new(provider, forked_block_num.into()), block.block_hash, block.parent_hash, @@ -101,10 +106,14 @@ impl Backend { _ => panic!("unable to fork for non-accepted block"), }, ) - .expect("able to create forked blockchain") + .expect("able to create forked blockchain"); + + (blockchain, forked_chain_id.into()) } else { - Blockchain::new_with_genesis(InMemoryProvider::new(), &block_context) - .expect("able to create blockchain from genesis block") + let blockchain = Blockchain::new_with_genesis(InMemoryProvider::new(), &block_context) + .expect("able to create blockchain from genesis block"); + + (blockchain, config.env.chain_id) }; let env = Env { block: block_context }; @@ -115,6 +124,7 @@ impl Backend { } Self { + chain_id, accounts, blockchain, config: RwLock::new(config), diff --git a/crates/katana/core/src/sequencer.rs b/crates/katana/core/src/sequencer.rs index 4a3e7a5610..1f2ac0d620 100644 --- a/crates/katana/core/src/sequencer.rs +++ b/crates/katana/core/src/sequencer.rs @@ -10,6 +10,7 @@ use katana_executor::blockifier::state::StateRefDb; use katana_executor::blockifier::utils::EntryPointCall; use katana_executor::blockifier::PendingState; use katana_primitives::block::{BlockHash, BlockHashOrNumber, BlockIdOrTag, BlockNumber}; +use katana_primitives::chain::ChainId; use katana_primitives::contract::{ ClassHash, CompiledContractClass, ContractAddress, Nonce, StorageKey, StorageValue, }; @@ -26,7 +27,6 @@ use katana_provider::traits::transaction::{ ReceiptProvider, TransactionProvider, TransactionsProviderExt, }; use starknet::core::types::{BlockTag, EmittedEvent, EventsPage, FeeEstimate}; -use starknet_api::core::ChainId; use crate::backend::config::StarknetConfig; use crate::backend::contract::StarknetContract; @@ -215,7 +215,7 @@ impl KatanaSequencer { } pub fn chain_id(&self) -> ChainId { - self.backend.env.read().block.chain_id.clone() + self.backend.chain_id } pub fn block_number(&self) -> BlockNumber { diff --git a/crates/katana/core/src/service/messaging/ethereum.rs b/crates/katana/core/src/service/messaging/ethereum.rs index ec4e0aa140..a9710c2cf4 100644 --- a/crates/katana/core/src/service/messaging/ethereum.rs +++ b/crates/katana/core/src/service/messaging/ethereum.rs @@ -8,6 +8,7 @@ use ethers::prelude::*; use ethers::providers::{Http, Provider}; use ethers::types::{Address, BlockNumber, Log}; use k256::ecdsa::SigningKey; +use katana_primitives::chain::ChainId; use katana_primitives::receipt::MessageToL1; use katana_primitives::transaction::L1HandlerTx; use katana_primitives::utils::transaction::compute_l1_message_hash; @@ -127,7 +128,7 @@ impl Messenger for EthereumMessaging { &self, from_block: u64, max_blocks: u64, - chain_id: FieldElement, + chain_id: ChainId, ) -> MessengerResult<(u64, Vec)> { let chain_latest_block: u64 = self .provider @@ -206,7 +207,7 @@ impl Messenger for EthereumMessaging { } } -fn l1_handler_tx_from_log(log: Log, chain_id: FieldElement) -> MessengerResult { +fn l1_handler_tx_from_log(log: Log, chain_id: ChainId) -> MessengerResult { let parsed_log = ::decode_log(&log.into()).map_err(|e| { error!(target: LOG_TARGET, "Log parsing failed {e}"); Error::GatherError @@ -259,6 +260,7 @@ fn felt_from_address(v: Address) -> FieldElement { #[cfg(test)] mod tests { + use katana_primitives::chain::{ChainId, NamedChainId}; use starknet::macros::{felt, selector}; use super::*; @@ -299,7 +301,7 @@ mod tests { }; // SN_GOERLI. - let chain_id = starknet::macros::felt!("0x534e5f474f45524c49"); + let chain_id = ChainId::Named(NamedChainId::Goerli); let to_address = FieldElement::from_hex_be(to_address).unwrap(); let from_address = FieldElement::from_hex_be(from_address).unwrap(); diff --git a/crates/katana/core/src/service/messaging/mod.rs b/crates/katana/core/src/service/messaging/mod.rs index 79d8c6364d..6b2de596c2 100644 --- a/crates/katana/core/src/service/messaging/mod.rs +++ b/crates/katana/core/src/service/messaging/mod.rs @@ -39,12 +39,12 @@ mod starknet; use std::path::Path; -use ::starknet::core::types::FieldElement; use ::starknet::providers::ProviderError as StarknetProviderError; use anyhow::Result; use async_trait::async_trait; use ethereum::EthereumMessaging; use ethers::providers::ProviderError as EthereumProviderError; +use katana_primitives::chain::ChainId; use katana_primitives::receipt::MessageToL1; use serde::Deserialize; use tracing::{error, info}; @@ -145,7 +145,7 @@ pub trait Messenger { &self, from_block: u64, max_blocks: u64, - chain_id: FieldElement, + chain_id: ChainId, ) -> MessengerResult<(u64, Vec)>; /// Computes the hash of the given messages and sends them to the settlement chain. diff --git a/crates/katana/core/src/service/messaging/service.rs b/crates/katana/core/src/service/messaging/service.rs index e43e8df890..379a866dce 100644 --- a/crates/katana/core/src/service/messaging/service.rs +++ b/crates/katana/core/src/service/messaging/service.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; -use ::starknet::core::types::FieldElement; use futures::{Future, FutureExt, Stream}; use katana_primitives::block::BlockHashOrNumber; use katana_primitives::receipt::MessageToL1; @@ -76,9 +75,6 @@ impl MessagingService { backend: Arc, from_block: u64, ) -> MessengerResult<(u64, usize)> { - let chain_id = FieldElement::from_hex_be(&backend.env.read().block.chain_id.as_hex()) - .expect("failed to parse katana chain id"); - // 200 avoids any possible rejection from RPC with possibly lot's of messages. // TODO: May this be configurable? let max_block = 200; @@ -86,7 +82,7 @@ impl MessagingService { match messenger.as_ref() { MessengerMode::Ethereum(inner) => { let (block_num, txs) = - inner.gather_messages(from_block, max_block, chain_id).await?; + inner.gather_messages(from_block, max_block, backend.chain_id).await?; let txs_count = txs.len(); txs.into_iter().for_each(|tx| { @@ -101,7 +97,7 @@ impl MessagingService { #[cfg(feature = "starknet-messaging")] MessengerMode::Starknet(inner) => { let (block_num, txs) = - inner.gather_messages(from_block, max_block, chain_id).await?; + inner.gather_messages(from_block, max_block, backend.chain_id).await?; let txs_count = txs.len(); txs.into_iter().for_each(|tx| { diff --git a/crates/katana/core/src/service/messaging/starknet.rs b/crates/katana/core/src/service/messaging/starknet.rs index 642f0de240..019c6e1970 100644 --- a/crates/katana/core/src/service/messaging/starknet.rs +++ b/crates/katana/core/src/service/messaging/starknet.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use anyhow::Result; use async_trait::async_trait; +use katana_primitives::chain::ChainId; use katana_primitives::receipt::MessageToL1; use katana_primitives::transaction::L1HandlerTx; use katana_primitives::utils::transaction::compute_l1_message_hash; @@ -163,7 +164,7 @@ impl Messenger for StarknetMessaging { &self, from_block: u64, max_blocks: u64, - chain_id: FieldElement, + chain_id: ChainId, ) -> MessengerResult<(u64, Vec)> { let chain_latest_block: u64 = match self.provider.block_number().await { Ok(n) => n, @@ -306,7 +307,7 @@ fn parse_messages(messages: &[MessageToL1]) -> MessengerResult<(Vec Result { +fn l1_handler_tx_from_event(event: &EmittedEvent, chain_id: ChainId) -> Result { if event.keys[0] != selector!("MessageSentToAppchain") { debug!( target: LOG_TARGET, @@ -429,7 +430,7 @@ mod tests { let from_address = selector!("from_address"); let to_address = selector!("to_address"); let selector = selector!("selector"); - let chain_id = selector!("KATANA"); + let chain_id = ChainId::parse("KATANA").unwrap(); let nonce = FieldElement::ONE; let calldata = vec![from_address, FieldElement::THREE]; @@ -438,7 +439,7 @@ mod tests { to_address, selector, &calldata, - chain_id, + chain_id.into(), nonce, ); @@ -512,7 +513,7 @@ mod tests { transaction_hash, }; - let _tx = l1_handler_tx_from_event(&event, FieldElement::ZERO).unwrap(); + let _tx = l1_handler_tx_from_event(&event, ChainId::default()).unwrap(); } #[test] @@ -536,6 +537,6 @@ mod tests { transaction_hash, }; - let _tx = l1_handler_tx_from_event(&event, FieldElement::ZERO).unwrap(); + let _tx = l1_handler_tx_from_event(&event, ChainId::default()).unwrap(); } } diff --git a/crates/katana/primitives/Cargo.toml b/crates/katana/primitives/Cargo.toml index 3f2b36eddc..9a3f8799dd 100644 --- a/crates/katana/primitives/Cargo.toml +++ b/crates/katana/primitives/Cargo.toml @@ -14,6 +14,8 @@ serde.workspace = true serde_json.workspace = true serde_with.workspace = true starknet.workspace = true +strum.workspace = true +strum_macros.workspace = true thiserror.workspace = true blockifier.workspace = true diff --git a/crates/katana/primitives/src/chain.rs b/crates/katana/primitives/src/chain.rs new file mode 100644 index 0000000000..ec1c2d6f77 --- /dev/null +++ b/crates/katana/primitives/src/chain.rs @@ -0,0 +1,219 @@ +use starknet::core::types::{FieldElement, FromStrError}; +use starknet::core::utils::{cairo_short_string_to_felt, CairoShortStringToFeltError}; + +/// Known chain ids that has been assigned a name. +#[derive(Debug, Clone, Copy, PartialEq, Eq, strum_macros::Display)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum NamedChainId { + Mainnet, + Goerli, + Sepolia, +} + +impl NamedChainId { + /// `SN_MAIN` in ASCII + pub const SN_MAIN: FieldElement = FieldElement::from_mont([ + 0xf596341657d6d657, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0x6f9757bd5443bc6, + ]); + + /// `SN_GOERLI` in ASCII + pub const SN_GOERLI: FieldElement = FieldElement::from_mont([ + 0x3417161755cc97b2, + 0xfffffffffffff596, + 0xffffffffffffffff, + 0x588778cb29612d1, + ]); + + /// `SN_SEPOLIA` in ASCII + pub const SN_SEPOLIA: FieldElement = FieldElement::from_mont([ + 0x159755f62c97a933, + 0xfffffffffff59634, + 0xffffffffffffffff, + 0x70cb558f6123c62, + ]); + + /// Returns the id of the chain. It is the ASCII representation of a predefined string + /// constants. + #[inline] + pub const fn id(&self) -> FieldElement { + match self { + NamedChainId::Mainnet => Self::SN_MAIN, + NamedChainId::Goerli => Self::SN_GOERLI, + NamedChainId::Sepolia => Self::SN_SEPOLIA, + } + } + + /// Returns the predefined string constant of the chain id. + #[inline] + pub const fn name(&self) -> &'static str { + match self { + NamedChainId::Mainnet => "SN_MAIN", + NamedChainId::Goerli => "SN_GOERLI", + NamedChainId::Sepolia => "SN_SEPOLIA", + } + } +} + +/// This `struct` is created by the [`NamedChainId::try_from`] method. +#[derive(Debug, thiserror::Error)] +#[error("Unknown named chain id {0:#x}")] +pub struct NamedChainTryFromError(FieldElement); + +impl TryFrom for NamedChainId { + type Error = NamedChainTryFromError; + fn try_from(value: FieldElement) -> Result { + if value == Self::SN_MAIN { + Ok(Self::Mainnet) + } else if value == Self::SN_GOERLI { + Ok(Self::Goerli) + } else if value == Self::SN_SEPOLIA { + Ok(Self::Sepolia) + } else { + Err(NamedChainTryFromError(value)) + } + } +} + +/// Represents a chain id. +#[derive(Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum ChainId { + /// A chain id with a known chain name. + Named(NamedChainId), + Id(FieldElement), +} + +#[derive(Debug, thiserror::Error)] +pub enum ParseChainIdError { + #[error(transparent)] + FromStr(#[from] FromStrError), + #[error(transparent)] + CairoShortStringToFelt(#[from] CairoShortStringToFeltError), +} + +impl ChainId { + /// Chain id of the Starknet mainnet. + pub const MAINNET: Self = Self::Named(NamedChainId::Mainnet); + /// Chain id of the Starknet goerli testnet. + pub const GOERLI: Self = Self::Named(NamedChainId::Goerli); + /// Chain id of the Starknet sepolia testnet. + pub const SEPOLIA: Self = Self::Named(NamedChainId::Sepolia); + + /// Parse a [`ChainId`] from a [`str`]. + /// + /// If the `str` starts with `0x` it is parsed as a hex string, otherwise it is parsed as a + /// Cairo short string. + pub fn parse(s: &str) -> Result { + let id = if s.starts_with("0x") { + FieldElement::from_hex_be(s)? + } else { + cairo_short_string_to_felt(s)? + }; + Ok(ChainId::from(id)) + } + + /// Returns the chain id value. + pub const fn id(&self) -> FieldElement { + match self { + ChainId::Named(name) => name.id(), + ChainId::Id(id) => *id, + } + } +} + +impl Default for ChainId { + fn default() -> Self { + ChainId::Id(FieldElement::ZERO) + } +} + +impl std::fmt::Debug for ChainId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ChainId::Named(name) => write!(f, "ChainId {{ name: {name}, id: {:#x} }}", name.id()), + ChainId::Id(id) => write!(f, "ChainId {{ id: {id:#x} }}"), + } + } +} + +impl std::fmt::Display for ChainId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ChainId::Named(id) => write!(f, "{id}"), + ChainId::Id(id) => write!(f, "{id:#x}"), + } + } +} + +impl From for ChainId { + fn from(value: FieldElement) -> Self { + NamedChainId::try_from(value).map(ChainId::Named).unwrap_or(ChainId::Id(value)) + } +} + +impl From for FieldElement { + fn from(value: ChainId) -> Self { + value.id() + } +} + +#[cfg(test)] +mod tests { + use std::convert::TryFrom; + + use starknet::core::utils::cairo_short_string_to_felt; + use starknet::macros::felt; + + use super::ChainId; + use crate::chain::NamedChainId; + + #[test] + fn named_chain_id() { + let mainnet_id = cairo_short_string_to_felt("SN_MAIN").unwrap(); + let goerli_id = cairo_short_string_to_felt("SN_GOERLI").unwrap(); + let sepolia_id = cairo_short_string_to_felt("SN_SEPOLIA").unwrap(); + + assert_eq!(NamedChainId::Mainnet.id(), mainnet_id); + assert_eq!(NamedChainId::Goerli.id(), goerli_id); + assert_eq!(NamedChainId::Sepolia.id(), sepolia_id); + + assert_eq!(NamedChainId::try_from(mainnet_id).unwrap(), NamedChainId::Mainnet); + assert_eq!(NamedChainId::try_from(goerli_id).unwrap(), NamedChainId::Goerli); + assert_eq!(NamedChainId::try_from(sepolia_id).unwrap(), NamedChainId::Sepolia); + assert!(NamedChainId::try_from(felt!("0x1337")).is_err()); + } + + #[test] + fn chain_id() { + let mainnet_id = cairo_short_string_to_felt("SN_MAIN").unwrap(); + let goerli_id = cairo_short_string_to_felt("SN_GOERLI").unwrap(); + let sepolia_id = cairo_short_string_to_felt("SN_SEPOLIA").unwrap(); + + assert_eq!(ChainId::MAINNET.id(), NamedChainId::Mainnet.id()); + assert_eq!(ChainId::GOERLI.id(), NamedChainId::Goerli.id()); + assert_eq!(ChainId::SEPOLIA.id(), NamedChainId::Sepolia.id()); + + assert_eq!(ChainId::from(mainnet_id), ChainId::MAINNET); + assert_eq!(ChainId::from(goerli_id), ChainId::GOERLI); + assert_eq!(ChainId::from(sepolia_id), ChainId::SEPOLIA); + assert_eq!(ChainId::from(felt!("0x1337")), ChainId::Id(felt!("0x1337"))); + + assert_eq!(ChainId::MAINNET.to_string(), "Mainnet"); + assert_eq!(ChainId::GOERLI.to_string(), "Goerli"); + assert_eq!(ChainId::SEPOLIA.to_string(), "Sepolia"); + assert_eq!(ChainId::Id(felt!("0x1337")).to_string(), "0x1337"); + } + + #[test] + fn parse_chain_id() { + let mainnet_id = cairo_short_string_to_felt("SN_MAIN").unwrap(); + let custom_id = cairo_short_string_to_felt("KATANA").unwrap(); + + assert_eq!(ChainId::parse("SN_MAIN").unwrap(), ChainId::MAINNET); + assert_eq!(ChainId::parse("KATANA").unwrap(), ChainId::Id(custom_id)); + assert_eq!(ChainId::parse(&format!("{mainnet_id:#x}")).unwrap(), ChainId::MAINNET); + } +} diff --git a/crates/katana/primitives/src/conversion/blockifier.rs b/crates/katana/primitives/src/conversion/blockifier.rs index 2ed3ca2983..80751fb896 100644 --- a/crates/katana/primitives/src/conversion/blockifier.rs +++ b/crates/katana/primitives/src/conversion/blockifier.rs @@ -1,9 +1,12 @@ //! Translation layer for converting the primitive types to the execution engine types. +use starknet::core::utils::parse_cairo_short_string; use starknet_api::core::{ContractAddress, PatriciaKey}; use starknet_api::hash::StarkHash; use starknet_api::patricia_key; +use crate::chain::ChainId; + impl From for ContractAddress { fn from(address: crate::contract::ContractAddress) -> Self { Self(patricia_key!(address.0)) @@ -15,3 +18,31 @@ impl From for crate::contract::ContractAddress { Self((*address.0.key()).into()) } } + +impl From for starknet_api::core::ChainId { + fn from(chain_id: ChainId) -> Self { + let name: String = match chain_id { + ChainId::Named(named) => named.name().to_string(), + ChainId::Id(id) => parse_cairo_short_string(&id).expect("valid cairo string"), + }; + Self(name) + } +} + +#[cfg(test)] +mod tests { + use starknet::core::utils::parse_cairo_short_string; + + use crate::chain::{ChainId, NamedChainId}; + + #[test] + fn convert_chain_id() { + let mainnet = starknet_api::core::ChainId::from(ChainId::Named(NamedChainId::Mainnet)); + let goerli = starknet_api::core::ChainId::from(ChainId::Named(NamedChainId::Goerli)); + let sepolia = starknet_api::core::ChainId::from(ChainId::Named(NamedChainId::Sepolia)); + + assert_eq!(mainnet.0, parse_cairo_short_string(&NamedChainId::Mainnet.id()).unwrap()); + assert_eq!(goerli.0, parse_cairo_short_string(&NamedChainId::Goerli.id()).unwrap()); + assert_eq!(sepolia.0, parse_cairo_short_string(&NamedChainId::Sepolia.id()).unwrap()); + } +} diff --git a/crates/katana/primitives/src/env.rs b/crates/katana/primitives/src/env.rs index 5d7633ce36..64c381ba7c 100644 --- a/crates/katana/primitives/src/env.rs +++ b/crates/katana/primitives/src/env.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use crate::chain::ChainId; use crate::contract::ContractAddress; /// Block environment values. @@ -21,7 +22,7 @@ pub struct BlockEnv { #[derive(Debug, Clone)] pub struct CfgEnv { /// The chain id. - pub chain_id: u64, + pub chain_id: ChainId, /// The fee cost of the VM resources. pub vm_resource_fee_cost: HashMap, /// The maximum number of steps allowed for an invoke transaction. diff --git a/crates/katana/primitives/src/lib.rs b/crates/katana/primitives/src/lib.rs index d09bf7c232..f3299e2e6b 100644 --- a/crates/katana/primitives/src/lib.rs +++ b/crates/katana/primitives/src/lib.rs @@ -1,4 +1,5 @@ pub mod block; +pub mod chain; pub mod contract; pub mod env; pub mod event; @@ -12,6 +13,3 @@ pub mod state; pub mod utils; pub type FieldElement = starknet::core::types::FieldElement; - -/// The id of the chain. -pub type ChainId = FieldElement; diff --git a/crates/katana/primitives/src/transaction.rs b/crates/katana/primitives/src/transaction.rs index ab5665d844..033374400b 100644 --- a/crates/katana/primitives/src/transaction.rs +++ b/crates/katana/primitives/src/transaction.rs @@ -1,6 +1,7 @@ use derive_more::{AsRef, Deref}; use ethers::types::H256; +use crate::chain::ChainId; use crate::contract::{ ClassHash, CompiledClassHash, CompiledContractClass, ContractAddress, FlattenedSierraClass, Nonce, @@ -9,7 +10,7 @@ use crate::utils::transaction::{ compute_declare_v1_tx_hash, compute_declare_v2_tx_hash, compute_deploy_account_v1_tx_hash, compute_invoke_v1_tx_hash, compute_l1_handler_tx_hash, }; -use crate::{ChainId, FieldElement}; +use crate::FieldElement; /// The hash of a transaction. pub type TxHash = FieldElement; @@ -136,7 +137,7 @@ impl InvokeTx { self.sender_address.into(), &self.calldata, self.max_fee, - self.chain_id, + self.chain_id.into(), self.nonce, is_query, ) @@ -195,7 +196,7 @@ impl DeclareTx { tx.sender_address.into(), tx.class_hash, tx.max_fee, - tx.chain_id, + tx.chain_id.into(), tx.nonce, is_query, ), @@ -204,7 +205,7 @@ impl DeclareTx { tx.sender_address.into(), tx.class_hash, tx.max_fee, - tx.chain_id, + tx.chain_id.into(), tx.nonce, tx.compiled_class_hash, is_query, @@ -234,7 +235,7 @@ impl L1HandlerTx { self.contract_address.into(), self.entry_point_selector, &self.calldata, - self.chain_id, + self.chain_id.into(), self.nonce, ) } @@ -263,7 +264,7 @@ impl DeployAccountTx { self.class_hash, self.contract_address_salt, self.max_fee, - self.chain_id, + self.chain_id.into(), self.nonce, is_query, ) diff --git a/crates/katana/rpc/rpc-types/src/message.rs b/crates/katana/rpc/rpc-types/src/message.rs index cff3689a2d..3b6c37b446 100644 --- a/crates/katana/rpc/rpc-types/src/message.rs +++ b/crates/katana/rpc/rpc-types/src/message.rs @@ -1,3 +1,4 @@ +use katana_primitives::chain::ChainId; use katana_primitives::transaction::L1HandlerTx; use katana_primitives::utils::transaction::compute_l1_message_hash; use katana_primitives::FieldElement; @@ -7,7 +8,7 @@ use serde::Deserialize; pub struct MsgFromL1(starknet::core::types::MsgFromL1); impl MsgFromL1 { - pub fn into_tx_with_chain_id(self, chain_id: FieldElement) -> L1HandlerTx { + pub fn into_tx_with_chain_id(self, chain_id: ChainId) -> L1HandlerTx { let message_hash = compute_l1_message_hash( // This conversion will never fail bcs `from_address` is 20 bytes and the it will only // fail if the slice is > 32 bytes diff --git a/crates/katana/rpc/rpc-types/src/transaction.rs b/crates/katana/rpc/rpc-types/src/transaction.rs index 69f4f1740c..00aab586bc 100644 --- a/crates/katana/rpc/rpc-types/src/transaction.rs +++ b/crates/katana/rpc/rpc-types/src/transaction.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use anyhow::Result; use derive_more::Deref; +use katana_primitives::chain::ChainId; use katana_primitives::contract::{ClassHash, ContractAddress}; use katana_primitives::conversion::rpc::{ compiled_class_hash_from_flattened_sierra_class, flattened_sierra_to_compiled_class, @@ -25,7 +26,7 @@ use starknet::core::utils::get_contract_address; pub struct BroadcastedInvokeTx(BroadcastedInvokeTransaction); impl BroadcastedInvokeTx { - pub fn into_tx_with_chain_id(self, chain_id: FieldElement) -> InvokeTx { + pub fn into_tx_with_chain_id(self, chain_id: ChainId) -> InvokeTx { InvokeTx { chain_id, nonce: self.0.nonce, @@ -57,7 +58,7 @@ impl BroadcastedDeclareTx { } /// This function assumes that the compiled class hash is valid. - pub fn try_into_tx_with_chain_id(self, chain_id: FieldElement) -> Result { + pub fn try_into_tx_with_chain_id(self, chain_id: ChainId) -> Result { match self.0 { BroadcastedDeclareTransaction::V1(tx) => { let (class_hash, compiled_class) = @@ -112,7 +113,7 @@ impl BroadcastedDeclareTx { pub struct BroadcastedDeployAccountTx(BroadcastedDeployAccountTransaction); impl BroadcastedDeployAccountTx { - pub fn into_tx_with_chain_id(self, chain_id: FieldElement) -> DeployAccountTx { + pub fn into_tx_with_chain_id(self, chain_id: ChainId) -> DeployAccountTx { let contract_address = get_contract_address( self.0.contract_address_salt, self.0.class_hash, @@ -276,7 +277,7 @@ impl From for InvokeTx { calldata: tx.0.calldata, signature: tx.0.signature, version: FieldElement::ONE, - chain_id: FieldElement::ZERO, + chain_id: ChainId::default(), sender_address: tx.0.sender_address.into(), max_fee: tx.0.max_fee.try_into().expect("max_fee is too big"), } @@ -297,7 +298,7 @@ impl From for DeployAccountTx { signature: tx.0.signature, version: FieldElement::ONE, class_hash: tx.0.class_hash, - chain_id: FieldElement::ZERO, + chain_id: ChainId::default(), contract_address: contract_address.into(), constructor_calldata: tx.0.constructor_calldata, contract_address_salt: tx.0.contract_address_salt, diff --git a/crates/katana/rpc/src/starknet.rs b/crates/katana/rpc/src/starknet.rs index a6a1a750b8..51b53b8812 100644 --- a/crates/katana/rpc/src/starknet.rs +++ b/crates/katana/rpc/src/starknet.rs @@ -1,4 +1,3 @@ -use std::str::FromStr; use std::sync::Arc; use jsonrpsee::core::{async_trait, Error}; @@ -47,8 +46,7 @@ impl StarknetApi { #[async_trait] impl StarknetApiServer for StarknetApi { async fn chain_id(&self) -> Result { - let chain_id = self.sequencer.chain_id().as_hex(); - Ok(FieldElement::from_str(&chain_id).map_err(|_| StarknetApiError::UnexpectedError)?.into()) + Ok(FieldElement::from(self.sequencer.chain_id()).into()) } async fn nonce( @@ -401,8 +399,7 @@ impl StarknetApiServer for StarknetApi { return Err(StarknetApiError::UnsupportedTransactionVersion.into()); } - let chain_id = FieldElement::from_hex_be(&self.sequencer.chain_id().as_hex()) - .map_err(|_| StarknetApiError::UnexpectedError)?; + let chain_id = self.sequencer.chain_id(); let tx = deploy_account_transaction.into_tx_with_chain_id(chain_id); let contract_address = tx.contract_address; @@ -420,8 +417,7 @@ impl StarknetApiServer for StarknetApi { request: Vec, block_id: BlockIdOrTag, ) -> Result, Error> { - let chain_id = FieldElement::from_hex_be(&self.sequencer.chain_id().as_hex()) - .map_err(|_| StarknetApiError::UnexpectedError)?; + let chain_id = self.sequencer.chain_id(); let transactions = request .into_iter() @@ -465,8 +461,7 @@ impl StarknetApiServer for StarknetApi { message: MsgFromL1, block_id: BlockIdOrTag, ) -> Result { - let chain_id = FieldElement::from_hex_be(&self.sequencer.chain_id().as_hex()) - .map_err(|_| StarknetApiError::UnexpectedError)?; + let chain_id = self.sequencer.chain_id(); let tx = message.into_tx_with_chain_id(chain_id); let hash = tx.calculate_hash(); @@ -496,8 +491,7 @@ impl StarknetApiServer for StarknetApi { return Err(StarknetApiError::UnsupportedTransactionVersion.into()); } - let chain_id = FieldElement::from_hex_be(&self.sequencer.chain_id().as_hex()) - .map_err(|_| StarknetApiError::UnexpectedError)?; + let chain_id = self.sequencer.chain_id(); // // validate compiled class hash // let is_valid = declare_transaction @@ -529,8 +523,7 @@ impl StarknetApiServer for StarknetApi { return Err(StarknetApiError::UnsupportedTransactionVersion.into()); } - let chain_id = FieldElement::from_hex_be(&self.sequencer.chain_id().as_hex()) - .map_err(|_| StarknetApiError::UnexpectedError)?; + let chain_id = self.sequencer.chain_id(); let tx = invoke_transaction.into_tx_with_chain_id(chain_id); let tx = ExecutableTxWithHash::new(ExecutableTx::Invoke(tx)); diff --git a/crates/katana/src/args.rs b/crates/katana/src/args.rs index 9d23d947e3..14237b3ece 100644 --- a/crates/katana/src/args.rs +++ b/crates/katana/src/args.rs @@ -20,6 +20,7 @@ use katana_core::constants::{ DEFAULT_GAS_PRICE, DEFAULT_INVOKE_MAX_STEPS, DEFAULT_VALIDATE_MAX_STEPS, }; use katana_core::sequencer::SequencerConfig; +use katana_primitives::chain::ChainId; use katana_rpc::api::ApiKind; use katana_rpc::config::ServerConfig; use metrics::utils::parse_socket_address; @@ -151,8 +152,12 @@ pub struct StarknetOptions { pub struct EnvironmentOptions { #[arg(long)] #[arg(help = "The chain ID.")] + #[arg(long_help = "The chain ID. If a raw hex string (`0x` prefix) is provided, then it'd \ + used as the actual chain ID. Otherwise, it's represented as the raw \ + ASCII values. It must be a valid Cairo short string.")] #[arg(default_value = "KATANA")] - pub chain_id: String, + #[arg(value_parser = ChainId::parse)] + pub chain_id: ChainId, #[arg(long)] #[arg(help = "The gas price.")] @@ -219,7 +224,7 @@ impl KatanaArgs { fork_rpc_url: self.rpc_url.clone(), fork_block_number: self.fork_block_number, env: Environment { - chain_id: self.starknet.environment.chain_id.clone(), + chain_id: self.starknet.environment.chain_id, gas_price: self.starknet.environment.gas_price.unwrap_or(DEFAULT_GAS_PRICE), invoke_max_steps: self .starknet