From d262fa43c3a7bad51db4b17b278413b2a0f08313 Mon Sep 17 00:00:00 2001 From: Yogalholic Date: Thu, 21 Mar 2024 14:30:18 +0100 Subject: [PATCH] update TryFrom implementation --- crates/katana/primitives/src/genesis/json.rs | 277 +++++++++++-------- 1 file changed, 165 insertions(+), 112 deletions(-) diff --git a/crates/katana/primitives/src/genesis/json.rs b/crates/katana/primitives/src/genesis/json.rs index aacdd8eaef..99028df931 100644 --- a/crates/katana/primitives/src/genesis/json.rs +++ b/crates/katana/primitives/src/genesis/json.rs @@ -1,11 +1,10 @@ //! JSON representation of the genesis configuration. Used to deserialize the genesis configuration //! from a JSON file. -use std::collections::{hash_map, BTreeMap, HashMap}; +use std::collections::hash_map::Entry; +use std::collections::{BTreeMap, HashMap}; use std::fs::File; -use std::io::{ - BufReader, {self}, -}; +use std::io::{self, BufReader}; use std::path::{Path, PathBuf}; use std::str::FromStr; use std::sync::Arc; @@ -99,13 +98,12 @@ pub struct GenesisClassJson { /// The class hash of the contract. If not provided, the class hash is computed from the /// class at `path`. pub class_hash: Option, - #[serde(default, skip_serializing)] name: Option, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(untagged)] -pub enum ClassOrHash { +pub enum NameOrHash { ClassName(String), ClassHash(ClassHash), } @@ -119,7 +117,7 @@ pub struct FeeTokenConfigJson { pub decimals: u8, /// The class hash of the fee token contract. /// If not provided, the default fee token class is used. - pub class: Option, + pub class: Option, /// To initialize the fee token contract storage pub storage: Option>, } @@ -131,7 +129,7 @@ pub struct UniversalDeployerConfigJson { pub address: Option, /// The class hash of the universal deployer contract. /// If not provided, the default UD class is used. - pub class: Option, + pub class: Option, /// To initialize the UD contract storage pub storage: Option>, } @@ -139,7 +137,7 @@ pub struct UniversalDeployerConfigJson { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "camelCase")] pub struct GenesisContractJson { - pub class: Option, + pub class: Option, pub balance: Option, pub nonce: Option, pub storage: Option>, @@ -195,7 +193,6 @@ pub enum GenesisJsonError { #[error(transparent)] Other(#[from] anyhow::Error), - } // The JSON representation of the [Genesis] configuration. This `struct` is used to deserialize @@ -244,15 +241,9 @@ impl GenesisJson { path.pop(); let mut genesis: Self = serde_json::from_reader(BufReader::new(file))?; - let mut class_name_map = HashMap::new(); - for class in &genesis.classes { - if let Some(name) = &class.name { - class_name_map.insert(name.to_string(), class.class_hash.unwrap()); - } - } // resolves the class paths, if any - genesis.resolve_class_artifacts(path, &class_name_map)?; + genesis.resolve_class_artifacts(path)?; Ok(genesis) } @@ -266,15 +257,8 @@ impl GenesisJson { pub fn resolve_class_artifacts( &mut self, base_path: impl AsRef, - class_name_map: &HashMap ) -> Result<(), GenesisJsonError> { for entry in &mut self.classes { - if let Some(name) = &entry.name { - if let Some(hash) = class_name_map.get(name) { - entry.class_hash = Some(*hash); - continue; - } - } if let PathOrFullArtifact::Path(rel_path) = &entry.class { let base_path = base_path.as_ref().to_path_buf(); let artifact = class_artifact_at_path(base_path, rel_path)?; @@ -289,89 +273,94 @@ impl TryFrom for Genesis { type Error = GenesisJsonError; fn try_from(value: GenesisJson) -> Result { - let mut name_to_class_hash = HashMap::new(); - let mut classes: HashMap = value - .classes - .into_par_iter() - .map(|entry| { - let GenesisClassJson { class, class_hash, name } = entry; - // Use name if present - if let (Some(name), Some(class_hash)) = (name, class_hash) { - name_to_class_hash.insert(name, class_hash); - } - let artifact = match class { - PathOrFullArtifact::Artifact(artifact) => artifact, - PathOrFullArtifact::Path(path) => { - return Err(GenesisJsonError::UnresolvedClassPath(path)); - } - }; - - let sierra = serde_json::from_value::(artifact.clone()); - - let (class_hash, compiled_class_hash, sierra, casm) = match sierra { - Ok(sierra) => { - let casm: ContractClass = serde_json::from_value(artifact)?; - let casm = CasmContractClass::from_contract_class(casm, true)?; - - // check if the class hash is provided, otherwise compute it from the - // artifacts - let class_hash = class_hash.unwrap_or(sierra.class_hash()?); - let compiled_hash = casm.compiled_class_hash().to_be_bytes(); - - ( - class_hash, - FieldElement::from_bytes_be(&compiled_hash)?, - Some(Arc::new(sierra.flatten()?)), - Arc::new(CompiledContractClass::V1(CompiledContractClassV1::try_from( - casm, - )?)), - ) - } - - // if the artifact is not a sierra contract, we check if it's a legacy contract - Err(_) => { - let casm: CompiledContractClassV0 = - serde_json::from_value(artifact.clone())?; - - let class_hash = if let Some(class_hash) = class_hash { - class_hash - } else { - let casm: LegacyContractClass = - serde_json::from_value(artifact.clone())?; - casm.class_hash()? - }; - (class_hash, class_hash, None, Arc::new(CompiledContractClass::V0(casm))) - } + let mut name_to_class_hash = HashMap::new(); + let mut classes: HashMap = value + .classes + .into_par_iter() + .map(|entry| { + let GenesisClassJson { class, class_hash, name } = entry; + let artifact = match class { + PathOrFullArtifact::Artifact(artifact) => artifact, + PathOrFullArtifact::Path(path) => { + return Err(GenesisJsonError::UnresolvedClassPath(path)); + } + }; + + let sierra = serde_json::from_value::(artifact.clone()); + + let (class_hash, compiled_class_hash, sierra, casm) = match sierra { + Ok(sierra) => { + let casm: ContractClass = serde_json::from_value(artifact)?; + let casm = CasmContractClass::from_contract_class(casm, true)?; + + // check if the class hash is provided, otherwise compute it from the + // artifacts + let class_hash = class_hash.unwrap_or(sierra.class_hash()?); + let compiled_hash = casm.compiled_class_hash().to_be_bytes(); + + ( + class_hash, + FieldElement::from_bytes_be(&compiled_hash)?, + Some(Arc::new(sierra.flatten()?)), + Arc::new(CompiledContractClass::V1(CompiledContractClassV1::try_from( + casm, + )?)), + ) + } + + // if the artifact is not a sierra contract, we check if it's a legacy contract + Err(_) => { + let casm: CompiledContractClassV0 = + serde_json::from_value(artifact.clone())?; + + let class_hash = if let Some(class_hash) = class_hash { + class_hash + } else { + let casm: LegacyContractClass = + serde_json::from_value(artifact.clone())?; + casm.class_hash()? }; + (class_hash, class_hash, None, Arc::new(CompiledContractClass::V0(casm))) + } + }; - // Add the class name and class hash to the mapping - let _ = name.map(|name| { - name_to_class_hash.insert(name, class_hash); - }); + // Add the class name and class hash to the mapping + let _ = name.map(|name| { + name_to_class_hash.insert(name, class_hash); + }); - Ok((class_hash, GenesisClass { compiled_class_hash, sierra, casm })) - }) - .collect::>()?; + Ok((class_hash, GenesisClass { compiled_class_hash, sierra, casm })) + }) + .collect::>()?; // Populate the classes and name_to_class_hash - let mut fee_token = FeeTokenConfig { + let mut fee_token = FeeTokenConfig { name: value.fee_token.name, symbol: value.fee_token.symbol, total_supply: U256::zero(), decimals: value.fee_token.decimals, address: value.fee_token.address.unwrap_or(DEFAULT_FEE_TOKEN_ADDRESS), class_hash: match value.fee_token.class { - Some(ClassOrHash::ClassHash(class_hash)) => class_hash, - Some(ClassOrHash::ClassName(class_name)) => { - *name_to_class_hash.get(&class_name).ok_or_else(|| value_out_of_range_error::ValueOutOfRangeError)? - }, - None => DEFAULT_LEGACY_ERC20_CONTRACT_CLASS_HASH, - }, + Some(NameOrHash::ClassHash(class_hash)) => class_hash, + Some(NameOrHash::ClassName(class_name)) => *name_to_class_hash + .get(&class_name) + .ok_or_else(|| value_out_of_range_error::ValueOutOfRangeError)?, + None => DEFAULT_LEGACY_ERC20_CONTRACT_CLASS_HASH, + }, storage: value.fee_token.storage, }; match value.fee_token.class { - Some(ClassOrHash::ClassHash(hash)) => { + Some(NameOrHash::ClassHash(hash)) => { + if !classes.contains_key(&hash) { + return Err(GenesisJsonError::MissingClass(hash)); + } + } + + Some(NameOrHash::ClassName(name)) => { + let hash = *name_to_class_hash + .get(&name) + .ok_or_else(|| value_out_of_range_error::ValueOutOfRangeError)?; if !classes.contains_key(&hash) { return Err(GenesisJsonError::MissingClass(hash)); } @@ -391,9 +380,54 @@ impl TryFrom for Genesis { } }; + // Populate the classes and name_to_class_hash + let mut fee_token = FeeTokenConfig { + name: value.fee_token.name, + symbol: value.fee_token.symbol, + total_supply: U256::zero(), + decimals: value.fee_token.decimals, + address: value.fee_token.address.unwrap_or(DEFAULT_FEE_TOKEN_ADDRESS), + class_hash: match value.fee_token.class { + Some(NameOrHash::ClassHash(class_hash)) => class_hash, + Some(NameOrHash::ClassName(class_name)) => *name_to_class_hash + .get(&class_name) + .ok_or_else(|| value_out_of_range_error::ValueOutOfRangeError)?, + None => DEFAULT_LEGACY_ERC20_CONTRACT_CLASS_HASH, + }, + storage: value.fee_token.storage, + }; + + match value.fee_token.class { + Some(NameOrHash::ClassHash(hash)) => { + if !classes.contains_key(&hash) { + return Err(GenesisJsonError::MissingClass(hash)); + } + } + Some(NameOrHash::ClassName(name)) => { + let hash = *name_to_class_hash + .get(&name) + .ok_or_else(|| value_out_of_range_error::ValueOutOfRangeError)?; + if !classes.contains_key(&hash) { + return Err(GenesisJsonError::MissingClass(hash)); + } + } + None => { + let name: String = value.fee_token.name.clone(); + let _ = classes.insert( + DEFAULT_LEGACY_ERC20_CONTRACT_CLASS_HASH, + GenesisClass { + sierra: None, + casm: Arc::new(DEFAULT_LEGACY_ERC20_CONTRACT_CASM.clone()), + compiled_class_hash: DEFAULT_LEGACY_ERC20_CONTRACT_COMPILED_CLASS_HASH, + }, + ); + } + }; + + // if no class hash is provided, use the default UD contract parameters let universal_deployer = if let Some(config) = value.universal_deployer { match config.class { - Some(hash) => { + Some(NameOrHash::ClassHash(hash)) => { if !classes.contains_key(&hash) { return Err(GenesisJsonError::MissingClass(hash)); } @@ -404,8 +438,20 @@ impl TryFrom for Genesis { storage: config.storage, }) } + Some(NameOrHash::ClassName(name)) => { + let hash = *name_to_class_hash + .get(&name) + .ok_or_else(|| value_out_of_range_error::ValueOutOfRangeError)?; + if !classes.contains_key(&hash) { + return Err(GenesisJsonError::MissingClass(hash)); + } - // if no class hash is provided, use the default UD contract parameters + Some(UniversalDeployerConfig { + class_hash: hash, + address: config.address.unwrap_or(DEFAULT_UDC_ADDRESS), + storage: config.storage, + }) + } None => { let name = value.fee_token.name.clone(); let class_hash = DEFAULT_LEGACY_UDC_CLASS_HASH; @@ -444,8 +490,8 @@ impl TryFrom for Genesis { None => { // check that the default account class exists in the classes field before // inserting it - if let hash_map::Entry::Vacant(e) = - classes.entry(DEFAULT_OZ_ACCOUNT_CONTRACT_CLASS_HASH) + + if let Entry::Vacant(e) = classes.entry(DEFAULT_OZ_ACCOUNT_CONTRACT_CLASS_HASH) { // insert default account class to the classes map e.insert(GenesisClass { @@ -497,9 +543,20 @@ impl TryFrom for Genesis { for (address, contract) in value.contracts { // check that the class hash exists in the classes field + let mut class_hash = None; if let Some(hash) = contract.class { - if !classes.contains_key(&hash) { - return Err(GenesisJsonError::MissingClass(hash)); + class_hash = Some(match hash { + NameOrHash::ClassHash(hash) => hash, + NameOrHash::ClassName(name) => { + // Handle the case when the class is specified by name. + *name_to_class_hash + .get(&name) + .ok_or_else(|| GenesisJsonError::MissingClass(name.clone()))? + } + }); + + if !classes.contains_key(&class_hash) { + return Err(GenesisJsonError::MissingClass(class_hash)); } } @@ -512,7 +569,7 @@ impl TryFrom for Genesis { address, GenesisAllocation::Contract(GenesisContractAlloc { balance: contract.balance, - class_hash: contract.class, + class_hash, nonce: contract.nonce, storage: contract.storage, }), @@ -550,9 +607,9 @@ impl FromStr for GenesisJson { pub fn resolve_artifacts_and_to_base64>( mut genesis: GenesisJson, base_path: P, - class_name_map: HashMap + class_name_map: HashMap, ) -> Result, GenesisJsonError> { - genesis.resolve_class_artifacts(base_path, &class_name_map)?; + genesis.resolve_class_artifacts(base_path)?; to_base64(genesis) } @@ -572,17 +629,15 @@ pub fn to_base64(genesis: GenesisJson) -> Result, GenesisJsonError> { /// Deserialize the [GenesisJson] from base64 encoded bytes. pub fn from_base64(data: &[u8]) -> Result { - // Decode base64 bytes + // Decode base64 bytes let decoded = BASE64_STANDARD.decode(data)?; // Deserialize JSON let mut genesis_json: GenesisJson = serde_json::from_slice(&decoded)?; // Populate name field for class in &mut genesis_json.classes { - match &class.class { - - // If artifact is provided, try to extract name + // If artifact is provided, try to extract name PathOrFullArtifact::Artifact(artifact) => { if let Value::Object(obj) = artifact { if let Some(name_value) = obj.get("name") { @@ -591,14 +646,12 @@ pub fn from_base64(data: &[u8]) -> Result { } } } - }, - - // Ignore path case - PathOrFullArtifact::Path(_path) => {} + } + // Ignore path case + PathOrFullArtifact::Path(_path) => {} + } } - - } Ok(genesis_json) } @@ -1180,4 +1233,4 @@ mod tests { assert_eq!(genesis, decoded); } -} \ No newline at end of file +}