From 2a092669264768cc2d58d3787a6b99658b10693f Mon Sep 17 00:00:00 2001 From: ok300 <106775972+ok300@users.noreply.github.com> Date: Wed, 23 Oct 2024 16:24:51 +0200 Subject: [PATCH] Keyset ID: fix deserialization edge-case, add unit tests --- crates/cdk/Cargo.toml | 2 +- crates/cdk/src/nuts/nut02.rs | 118 +++++++++++++++++------------------ 2 files changed, 60 insertions(+), 60 deletions(-) diff --git a/crates/cdk/Cargo.toml b/crates/cdk/Cargo.toml index b72e50d23..e89311a53 100644 --- a/crates/cdk/Cargo.toml +++ b/crates/cdk/Cargo.toml @@ -35,7 +35,7 @@ reqwest = { version = "0.12", default-features = false, features = [ ], optional = true } serde = { version = "1", default-features = false, features = ["derive"] } serde_json = "1" -serde_with = "3.1" +serde_with = "3" tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] } thiserror = "1" futures = { version = "0.3.28", default-features = false, optional = true } diff --git a/crates/cdk/src/nuts/nut02.rs b/crates/cdk/src/nuts/nut02.rs index 32e961909..0de24c984 100644 --- a/crates/cdk/src/nuts/nut02.rs +++ b/crates/cdk/src/nuts/nut02.rs @@ -18,7 +18,7 @@ use bitcoin::hashes::Hash; use bitcoin::key::Secp256k1; #[cfg(feature = "mint")] use bitcoin::secp256k1; -use serde::{Deserialize, Deserializer, Serialize}; +use serde::{Deserialize, Serialize}; use serde_with::{serde_as, VecSkipError}; use thiserror::Error; @@ -86,10 +86,11 @@ impl fmt::Display for KeySetVersion { /// A keyset ID is an identifier for a specific keyset. It can be derived by /// anyone who knows the set of public keys of a mint. The keyset ID **CAN** -/// be stored in a Cashu token such that the token can be used to identify +/// be stored in a Cashu token such that the token can be used to identify /// which mint or keyset it was generated from. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +#[serde(into = "String", try_from = "String")] +#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema), schema(as = String))] pub struct Id { version: KeySetVersion, id: [u8; Self::BYTELEN], @@ -130,17 +131,16 @@ impl fmt::Display for Id { } } -impl FromStr for Id { - type Err = Error; +impl TryFrom for Id { + type Error = Error; - fn from_str(s: &str) -> Result { - // Check if the string length is valid + fn try_from(s: String) -> Result { if s.len() != 16 { return Err(Error::Length); } Ok(Self { - version: KeySetVersion::Version00, + version: KeySetVersion::from_byte(&hex::decode(&s[..2])?[0])?, id: hex::decode(&s[2..])? .try_into() .map_err(|_| Error::Length)?, @@ -148,63 +148,29 @@ impl FromStr for Id { } } -impl Serialize for Id { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_str(&self.to_string()) +impl FromStr for Id { + type Err = Error; + + fn from_str(s: &str) -> Result { + Self::try_from(s.to_string()) } } -impl<'de> Deserialize<'de> for Id { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct IdVisitor; - - impl<'de> serde::de::Visitor<'de> for IdVisitor { - type Value = Id; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("Expecting a 14 char hex string") - } - - fn visit_str(self, v: &str) -> Result - where - E: serde::de::Error, - { - Id::from_str(v).map_err(|e| match e { - Error::Length => E::custom(format!( - "Invalid Length: Expected {}, got {}: - {}", - Id::STRLEN, - v.len(), - v - )), - _ => E::custom(e), - }) - } - } - - deserializer.deserialize_str(IdVisitor) +impl From for String { + fn from(value: Id) -> Self { + value.to_string() } } impl From<&Keys> for Id { + /// As per NUT-02: + /// 1. sort public keys by their amount in ascending order + /// 2. concatenate all public keys to one string + /// 3. HASH_SHA256 the concatenated public keys + /// 4. take the first 14 characters of the hex-encoded hash + /// 5. prefix it with a keyset ID version byte fn from(map: &Keys) -> Self { - // REVIEW: Is it 16 or 14 bytes - /* NUT-02 - 1 - sort public keys by their amount in ascending order - 2 - concatenate all public keys to one string - 3 - HASH_SHA256 the concatenated public keys - 4 - take the first 14 characters of the hex-encoded hash - 5 - prefix it with a keyset ID version byte - */ - let mut keys: Vec<(&AmountStr, &super::PublicKey)> = map.iter().collect(); - keys.sort_by_key(|(amt, _v)| *amt); let pubkeys_concat: Vec = keys @@ -400,12 +366,14 @@ impl From<&MintKeys> for Id { #[cfg(test)] mod test { - use std::str::FromStr; + use rand::RngCore; + use super::{KeySetInfo, Keys, KeysetResponse}; - use crate::nuts::nut02::Id; + use crate::nuts::nut02::{Error, Id}; use crate::nuts::KeysResponse; + use crate::util::hex; const SHORT_KEYSET_ID: &str = "00456a94ab4e1c46"; const SHORT_KEYSET: &str = r#" @@ -547,4 +515,36 @@ mod test { assert_eq!(keys_response.keysets.len(), 2); } + + fn generate_random_id() -> Id { + let mut rand_bytes = vec![0u8; 8]; + rand::thread_rng().fill_bytes(&mut rand_bytes[1..]); + Id::from_bytes(&rand_bytes) + .unwrap_or_else(|e| panic!("Failed to create Id from {}: {e}", hex::encode(rand_bytes))) + } + + #[test] + fn test_id_serialization() { + let id = generate_random_id(); + let id_str = id.to_string(); + + assert!(id_str.chars().all(|c| c.is_ascii_hexdigit())); + assert_eq!(16, id_str.len()); + assert_eq!(id_str.to_lowercase(), id_str); + } + + #[test] + fn test_id_deserialization() { + let id_from_short_str = Id::from_str("00123"); + assert!(matches!(id_from_short_str, Err(Error::Length))); + + let id_from_non_hex_str = Id::from_str(&SHORT_KEYSET_ID.replace('a', "x")); + assert!(matches!(id_from_non_hex_str, Err(Error::HexError(_)))); + + let id_invalid_version = Id::from_str(&SHORT_KEYSET_ID.replace("00", "99")); + assert!(matches!(id_invalid_version, Err(Error::UnknownVersion))); + + let id_from_uppercase = Id::from_str(&SHORT_KEYSET_ID.to_uppercase()); + assert!(id_from_uppercase.is_ok()); + } }