diff --git a/crates/cdk-axum/src/ws/subscribe.rs b/crates/cdk-axum/src/ws/subscribe.rs index 7b755eda2..971e15a73 100644 --- a/crates/cdk-axum/src/ws/subscribe.rs +++ b/crates/cdk-axum/src/ws/subscribe.rs @@ -3,7 +3,10 @@ use super::{ WsContext, WsError, JSON_RPC_VERSION, }; use cdk::{ - nuts::nut17::{NotificationPayload, Params}, + nuts::{ + nut17::{Kind, NotificationPayload, Params}, + MeltQuoteBolt11Response, MintQuoteBolt11Response, ProofState, PublicKey, + }, pub_sub::SubId, }; @@ -11,17 +14,26 @@ use cdk::{ pub struct Method(Params); #[derive(Debug, Clone, serde::Serialize)] +/// The response to a subscription request pub struct Response { + /// Status status: String, + /// Subscription ID #[serde(rename = "subId")] sub_id: SubId, } #[derive(Debug, Clone, serde::Serialize)] +/// The notification +/// +/// This is the notification that is sent to the client when an event matches a +/// subscription pub struct Notification { + /// The subscription ID #[serde(rename = "subId")] pub sub_id: SubId, + /// The notification payload pub payload: NotificationPayload, } @@ -39,13 +51,96 @@ impl From<(SubId, NotificationPayload)> for WsNotification { impl WsHandle for Method { type Response = Response; + /// The `handle` method is called when a client sends a subscription request async fn handle(self, context: &mut WsContext) -> Result { let sub_id = self.0.id.clone(); if context.subscriptions.contains_key(&sub_id) { + // Subscription ID already exits. Returns an error instead of + // replacing the other subscription or avoiding it. return Err(WsError::InvalidParams); } - let mut subscription = context.state.mint.pubsub_manager.subscribe(self.0).await; + + let mut subscription = context + .state + .mint + .pubsub_manager + .subscribe(self.0.clone()) + .await; let publisher = context.publisher.clone(); + + let current_notification_to_send: Vec = match self.0.kind { + Kind::Bolt11MeltQuote => { + let queries = self + .0 + .filters + .iter() + .map(|id| context.state.mint.localstore.get_melt_quote(id)) + .collect::>(); + + futures::future::try_join_all(queries) + .await + .map(|quotes| { + quotes + .into_iter() + .filter_map(|quote| quote.map(|x| x.into())) + .map(|x: MeltQuoteBolt11Response| x.into()) + .collect::>() + }) + .unwrap_or_default() + } + Kind::Bolt11MintQuote => { + let queries = self + .0 + .filters + .iter() + .map(|id| context.state.mint.localstore.get_mint_quote(id)) + .collect::>(); + + futures::future::try_join_all(queries) + .await + .map(|quotes| { + quotes + .into_iter() + .filter_map(|quote| quote.map(|x| x.into())) + .map(|x: MintQuoteBolt11Response| x.into()) + .collect::>() + }) + .unwrap_or_default() + } + Kind::ProofState => { + if let Ok(public_keys) = self + .0 + .filters + .iter() + .map(PublicKey::from_hex) + .collect::, _>>() + { + context + .state + .mint + .localstore + .get_proofs_states(&public_keys) + .await + .map(|x| { + x.into_iter() + .enumerate() + .filter_map(|(idx, state)| { + state.map(|state| (public_keys[idx], state).into()) + }) + .map(|x: ProofState| x.into()) + .collect::>() + }) + .unwrap_or_default() + } else { + vec![] + } + } + }; + + for notification in current_notification_to_send.into_iter() { + let _ = publisher.send((sub_id.clone(), notification)).await; + } + context.subscriptions.insert( sub_id.clone(), tokio::spawn(async move { diff --git a/crates/cdk-integration-tests/tests/regtest.rs b/crates/cdk-integration-tests/tests/regtest.rs index 2ba545951..be5ce00c0 100644 --- a/crates/cdk-integration-tests/tests/regtest.rs +++ b/crates/cdk-integration-tests/tests/regtest.rs @@ -112,6 +112,18 @@ async fn test_regtest_mint_melt_round_trip() -> Result<()> { assert!(melt_response.preimage.is_some()); assert!(melt_response.state == MeltQuoteState::Paid); + let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await; + // first message is the current state + assert_eq!("test-sub", sub_id); + let payload = match payload { + NotificationPayload::MeltQuoteBolt11Response(melt) => melt, + _ => panic!("Wrong payload"), + }; + assert_eq!(payload.amount + payload.fee_reserve, 100.into()); + assert_eq!(payload.quote, melt.id); + assert_eq!(payload.state, MeltQuoteState::Unpaid); + + // get current state let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await; assert_eq!("test-sub", sub_id); let payload = match payload {