Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send current state of the subscriptions #444

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion crates/cdk-axum/src/ws/subscribe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,26 @@ use cdk::{
pub struct Method(Params);

#[derive(Debug, Clone, serde::Serialize)]
/// The response to a subscription request
pub struct Response {
/// Status
status: String,
/// Subscription ID
#[serde(rename = "subId")]
sub_id: SubId,
}

#[derive(Debug, Clone, serde::Serialize)]
/// The notification
///
/// This is the notification that is sent to the client when an event matches a
/// subscription
pub struct Notification {
/// The subscription ID
#[serde(rename = "subId")]
pub sub_id: SubId,

/// The notification payload
pub payload: NotificationPayload,
}

Expand All @@ -39,12 +48,21 @@ impl From<(SubId, NotificationPayload)> for WsNotification<Notification> {
impl WsHandle for Method {
type Response = Response;

/// The `handle` method is called when a client sends a subscription request
async fn handle(self, context: &mut WsContext) -> Result<Self::Response, WsError> {
let sub_id = self.0.id.clone();
if context.subscriptions.contains_key(&sub_id) {
// Subscription ID already exits. Returns an error instead of
// replacing the other subscription or avoiding it.
return Err(WsError::InvalidParams);
}
let mut subscription = context.state.mint.pubsub_manager.subscribe(self.0).await;

let mut subscription = context
.state
.mint
.pubsub_manager
.subscribe(self.0.clone())
.await;
let publisher = context.publisher.clone();
context.subscriptions.insert(
sub_id.clone(),
Expand Down
14 changes: 13 additions & 1 deletion crates/cdk-integration-tests/tests/regtest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async fn get_notification<T: StreamExt<Item = Result<Message, E>> + Unpin, E: De
.unwrap();

let mut response: serde_json::Value =
serde_json::from_str(&msg.to_text().unwrap()).expect("valid json");
serde_json::from_str(msg.to_text().unwrap()).expect("valid json");

let mut params_raw = response
.as_object_mut()
Expand Down Expand Up @@ -112,6 +112,18 @@ async fn test_regtest_mint_melt_round_trip() -> Result<()> {
assert!(melt_response.preimage.is_some());
assert!(melt_response.state == MeltQuoteState::Paid);

let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await;
// first message is the current state
assert_eq!("test-sub", sub_id);
let payload = match payload {
NotificationPayload::MeltQuoteBolt11Response(melt) => melt,
_ => panic!("Wrong payload"),
};
assert_eq!(payload.amount + payload.fee_reserve, 100.into());
assert_eq!(payload.quote, melt.id);
assert_eq!(payload.state, MeltQuoteState::Unpaid);

// get current state
let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await;
assert_eq!("test-sub", sub_id);
let payload = match payload {
Expand Down
2 changes: 1 addition & 1 deletion crates/cdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ serde_json = "1"
serde_with = "3"
tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
thiserror = "1"
futures = { version = "0.3.28", default-features = false, optional = true }
futures = { version = "0.3.28", default-features = false, optional = true, features = ["alloc"] }
url = "2.3"
utoipa = { version = "4", optional = true }
uuid = { version = "1", features = ["v4"] }
Expand Down
2 changes: 1 addition & 1 deletion crates/cdk/src/mint/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl Mint {
Ok(Self {
mint_url: MintUrl::from_str(mint_url)?,
keysets: Arc::new(RwLock::new(active_keysets)),
pubsub_manager: Default::default(),
pubsub_manager: Arc::new(localstore.clone().into()),
secp_ctx,
quote_ttl,
xpriv,
Expand Down
2 changes: 2 additions & 0 deletions crates/cdk/src/nuts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub mod nut12;
pub mod nut13;
pub mod nut14;
pub mod nut15;
#[cfg(feature = "mint")]
pub mod nut17;
pub mod nut18;

Expand Down Expand Up @@ -48,5 +49,6 @@ pub use nut11::{Conditions, P2PKWitness, SigFlag, SpendingConditions};
pub use nut12::{BlindSignatureDleq, ProofDleq};
pub use nut14::HTLCWitness;
pub use nut15::{Mpp, MppMethodSettings, Settings as NUT15Settings};
#[cfg(feature = "mint")]
pub use nut17::{NotificationPayload, PubSubManager};
pub use nut18::{PaymentRequest, PaymentRequestPayload, Transport};
5 changes: 3 additions & 2 deletions crates/cdk/src/nuts/nut06.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
use serde::{Deserialize, Deserializer, Serialize, Serializer};

use super::nut01::PublicKey;
use super::{nut04, nut05, nut15, nut17, MppMethodSettings};
use super::{nut04, nut05, nut15, MppMethodSettings};

/// Mint Version
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -238,7 +238,8 @@ pub struct Nuts {
/// NUT17 Settings
#[serde(default)]
#[serde(rename = "17")]
pub nut17: nut17::SupportedSettings,
#[cfg(feature = "mint")]
pub nut17: super::nut17::SupportedSettings,
}

impl Nuts {
Expand Down
31 changes: 23 additions & 8 deletions crates/cdk/src/nuts/nut17.rs → crates/cdk/src/nuts/nut17/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
//! Specific Subscription for the cdk crate

use super::{BlindSignature, CurrencyUnit, PaymentMethod};
use crate::cdk_database::{self, MintDatabase};
pub use crate::pub_sub::SubId;
use crate::{
nuts::{
MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response, MintQuoteState,
Expand All @@ -8,7 +11,11 @@ use crate::{
pub_sub::{self, Index, Indexable, SubscriptionGlobalId},
};
use serde::{Deserialize, Serialize};
use std::ops::Deref;
use std::{ops::Deref, sync::Arc};

mod on_subscription;

pub use on_subscription::OnSubscription;

/// Subscription Parameter according to the standard
#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -57,10 +64,6 @@ impl Default for SupportedMethods {
}
}

pub use crate::pub_sub::SubId;

use super::{BlindSignature, CurrencyUnit, PaymentMethod};

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
/// Subscription response
Expand Down Expand Up @@ -145,15 +148,27 @@ impl From<Params> for Vec<Index<(String, Kind)>> {
}

/// Manager
#[derive(Default)]
/// Publish–subscribe manager
///
/// Nut-17 implementation is system-wide and not only through the WebSocket, so
/// it is possible for another part of the system to subscribe to events.
pub struct PubSubManager(pub_sub::Manager<NotificationPayload, (String, Kind)>);
pub struct PubSubManager(pub_sub::Manager<NotificationPayload, (String, Kind), OnSubscription>);

#[allow(clippy::default_constructed_unit_structs)]
impl Default for PubSubManager {
fn default() -> Self {
PubSubManager(OnSubscription::default().into())
}
}

impl From<Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>> for PubSubManager {
fn from(val: Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>) -> Self {
PubSubManager(OnSubscription(Some(val)).into())
}
}

impl Deref for PubSubManager {
type Target = pub_sub::Manager<NotificationPayload, (String, Kind)>;
type Target = pub_sub::Manager<NotificationPayload, (String, Kind), OnSubscription>;

fn deref(&self) -> &Self::Target {
&self.0
Expand Down
110 changes: 110 additions & 0 deletions crates/cdk/src/nuts/nut17/on_subscription.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
//! On Subscription
//!
//! This module contains the code that is triggered when a new subscription is created.
use super::{Kind, NotificationPayload};
use crate::{
cdk_database::{self, MintDatabase},
nuts::{MeltQuoteBolt11Response, MintQuoteBolt11Response, ProofState, PublicKey},
pub_sub::OnNewSubscription,
};
use std::{collections::HashMap, sync::Arc};

#[derive(Default)]
/// Subscription Init
///
/// This struct triggers code when a new subscription is created.
///
/// It is used to send the initial state of the subscription to the client.
pub struct OnSubscription(
pub(crate) Option<Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>>,
);

#[async_trait::async_trait]
impl OnNewSubscription for OnSubscription {
type Event = NotificationPayload;
type Index = (String, Kind);

async fn on_new_subscription(
&self,
request: &[&Self::Index],
) -> Result<Vec<Self::Event>, String> {
let datastore = if let Some(localstore) = self.0.as_ref() {
localstore
} else {
return Ok(vec![]);
};

let mut to_return = vec![];

for (kind, values) in request.iter().fold(
HashMap::new(),
|mut acc: HashMap<&Kind, Vec<&String>>, (data, kind)| {
acc.entry(kind).or_default().push(data);
acc
},
) {
match kind {
Kind::Bolt11MeltQuote => {
let queries = values
.iter()
.map(|id| datastore.get_melt_quote(id))
.collect::<Vec<_>>();

to_return.extend(
futures::future::try_join_all(queries)
.await
.map(|quotes| {
quotes
.into_iter()
.filter_map(|quote| quote.map(|x| x.into()))
.map(|x: MeltQuoteBolt11Response| x.into())
.collect::<Vec<_>>()
})
.map_err(|e| e.to_string())?,
);
}
Kind::Bolt11MintQuote => {
let queries = values
.iter()
.map(|id| datastore.get_mint_quote(id))
.collect::<Vec<_>>();

to_return.extend(
futures::future::try_join_all(queries)
.await
.map(|quotes| {
quotes
.into_iter()
.filter_map(|quote| quote.map(|x| x.into()))
.map(|x: MintQuoteBolt11Response| x.into())
.collect::<Vec<_>>()
})
.map_err(|e| e.to_string())?,
);
}
Kind::ProofState => {
let public_keys = values
.iter()
.map(PublicKey::from_hex)
.collect::<Result<Vec<PublicKey>, _>>()
.map_err(|e| e.to_string())?;

to_return.extend(
datastore
.get_proofs_states(&public_keys)
.await
.map_err(|e| e.to_string())?
.into_iter()
.enumerate()
.filter_map(|(idx, state)| {
state.map(|state| (public_keys[idx], state).into())
})
.map(|state: ProofState| state.into()),
);
}
}
}

Ok(to_return)
}
}
Loading
Loading