Skip to content

Commit

Permalink
Keyset ID: fix deserialization edge-case, add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ok300 committed Oct 25, 2024
1 parent 103574b commit 2a09266
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 60 deletions.
2 changes: 1 addition & 1 deletion crates/cdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ reqwest = { version = "0.12", default-features = false, features = [
], optional = true }
serde = { version = "1", default-features = false, features = ["derive"] }
serde_json = "1"
serde_with = "3.1"
serde_with = "3"
tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
thiserror = "1"
futures = { version = "0.3.28", default-features = false, optional = true }
Expand Down
118 changes: 59 additions & 59 deletions crates/cdk/src/nuts/nut02.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use bitcoin::hashes::Hash;
use bitcoin::key::Secp256k1;
#[cfg(feature = "mint")]
use bitcoin::secp256k1;
use serde::{Deserialize, Deserializer, Serialize};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, VecSkipError};
use thiserror::Error;

Expand Down Expand Up @@ -86,10 +86,11 @@ impl fmt::Display for KeySetVersion {

/// A keyset ID is an identifier for a specific keyset. It can be derived by
/// anyone who knows the set of public keys of a mint. The keyset ID **CAN**
/// be stored in a Cashu token such that the token can be used to identify
/// be stored in a Cashu token such that the token can be used to identify
/// which mint or keyset it was generated from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(into = "String", try_from = "String")]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema), schema(as = String))]
pub struct Id {
version: KeySetVersion,
id: [u8; Self::BYTELEN],
Expand Down Expand Up @@ -130,81 +131,46 @@ impl fmt::Display for Id {
}
}

impl FromStr for Id {
type Err = Error;
impl TryFrom<String> for Id {
type Error = Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
// Check if the string length is valid
fn try_from(s: String) -> Result<Self, Self::Error> {
if s.len() != 16 {
return Err(Error::Length);
}

Ok(Self {
version: KeySetVersion::Version00,
version: KeySetVersion::from_byte(&hex::decode(&s[..2])?[0])?,
id: hex::decode(&s[2..])?
.try_into()
.map_err(|_| Error::Length)?,
})
}
}

impl Serialize for Id {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&self.to_string())
impl FromStr for Id {
type Err = Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::try_from(s.to_string())
}
}

impl<'de> Deserialize<'de> for Id {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct IdVisitor;

impl<'de> serde::de::Visitor<'de> for IdVisitor {
type Value = Id;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("Expecting a 14 char hex string")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Id::from_str(v).map_err(|e| match e {
Error::Length => E::custom(format!(
"Invalid Length: Expected {}, got {}:
{}",
Id::STRLEN,
v.len(),
v
)),
_ => E::custom(e),
})
}
}

deserializer.deserialize_str(IdVisitor)
impl From<Id> for String {
fn from(value: Id) -> Self {
value.to_string()
}
}

impl From<&Keys> for Id {
/// As per NUT-02:
/// 1. sort public keys by their amount in ascending order
/// 2. concatenate all public keys to one string
/// 3. HASH_SHA256 the concatenated public keys
/// 4. take the first 14 characters of the hex-encoded hash
/// 5. prefix it with a keyset ID version byte
fn from(map: &Keys) -> Self {
// REVIEW: Is it 16 or 14 bytes
/* NUT-02
1 - sort public keys by their amount in ascending order
2 - concatenate all public keys to one string
3 - HASH_SHA256 the concatenated public keys
4 - take the first 14 characters of the hex-encoded hash
5 - prefix it with a keyset ID version byte
*/

let mut keys: Vec<(&AmountStr, &super::PublicKey)> = map.iter().collect();

keys.sort_by_key(|(amt, _v)| *amt);

let pubkeys_concat: Vec<u8> = keys
Expand Down Expand Up @@ -400,12 +366,14 @@ impl From<&MintKeys> for Id {

#[cfg(test)]
mod test {

use std::str::FromStr;

use rand::RngCore;

use super::{KeySetInfo, Keys, KeysetResponse};
use crate::nuts::nut02::Id;
use crate::nuts::nut02::{Error, Id};
use crate::nuts::KeysResponse;
use crate::util::hex;

const SHORT_KEYSET_ID: &str = "00456a94ab4e1c46";
const SHORT_KEYSET: &str = r#"
Expand Down Expand Up @@ -547,4 +515,36 @@ mod test {

assert_eq!(keys_response.keysets.len(), 2);
}

fn generate_random_id() -> Id {
let mut rand_bytes = vec![0u8; 8];
rand::thread_rng().fill_bytes(&mut rand_bytes[1..]);
Id::from_bytes(&rand_bytes)
.unwrap_or_else(|e| panic!("Failed to create Id from {}: {e}", hex::encode(rand_bytes)))
}

#[test]
fn test_id_serialization() {
let id = generate_random_id();
let id_str = id.to_string();

assert!(id_str.chars().all(|c| c.is_ascii_hexdigit()));
assert_eq!(16, id_str.len());
assert_eq!(id_str.to_lowercase(), id_str);
}

#[test]
fn test_id_deserialization() {
let id_from_short_str = Id::from_str("00123");
assert!(matches!(id_from_short_str, Err(Error::Length)));

let id_from_non_hex_str = Id::from_str(&SHORT_KEYSET_ID.replace('a', "x"));
assert!(matches!(id_from_non_hex_str, Err(Error::HexError(_))));

let id_invalid_version = Id::from_str(&SHORT_KEYSET_ID.replace("00", "99"));
assert!(matches!(id_invalid_version, Err(Error::UnknownVersion)));

let id_from_uppercase = Id::from_str(&SHORT_KEYSET_ID.to_uppercase());
assert!(id_from_uppercase.is_ok());
}
}

0 comments on commit 2a09266

Please sign in to comment.