From 6e0d1546424c9eeac5bb9c7ff1ef40538ce82d67 Mon Sep 17 00:00:00 2001 From: ok300 <106775972+ok300@users.noreply.github.com> Date: Mon, 14 Oct 2024 21:35:41 +0200 Subject: [PATCH] Add AmountStr wrapper for string serialization --- crates/cdk/src/amount.rs | 51 +++++++++++++++++++++++++++++++- crates/cdk/src/nuts/nut01/mod.rs | 14 ++++----- crates/cdk/src/nuts/nut02.rs | 5 ++-- 3 files changed, 60 insertions(+), 10 deletions(-) diff --git a/crates/cdk/src/amount.rs b/crates/cdk/src/amount.rs index 7e2d4e26..8fa53496 100644 --- a/crates/cdk/src/amount.rs +++ b/crates/cdk/src/amount.rs @@ -4,8 +4,9 @@ use std::cmp::Ordering; use std::fmt; +use std::str::FromStr; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use thiserror::Error; use crate::nuts::CurrencyUnit; @@ -210,6 +211,54 @@ impl std::ops::Div for Amount { } } +/// String wrapper for an [Amount]. +/// +/// It ser-/deserializes the inner [Amount] to a string, while at the same time using the [u64] +/// value of the [Amount] for comparison and ordering. This helps automatically sort the keys of +/// a [BTreeMap] when [AmountStr] is used as key. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AmountStr(Amount); + +impl AmountStr { + pub(crate) fn from(amt: Amount) -> Self { + Self(amt) + } +} + +impl PartialOrd for AmountStr { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for AmountStr { + fn cmp(&self, other: &Self) -> Ordering { + self.0.cmp(&other.0) + } +} + +impl<'de> Deserialize<'de> for AmountStr { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + u64::from_str(&s) + .map(Amount) + .map(Self) + .map_err(serde::de::Error::custom) + } +} + +impl Serialize for AmountStr { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&self.0.to_string()) + } +} + /// Kinds of targeting that are supported #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)] pub enum SplitTarget { diff --git a/crates/cdk/src/nuts/nut01/mod.rs b/crates/cdk/src/nuts/nut01/mod.rs index c23f7936..12819222 100644 --- a/crates/cdk/src/nuts/nut01/mod.rs +++ b/crates/cdk/src/nuts/nut01/mod.rs @@ -16,7 +16,7 @@ mod secret_key; pub use self::public_key::PublicKey; pub use self::secret_key::SecretKey; use super::nut02::KeySet; -use crate::amount::Amount; +use crate::amount::{Amount, AmountStr}; /// Nut01 Error #[derive(Debug, Error)] @@ -43,14 +43,14 @@ pub enum Error { /// /// See [NUT-01] #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] -pub struct Keys(BTreeMap); +pub struct Keys(BTreeMap); impl From for Keys { fn from(keys: MintKeys) -> Self { Self( keys.0 .into_iter() - .map(|(amount, keypair)| (amount, keypair.public_key)) + .map(|(amount, keypair)| (AmountStr::from(amount), keypair.public_key)) .collect(), ) } @@ -59,25 +59,25 @@ impl From for Keys { impl Keys { /// Create new [`Keys`] #[inline] - pub fn new(keys: BTreeMap) -> Self { + pub fn new(keys: BTreeMap) -> Self { Self(keys) } /// Get [`Keys`] #[inline] - pub fn keys(&self) -> &BTreeMap { + pub fn keys(&self) -> &BTreeMap { &self.0 } /// Get [`PublicKey`] for [`Amount`] #[inline] pub fn amount_key(&self, amount: Amount) -> Option { - self.0.get(&amount).copied() + self.0.get(&AmountStr::from(amount)).copied() } /// Iterate through the (`Amount`, `PublicKey`) entries in the Map #[inline] - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator { self.0.iter() } } diff --git a/crates/cdk/src/nuts/nut02.rs b/crates/cdk/src/nuts/nut02.rs index f3e22833..79d216c0 100644 --- a/crates/cdk/src/nuts/nut02.rs +++ b/crates/cdk/src/nuts/nut02.rs @@ -25,6 +25,7 @@ use thiserror::Error; use super::nut01::Keys; #[cfg(feature = "mint")] use super::nut01::{MintKeyPair, MintKeys}; +use crate::amount::AmountStr; use crate::nuts::nut00::CurrencyUnit; use crate::util::hex; use crate::Amount; @@ -196,9 +197,9 @@ impl From<&Keys> for Id { 5 - prefix it with a keyset ID version byte */ - let mut keys: Vec<(&Amount, &super::PublicKey)> = map.iter().collect(); + let mut keys: Vec<(&AmountStr, &super::PublicKey)> = map.iter().collect(); - keys.sort_by_key(|(&amt, _v)| u64::from(amt)); + keys.sort_by_key(|(amt, _v)| *amt); let pubkeys_concat: Vec = keys .iter()