From cc5b2673674f3517d267490008f605dec3b48628 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9sar=20D=2E=20Rodas?= Date: Sun, 10 Nov 2024 09:08:44 -0300 Subject: [PATCH] fix: Send current state of the subscriptions (#444) --- crates/cdk-axum/src/ws/subscribe.rs | 20 +++- crates/cdk-integration-tests/tests/regtest.rs | 14 ++- crates/cdk/Cargo.toml | 2 +- crates/cdk/src/mint/mod.rs | 2 +- crates/cdk/src/nuts/mod.rs | 2 + crates/cdk/src/nuts/nut06.rs | 5 +- .../cdk/src/nuts/{nut17.rs => nut17/mod.rs} | 31 +++-- crates/cdk/src/nuts/nut17/on_subscription.rs | 110 ++++++++++++++++++ crates/cdk/src/pub_sub/mod.rs | 69 ++++++++++- 9 files changed, 236 insertions(+), 19 deletions(-) rename crates/cdk/src/nuts/{nut17.rs => nut17/mod.rs} (94%) create mode 100644 crates/cdk/src/nuts/nut17/on_subscription.rs diff --git a/crates/cdk-axum/src/ws/subscribe.rs b/crates/cdk-axum/src/ws/subscribe.rs index 7b755eda..0a7de158 100644 --- a/crates/cdk-axum/src/ws/subscribe.rs +++ b/crates/cdk-axum/src/ws/subscribe.rs @@ -11,17 +11,26 @@ use cdk::{ pub struct Method(Params); #[derive(Debug, Clone, serde::Serialize)] +/// The response to a subscription request pub struct Response { + /// Status status: String, + /// Subscription ID #[serde(rename = "subId")] sub_id: SubId, } #[derive(Debug, Clone, serde::Serialize)] +/// The notification +/// +/// This is the notification that is sent to the client when an event matches a +/// subscription pub struct Notification { + /// The subscription ID #[serde(rename = "subId")] pub sub_id: SubId, + /// The notification payload pub payload: NotificationPayload, } @@ -39,12 +48,21 @@ impl From<(SubId, NotificationPayload)> for WsNotification { impl WsHandle for Method { type Response = Response; + /// The `handle` method is called when a client sends a subscription request async fn handle(self, context: &mut WsContext) -> Result { let sub_id = self.0.id.clone(); if context.subscriptions.contains_key(&sub_id) { + // Subscription ID already exits. Returns an error instead of + // replacing the other subscription or avoiding it. return Err(WsError::InvalidParams); } - let mut subscription = context.state.mint.pubsub_manager.subscribe(self.0).await; + + let mut subscription = context + .state + .mint + .pubsub_manager + .subscribe(self.0.clone()) + .await; let publisher = context.publisher.clone(); context.subscriptions.insert( sub_id.clone(), diff --git a/crates/cdk-integration-tests/tests/regtest.rs b/crates/cdk-integration-tests/tests/regtest.rs index 2ba54595..24bbe89b 100644 --- a/crates/cdk-integration-tests/tests/regtest.rs +++ b/crates/cdk-integration-tests/tests/regtest.rs @@ -35,7 +35,7 @@ async fn get_notification> + Unpin, E: De .unwrap(); let mut response: serde_json::Value = - serde_json::from_str(&msg.to_text().unwrap()).expect("valid json"); + serde_json::from_str(msg.to_text().unwrap()).expect("valid json"); let mut params_raw = response .as_object_mut() @@ -112,6 +112,18 @@ async fn test_regtest_mint_melt_round_trip() -> Result<()> { assert!(melt_response.preimage.is_some()); assert!(melt_response.state == MeltQuoteState::Paid); + let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await; + // first message is the current state + assert_eq!("test-sub", sub_id); + let payload = match payload { + NotificationPayload::MeltQuoteBolt11Response(melt) => melt, + _ => panic!("Wrong payload"), + }; + assert_eq!(payload.amount + payload.fee_reserve, 100.into()); + assert_eq!(payload.quote, melt.id); + assert_eq!(payload.state, MeltQuoteState::Unpaid); + + // get current state let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await; assert_eq!("test-sub", sub_id); let payload = match payload { diff --git a/crates/cdk/Cargo.toml b/crates/cdk/Cargo.toml index 81673861..4d72da31 100644 --- a/crates/cdk/Cargo.toml +++ b/crates/cdk/Cargo.toml @@ -39,7 +39,7 @@ serde_json = "1" serde_with = "3" tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] } thiserror = "1" -futures = { version = "0.3.28", default-features = false, optional = true } +futures = { version = "0.3.28", default-features = false, optional = true, features = ["alloc"] } url = "2.3" utoipa = { version = "4", optional = true } uuid = { version = "1", features = ["v4"] } diff --git a/crates/cdk/src/mint/mod.rs b/crates/cdk/src/mint/mod.rs index 5c8e28e1..dbbdc59b 100644 --- a/crates/cdk/src/mint/mod.rs +++ b/crates/cdk/src/mint/mod.rs @@ -185,7 +185,7 @@ impl Mint { Ok(Self { mint_url: MintUrl::from_str(mint_url)?, keysets: Arc::new(RwLock::new(active_keysets)), - pubsub_manager: Default::default(), + pubsub_manager: Arc::new(localstore.clone().into()), secp_ctx, quote_ttl, xpriv, diff --git a/crates/cdk/src/nuts/mod.rs b/crates/cdk/src/nuts/mod.rs index 79bfb0d8..eb1f8170 100644 --- a/crates/cdk/src/nuts/mod.rs +++ b/crates/cdk/src/nuts/mod.rs @@ -18,6 +18,7 @@ pub mod nut12; pub mod nut13; pub mod nut14; pub mod nut15; +#[cfg(feature = "mint")] pub mod nut17; pub mod nut18; @@ -48,5 +49,6 @@ pub use nut11::{Conditions, P2PKWitness, SigFlag, SpendingConditions}; pub use nut12::{BlindSignatureDleq, ProofDleq}; pub use nut14::HTLCWitness; pub use nut15::{Mpp, MppMethodSettings, Settings as NUT15Settings}; +#[cfg(feature = "mint")] pub use nut17::{NotificationPayload, PubSubManager}; pub use nut18::{PaymentRequest, PaymentRequestPayload, Transport}; diff --git a/crates/cdk/src/nuts/nut06.rs b/crates/cdk/src/nuts/nut06.rs index 9ecabe87..17ba18b6 100644 --- a/crates/cdk/src/nuts/nut06.rs +++ b/crates/cdk/src/nuts/nut06.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use super::nut01::PublicKey; -use super::{nut04, nut05, nut15, nut17, MppMethodSettings}; +use super::{nut04, nut05, nut15, MppMethodSettings}; /// Mint Version #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -238,7 +238,8 @@ pub struct Nuts { /// NUT17 Settings #[serde(default)] #[serde(rename = "17")] - pub nut17: nut17::SupportedSettings, + #[cfg(feature = "mint")] + pub nut17: super::nut17::SupportedSettings, } impl Nuts { diff --git a/crates/cdk/src/nuts/nut17.rs b/crates/cdk/src/nuts/nut17/mod.rs similarity index 94% rename from crates/cdk/src/nuts/nut17.rs rename to crates/cdk/src/nuts/nut17/mod.rs index 3d64fad9..d186cc63 100644 --- a/crates/cdk/src/nuts/nut17.rs +++ b/crates/cdk/src/nuts/nut17/mod.rs @@ -1,5 +1,8 @@ //! Specific Subscription for the cdk crate +use super::{BlindSignature, CurrencyUnit, PaymentMethod}; +use crate::cdk_database::{self, MintDatabase}; +pub use crate::pub_sub::SubId; use crate::{ nuts::{ MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response, MintQuoteState, @@ -8,7 +11,11 @@ use crate::{ pub_sub::{self, Index, Indexable, SubscriptionGlobalId}, }; use serde::{Deserialize, Serialize}; -use std::ops::Deref; +use std::{ops::Deref, sync::Arc}; + +mod on_subscription; + +pub use on_subscription::OnSubscription; /// Subscription Parameter according to the standard #[derive(Debug, Clone, Serialize, Deserialize)] @@ -57,10 +64,6 @@ impl Default for SupportedMethods { } } -pub use crate::pub_sub::SubId; - -use super::{BlindSignature, CurrencyUnit, PaymentMethod}; - #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(untagged)] /// Subscription response @@ -145,15 +148,27 @@ impl From for Vec> { } /// Manager -#[derive(Default)] /// Publish–subscribe manager /// /// Nut-17 implementation is system-wide and not only through the WebSocket, so /// it is possible for another part of the system to subscribe to events. -pub struct PubSubManager(pub_sub::Manager); +pub struct PubSubManager(pub_sub::Manager); + +#[allow(clippy::default_constructed_unit_structs)] +impl Default for PubSubManager { + fn default() -> Self { + PubSubManager(OnSubscription::default().into()) + } +} + +impl From + Send + Sync>> for PubSubManager { + fn from(val: Arc + Send + Sync>) -> Self { + PubSubManager(OnSubscription(Some(val)).into()) + } +} impl Deref for PubSubManager { - type Target = pub_sub::Manager; + type Target = pub_sub::Manager; fn deref(&self) -> &Self::Target { &self.0 diff --git a/crates/cdk/src/nuts/nut17/on_subscription.rs b/crates/cdk/src/nuts/nut17/on_subscription.rs new file mode 100644 index 00000000..b2347e60 --- /dev/null +++ b/crates/cdk/src/nuts/nut17/on_subscription.rs @@ -0,0 +1,110 @@ +//! On Subscription +//! +//! This module contains the code that is triggered when a new subscription is created. +use super::{Kind, NotificationPayload}; +use crate::{ + cdk_database::{self, MintDatabase}, + nuts::{MeltQuoteBolt11Response, MintQuoteBolt11Response, ProofState, PublicKey}, + pub_sub::OnNewSubscription, +}; +use std::{collections::HashMap, sync::Arc}; + +#[derive(Default)] +/// Subscription Init +/// +/// This struct triggers code when a new subscription is created. +/// +/// It is used to send the initial state of the subscription to the client. +pub struct OnSubscription( + pub(crate) Option + Send + Sync>>, +); + +#[async_trait::async_trait] +impl OnNewSubscription for OnSubscription { + type Event = NotificationPayload; + type Index = (String, Kind); + + async fn on_new_subscription( + &self, + request: &[&Self::Index], + ) -> Result, String> { + let datastore = if let Some(localstore) = self.0.as_ref() { + localstore + } else { + return Ok(vec![]); + }; + + let mut to_return = vec![]; + + for (kind, values) in request.iter().fold( + HashMap::new(), + |mut acc: HashMap<&Kind, Vec<&String>>, (data, kind)| { + acc.entry(kind).or_default().push(data); + acc + }, + ) { + match kind { + Kind::Bolt11MeltQuote => { + let queries = values + .iter() + .map(|id| datastore.get_melt_quote(id)) + .collect::>(); + + to_return.extend( + futures::future::try_join_all(queries) + .await + .map(|quotes| { + quotes + .into_iter() + .filter_map(|quote| quote.map(|x| x.into())) + .map(|x: MeltQuoteBolt11Response| x.into()) + .collect::>() + }) + .map_err(|e| e.to_string())?, + ); + } + Kind::Bolt11MintQuote => { + let queries = values + .iter() + .map(|id| datastore.get_mint_quote(id)) + .collect::>(); + + to_return.extend( + futures::future::try_join_all(queries) + .await + .map(|quotes| { + quotes + .into_iter() + .filter_map(|quote| quote.map(|x| x.into())) + .map(|x: MintQuoteBolt11Response| x.into()) + .collect::>() + }) + .map_err(|e| e.to_string())?, + ); + } + Kind::ProofState => { + let public_keys = values + .iter() + .map(PublicKey::from_hex) + .collect::, _>>() + .map_err(|e| e.to_string())?; + + to_return.extend( + datastore + .get_proofs_states(&public_keys) + .await + .map_err(|e| e.to_string())? + .into_iter() + .enumerate() + .filter_map(|(idx, state)| { + state.map(|state| (public_keys[idx], state).into()) + }) + .map(|state: ProofState| state.into()), + ); + } + } + } + + Ok(to_return) + } +} diff --git a/crates/cdk/src/pub_sub/mod.rs b/crates/cdk/src/pub_sub/mod.rs index f9347c78..c511c803 100644 --- a/crates/cdk/src/pub_sub/mod.rs +++ b/crates/cdk/src/pub_sub/mod.rs @@ -37,6 +37,25 @@ pub const DEFAULT_REMOVE_SIZE: usize = 10_000; /// Default channel size for subscription buffering pub const DEFAULT_CHANNEL_SIZE: usize = 10; +#[async_trait::async_trait] +/// On New Subscription trait +/// +/// This trait is optional and it is used to notify the application when a new +/// subscription is created. This is useful when the application needs to send +/// the initial state to the subscriber upon subscription +pub trait OnNewSubscription { + /// Index type + type Index; + /// Subscription event type + type Event; + + /// Called when a new subscription is created + async fn on_new_subscription( + &self, + request: &[&Self::Index], + ) -> Result, String>; +} + /// Subscription manager /// /// This object keep track of all subscription listener and it is also @@ -45,21 +64,24 @@ pub const DEFAULT_CHANNEL_SIZE: usize = 10; /// The content of the notification is not relevant to this scope and it is up /// to the application, therefore the generic T is used instead of a specific /// type -pub struct Manager +pub struct Manager where T: Indexable + Clone + Send + Sync + 'static, I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static, + F: OnNewSubscription + 'static, { indexes: IndexTree, + on_new_subscription: Option, unsubscription_sender: mpsc::Sender<(SubId, Vec>)>, active_subscriptions: Arc, background_subscription_remover: Option>, } -impl Default for Manager +impl Default for Manager where T: Indexable + Clone + Send + Sync + 'static, I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static, + F: OnNewSubscription + 'static, { fn default() -> Self { let (sender, receiver) = mpsc::channel(DEFAULT_REMOVE_SIZE); @@ -72,6 +94,7 @@ where storage.clone(), active_subscriptions.clone(), ))), + on_new_subscription: None, unsubscription_sender: sender, active_subscriptions, indexes: storage, @@ -79,10 +102,24 @@ where } } -impl Manager +impl From for Manager where T: Indexable + Clone + Send + Sync + 'static, - I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static, + I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static, + F: OnNewSubscription + 'static, +{ + fn from(value: F) -> Self { + let mut manager: Self = Default::default(); + manager.on_new_subscription = Some(value); + manager + } +} + +impl Manager +where + T: Indexable + Clone + Send + Sync + 'static, + I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static, + F: OnNewSubscription + 'static, { #[inline] /// Broadcast an event to all listeners @@ -132,8 +169,29 @@ where ) -> ActiveSubscription { let (sender, receiver) = mpsc::channel(10); let sub_id: SubId = params.as_ref().clone(); + let indexes: Vec> = params.into(); + if let Some(on_new_subscription) = self.on_new_subscription.as_ref() { + match on_new_subscription + .on_new_subscription(&indexes.iter().map(|x| x.deref()).collect::>()) + .await + { + Ok(events) => { + for event in events { + let _ = sender.try_send((sub_id.clone(), event)); + } + } + Err(err) => { + tracing::info!( + "Failed to get initial state for subscription: {:?}, {}", + sub_id, + err + ); + } + } + } + let mut index_storage = self.indexes.write().await; for index in indexes.clone() { index_storage.insert(index, sender.clone()); @@ -180,10 +238,11 @@ where } /// Manager goes out of scope, stop all background tasks -impl Drop for Manager +impl Drop for Manager where T: Indexable + Clone + Send + Sync + 'static, I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static, + F: OnNewSubscription + 'static, { fn drop(&mut self) { if let Some(handler) = self.background_subscription_remover.take() {