Skip to content

Commit

Permalink
fix: Send current state of the subscriptions (#444)
Browse files Browse the repository at this point in the history
  • Loading branch information
crodas authored Nov 10, 2024
1 parent 70ef5a4 commit cc5b267
Show file tree
Hide file tree
Showing 9 changed files with 236 additions and 19 deletions.
20 changes: 19 additions & 1 deletion crates/cdk-axum/src/ws/subscribe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,26 @@ use cdk::{
pub struct Method(Params);

#[derive(Debug, Clone, serde::Serialize)]
/// The response to a subscription request
pub struct Response {
/// Status
status: String,
/// Subscription ID
#[serde(rename = "subId")]
sub_id: SubId,
}

#[derive(Debug, Clone, serde::Serialize)]
/// The notification
///
/// This is the notification that is sent to the client when an event matches a
/// subscription
pub struct Notification {
/// The subscription ID
#[serde(rename = "subId")]
pub sub_id: SubId,

/// The notification payload
pub payload: NotificationPayload,
}

Expand All @@ -39,12 +48,21 @@ impl From<(SubId, NotificationPayload)> for WsNotification<Notification> {
impl WsHandle for Method {
type Response = Response;

/// The `handle` method is called when a client sends a subscription request
async fn handle(self, context: &mut WsContext) -> Result<Self::Response, WsError> {
let sub_id = self.0.id.clone();
if context.subscriptions.contains_key(&sub_id) {
// Subscription ID already exits. Returns an error instead of
// replacing the other subscription or avoiding it.
return Err(WsError::InvalidParams);
}
let mut subscription = context.state.mint.pubsub_manager.subscribe(self.0).await;

let mut subscription = context
.state
.mint
.pubsub_manager
.subscribe(self.0.clone())
.await;
let publisher = context.publisher.clone();
context.subscriptions.insert(
sub_id.clone(),
Expand Down
14 changes: 13 additions & 1 deletion crates/cdk-integration-tests/tests/regtest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async fn get_notification<T: StreamExt<Item = Result<Message, E>> + Unpin, E: De
.unwrap();

let mut response: serde_json::Value =
serde_json::from_str(&msg.to_text().unwrap()).expect("valid json");
serde_json::from_str(msg.to_text().unwrap()).expect("valid json");

let mut params_raw = response
.as_object_mut()
Expand Down Expand Up @@ -112,6 +112,18 @@ async fn test_regtest_mint_melt_round_trip() -> Result<()> {
assert!(melt_response.preimage.is_some());
assert!(melt_response.state == MeltQuoteState::Paid);

let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await;
// first message is the current state
assert_eq!("test-sub", sub_id);
let payload = match payload {
NotificationPayload::MeltQuoteBolt11Response(melt) => melt,
_ => panic!("Wrong payload"),
};
assert_eq!(payload.amount + payload.fee_reserve, 100.into());
assert_eq!(payload.quote, melt.id);
assert_eq!(payload.state, MeltQuoteState::Unpaid);

// get current state
let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await;
assert_eq!("test-sub", sub_id);
let payload = match payload {
Expand Down
2 changes: 1 addition & 1 deletion crates/cdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ serde_json = "1"
serde_with = "3"
tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
thiserror = "1"
futures = { version = "0.3.28", default-features = false, optional = true }
futures = { version = "0.3.28", default-features = false, optional = true, features = ["alloc"] }
url = "2.3"
utoipa = { version = "4", optional = true }
uuid = { version = "1", features = ["v4"] }
Expand Down
2 changes: 1 addition & 1 deletion crates/cdk/src/mint/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl Mint {
Ok(Self {
mint_url: MintUrl::from_str(mint_url)?,
keysets: Arc::new(RwLock::new(active_keysets)),
pubsub_manager: Default::default(),
pubsub_manager: Arc::new(localstore.clone().into()),
secp_ctx,
quote_ttl,
xpriv,
Expand Down
2 changes: 2 additions & 0 deletions crates/cdk/src/nuts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub mod nut12;
pub mod nut13;
pub mod nut14;
pub mod nut15;
#[cfg(feature = "mint")]
pub mod nut17;
pub mod nut18;

Expand Down Expand Up @@ -48,5 +49,6 @@ pub use nut11::{Conditions, P2PKWitness, SigFlag, SpendingConditions};
pub use nut12::{BlindSignatureDleq, ProofDleq};
pub use nut14::HTLCWitness;
pub use nut15::{Mpp, MppMethodSettings, Settings as NUT15Settings};
#[cfg(feature = "mint")]
pub use nut17::{NotificationPayload, PubSubManager};
pub use nut18::{PaymentRequest, PaymentRequestPayload, Transport};
5 changes: 3 additions & 2 deletions crates/cdk/src/nuts/nut06.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
use serde::{Deserialize, Deserializer, Serialize, Serializer};

use super::nut01::PublicKey;
use super::{nut04, nut05, nut15, nut17, MppMethodSettings};
use super::{nut04, nut05, nut15, MppMethodSettings};

/// Mint Version
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -238,7 +238,8 @@ pub struct Nuts {
/// NUT17 Settings
#[serde(default)]
#[serde(rename = "17")]
pub nut17: nut17::SupportedSettings,
#[cfg(feature = "mint")]
pub nut17: super::nut17::SupportedSettings,
}

impl Nuts {
Expand Down
31 changes: 23 additions & 8 deletions crates/cdk/src/nuts/nut17.rs → crates/cdk/src/nuts/nut17/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
//! Specific Subscription for the cdk crate
use super::{BlindSignature, CurrencyUnit, PaymentMethod};
use crate::cdk_database::{self, MintDatabase};
pub use crate::pub_sub::SubId;
use crate::{
nuts::{
MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response, MintQuoteState,
Expand All @@ -8,7 +11,11 @@ use crate::{
pub_sub::{self, Index, Indexable, SubscriptionGlobalId},
};
use serde::{Deserialize, Serialize};
use std::ops::Deref;
use std::{ops::Deref, sync::Arc};

mod on_subscription;

pub use on_subscription::OnSubscription;

/// Subscription Parameter according to the standard
#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -57,10 +64,6 @@ impl Default for SupportedMethods {
}
}

pub use crate::pub_sub::SubId;

use super::{BlindSignature, CurrencyUnit, PaymentMethod};

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
/// Subscription response
Expand Down Expand Up @@ -145,15 +148,27 @@ impl From<Params> for Vec<Index<(String, Kind)>> {
}

/// Manager
#[derive(Default)]
/// Publish–subscribe manager
///
/// Nut-17 implementation is system-wide and not only through the WebSocket, so
/// it is possible for another part of the system to subscribe to events.
pub struct PubSubManager(pub_sub::Manager<NotificationPayload, (String, Kind)>);
pub struct PubSubManager(pub_sub::Manager<NotificationPayload, (String, Kind), OnSubscription>);

#[allow(clippy::default_constructed_unit_structs)]
impl Default for PubSubManager {
fn default() -> Self {
PubSubManager(OnSubscription::default().into())
}
}

impl From<Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>> for PubSubManager {
fn from(val: Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>) -> Self {
PubSubManager(OnSubscription(Some(val)).into())
}
}

impl Deref for PubSubManager {
type Target = pub_sub::Manager<NotificationPayload, (String, Kind)>;
type Target = pub_sub::Manager<NotificationPayload, (String, Kind), OnSubscription>;

fn deref(&self) -> &Self::Target {
&self.0
Expand Down
110 changes: 110 additions & 0 deletions crates/cdk/src/nuts/nut17/on_subscription.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
//! On Subscription
//!
//! This module contains the code that is triggered when a new subscription is created.
use super::{Kind, NotificationPayload};
use crate::{
cdk_database::{self, MintDatabase},
nuts::{MeltQuoteBolt11Response, MintQuoteBolt11Response, ProofState, PublicKey},
pub_sub::OnNewSubscription,
};
use std::{collections::HashMap, sync::Arc};

#[derive(Default)]
/// Subscription Init
///
/// This struct triggers code when a new subscription is created.
///
/// It is used to send the initial state of the subscription to the client.
pub struct OnSubscription(
pub(crate) Option<Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>>,
);

#[async_trait::async_trait]
impl OnNewSubscription for OnSubscription {
type Event = NotificationPayload;
type Index = (String, Kind);

async fn on_new_subscription(
&self,
request: &[&Self::Index],
) -> Result<Vec<Self::Event>, String> {
let datastore = if let Some(localstore) = self.0.as_ref() {
localstore
} else {
return Ok(vec![]);
};

let mut to_return = vec![];

for (kind, values) in request.iter().fold(
HashMap::new(),
|mut acc: HashMap<&Kind, Vec<&String>>, (data, kind)| {
acc.entry(kind).or_default().push(data);
acc
},
) {
match kind {
Kind::Bolt11MeltQuote => {
let queries = values
.iter()
.map(|id| datastore.get_melt_quote(id))
.collect::<Vec<_>>();

to_return.extend(
futures::future::try_join_all(queries)
.await
.map(|quotes| {
quotes
.into_iter()
.filter_map(|quote| quote.map(|x| x.into()))
.map(|x: MeltQuoteBolt11Response| x.into())
.collect::<Vec<_>>()
})
.map_err(|e| e.to_string())?,
);
}
Kind::Bolt11MintQuote => {
let queries = values
.iter()
.map(|id| datastore.get_mint_quote(id))
.collect::<Vec<_>>();

to_return.extend(
futures::future::try_join_all(queries)
.await
.map(|quotes| {
quotes
.into_iter()
.filter_map(|quote| quote.map(|x| x.into()))
.map(|x: MintQuoteBolt11Response| x.into())
.collect::<Vec<_>>()
})
.map_err(|e| e.to_string())?,
);
}
Kind::ProofState => {
let public_keys = values
.iter()
.map(PublicKey::from_hex)
.collect::<Result<Vec<PublicKey>, _>>()
.map_err(|e| e.to_string())?;

to_return.extend(
datastore
.get_proofs_states(&public_keys)
.await
.map_err(|e| e.to_string())?
.into_iter()
.enumerate()
.filter_map(|(idx, state)| {
state.map(|state| (public_keys[idx], state).into())
})
.map(|state: ProofState| state.into()),
);
}
}
}

Ok(to_return)
}
}
Loading

0 comments on commit cc5b267

Please sign in to comment.