diff --git a/starknet-core/src/types/serde_impls.rs b/starknet-core/src/types/serde_impls.rs index ac959764..39ee3e49 100644 --- a/starknet-core/src/types/serde_impls.rs +++ b/starknet-core/src/types/serde_impls.rs @@ -1,12 +1,14 @@ -use alloc::{format, string::String}; +use alloc::{fmt::Formatter, format}; -use serde::{Deserialize, Deserializer, Serialize}; +use serde::{de::Visitor, Deserialize, Deserializer, Serialize}; use serde_with::{DeserializeAs, SerializeAs}; use super::{SyncStatus, SyncStatusType}; pub(crate) struct NumAsHex; +struct NumAsHexVisitor; + impl SerializeAs for NumAsHex { fn serialize_as(value: &u64, serializer: S) -> Result where @@ -21,14 +23,48 @@ impl<'de> DeserializeAs<'de, u64> for NumAsHex { where D: Deserializer<'de>, { - let value = String::deserialize(deserializer)?; - match u64::from_str_radix(&value[2..], 16) { + deserializer.deserialize_any(NumAsHexVisitor) + } +} + +impl<'de> Visitor<'de> for NumAsHexVisitor { + type Value = u64; + + fn expecting(&self, formatter: &mut Formatter) -> alloc::fmt::Result { + write!(formatter, "string or number") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + match u64::from_str_radix(v.trim_start_matches("0x"), 16) { Ok(value) => Ok(value), Err(err) => Err(serde::de::Error::custom(format!( "invalid hex string: {err}" ))), } } + + fn visit_i64(self, v: i64) -> Result + where + E: serde::de::Error, + { + match v.try_into() { + Ok(value) => self.visit_u64(value), + Err(_) => Err(serde::de::Error::custom(format!( + "value cannot be negative: {}", + v + ))), + } + } + + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + Ok(v) + } } #[derive(Deserialize)] @@ -231,7 +267,13 @@ mod enum_ser_impls { #[cfg(test)] mod tests { - use super::super::{BlockId, BlockTag, FieldElement}; + use serde::Deserialize; + use serde_with::serde_as; + + use super::{ + super::{BlockId, BlockTag, FieldElement}, + NumAsHex, + }; #[test] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] @@ -251,4 +293,16 @@ mod tests { assert_eq!(serde_json::from_str::(json).unwrap(), block_id); } } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_num_as_hex_deser() { + #[serde_as] + #[derive(Debug, PartialEq, Eq, Deserialize)] + struct Value(#[serde_as(as = "NumAsHex")] u64); + + for (num, json) in [(Value(100), "\"0x64\""), (Value(100), "100")].into_iter() { + assert_eq!(serde_json::from_str::(json).unwrap(), num); + } + } }