diff --git a/bin/katana/src/cli/node.rs b/bin/katana/src/cli/node.rs index c187f4fe59..461fc6e036 100644 --- a/bin/katana/src/cli/node.rs +++ b/bin/katana/src/cli/node.rs @@ -360,11 +360,8 @@ impl NodeArgs { } } - if self.starknet == StarknetOptions::default() { - if let Some(starknet) = config.starknet { - self.starknet = starknet; - } - } + self.starknet.merge(config.starknet.as_ref()); + self.development.merge(config.development.as_ref()); if self.gpo == GasPriceOracleOptions::default() { if let Some(gpo) = config.gpo { @@ -378,12 +375,6 @@ impl NodeArgs { } } - if self.development == DevOptions::default() { - if let Some(development) = config.development { - self.development = development; - } - } - #[cfg(feature = "slot")] if self.slot == SlotOptions::default() { if let Some(slot) = config.slot { @@ -509,6 +500,8 @@ PREFUNDED ACCOUNTS #[cfg(test)] mod test { + use std::str::FromStr; + use assert_matches::assert_matches; use katana_core::constants::{ DEFAULT_ETH_L1_DATA_GAS_PRICE, DEFAULT_ETH_L1_GAS_PRICE, DEFAULT_STRK_L1_DATA_GAS_PRICE, @@ -518,7 +511,7 @@ mod test { DEFAULT_INVOCATION_MAX_STEPS, DEFAULT_VALIDATION_MAX_STEPS, }; use katana_primitives::chain::ChainId; - use katana_primitives::{address, felt}; + use katana_primitives::{address, felt, Felt}; use super::*; @@ -694,6 +687,8 @@ total_accounts = 20 [starknet.env] validate_max_steps = 500 +invoke_max_steps = 9988 +chain_id.Named = "Mainnet" "#; let path = std::env::temp_dir().join("katana-config.json"); std::fs::write(&path, content).unwrap(); @@ -710,14 +705,15 @@ validate_max_steps = 500 "1234", "--dev", "--dev.no-fee", + "--chain-id", + "0x123", ]; let config = NodeArgs::parse_from(args.clone()).with_config_file().unwrap().config().unwrap(); assert_eq!(config.execution.validation_max_steps, 1234); - assert_eq!(config.execution.invocation_max_steps, 10_000_000); - assert_eq!(config.chain.id, ChainId::parse("KATANA").unwrap()); + assert_eq!(config.execution.invocation_max_steps, 9988); assert!(!config.dev.fee); assert_matches!(config.dev.fixed_gas_prices, Some(prices) => { assert_eq!(prices.gas_price.eth, 254); @@ -732,5 +728,6 @@ validate_max_steps = 500 assert_eq!(config.chain.genesis.sequencer_address, address!("0x100")); assert_eq!(config.chain.genesis.gas_prices.eth, 9999); assert_eq!(config.chain.genesis.gas_prices.strk, 8888); + assert_eq!(config.chain.id, ChainId::Id(Felt::from_str("0x123").unwrap())); } } diff --git a/bin/katana/src/cli/options.rs b/bin/katana/src/cli/options.rs index c915ce3118..83b90eb0b5 100644 --- a/bin/katana/src/cli/options.rs +++ b/bin/katana/src/cli/options.rs @@ -108,6 +108,18 @@ pub struct StarknetOptions { pub genesis: Option, } +impl StarknetOptions { + pub fn merge(&mut self, other: Option<&Self>) { + if let Some(other) = other { + self.environment.merge(Some(&other.environment)); + + if self.genesis.is_none() { + self.genesis = other.genesis.clone(); + } + } + } +} + #[derive(Debug, Args, Clone, Serialize, Deserialize, PartialEq)] #[command(next_help_heading = "Environment options")] pub struct EnvironmentOptions { @@ -143,6 +155,24 @@ impl Default for EnvironmentOptions { } } +impl EnvironmentOptions { + pub fn merge(&mut self, other: Option<&Self>) { + if let Some(other) = other { + if self.chain_id.is_none() { + self.chain_id = other.chain_id; + } + + if self.validate_max_steps == DEFAULT_VALIDATION_MAX_STEPS { + self.validate_max_steps = other.validate_max_steps; + } + + if self.invoke_max_steps == DEFAULT_INVOCATION_MAX_STEPS { + self.invoke_max_steps = other.invoke_max_steps; + } + } + } +} + #[derive(Debug, Args, Clone, Serialize, Deserialize, PartialEq)] #[command(next_help_heading = "Development options")] #[serde(rename = "dev")] @@ -192,6 +222,32 @@ impl Default for DevOptions { } } +impl DevOptions { + pub fn merge(&mut self, other: Option<&Self>) { + if let Some(other) = other { + if !self.dev { + self.dev = other.dev; + } + + if self.seed == DEFAULT_DEV_SEED { + self.seed = other.seed.clone(); + } + + if self.total_accounts == DEFAULT_DEV_ACCOUNTS { + self.total_accounts = other.total_accounts; + } + + if !self.no_fee { + self.no_fee = other.no_fee; + } + + if !self.no_account_validation { + self.no_account_validation = other.no_account_validation; + } + } + } +} + #[derive(Debug, Args, Clone, Serialize, Deserialize, Default, PartialEq)] #[command(next_help_heading = "Forking options")] pub struct ForkingOptions {