diff --git a/crates/cdk-axum/src/ws/subscribe.rs b/crates/cdk-axum/src/ws/subscribe.rs index 7b755eda..0a7de158 100644 --- a/crates/cdk-axum/src/ws/subscribe.rs +++ b/crates/cdk-axum/src/ws/subscribe.rs @@ -11,17 +11,26 @@ use cdk::{ pub struct Method(Params); #[derive(Debug, Clone, serde::Serialize)] +/// The response to a subscription request pub struct Response { + /// Status status: String, + /// Subscription ID #[serde(rename = "subId")] sub_id: SubId, } #[derive(Debug, Clone, serde::Serialize)] +/// The notification +/// +/// This is the notification that is sent to the client when an event matches a +/// subscription pub struct Notification { + /// The subscription ID #[serde(rename = "subId")] pub sub_id: SubId, + /// The notification payload pub payload: NotificationPayload, } @@ -39,12 +48,21 @@ impl From<(SubId, NotificationPayload)> for WsNotification { impl WsHandle for Method { type Response = Response; + /// The `handle` method is called when a client sends a subscription request async fn handle(self, context: &mut WsContext) -> Result { let sub_id = self.0.id.clone(); if context.subscriptions.contains_key(&sub_id) { + // Subscription ID already exits. Returns an error instead of + // replacing the other subscription or avoiding it. return Err(WsError::InvalidParams); } - let mut subscription = context.state.mint.pubsub_manager.subscribe(self.0).await; + + let mut subscription = context + .state + .mint + .pubsub_manager + .subscribe(self.0.clone()) + .await; let publisher = context.publisher.clone(); context.subscriptions.insert( sub_id.clone(), diff --git a/crates/cdk-integration-tests/tests/regtest.rs b/crates/cdk-integration-tests/tests/regtest.rs index 2ba54595..24bbe89b 100644 --- a/crates/cdk-integration-tests/tests/regtest.rs +++ b/crates/cdk-integration-tests/tests/regtest.rs @@ -35,7 +35,7 @@ async fn get_notification> + Unpin, E: De .unwrap(); let mut response: serde_json::Value = - serde_json::from_str(&msg.to_text().unwrap()).expect("valid json"); + serde_json::from_str(msg.to_text().unwrap()).expect("valid json"); let mut params_raw = response .as_object_mut() @@ -112,6 +112,18 @@ async fn test_regtest_mint_melt_round_trip() -> Result<()> { assert!(melt_response.preimage.is_some()); assert!(melt_response.state == MeltQuoteState::Paid); + let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await; + // first message is the current state + assert_eq!("test-sub", sub_id); + let payload = match payload { + NotificationPayload::MeltQuoteBolt11Response(melt) => melt, + _ => panic!("Wrong payload"), + }; + assert_eq!(payload.amount + payload.fee_reserve, 100.into()); + assert_eq!(payload.quote, melt.id); + assert_eq!(payload.state, MeltQuoteState::Unpaid); + + // get current state let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await; assert_eq!("test-sub", sub_id); let payload = match payload { diff --git a/crates/cdk/Cargo.toml b/crates/cdk/Cargo.toml index 81673861..4d72da31 100644 --- a/crates/cdk/Cargo.toml +++ b/crates/cdk/Cargo.toml @@ -39,7 +39,7 @@ serde_json = "1" serde_with = "3" tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] } thiserror = "1" -futures = { version = "0.3.28", default-features = false, optional = true } +futures = { version = "0.3.28", default-features = false, optional = true, features = ["alloc"] } url = "2.3" utoipa = { version = "4", optional = true } uuid = { version = "1", features = ["v4"] } diff --git a/crates/cdk/src/mint/mod.rs b/crates/cdk/src/mint/mod.rs index 5c8e28e1..dbbdc59b 100644 --- a/crates/cdk/src/mint/mod.rs +++ b/crates/cdk/src/mint/mod.rs @@ -185,7 +185,7 @@ impl Mint { Ok(Self { mint_url: MintUrl::from_str(mint_url)?, keysets: Arc::new(RwLock::new(active_keysets)), - pubsub_manager: Default::default(), + pubsub_manager: Arc::new(localstore.clone().into()), secp_ctx, quote_ttl, xpriv, diff --git a/crates/cdk/src/nuts/mod.rs b/crates/cdk/src/nuts/mod.rs index 79bfb0d8..eb1f8170 100644 --- a/crates/cdk/src/nuts/mod.rs +++ b/crates/cdk/src/nuts/mod.rs @@ -18,6 +18,7 @@ pub mod nut12; pub mod nut13; pub mod nut14; pub mod nut15; +#[cfg(feature = "mint")] pub mod nut17; pub mod nut18; @@ -48,5 +49,6 @@ pub use nut11::{Conditions, P2PKWitness, SigFlag, SpendingConditions}; pub use nut12::{BlindSignatureDleq, ProofDleq}; pub use nut14::HTLCWitness; pub use nut15::{Mpp, MppMethodSettings, Settings as NUT15Settings}; +#[cfg(feature = "mint")] pub use nut17::{NotificationPayload, PubSubManager}; pub use nut18::{PaymentRequest, PaymentRequestPayload, Transport}; diff --git a/crates/cdk/src/nuts/nut02.rs b/crates/cdk/src/nuts/nut02.rs index 0de24c98..57393980 100644 --- a/crates/cdk/src/nuts/nut02.rs +++ b/crates/cdk/src/nuts/nut02.rs @@ -114,14 +114,33 @@ impl Id { } } -impl TryFrom for u64 { +impl TryFrom for u32 { type Error = Error; fn try_from(value: Id) -> Result { let hex_bytes: [u8; 8] = value.to_bytes().try_into().map_err(|_| Error::Length)?; let int = u64::from_be_bytes(hex_bytes); - Ok(int % (2_u64.pow(31) - 1)) + let result = (int % (2_u64.pow(31) - 1)) as u32; + Ok(result) + } +} + +impl TryFrom for Id { + type Error = Error; + fn try_from(value: u64) -> Result { + let bytes = value.to_be_bytes(); + Self::from_bytes(&bytes) + } +} + +impl TryFrom for u64 { + type Error = Error; + + fn try_from(value: Id) -> Result { + let bytes = value.to_bytes(); + let byte_array: [u8; 8] = bytes.try_into().map_err(|_| Error::Length)?; + Ok(u64::from_be_bytes(byte_array)) } } @@ -490,10 +509,28 @@ mod test { fn test_to_int() { let id = Id::from_str("009a1f293253e41e").unwrap(); - let id_int = u64::try_from(id).unwrap(); + let id_int = u32::try_from(id).unwrap(); assert_eq!(864559728, id_int) } + #[test] + fn test_to_u64_and_back() { + let id = Id::from_str("009a1f293253e41e").unwrap(); + + let id_long = u64::try_from(id).unwrap(); + assert_eq!(43381408211919902, id_long); + + let new_id = Id::try_from(id_long).unwrap(); + assert_eq!(id, new_id); + } + + #[test] + fn test_id_from_invalid_byte_length() { + let three_bytes = [0x01, 0x02, 0x03]; + let result = Id::from_bytes(&three_bytes); + assert!(result.is_err(), "Expected an invalid byte length error"); + } + #[test] fn test_keyset_bytes() { let id = Id::from_str("009a1f293253e41e").unwrap(); diff --git a/crates/cdk/src/nuts/nut06.rs b/crates/cdk/src/nuts/nut06.rs index 9ecabe87..17ba18b6 100644 --- a/crates/cdk/src/nuts/nut06.rs +++ b/crates/cdk/src/nuts/nut06.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use super::nut01::PublicKey; -use super::{nut04, nut05, nut15, nut17, MppMethodSettings}; +use super::{nut04, nut05, nut15, MppMethodSettings}; /// Mint Version #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -238,7 +238,8 @@ pub struct Nuts { /// NUT17 Settings #[serde(default)] #[serde(rename = "17")] - pub nut17: nut17::SupportedSettings, + #[cfg(feature = "mint")] + pub nut17: super::nut17::SupportedSettings, } impl Nuts { diff --git a/crates/cdk/src/nuts/nut17.rs b/crates/cdk/src/nuts/nut17/mod.rs similarity index 93% rename from crates/cdk/src/nuts/nut17.rs rename to crates/cdk/src/nuts/nut17/mod.rs index 3d64fad9..7d2a3c4a 100644 --- a/crates/cdk/src/nuts/nut17.rs +++ b/crates/cdk/src/nuts/nut17/mod.rs @@ -1,5 +1,13 @@ //! Specific Subscription for the cdk crate +use std::{ops::Deref, sync::Arc}; + +use serde::{Deserialize, Serialize}; + +mod on_subscription; + +use crate::cdk_database::{self, MintDatabase}; +use crate::nuts::{BlindSignature, CurrencyUnit, PaymentMethod}; use crate::{ nuts::{ MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response, MintQuoteState, @@ -7,8 +15,9 @@ use crate::{ }, pub_sub::{self, Index, Indexable, SubscriptionGlobalId}, }; -use serde::{Deserialize, Serialize}; -use std::ops::Deref; + +pub use crate::pub_sub::SubId; +pub use on_subscription::OnSubscription; /// Subscription Parameter according to the standard #[derive(Debug, Clone, Serialize, Deserialize)] @@ -57,10 +66,6 @@ impl Default for SupportedMethods { } } -pub use crate::pub_sub::SubId; - -use super::{BlindSignature, CurrencyUnit, PaymentMethod}; - #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(untagged)] /// Subscription response @@ -145,15 +150,27 @@ impl From for Vec> { } /// Manager -#[derive(Default)] /// Publish–subscribe manager /// /// Nut-17 implementation is system-wide and not only through the WebSocket, so /// it is possible for another part of the system to subscribe to events. -pub struct PubSubManager(pub_sub::Manager); +pub struct PubSubManager(pub_sub::Manager); + +#[allow(clippy::default_constructed_unit_structs)] +impl Default for PubSubManager { + fn default() -> Self { + PubSubManager(OnSubscription::default().into()) + } +} + +impl From + Send + Sync>> for PubSubManager { + fn from(val: Arc + Send + Sync>) -> Self { + PubSubManager(OnSubscription(Some(val)).into()) + } +} impl Deref for PubSubManager { - type Target = pub_sub::Manager; + type Target = pub_sub::Manager; fn deref(&self) -> &Self::Target { &self.0 diff --git a/crates/cdk/src/nuts/nut17/on_subscription.rs b/crates/cdk/src/nuts/nut17/on_subscription.rs new file mode 100644 index 00000000..9aaf8c18 --- /dev/null +++ b/crates/cdk/src/nuts/nut17/on_subscription.rs @@ -0,0 +1,111 @@ +//! On Subscription +//! +//! This module contains the code that is triggered when a new subscription is created. +use std::{collections::HashMap, sync::Arc}; + +use super::{Kind, NotificationPayload}; +use crate::{ + cdk_database::{self, MintDatabase}, + nuts::{MeltQuoteBolt11Response, MintQuoteBolt11Response, ProofState, PublicKey}, + pub_sub::OnNewSubscription, +}; + +#[derive(Default)] +/// Subscription Init +/// +/// This struct triggers code when a new subscription is created. +/// +/// It is used to send the initial state of the subscription to the client. +pub struct OnSubscription( + pub(crate) Option + Send + Sync>>, +); + +#[async_trait::async_trait] +impl OnNewSubscription for OnSubscription { + type Event = NotificationPayload; + type Index = (String, Kind); + + async fn on_new_subscription( + &self, + request: &[&Self::Index], + ) -> Result, String> { + let datastore = if let Some(localstore) = self.0.as_ref() { + localstore + } else { + return Ok(vec![]); + }; + + let mut to_return = vec![]; + + for (kind, values) in request.iter().fold( + HashMap::new(), + |mut acc: HashMap<&Kind, Vec<&String>>, (data, kind)| { + acc.entry(kind).or_default().push(data); + acc + }, + ) { + match kind { + Kind::Bolt11MeltQuote => { + let queries = values + .iter() + .map(|id| datastore.get_melt_quote(id)) + .collect::>(); + + to_return.extend( + futures::future::try_join_all(queries) + .await + .map(|quotes| { + quotes + .into_iter() + .filter_map(|quote| quote.map(|x| x.into())) + .map(|x: MeltQuoteBolt11Response| x.into()) + .collect::>() + }) + .map_err(|e| e.to_string())?, + ); + } + Kind::Bolt11MintQuote => { + let queries = values + .iter() + .map(|id| datastore.get_mint_quote(id)) + .collect::>(); + + to_return.extend( + futures::future::try_join_all(queries) + .await + .map(|quotes| { + quotes + .into_iter() + .filter_map(|quote| quote.map(|x| x.into())) + .map(|x: MintQuoteBolt11Response| x.into()) + .collect::>() + }) + .map_err(|e| e.to_string())?, + ); + } + Kind::ProofState => { + let public_keys = values + .iter() + .map(PublicKey::from_hex) + .collect::, _>>() + .map_err(|e| e.to_string())?; + + to_return.extend( + datastore + .get_proofs_states(&public_keys) + .await + .map_err(|e| e.to_string())? + .into_iter() + .enumerate() + .filter_map(|(idx, state)| { + state.map(|state| (public_keys[idx], state).into()) + }) + .map(|state: ProofState| state.into()), + ); + } + } + } + + Ok(to_return) + } +} diff --git a/crates/cdk/src/pub_sub/mod.rs b/crates/cdk/src/pub_sub/mod.rs index f9347c78..c511c803 100644 --- a/crates/cdk/src/pub_sub/mod.rs +++ b/crates/cdk/src/pub_sub/mod.rs @@ -37,6 +37,25 @@ pub const DEFAULT_REMOVE_SIZE: usize = 10_000; /// Default channel size for subscription buffering pub const DEFAULT_CHANNEL_SIZE: usize = 10; +#[async_trait::async_trait] +/// On New Subscription trait +/// +/// This trait is optional and it is used to notify the application when a new +/// subscription is created. This is useful when the application needs to send +/// the initial state to the subscriber upon subscription +pub trait OnNewSubscription { + /// Index type + type Index; + /// Subscription event type + type Event; + + /// Called when a new subscription is created + async fn on_new_subscription( + &self, + request: &[&Self::Index], + ) -> Result, String>; +} + /// Subscription manager /// /// This object keep track of all subscription listener and it is also @@ -45,21 +64,24 @@ pub const DEFAULT_CHANNEL_SIZE: usize = 10; /// The content of the notification is not relevant to this scope and it is up /// to the application, therefore the generic T is used instead of a specific /// type -pub struct Manager +pub struct Manager where T: Indexable + Clone + Send + Sync + 'static, I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static, + F: OnNewSubscription + 'static, { indexes: IndexTree, + on_new_subscription: Option, unsubscription_sender: mpsc::Sender<(SubId, Vec>)>, active_subscriptions: Arc, background_subscription_remover: Option>, } -impl Default for Manager +impl Default for Manager where T: Indexable + Clone + Send + Sync + 'static, I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static, + F: OnNewSubscription + 'static, { fn default() -> Self { let (sender, receiver) = mpsc::channel(DEFAULT_REMOVE_SIZE); @@ -72,6 +94,7 @@ where storage.clone(), active_subscriptions.clone(), ))), + on_new_subscription: None, unsubscription_sender: sender, active_subscriptions, indexes: storage, @@ -79,10 +102,24 @@ where } } -impl Manager +impl From for Manager where T: Indexable + Clone + Send + Sync + 'static, - I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static, + I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static, + F: OnNewSubscription + 'static, +{ + fn from(value: F) -> Self { + let mut manager: Self = Default::default(); + manager.on_new_subscription = Some(value); + manager + } +} + +impl Manager +where + T: Indexable + Clone + Send + Sync + 'static, + I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static, + F: OnNewSubscription + 'static, { #[inline] /// Broadcast an event to all listeners @@ -132,8 +169,29 @@ where ) -> ActiveSubscription { let (sender, receiver) = mpsc::channel(10); let sub_id: SubId = params.as_ref().clone(); + let indexes: Vec> = params.into(); + if let Some(on_new_subscription) = self.on_new_subscription.as_ref() { + match on_new_subscription + .on_new_subscription(&indexes.iter().map(|x| x.deref()).collect::>()) + .await + { + Ok(events) => { + for event in events { + let _ = sender.try_send((sub_id.clone(), event)); + } + } + Err(err) => { + tracing::info!( + "Failed to get initial state for subscription: {:?}, {}", + sub_id, + err + ); + } + } + } + let mut index_storage = self.indexes.write().await; for index in indexes.clone() { index_storage.insert(index, sender.clone()); @@ -180,10 +238,11 @@ where } /// Manager goes out of scope, stop all background tasks -impl Drop for Manager +impl Drop for Manager where T: Indexable + Clone + Send + Sync + 'static, I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static, + F: OnNewSubscription + 'static, { fn drop(&mut self) { if let Some(handler) = self.background_subscription_remover.take() { diff --git a/crates/cdk/src/wallet/keysets.rs b/crates/cdk/src/wallet/keysets.rs index 9e5babc5..546c17a5 100644 --- a/crates/cdk/src/wallet/keysets.rs +++ b/crates/cdk/src/wallet/keysets.rs @@ -50,7 +50,7 @@ impl Wallet { /// Queries mint for current keysets then gets [`Keys`] for any unknown /// keysets #[instrument(skip(self))] - pub async fn get_active_mint_keyset(&self) -> Result { + pub async fn get_active_mint_keysets(&self) -> Result, Error> { let keysets = self.client.get_mint_keysets(self.mint_url.clone()).await?; let keysets = keysets.keysets; @@ -86,6 +86,21 @@ impl Wallet { } } - active_keysets.first().ok_or(Error::NoActiveKeyset).cloned() + Ok(active_keysets) + } + + /// Get active keyset for mint with the lowest fees + /// + /// Queries mint for current keysets then gets [`Keys`] for any unknown + /// keysets + #[instrument(skip(self))] + pub async fn get_active_mint_keyset(&self) -> Result { + let active_keysets = self.get_active_mint_keysets().await?; + + let keyset_with_lowest_fee = active_keysets + .into_iter() + .min_by_key(|key| key.input_fee_ppk) + .ok_or(Error::NoActiveKeyset)?; + Ok(keyset_with_lowest_fee) } } diff --git a/rustfmt.toml b/rustfmt.toml index 9b155f2e..547421e7 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,5 +1,9 @@ tab_spaces = 4 -max_width = 100 newline_style = "Auto" reorder_imports = true reorder_modules = true +reorder_impl_items = false +indent_style = "Block" +normalize_comments = false +imports_granularity = "Module" +group_imports = "StdExternalCrate"