From f3f5d3271ce091029ed0a044aea318bc4961ef99 Mon Sep 17 00:00:00 2001 From: Cesar Rodas Date: Tue, 19 Nov 2024 18:13:52 -0300 Subject: [PATCH] Introduce subscription support in the Wallet crate. The main goal is to add a subscription to CDK Mint updates into the wallet. This feature will be particularly useful for improving the code whenever loops hit the mint server to check status changes. The goal is to add an easy-to-use interface that will hide the fact that we're connecting to WebSocket and subscribing to events. This will also hide the fact that the CDK-mint server may not support WebSocket updates. To be fully backward compatible, the HttpClientMethods traits have a new method, `subscribe,` which will return an object that implements `ActiveSubscription.` In the primary implementation, there is a `SubscriptionClient` that will attempt to connect through WebSocket and will fall to the HTTP-status pull and sleep approach (the current approach), but upper stream code will receive updates as if they come from a stream of updates through WebSocket. This `SubscriptionClient` struct will also manage reconnections to WebSockets (with automatic resubscriptions) and all the low-level stuff, providing an easy-to-use interface and leaving the upper-level code with a nice interface that is hard to misuse. When `ActiveSubscription` is dropped, it will automatically unsubscribe. --- Cargo.lock | 86 ++++- crates/cdk-axum/src/ws/error.rs | 15 + crates/cdk-axum/src/ws/handler.rs | 71 ---- crates/cdk-axum/src/ws/mod.rs | 56 ++- crates/cdk-axum/src/ws/subscribe.rs | 97 ++---- crates/cdk-axum/src/ws/unsubscribe.rs | 38 +- crates/cdk-cli/src/sub_commands/mint.rs | 22 +- crates/cdk-integration-tests/Cargo.toml | 2 +- crates/cdk-integration-tests/src/lib.rs | 52 +-- .../tests/fake_wallet.rs | 26 +- crates/cdk-integration-tests/tests/regtest.rs | 22 +- crates/cdk/Cargo.toml | 24 +- crates/cdk/examples/mint-token.rs | 24 +- crates/cdk/examples/p2pk.rs | 26 +- crates/cdk/examples/proof-selection.rs | 26 +- crates/cdk/src/lib.rs | 2 +- crates/cdk/src/nuts/mod.rs | 4 +- crates/cdk/src/nuts/nut04.rs | 12 + crates/cdk/src/nuts/nut05.rs | 17 + crates/cdk/src/nuts/nut17/manager.rs | 215 ++++++++++++ crates/cdk/src/nuts/nut17/mod.rs | 256 ++------------ crates/cdk/src/nuts/nut17/on_subscription.rs | 2 +- crates/cdk/src/nuts/nut17/ws.rs | 215 ++++++++++++ crates/cdk/src/wallet/client.rs | 1 + crates/cdk/src/wallet/mod.rs | 70 +++- crates/cdk/src/wallet/subscription/http.rs | 152 ++++++++ crates/cdk/src/wallet/subscription/mod.rs | 326 ++++++++++++++++++ crates/cdk/src/wallet/subscription/ws.rs | 216 ++++++++++++ crates/cdk/src/wallet/websocket.rs | 54 +++ misc/itests.sh | 3 + 30 files changed, 1596 insertions(+), 536 deletions(-) delete mode 100644 crates/cdk-axum/src/ws/handler.rs create mode 100644 crates/cdk/src/nuts/nut17/manager.rs create mode 100644 crates/cdk/src/nuts/nut17/ws.rs create mode 100644 crates/cdk/src/wallet/subscription/http.rs create mode 100644 crates/cdk/src/wallet/subscription/mod.rs create mode 100644 crates/cdk/src/wallet/subscription/ws.rs create mode 100644 crates/cdk/src/wallet/websocket.rs diff --git a/Cargo.lock b/Cargo.lock index bbb7dd16..b02c4712 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -693,6 +693,7 @@ dependencies = [ "sync_wrapper 0.1.2", "thiserror 1.0.69", "tokio", + "tokio-tungstenite 0.19.0", "tracing", "url", "utoipa", @@ -715,6 +716,7 @@ dependencies = [ "tokio", "tracing", "utoipa", + "uuid", ] [[package]] @@ -772,7 +774,6 @@ dependencies = [ "tokio-stream", "tokio-util", "tracing", - "uuid", ] [[package]] @@ -901,6 +902,7 @@ dependencies = [ "serde_json", "thiserror 1.0.69", "tracing", + "uuid", ] [[package]] @@ -931,6 +933,7 @@ dependencies = [ "thiserror 1.0.69", "tokio", "tracing", + "uuid", ] [[package]] @@ -1148,6 +1151,16 @@ dependencies = [ "yaml-rust", ] +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation" version = "0.10.0" @@ -1964,7 +1977,7 @@ dependencies = [ "hyper 1.5.1", "hyper-util", "rustls 0.23.19", - "rustls-native-certs", + "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", "tokio-rustls 0.26.0", @@ -3416,7 +3429,7 @@ dependencies = [ "pin-project-lite", "quinn", "rustls 0.23.19", - "rustls-native-certs", + "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", "rustls-pki-types", "serde", @@ -3618,6 +3631,18 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile 1.0.4", + "schannel", + "security-framework 2.11.1", +] + [[package]] name = "rustls-native-certs" version = "0.8.1" @@ -3627,7 +3652,7 @@ dependencies = [ "openssl-probe", "rustls-pki-types", "schannel", - "security-framework", + "security-framework 3.0.1", ] [[package]] @@ -3795,6 +3820,19 @@ dependencies = [ "cc", ] +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.6.0", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + [[package]] name = "security-framework" version = "3.0.1" @@ -3802,7 +3840,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1415a607e92bec364ea2cf9264646dcce0f91e6d65281bd6f2819cca3bf39c8" dependencies = [ "bitflags 2.6.0", - "core-foundation", + "core-foundation 0.10.0", "core-foundation-sys", "libc", "security-framework-sys", @@ -4112,6 +4150,7 @@ dependencies = [ "thiserror 1.0.69", "tokio-stream", "url", + "uuid", "webpki-roots 0.22.6", ] @@ -4474,6 +4513,21 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec509ac96e9a0c43427c74f003127d953a265737636129424288d27cb5c4b12c" +dependencies = [ + "futures-util", + "log", + "rustls 0.21.12", + "rustls-native-certs 0.6.3", + "tokio", + "tokio-rustls 0.24.1", + "tungstenite 0.19.0", +] + [[package]] name = "tokio-tungstenite" version = "0.20.1" @@ -4692,6 +4746,27 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15fba1a6d6bb030745759a9a2a588bfe8490fc8b4751a277db3a0be1c9ebbf67" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 0.2.12", + "httparse", + "log", + "rand", + "rustls 0.21.12", + "sha1", + "thiserror 1.0.69", + "url", + "utf-8", + "webpki", +] + [[package]] name = "tungstenite" version = "0.20.1" @@ -4893,6 +4968,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "getrandom", + "serde", ] [[package]] diff --git a/crates/cdk-axum/src/ws/error.rs b/crates/cdk-axum/src/ws/error.rs index 24fa4c8c..d67e3ef8 100644 --- a/crates/cdk-axum/src/ws/error.rs +++ b/crates/cdk-axum/src/ws/error.rs @@ -1,3 +1,4 @@ +use cdk::nuts::nut17::ws::WsErrorBody; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -17,3 +18,17 @@ pub enum WsError { /// Custom error ServerError(i32, String), } + +impl From for WsErrorBody { + fn from(val: WsError) -> Self { + let (id, message) = match val { + WsError::ParseError => (-32700, "Parse error".to_string()), + WsError::InvalidRequest => (-32600, "Invalid Request".to_string()), + WsError::MethodNotFound => (-32601, "Method not found".to_string()), + WsError::InvalidParams => (-32602, "Invalid params".to_string()), + WsError::InternalError => (-32603, "Internal error".to_string()), + WsError::ServerError(code, message) => (code, message), + }; + WsErrorBody { code: id, message } + } +} diff --git a/crates/cdk-axum/src/ws/handler.rs b/crates/cdk-axum/src/ws/handler.rs deleted file mode 100644 index b2298551..00000000 --- a/crates/cdk-axum/src/ws/handler.rs +++ /dev/null @@ -1,71 +0,0 @@ -use serde::Serialize; - -use super::{WsContext, WsError, JSON_RPC_VERSION}; - -impl From for WsErrorResponse { - fn from(val: WsError) -> Self { - let (id, message) = match val { - WsError::ParseError => (-32700, "Parse error".to_string()), - WsError::InvalidRequest => (-32600, "Invalid Request".to_string()), - WsError::MethodNotFound => (-32601, "Method not found".to_string()), - WsError::InvalidParams => (-32602, "Invalid params".to_string()), - WsError::InternalError => (-32603, "Internal error".to_string()), - WsError::ServerError(code, message) => (code, message), - }; - WsErrorResponse { code: id, message } - } -} - -#[derive(Debug, Clone, Serialize)] -struct WsErrorResponse { - code: i32, - message: String, -} - -#[derive(Debug, Clone, Serialize)] -struct WsResponse { - jsonrpc: String, - #[serde(skip_serializing_if = "Option::is_none")] - result: Option, - #[serde(skip_serializing_if = "Option::is_none")] - error: Option, - id: usize, -} - -#[derive(Debug, Clone, Serialize)] -pub struct WsNotification { - pub jsonrpc: String, - pub method: String, - pub params: T, -} - -#[async_trait::async_trait] -pub trait WsHandle { - type Response: Serialize + Sized; - - async fn process( - self, - req_id: usize, - context: &mut WsContext, - ) -> Result - where - Self: Sized, - { - serde_json::to_value(&match self.handle(context).await { - Ok(response) => WsResponse { - jsonrpc: JSON_RPC_VERSION.to_owned(), - result: Some(response), - error: None, - id: req_id, - }, - Err(error) => WsResponse { - jsonrpc: JSON_RPC_VERSION.to_owned(), - result: None, - error: Some(error.into()), - id: req_id, - }, - }) - } - - async fn handle(self, context: &mut WsContext) -> Result; -} diff --git a/crates/cdk-axum/src/ws/mod.rs b/crates/cdk-axum/src/ws/mod.rs index 4b71368e..4581e7f8 100644 --- a/crates/cdk-axum/src/ws/mod.rs +++ b/crates/cdk-axum/src/ws/mod.rs @@ -1,50 +1,33 @@ use std::collections::HashMap; use axum::extract::ws::{Message, WebSocket}; +use cdk::nuts::nut17::ws::{ + NotificationInner, WsErrorBody, WsMessageOrResponse, WsMethodRequest, WsRequest, +}; use cdk::nuts::nut17::{NotificationPayload, SubId}; use futures::StreamExt; -use handler::{WsHandle, WsNotification}; -use serde::{Deserialize, Serialize}; -use subscribe::Notification; use tokio::sync::mpsc; +use uuid::Uuid; use crate::MintState; mod error; -mod handler; mod subscribe; mod unsubscribe; -/// JSON RPC version -pub const JSON_RPC_VERSION: &str = "2.0"; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct WsRequest { - jsonrpc: String, - #[serde(flatten)] - method: WsMethod, - id: usize, -} +async fn process( + context: &mut WsContext, + body: WsRequest, +) -> Result { + let response = match body.method { + WsMethodRequest::Subscribe(sub) => subscribe::handle(context, sub).await, + WsMethodRequest::Unsubscribe(unsub) => unsubscribe::handle(context, unsub).await, + } + .map_err(WsErrorBody::from); -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "snake_case", tag = "method", content = "params")] -pub enum WsMethod { - Subscribe(subscribe::Method), - Unsubscribe(unsubscribe::Method), -} + let response: WsMessageOrResponse = (body.id, response).into(); -impl WsMethod { - pub async fn process( - self, - req_id: usize, - context: &mut WsContext, - ) -> Result { - match self { - WsMethod::Subscribe(sub) => sub.process(req_id, context), - WsMethod::Unsubscribe(unsub) => unsub.process(req_id, context), - } - .await - } + serde_json::to_value(response) } pub use error::WsError; @@ -52,7 +35,7 @@ pub use error::WsError; pub struct WsContext { state: MintState, subscriptions: HashMap>, - publisher: mpsc::Sender<(SubId, NotificationPayload)>, + publisher: mpsc::Sender<(SubId, NotificationPayload)>, } /// Main function for websocket connections @@ -78,7 +61,10 @@ pub async fn main_websocket(mut socket: WebSocket, state: MintState) { // unsubscribed from the subscription manager, just ignore it. continue; } - let notification: WsNotification = (sub_id, payload).into(); + let notification: WsMessageOrResponse= NotificationInner { + sub_id, + payload, + }.into(); let message = match serde_json::to_string(¬ification) { Ok(message) => message, Err(err) => { @@ -101,7 +87,7 @@ pub async fn main_websocket(mut socket: WebSocket, state: MintState) { } }; - match request.method.process(request.id, &mut context).await { + match process(&mut context, request).await { Ok(result) => { if let Err(err) = socket .send(Message::Text(result.to_string())) diff --git a/crates/cdk-axum/src/ws/subscribe.rs b/crates/cdk-axum/src/ws/subscribe.rs index f177ae08..a1cc6ecd 100644 --- a/crates/cdk-axum/src/ws/subscribe.rs +++ b/crates/cdk-axum/src/ws/subscribe.rs @@ -1,72 +1,33 @@ -use cdk::nuts::nut17::{NotificationPayload, Params}; -use cdk::pub_sub::SubId; - -use super::handler::{WsHandle, WsNotification}; -use super::{WsContext, WsError, JSON_RPC_VERSION}; - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct Method(Params); - -#[derive(Debug, Clone, serde::Serialize)] -/// The response to a subscription request -pub struct Response { - /// Status - status: String, - /// Subscription ID - #[serde(rename = "subId")] - sub_id: SubId, -} - -#[derive(Debug, Clone, serde::Serialize)] -/// The notification -/// -/// This is the notification that is sent to the client when an event matches a -/// subscription -pub struct Notification { - /// The subscription ID - #[serde(rename = "subId")] - pub sub_id: SubId, - - /// The notification payload - pub payload: NotificationPayload, -} - -impl From<(SubId, NotificationPayload)> for WsNotification { - fn from((sub_id, payload): (SubId, NotificationPayload)) -> Self { - WsNotification { - jsonrpc: JSON_RPC_VERSION.to_owned(), - method: "subscribe".to_string(), - params: Notification { sub_id, payload }, - } +use cdk::nuts::nut17::ws::{WsResponseResult, WsSubscribeResponse}; +use cdk::nuts::nut17::Params; + +use super::{WsContext, WsError}; + +/// The `handle` method is called when a client sends a subscription request +pub(crate) async fn handle( + context: &mut WsContext, + params: Params, +) -> Result { + let sub_id = params.id.clone(); + if context.subscriptions.contains_key(&sub_id) { + // Subscription ID already exits. Returns an error instead of + // replacing the other subscription or avoiding it. + return Err(WsError::InvalidParams); } -} - -#[async_trait::async_trait] -impl WsHandle for Method { - type Response = Response; - - /// The `handle` method is called when a client sends a subscription request - async fn handle(self, context: &mut WsContext) -> Result { - let sub_id = self.0.id.clone(); - if context.subscriptions.contains_key(&sub_id) { - // Subscription ID already exits. Returns an error instead of - // replacing the other subscription or avoiding it. - return Err(WsError::InvalidParams); - } - let mut subscription = context.state.mint.pubsub_manager.subscribe(self.0).await; - let publisher = context.publisher.clone(); - context.subscriptions.insert( - sub_id.clone(), - tokio::spawn(async move { - while let Some(response) = subscription.recv().await { - let _ = publisher.send(response).await; - } - }), - ); - Ok(Response { - status: "OK".to_string(), - sub_id, - }) + let mut subscription = context.state.mint.pubsub_manager.subscribe(params).await; + let publisher = context.publisher.clone(); + context.subscriptions.insert( + sub_id.clone(), + tokio::spawn(async move { + while let Some(response) = subscription.recv().await { + let _ = publisher.send(response).await; + } + }), + ); + Ok(WsSubscribeResponse { + status: "OK".to_string(), + sub_id, } + .into()) } diff --git a/crates/cdk-axum/src/ws/unsubscribe.rs b/crates/cdk-axum/src/ws/unsubscribe.rs index 8a8d3660..0689e201 100644 --- a/crates/cdk-axum/src/ws/unsubscribe.rs +++ b/crates/cdk-axum/src/ws/unsubscribe.rs @@ -1,32 +1,18 @@ -use cdk::pub_sub::SubId; +use cdk::nuts::nut17::ws::{WsResponseResult, WsUnsubscribeRequest, WsUnsubscribeResponse}; -use super::handler::WsHandle; use super::{WsContext, WsError}; -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct Method { - #[serde(rename = "subId")] - pub sub_id: SubId, -} - -#[derive(Debug, Clone, serde::Serialize)] -pub struct Response { - status: String, - sub_id: SubId, -} - -#[async_trait::async_trait] -impl WsHandle for Method { - type Response = Response; - - async fn handle(self, context: &mut WsContext) -> Result { - if context.subscriptions.remove(&self.sub_id).is_some() { - Ok(Response { - status: "OK".to_string(), - sub_id: self.sub_id, - }) - } else { - Err(WsError::InvalidParams) +pub(crate) async fn handle( + context: &mut WsContext, + req: WsUnsubscribeRequest, +) -> Result { + if context.subscriptions.remove(&req.sub_id).is_some() { + Ok(WsUnsubscribeResponse { + status: "OK".to_string(), + sub_id: req.sub_id, } + .into()) + } else { + Err(WsError::InvalidParams) } } diff --git a/crates/cdk-cli/src/sub_commands/mint.rs b/crates/cdk-cli/src/sub_commands/mint.rs index 43e71a44..a73e178c 100644 --- a/crates/cdk-cli/src/sub_commands/mint.rs +++ b/crates/cdk-cli/src/sub_commands/mint.rs @@ -1,18 +1,16 @@ use std::str::FromStr; use std::sync::Arc; -use std::time::Duration; use anyhow::Result; use cdk::amount::SplitTarget; use cdk::cdk_database::{Error, WalletDatabase}; use cdk::mint_url::MintUrl; -use cdk::nuts::{CurrencyUnit, MintQuoteState}; +use cdk::nuts::{CurrencyUnit, MintQuoteState, NotificationPayload}; use cdk::wallet::multi_mint_wallet::WalletKey; -use cdk::wallet::{MultiMintWallet, Wallet}; +use cdk::wallet::{MultiMintWallet, Wallet, WalletSubscription}; use cdk::Amount; use clap::Args; use serde::{Deserialize, Serialize}; -use tokio::time::sleep; #[derive(Args, Serialize, Deserialize)] pub struct MintSubCommand { @@ -59,14 +57,18 @@ pub async fn mint( println!("Please pay: {}", quote.request); - loop { - let status = wallet.mint_quote_state("e.id).await?; + let mut subscription = wallet + .subscribe(WalletSubscription::Bolt11MintQuoteState(vec![quote + .id + .clone()])) + .await; - if status.state == MintQuoteState::Paid { - break; + while let Some(msg) = subscription.recv().await { + if let NotificationPayload::MintQuoteBolt11Response(response) = msg { + if response.state == MintQuoteState::Paid { + break; + } } - - sleep(Duration::from_secs(2)).await; } let receive_amount = wallet.mint("e.id, SplitTarget::default(), None).await?; diff --git a/crates/cdk-integration-tests/Cargo.toml b/crates/cdk-integration-tests/Cargo.toml index 135186f3..a8815d1d 100644 --- a/crates/cdk-integration-tests/Cargo.toml +++ b/crates/cdk-integration-tests/Cargo.toml @@ -11,7 +11,7 @@ rust-version = "1.63.0" [features] - +http_subscription = ["cdk/http_subscription"] [dependencies] axum = "0.6.20" diff --git a/crates/cdk-integration-tests/src/lib.rs b/crates/cdk-integration-tests/src/lib.rs index 2e52b034..c07c8db6 100644 --- a/crates/cdk-integration-tests/src/lib.rs +++ b/crates/cdk-integration-tests/src/lib.rs @@ -1,6 +1,5 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; -use std::time::Duration; use anyhow::{bail, Result}; use axum::Router; @@ -10,17 +9,19 @@ use cdk::cdk_database::mint_memory::MintMemoryDatabase; use cdk::cdk_lightning::MintLightning; use cdk::dhke::construct_proofs; use cdk::mint::FeeReserve; +use cdk::nuts::nut17::Params; use cdk::nuts::{ CurrencyUnit, Id, KeySet, MintBolt11Request, MintInfo, MintQuoteBolt11Request, MintQuoteState, - Nuts, PaymentMethod, PreMintSecrets, Proofs, State, + NotificationPayload, Nuts, PaymentMethod, PreMintSecrets, Proofs, State, }; use cdk::types::{LnKey, QuoteTTL}; use cdk::wallet::client::{HttpClient, HttpClientMethods}; +use cdk::wallet::subscription::SubscriptionManager; +use cdk::wallet::WalletSubscription; use cdk::{Mint, Wallet}; use cdk_fake_wallet::FakeWallet; use init_regtest::{get_mint_addr, get_mint_port, get_mint_url}; use tokio::sync::Notify; -use tokio::time::sleep; use tower_http::cors::CorsLayer; pub mod init_fake_wallet; @@ -127,15 +128,18 @@ pub async fn wallet_mint( ) -> Result<()> { let quote = wallet.mint_quote(amount, description).await?; - loop { - let status = wallet.mint_quote_state("e.id).await?; + let mut subscription = wallet + .subscribe(WalletSubscription::Bolt11MintQuoteState(vec![quote + .id + .clone()])) + .await; - if status.state == MintQuoteState::Paid { - break; + while let Some(msg) = subscription.recv().await { + if let NotificationPayload::MintQuoteBolt11Response(response) = msg { + if response.state == MintQuoteState::Paid { + break; + } } - println!("{:?}", status); - - sleep(Duration::from_secs(2)).await; } let receive_amount = wallet.mint("e.id, split_target, None).await?; @@ -155,7 +159,7 @@ pub async fn mint_proofs( println!("Minting for ecash"); println!(); - let wallet_client = HttpClient::new(); + let wallet_client = Arc::new(HttpClient::new()); let request = MintQuoteBolt11Request { amount, @@ -169,17 +173,25 @@ pub async fn mint_proofs( println!("Please pay: {}", mint_quote.request); - loop { - let status = wallet_client - .get_mint_quote_status(mint_url.parse()?, &mint_quote.quote) - .await?; + let subscription_client = SubscriptionManager::new(wallet_client.clone()); - if status.state == MintQuoteState::Paid { - break; - } - println!("{:?}", status.state); + let mut subscription = subscription_client + .subscribe( + mint_url.parse()?, + Params { + filters: vec![mint_quote.quote.clone()], + kind: cdk::nuts::nut17::Kind::Bolt11MintQuote, + id: "sub".into(), + }, + ) + .await; - sleep(Duration::from_secs(2)).await; + while let Some(msg) = subscription.recv().await { + if let NotificationPayload::MintQuoteBolt11Response(response) = msg { + if response.state == MintQuoteState::Paid { + break; + } + } } let premint_secrets = PreMintSecrets::random(keyset_id, amount, &SplitTarget::default())?; diff --git a/crates/cdk-integration-tests/tests/fake_wallet.rs b/crates/cdk-integration-tests/tests/fake_wallet.rs index cf0ea1be..b4f4213b 100644 --- a/crates/cdk-integration-tests/tests/fake_wallet.rs +++ b/crates/cdk-integration-tests/tests/fake_wallet.rs @@ -1,18 +1,17 @@ use std::sync::Arc; -use std::time::Duration; use anyhow::Result; use bip39::Mnemonic; use cdk::amount::SplitTarget; use cdk::cdk_database::WalletMemoryDatabase; use cdk::nuts::{ - CurrencyUnit, MeltBolt11Request, MeltQuoteState, MintQuoteState, PreMintSecrets, State, + CurrencyUnit, MeltBolt11Request, MeltQuoteState, MintQuoteState, NotificationPayload, + PreMintSecrets, State, }; use cdk::wallet::client::{HttpClient, HttpClientMethods}; -use cdk::wallet::Wallet; +use cdk::wallet::{Wallet, WalletSubscription}; use cdk_fake_wallet::{create_fake_invoice, FakeInvoiceDescription}; use cdk_integration_tests::attempt_to_swap_pending; -use tokio::time::sleep; const MINT_URL: &str = "http://127.0.0.1:8086"; @@ -379,12 +378,19 @@ async fn test_fake_melt_change_in_quote() -> Result<()> { // Keep polling the state of the mint quote id until it's paid async fn wait_for_mint_to_be_paid(wallet: &Wallet, mint_quote_id: &str) -> Result<()> { - loop { - let status = wallet.mint_quote_state(mint_quote_id).await?; - if status.state == MintQuoteState::Paid { - return Ok(()); + let mut subscription = wallet + .subscribe(WalletSubscription::Bolt11MintQuoteState(vec![ + mint_quote_id.to_owned(), + ])) + .await; + + while let Some(msg) = subscription.recv().await { + if let NotificationPayload::MintQuoteBolt11Response(response) = msg { + if response.state == MintQuoteState::Paid { + break; + } } - - sleep(Duration::from_millis(5)).await; } + + Ok(()) } diff --git a/crates/cdk-integration-tests/tests/regtest.rs b/crates/cdk-integration-tests/tests/regtest.rs index c88a86fb..16ff8a5d 100644 --- a/crates/cdk-integration-tests/tests/regtest.rs +++ b/crates/cdk-integration-tests/tests/regtest.rs @@ -12,7 +12,7 @@ use cdk::nuts::{ PreMintSecrets, State, }; use cdk::wallet::client::{HttpClient, HttpClientMethods}; -use cdk::wallet::Wallet; +use cdk::wallet::{Wallet, WalletSubscription}; use cdk_integration_tests::init_regtest::{ get_mint_url, get_mint_ws_url, init_cln_client, init_lnd_client, }; @@ -20,7 +20,7 @@ use futures::{SinkExt, StreamExt}; use lightning_invoice::Bolt11Invoice; use ln_regtest_rs::InvoiceStatus; use serde_json::json; -use tokio::time::{sleep, timeout}; +use tokio::time::timeout; use tokio_tungstenite::connect_async; use tokio_tungstenite::tungstenite::protocol::Message; @@ -361,16 +361,18 @@ async fn test_cached_mint() -> Result<()> { let quote = wallet.mint_quote(mint_amount, None).await?; lnd_client.pay_invoice(quote.request).await?; - loop { - let status = wallet.mint_quote_state("e.id).await.unwrap(); + let mut subscription = wallet + .subscribe(WalletSubscription::Bolt11MintQuoteState(vec![quote + .id + .clone()])) + .await; - println!("Quote status: {}", status.state); - - if status.state == MintQuoteState::Paid { - break; + while let Some(msg) = subscription.recv().await { + if let NotificationPayload::MintQuoteBolt11Response(response) = msg { + if response.state == MintQuoteState::Paid { + break; + } } - - sleep(Duration::from_secs(5)).await; } let active_keyset_id = wallet.get_active_mint_keyset().await?.id; diff --git a/crates/cdk/Cargo.toml b/crates/cdk/Cargo.toml index c5f97b11..83ea0401 100644 --- a/crates/cdk/Cargo.toml +++ b/crates/cdk/Cargo.toml @@ -6,7 +6,7 @@ authors = ["CDK Developers"] description = "Core Cashu Development Kit library implementing the Cashu protocol" homepage = "https://github.com/cashubtc/cdk" repository = "https://github.com/cashubtc/cdk.git" -rust-version = "1.63.0" # MSRV +rust-version = "1.63.0" # MSRV license = "MIT" @@ -17,12 +17,18 @@ mint = ["dep:futures"] swagger = ["mint", "dep:utoipa"] wallet = ["dep:reqwest"] bench = [] +http_subscription = [] [dependencies] async-trait = "0.1" anyhow = { version = "1.0.43", features = ["backtrace"] } -bitcoin = { version= "0.32.2", features = ["base64", "serde", "rand", "rand-std"] } +bitcoin = { version = "0.32.2", features = [ + "base64", + "serde", + "rand", + "rand-std", +] } ciborium = { version = "0.2.2", default-features = false, features = ["std"] } cbor-diag = "0.1.12" lightning-invoice = { version = "0.32.0", features = ["serde", "std"] } @@ -37,9 +43,14 @@ reqwest = { version = "0.12", default-features = false, features = [ serde = { version = "1", default-features = false, features = ["derive"] } serde_json = "1" serde_with = "3" -tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] } +tracing = { version = "0.1", default-features = false, features = [ + "attributes", + "log", +] } thiserror = "1" -futures = { version = "0.3.28", default-features = false, optional = true, features = ["alloc"] } +futures = { version = "0.3.28", default-features = false, optional = true, features = [ + "alloc", +] } url = "2.3" utoipa = { version = "4", optional = true } uuid = { version = "1", features = ["v4", "serde"] } @@ -55,6 +66,11 @@ tokio = { version = "1.21", features = [ "macros", "sync", ] } +getrandom = { version = "0.2" } +tokio-tungstenite = { version = "0.19.0", features = [ + "rustls", + "rustls-tls-native-roots", +] } [target.'cfg(target_arch = "wasm32")'.dependencies] tokio = { version = "1.21", features = ["rt", "macros", "sync", "time"] } diff --git a/crates/cdk/examples/mint-token.rs b/crates/cdk/examples/mint-token.rs index 195fb0ff..cb2dce1e 100644 --- a/crates/cdk/examples/mint-token.rs +++ b/crates/cdk/examples/mint-token.rs @@ -1,15 +1,13 @@ use std::sync::Arc; -use std::time::Duration; use cdk::amount::SplitTarget; use cdk::cdk_database::WalletMemoryDatabase; use cdk::error::Error; -use cdk::nuts::{CurrencyUnit, MintQuoteState}; +use cdk::nuts::{CurrencyUnit, MintQuoteState, NotificationPayload}; use cdk::wallet::types::SendKind; -use cdk::wallet::Wallet; +use cdk::wallet::{Wallet, WalletSubscription}; use cdk::Amount; use rand::Rng; -use tokio::time::sleep; #[tokio::main] async fn main() -> Result<(), Error> { @@ -26,16 +24,18 @@ async fn main() -> Result<(), Error> { println!("Quote: {:#?}", quote); - loop { - let status = wallet.mint_quote_state("e.id).await.unwrap(); + let mut subscription = wallet + .subscribe(WalletSubscription::Bolt11MintQuoteState(vec![quote + .id + .clone()])) + .await; - println!("Quote status: {}", status.state); - - if status.state == MintQuoteState::Paid { - break; + while let Some(msg) = subscription.recv().await { + if let NotificationPayload::MintQuoteBolt11Response(response) = msg { + if response.state == MintQuoteState::Paid { + break; + } } - - sleep(Duration::from_secs(5)).await; } let receive_amount = wallet diff --git a/crates/cdk/examples/p2pk.rs b/crates/cdk/examples/p2pk.rs index 6e51f781..7edbe579 100644 --- a/crates/cdk/examples/p2pk.rs +++ b/crates/cdk/examples/p2pk.rs @@ -1,15 +1,13 @@ use std::sync::Arc; -use std::time::Duration; use cdk::amount::SplitTarget; use cdk::cdk_database::WalletMemoryDatabase; use cdk::error::Error; -use cdk::nuts::{CurrencyUnit, MintQuoteState, SecretKey, SpendingConditions}; +use cdk::nuts::{CurrencyUnit, MintQuoteState, NotificationPayload, SecretKey, SpendingConditions}; use cdk::wallet::types::SendKind; -use cdk::wallet::Wallet; +use cdk::wallet::{Wallet, WalletSubscription}; use cdk::Amount; use rand::Rng; -use tokio::time::sleep; #[tokio::main] async fn main() -> Result<(), Error> { @@ -26,16 +24,18 @@ async fn main() -> Result<(), Error> { println!("Minting nuts ..."); - loop { - let status = wallet.mint_quote_state("e.id).await.unwrap(); - - println!("Quote status: {}", status.state); - - if status.state == MintQuoteState::Paid { - break; + let mut subscription = wallet + .subscribe(WalletSubscription::Bolt11MintQuoteState(vec![quote + .id + .clone()])) + .await; + + while let Some(msg) = subscription.recv().await { + if let NotificationPayload::MintQuoteBolt11Response(response) = msg { + if response.state == MintQuoteState::Paid { + break; + } } - - sleep(Duration::from_secs(5)).await; } let _receive_amount = wallet diff --git a/crates/cdk/examples/proof-selection.rs b/crates/cdk/examples/proof-selection.rs index 210b7731..dcfab297 100644 --- a/crates/cdk/examples/proof-selection.rs +++ b/crates/cdk/examples/proof-selection.rs @@ -1,15 +1,13 @@ //! Wallet example with memory store use std::sync::Arc; -use std::time::Duration; use cdk::amount::SplitTarget; use cdk::cdk_database::WalletMemoryDatabase; -use cdk::nuts::{CurrencyUnit, MintQuoteState}; -use cdk::wallet::Wallet; +use cdk::nuts::{CurrencyUnit, MintQuoteState, NotificationPayload}; +use cdk::wallet::{Wallet, WalletSubscription}; use cdk::Amount; use rand::Rng; -use tokio::time::sleep; #[tokio::main] async fn main() { @@ -28,16 +26,18 @@ async fn main() { println!("Pay request: {}", quote.request); - loop { - let status = wallet.mint_quote_state("e.id).await.unwrap(); - - if status.state == MintQuoteState::Paid { - break; + let mut subscription = wallet + .subscribe(WalletSubscription::Bolt11MintQuoteState(vec![quote + .id + .clone()])) + .await; + + while let Some(msg) = subscription.recv().await { + if let NotificationPayload::MintQuoteBolt11Response(response) = msg { + if response.state == MintQuoteState::Paid { + break; + } } - - println!("Quote state: {}", status.state); - - sleep(Duration::from_secs(5)).await; } let receive_amount = wallet diff --git a/crates/cdk/src/lib.rs b/crates/cdk/src/lib.rs index effb04f9..80be76bd 100644 --- a/crates/cdk/src/lib.rs +++ b/crates/cdk/src/lib.rs @@ -35,7 +35,7 @@ pub use lightning_invoice::{self, Bolt11Invoice}; pub use mint::Mint; #[cfg(feature = "wallet")] #[doc(hidden)] -pub use wallet::Wallet; +pub use wallet::{Wallet, WalletSubscription}; #[doc(hidden)] pub use self::amount::Amount; diff --git a/crates/cdk/src/nuts/mod.rs b/crates/cdk/src/nuts/mod.rs index 7f913f49..0f3bc525 100644 --- a/crates/cdk/src/nuts/mod.rs +++ b/crates/cdk/src/nuts/mod.rs @@ -18,7 +18,6 @@ pub mod nut12; pub mod nut13; pub mod nut14; pub mod nut15; -#[cfg(feature = "mint")] pub mod nut17; pub mod nut18; pub mod nut19; @@ -51,5 +50,6 @@ pub use nut12::{BlindSignatureDleq, ProofDleq}; pub use nut14::HTLCWitness; pub use nut15::{Mpp, MppMethodSettings, Settings as NUT15Settings}; #[cfg(feature = "mint")] -pub use nut17::{NotificationPayload, PubSubManager}; +pub use nut17::PubSubManager; +pub use nut17::{NotificationPayload, SupportedSettings as Nut17SupportedSettings}; pub use nut18::{PaymentRequest, PaymentRequestPayload, Transport}; diff --git a/crates/cdk/src/nuts/nut04.rs b/crates/cdk/src/nuts/nut04.rs index 7f90ba32..d7d4213d 100644 --- a/crates/cdk/src/nuts/nut04.rs +++ b/crates/cdk/src/nuts/nut04.rs @@ -96,6 +96,18 @@ pub struct MintQuoteBolt11Response { pub expiry: Option, } +impl MintQuoteBolt11Response { + /// Convert the MintQuote with a quote type Q to a String + pub fn to_string_id(&self) -> MintQuoteBolt11Response { + MintQuoteBolt11Response { + quote: self.quote.to_string(), + request: self.request.clone(), + state: self.state, + expiry: self.expiry, + } + } +} + #[cfg(feature = "mint")] impl From for MintQuoteBolt11Response { fn from(mint_quote: crate::mint::MintQuote) -> MintQuoteBolt11Response { diff --git a/crates/cdk/src/nuts/nut05.rs b/crates/cdk/src/nuts/nut05.rs index 3510ef4f..c0f146ed 100644 --- a/crates/cdk/src/nuts/nut05.rs +++ b/crates/cdk/src/nuts/nut05.rs @@ -115,6 +115,23 @@ pub struct MeltQuoteBolt11Response { pub change: Option>, } +impl MeltQuoteBolt11Response { + /// Convert a `MeltQuoteBolt11Response` with type Q (generic/unknown) to a + /// `MeltQuoteBolt11Response` with `String` + pub fn to_string_id(self) -> MeltQuoteBolt11Response { + MeltQuoteBolt11Response { + quote: self.quote.to_string(), + amount: self.amount, + fee_reserve: self.fee_reserve, + paid: self.paid, + state: self.state, + expiry: self.expiry, + payment_preimage: self.payment_preimage, + change: self.change, + } + } +} + #[cfg(feature = "mint")] impl From<&MeltQuote> for MeltQuoteBolt11Response { fn from(melt_quote: &MeltQuote) -> MeltQuoteBolt11Response { diff --git a/crates/cdk/src/nuts/nut17/manager.rs b/crates/cdk/src/nuts/nut17/manager.rs new file mode 100644 index 00000000..67d2566b --- /dev/null +++ b/crates/cdk/src/nuts/nut17/manager.rs @@ -0,0 +1,215 @@ +//! Specific Subscription for the cdk crate +use std::ops::Deref; +use std::sync::Arc; + +use uuid::Uuid; + +use super::{Notification, NotificationPayload, OnSubscription}; +use crate::cdk_database::{self, MintDatabase}; +use crate::nuts::{ + BlindSignature, MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response, + MintQuoteState, ProofState, +}; +use crate::pub_sub; + +/// Manager +/// Publish–subscribe manager +/// +/// Nut-17 implementation is system-wide and not only through the WebSocket, so +/// it is possible for another part of the system to subscribe to events. +pub struct PubSubManager(pub_sub::Manager, Notification, OnSubscription>); + +#[allow(clippy::default_constructed_unit_structs)] +impl Default for PubSubManager { + fn default() -> Self { + PubSubManager(OnSubscription::default().into()) + } +} + +impl From + Send + Sync>> for PubSubManager { + fn from(val: Arc + Send + Sync>) -> Self { + PubSubManager(OnSubscription(Some(val)).into()) + } +} + +impl Deref for PubSubManager { + type Target = pub_sub::Manager, Notification, OnSubscription>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl PubSubManager { + /// Helper function to emit a ProofState status + pub fn proof_state>(&self, event: E) { + self.broadcast(event.into().into()); + } + + /// Helper function to emit a MintQuoteBolt11Response status + pub fn mint_quote_bolt11_status>>( + &self, + quote: E, + new_state: MintQuoteState, + ) { + let mut event = quote.into(); + event.state = new_state; + + self.broadcast(event.into()); + } + + /// Helper function to emit a MeltQuoteBolt11Response status + pub fn melt_quote_status>>( + &self, + quote: E, + payment_preimage: Option, + change: Option>, + new_state: MeltQuoteState, + ) { + let mut quote = quote.into(); + quote.state = new_state; + quote.paid = Some(new_state == MeltQuoteState::Paid); + quote.payment_preimage = payment_preimage; + quote.change = change; + self.broadcast(quote.into()); + } +} + +#[cfg(test)] +mod test { + use std::time::Duration; + + use tokio::time::sleep; + + use super::*; + use crate::nuts::nut17::Params; + use crate::nuts::{PublicKey, State}; + + #[tokio::test] + async fn active_and_drop() { + let manager = PubSubManager::default(); + let params = Params { + kind: Kind::ProofState, + filters: vec!["x".to_string()], + id: "uno".into(), + }; + + // Although the same param is used, two subscriptions are created, that + // is because each index is unique, thanks to `Unique`, it is the + // responsibility of the implementor to make sure that SubId are unique + // either globally or per client + let subscriptions = vec![ + manager.subscribe(params.clone()).await, + manager.subscribe(params).await, + ]; + assert_eq!(2, manager.active_subscriptions()); + drop(subscriptions); + + sleep(Duration::from_millis(10)).await; + + assert_eq!(0, manager.active_subscriptions()); + } + + #[tokio::test] + async fn broadcast() { + let manager = PubSubManager::default(); + let mut subscriptions = [ + manager + .subscribe(Params { + kind: Kind::ProofState, + filters: vec![ + "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" + .to_string(), + ], + id: "uno".into(), + }) + .await, + manager + .subscribe(Params { + kind: Kind::ProofState, + filters: vec![ + "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" + .to_string(), + ], + id: "dos".into(), + }) + .await, + ]; + + let event = ProofState { + y: PublicKey::from_hex( + "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104", + ) + .expect("valid pk"), + state: State::Pending, + witness: None, + }; + + manager.broadcast(event.into()); + + sleep(Duration::from_millis(10)).await; + + let (sub1, _) = subscriptions[0].try_recv().expect("valid message"); + assert_eq!("uno", *sub1); + + let (sub1, _) = subscriptions[1].try_recv().expect("valid message"); + assert_eq!("dos", *sub1); + + assert!(subscriptions[0].try_recv().is_err()); + assert!(subscriptions[1].try_recv().is_err()); + } + + #[test] + fn parsing_request() { + let json = r#"{"kind":"proof_state","filters":["x"],"subId":"uno"}"#; + let params: Params = serde_json::from_str(json).expect("valid json"); + assert_eq!(params.kind, Kind::ProofState); + assert_eq!(params.filters, vec!["x"]); + assert_eq!(*params.id, "uno"); + } + + #[tokio::test] + async fn json_test() { + let manager = PubSubManager::default(); + let mut subscription = manager + .subscribe::( + serde_json::from_str(r#"{"kind":"proof_state","filters":["02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104"],"subId":"uno"}"#) + .expect("valid json"), + ) + .await; + + manager.broadcast( + ProofState { + y: PublicKey::from_hex( + "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104", + ) + .expect("valid pk"), + state: State::Pending, + witness: None, + } + .into(), + ); + + // no one is listening for this event + manager.broadcast( + ProofState { + y: PublicKey::from_hex( + "020000000000000000000000000000000000000000000000000000000000000001", + ) + .expect("valid pk"), + state: State::Pending, + witness: None, + } + .into(), + ); + + sleep(Duration::from_millis(10)).await; + let (sub1, msg) = subscription.try_recv().expect("valid message"); + assert_eq!("uno", *sub1); + assert_eq!( + r#"{"Y":"02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104","state":"PENDING","witness":null}"#, + serde_json::to_string(&msg).expect("valid json") + ); + assert!(subscription.try_recv().is_err()); + } +} diff --git a/crates/cdk/src/nuts/nut17/mod.rs b/crates/cdk/src/nuts/nut17/mod.rs index 424eb435..437b0c4f 100644 --- a/crates/cdk/src/nuts/nut17/mod.rs +++ b/crates/cdk/src/nuts/nut17/mod.rs @@ -1,27 +1,30 @@ //! Specific Subscription for the cdk crate - -use std::ops::Deref; use std::str::FromStr; -use std::sync::Arc; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; - -mod on_subscription; - -pub use on_subscription::OnSubscription; use uuid::Uuid; use super::PublicKey; -use crate::cdk_database::{self, MintDatabase}; use crate::nuts::{ - BlindSignature, CurrencyUnit, MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response, - MintQuoteState, PaymentMethod, ProofState, + CurrencyUnit, MeltQuoteBolt11Response, MintQuoteBolt11Response, PaymentMethod, ProofState, }; +use crate::pub_sub::{Index, Indexable, SubscriptionGlobalId}; + +#[cfg(feature = "mint")] +mod manager; +#[cfg(feature = "mint")] +mod on_subscription; +#[cfg(feature = "mint")] +pub use manager::PubSubManager; +#[cfg(feature = "mint")] +pub use on_subscription::OnSubscription; + pub use crate::pub_sub::SubId; -use crate::pub_sub::{self, Index, Indexable, SubscriptionGlobalId}; +pub mod ws; /// Subscription Parameter according to the standard -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash, Deserialize)] pub struct Params { /// Kind pub kind: Kind, @@ -88,31 +91,32 @@ impl Default for SupportedMethods { } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(bound = "T: Serialize + DeserializeOwned")] #[serde(untagged)] /// Subscription response -pub enum NotificationPayload { +pub enum NotificationPayload { /// Proof State ProofState(ProofState), /// Melt Quote Bolt11 Response - MeltQuoteBolt11Response(MeltQuoteBolt11Response), + MeltQuoteBolt11Response(MeltQuoteBolt11Response), /// Mint Quote Bolt11 Response - MintQuoteBolt11Response(MintQuoteBolt11Response), + MintQuoteBolt11Response(MintQuoteBolt11Response), } -impl From for NotificationPayload { - fn from(proof_state: ProofState) -> NotificationPayload { +impl From for NotificationPayload { + fn from(proof_state: ProofState) -> NotificationPayload { NotificationPayload::ProofState(proof_state) } } -impl From> for NotificationPayload { - fn from(melt_quote: MeltQuoteBolt11Response) -> NotificationPayload { +impl From> for NotificationPayload { + fn from(melt_quote: MeltQuoteBolt11Response) -> NotificationPayload { NotificationPayload::MeltQuoteBolt11Response(melt_quote) } } -impl From> for NotificationPayload { - fn from(mint_quote: MintQuoteBolt11Response) -> NotificationPayload { +impl From> for NotificationPayload { + fn from(mint_quote: MintQuoteBolt11Response) -> NotificationPayload { NotificationPayload::MintQuoteBolt11Response(mint_quote) } } @@ -128,7 +132,7 @@ pub enum Notification { MintQuoteBolt11(Uuid), } -impl Indexable for NotificationPayload { +impl Indexable for NotificationPayload { type Type = Notification; fn to_indexes(&self) -> Vec> { @@ -146,10 +150,9 @@ impl Indexable for NotificationPayload { } } +/// Kind #[derive(Debug, Clone, Copy, Eq, Ord, PartialOrd, PartialEq, Hash, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] - -/// Kind pub enum Kind { /// Bolt 11 Melt Quote Bolt11MeltQuote, @@ -188,208 +191,3 @@ impl From for Vec> { // TODO don't unwrap, move to try from } } - -/// Manager -/// Publish–subscribe manager -/// -/// Nut-17 implementation is system-wide and not only through the WebSocket, so -/// it is possible for another part of the system to subscribe to events. -pub struct PubSubManager(pub_sub::Manager); - -#[allow(clippy::default_constructed_unit_structs)] -impl Default for PubSubManager { - fn default() -> Self { - PubSubManager(OnSubscription::default().into()) - } -} - -impl From + Send + Sync>> for PubSubManager { - fn from(val: Arc + Send + Sync>) -> Self { - PubSubManager(OnSubscription(Some(val)).into()) - } -} - -impl Deref for PubSubManager { - type Target = pub_sub::Manager; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl PubSubManager { - /// Helper function to emit a ProofState status - pub fn proof_state>(&self, event: E) { - self.broadcast(event.into().into()); - } - - /// Helper function to emit a MintQuoteBolt11Response status - pub fn mint_quote_bolt11_status>>( - &self, - quote: E, - new_state: MintQuoteState, - ) { - let mut event = quote.into(); - event.state = new_state; - - self.broadcast(event.into()); - } - - /// Helper function to emit a MeltQuoteBolt11Response status - pub fn melt_quote_status>>( - &self, - quote: E, - payment_preimage: Option, - change: Option>, - new_state: MeltQuoteState, - ) { - let mut quote = quote.into(); - quote.state = new_state; - quote.paid = Some(new_state == MeltQuoteState::Paid); - quote.payment_preimage = payment_preimage; - quote.change = change; - self.broadcast(quote.into()); - } -} - -#[cfg(test)] -mod test { - use std::time::Duration; - - use tokio::time::sleep; - - use super::*; - use crate::nuts::{PublicKey, State}; - - #[tokio::test] - async fn active_and_drop() { - let manager = PubSubManager::default(); - let params = Params { - kind: Kind::ProofState, - filters: vec![PublicKey::from_hex( - "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104", - ) - .unwrap() - .to_string()], - id: "uno".into(), - }; - - // Although the same param is used, two subscriptions are created, that - // is because each index is unique, thanks to `Unique`, it is the - // responsibility of the implementor to make sure that SubId are unique - // either globally or per client - let subscriptions = vec![ - manager.subscribe(params.clone()).await, - manager.subscribe(params).await, - ]; - assert_eq!(2, manager.active_subscriptions()); - drop(subscriptions); - - sleep(Duration::from_millis(10)).await; - - assert_eq!(0, manager.active_subscriptions()); - } - - #[tokio::test] - async fn broadcast() { - let manager = PubSubManager::default(); - let mut subscriptions = [ - manager - .subscribe(Params { - kind: Kind::ProofState, - filters: vec![ - "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" - .to_string(), - ], - id: "uno".into(), - }) - .await, - manager - .subscribe(Params { - kind: Kind::ProofState, - filters: vec![ - "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" - .to_string(), - ], - id: "dos".into(), - }) - .await, - ]; - - let event = ProofState { - y: PublicKey::from_hex( - "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104", - ) - .expect("valid pk"), - state: State::Pending, - witness: None, - }; - - manager.broadcast(event.into()); - - sleep(Duration::from_millis(10)).await; - - let (sub1, _) = subscriptions[0].try_recv().expect("valid message"); - assert_eq!("uno", *sub1); - - let (sub1, _) = subscriptions[1].try_recv().expect("valid message"); - assert_eq!("dos", *sub1); - - assert!(subscriptions[0].try_recv().is_err()); - assert!(subscriptions[1].try_recv().is_err()); - } - - #[test] - fn parsing_request() { - let json = r#"{"kind":"proof_state","filters":["x"],"subId":"uno"}"#; - let params: Params = serde_json::from_str(json).expect("valid json"); - assert_eq!(params.kind, Kind::ProofState); - assert_eq!(params.filters, vec!["x"]); - assert_eq!(*params.id, "uno"); - } - - #[tokio::test] - async fn json_test() { - let manager = PubSubManager::default(); - let mut subscription = manager - .subscribe::( - serde_json::from_str(r#"{"kind":"proof_state","filters":["02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104"],"subId":"uno"}"#) - .expect("valid json"), - ) - .await; - - manager.broadcast( - ProofState { - y: PublicKey::from_hex( - "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104", - ) - .expect("valid pk"), - state: State::Pending, - witness: None, - } - .into(), - ); - - // no one is listening for this event - manager.broadcast( - ProofState { - y: PublicKey::from_hex( - "020000000000000000000000000000000000000000000000000000000000000001", - ) - .expect("valid pk"), - state: State::Pending, - witness: None, - } - .into(), - ); - - sleep(Duration::from_millis(10)).await; - let (sub1, msg) = subscription.try_recv().expect("valid message"); - assert_eq!("uno", *sub1); - assert_eq!( - r#"{"Y":"02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104","state":"PENDING","witness":null}"#, - serde_json::to_string(&msg).expect("valid json") - ); - assert!(subscription.try_recv().is_err()); - } -} diff --git a/crates/cdk/src/nuts/nut17/on_subscription.rs b/crates/cdk/src/nuts/nut17/on_subscription.rs index d7465567..37d74587 100644 --- a/crates/cdk/src/nuts/nut17/on_subscription.rs +++ b/crates/cdk/src/nuts/nut17/on_subscription.rs @@ -22,7 +22,7 @@ pub struct OnSubscription( #[async_trait::async_trait] impl OnNewSubscription for OnSubscription { - type Event = NotificationPayload; + type Event = NotificationPayload; type Index = Notification; async fn on_new_subscription( diff --git a/crates/cdk/src/nuts/nut17/ws.rs b/crates/cdk/src/nuts/nut17/ws.rs new file mode 100644 index 00000000..e59d87b6 --- /dev/null +++ b/crates/cdk/src/nuts/nut17/ws.rs @@ -0,0 +1,215 @@ +//! Websocket types + +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use super::{NotificationPayload, Params, SubId}; + +/// JSON RPC version +pub const JSON_RPC_VERSION: &str = "2.0"; + +/// The response to a subscription request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsSubscribeResponse { + /// Status + pub status: String, + /// Subscription ID + #[serde(rename = "subId")] + pub sub_id: SubId, +} + +/// The response to an unsubscription request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsUnsubscribeResponse { + /// Status + pub status: String, + /// Subscription ID + #[serde(rename = "subId")] + pub sub_id: SubId, +} + +/// The notification +/// +/// This is the notification that is sent to the client when an event matches a +/// subscription +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "T: Serialize + DeserializeOwned")] +pub struct NotificationInner { + /// The subscription ID + #[serde(rename = "subId")] + pub sub_id: SubId, + + /// The notification payload + pub payload: NotificationPayload, +} + +impl From> for NotificationInner { + fn from(value: NotificationInner) -> Self { + NotificationInner { + sub_id: value.sub_id, + payload: match value.payload { + NotificationPayload::ProofState(pk) => NotificationPayload::ProofState(pk), + NotificationPayload::MeltQuoteBolt11Response(quote) => { + NotificationPayload::MeltQuoteBolt11Response(quote.to_string_id()) + } + NotificationPayload::MintQuoteBolt11Response(quote) => { + NotificationPayload::MintQuoteBolt11Response(quote.to_string_id()) + } + }, + } + } +} + +/// Responses from the web socket server +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum WsResponseResult { + /// A response to a subscription request + Subscribe(WsSubscribeResponse), + /// Unsubscribe + Unsubscribe(WsUnsubscribeResponse), +} + +impl From for WsResponseResult { + fn from(response: WsSubscribeResponse) -> Self { + WsResponseResult::Subscribe(response) + } +} + +impl From for WsResponseResult { + fn from(response: WsUnsubscribeResponse) -> Self { + WsResponseResult::Unsubscribe(response) + } +} + +/// The request to unsubscribe +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsUnsubscribeRequest { + /// Subscription ID + #[serde(rename = "subId")] + pub sub_id: SubId, +} + +/// The inner method of the websocket request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "method", content = "params")] +pub enum WsMethodRequest { + /// Subscribe method + Subscribe(Params), + /// Unsubscribe method + Unsubscribe(WsUnsubscribeRequest), +} + +/// Websocket request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsRequest { + /// JSON RPC version + pub jsonrpc: String, + /// The method body + #[serde(flatten)] + pub method: WsMethodRequest, + /// The request ID + pub id: usize, +} + +impl From<(WsMethodRequest, usize)> for WsRequest { + fn from((method, id): (WsMethodRequest, usize)) -> Self { + WsRequest { + jsonrpc: JSON_RPC_VERSION.to_owned(), + method, + id, + } + } +} + +/// Notification from the server +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsNotification { + /// JSON RPC version + pub jsonrpc: String, + /// The method + pub method: String, + /// The parameters + pub params: T, +} + +/// Websocket error +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct WsErrorBody { + /// Error code + pub code: i32, + /// Error message + pub message: String, +} + +/// Websocket response +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsResponse { + /// JSON RPC version + pub jsonrpc: String, + /// The result + pub result: WsResponseResult, + /// The request ID + pub id: usize, +} + +/// WebSocket error response +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsErrorResponse { + /// JSON RPC version + pub jsonrpc: String, + /// The result + pub error: WsErrorBody, + /// The request ID + pub id: usize, +} + +/// Message from the server to the client +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum WsMessageOrResponse { + /// A response to a request + Response(WsResponse), + /// An error response + ErrorResponse(WsErrorResponse), + /// A notification + Notification(WsNotification>), +} + +impl From<(usize, Result)> for WsMessageOrResponse { + fn from((id, result): (usize, Result)) -> Self { + match result { + Ok(result) => WsMessageOrResponse::Response(WsResponse { + jsonrpc: JSON_RPC_VERSION.to_owned(), + result, + id, + }), + Err(err) => WsMessageOrResponse::ErrorResponse(WsErrorResponse { + jsonrpc: JSON_RPC_VERSION.to_owned(), + error: err, + id, + }), + } + } +} + +impl From> for WsMessageOrResponse { + fn from(notification: NotificationInner) -> Self { + WsMessageOrResponse::Notification(WsNotification { + jsonrpc: JSON_RPC_VERSION.to_owned(), + method: "subscribe".to_string(), + params: notification.into(), + }) + } +} + +impl From> for WsMessageOrResponse { + fn from(notification: NotificationInner) -> Self { + WsMessageOrResponse::Notification(WsNotification { + jsonrpc: JSON_RPC_VERSION.to_owned(), + method: "subscribe".to_string(), + params: notification, + }) + } +} diff --git a/crates/cdk/src/wallet/client.rs b/crates/cdk/src/wallet/client.rs index 52091d03..271d9a01 100644 --- a/crates/cdk/src/wallet/client.rs +++ b/crates/cdk/src/wallet/client.rs @@ -5,6 +5,7 @@ use std::fmt::Debug; use async_trait::async_trait; use reqwest::Client; use tracing::instrument; +#[cfg(not(target_arch = "wasm32"))] use url::Url; use super::Error; diff --git a/crates/cdk/src/wallet/mod.rs b/crates/cdk/src/wallet/mod.rs index bff6d4bf..2885c6c1 100644 --- a/crates/cdk/src/wallet/mod.rs +++ b/crates/cdk/src/wallet/mod.rs @@ -7,7 +7,11 @@ use std::sync::Arc; use bitcoin::bip32::Xpriv; use bitcoin::Network; use client::HttpClientMethods; +use getrandom::getrandom; +pub use multi_mint_wallet::MultiMintWallet; +use subscription::{ActiveSubscription, SubscriptionManager}; use tracing::instrument; +pub use types::{MeltQuote, MintQuote, SendKind}; use crate::amount::SplitTarget; use crate::cdk_database::{self, WalletDatabase}; @@ -16,6 +20,7 @@ use crate::error::Error; use crate::fees::calculate_fee; use crate::mint_url::MintUrl; use crate::nuts::nut00::token::Token; +use crate::nuts::nut17::{Kind, Params}; use crate::nuts::{ nut10, CurrencyUnit, Id, Keys, MintInfo, MintQuoteState, PreMintSecrets, Proof, Proofs, RestoreRequest, SpendingConditions, State, @@ -32,13 +37,11 @@ pub mod multi_mint_wallet; mod proofs; mod receive; mod send; +pub mod subscription; mod swap; pub mod types; pub mod util; -pub use multi_mint_wallet::MultiMintWallet; -pub use types::{MeltQuote, MintQuote, SendKind}; - use crate::nuts::nut00::ProofsMethods; /// CDK Wallet @@ -58,6 +61,54 @@ pub struct Wallet { pub target_proof_count: usize, xpriv: Xpriv, client: Arc, + subscription: SubscriptionManager, +} + +const ALPHANUMERIC: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + +/// Wallet Subscription filter +#[derive(Debug, Clone)] +pub enum WalletSubscription { + /// Proof subscription + ProofState(Vec), + /// Mint quote subscription + Bolt11MintQuoteState(Vec), + /// Melt quote subscription + Bolt11MeltQuoteState(Vec), +} + +impl From for Params { + fn from(val: WalletSubscription) -> Self { + let mut buffer = vec![0u8; 10]; + + getrandom(&mut buffer).expect("Failed to generate random bytes"); + + let id = buffer + .iter() + .map(|&byte| { + let index = byte as usize % ALPHANUMERIC.len(); // 62 alphanumeric characters (A-Z, a-z, 0-9) + ALPHANUMERIC[index] as char + }) + .collect::(); + + match val { + WalletSubscription::ProofState(filters) => Params { + filters, + kind: Kind::ProofState, + id: id.into(), + }, + WalletSubscription::Bolt11MintQuoteState(filters) => Params { + filters, + kind: Kind::Bolt11MintQuote, + id: id.into(), + }, + WalletSubscription::Bolt11MeltQuoteState(filters) => Params { + filters, + kind: Kind::Bolt11MeltQuote, + id: id.into(), + }, + } + } } impl Wallet { @@ -86,11 +137,12 @@ impl Wallet { target_proof_count: Option, ) -> Result { let xpriv = Xpriv::new_master(Network::Bitcoin, seed).expect("Could not create master key"); - + let http_client = Arc::new(HttpClient::new()); Ok(Self { mint_url: MintUrl::from_str(mint_url)?, unit, - client: Arc::new(HttpClient::new()), + client: http_client.clone(), + subscription: SubscriptionManager::new(http_client), localstore, xpriv, target_proof_count: target_proof_count.unwrap_or(3), @@ -100,6 +152,14 @@ impl Wallet { /// Change HTTP client pub fn set_client(&mut self, client: C) { self.client = Arc::new(client); + self.subscription = SubscriptionManager::new(self.client.clone()); + } + + /// Subscribe to events + pub async fn subscribe>(&self, query: T) -> ActiveSubscription { + self.subscription + .subscribe(self.mint_url.clone(), query.into()) + .await } /// Fee required for proof set diff --git a/crates/cdk/src/wallet/subscription/http.rs b/crates/cdk/src/wallet/subscription/http.rs new file mode 100644 index 00000000..cc98199a --- /dev/null +++ b/crates/cdk/src/wallet/subscription/http.rs @@ -0,0 +1,152 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::{mpsc, RwLock}; +use tokio::time; + +use super::WsSubscriptionBody; +use crate::mint_url::MintUrl; +use crate::nuts::nut17::Kind; +use crate::nuts::{nut01, nut04, nut05, nut07, CheckStateRequest, NotificationPayload}; +use crate::pub_sub::SubId; +use crate::wallet::client::HttpClientMethods; + +#[derive(Debug, Hash, PartialEq, Eq)] +enum UrlType { + Mint(String), + Melt(String), + PublicKey(nut01::PublicKey), +} + +#[derive(Debug, Eq, PartialEq)] +enum AnyState { + MintQuoteState(nut04::QuoteState), + MeltQuoteState(nut05::QuoteState), + PublicKey(nut07::State), + Empty, +} + +type SubscribedTo = HashMap>, SubId, AnyState)>; + +async fn convert_subscription( + sub_id: SubId, + subscriptions: &Arc>>, + subscribed_to: &mut SubscribedTo, +) -> Option<()> { + let subscription = subscriptions.read().await; + let sub = subscription.get(&sub_id)?; + tracing::debug!("New subscription: {:?}", sub); + match sub.1.kind { + Kind::Bolt11MintQuote => { + for id in sub.1.filters.iter().map(|id| UrlType::Mint(id.clone())) { + subscribed_to.insert(id, (sub.0.clone(), sub.1.id.clone(), AnyState::Empty)); + } + } + Kind::Bolt11MeltQuote => { + for id in sub.1.filters.iter().map(|id| UrlType::Melt(id.clone())) { + subscribed_to.insert(id, (sub.0.clone(), sub.1.id.clone(), AnyState::Empty)); + } + } + Kind::ProofState => { + for id in sub + .1 + .filters + .iter() + .map(|id| nut01::PublicKey::from_hex(id).map(UrlType::PublicKey)) + { + match id { + Ok(id) => { + subscribed_to + .insert(id, (sub.0.clone(), sub.1.id.clone(), AnyState::Empty)); + } + Err(err) => { + tracing::error!("Error parsing public key: {:?}. Subscription ignored, will never yield any result", err); + } + } + } + } + } + + Some(()) +} + +#[allow(clippy::incompatible_msrv)] +#[inline] +pub async fn http_main>( + initial_state: S, + http_client: Arc, + mint_url: MintUrl, + subscriptions: Arc>>, + mut new_subscription_recv: mpsc::Receiver, + mut on_drop: mpsc::Receiver, +) { + let mut interval = time::interval(Duration::from_secs(2)); + let mut subscribed_to = HashMap::, _, AnyState)>::new(); + + for sub_id in initial_state { + convert_subscription(sub_id, &subscriptions, &mut subscribed_to).await; + } + + loop { + tokio::select! { + _ = interval.tick() => { + for (url, (sender, _, last_state)) in subscribed_to.iter_mut() { + tracing::debug!("Polling: {:?}", url); + match url { + UrlType::Mint(id) => { + let response = http_client.get_mint_quote_status(mint_url.clone(), id).await; + if let Ok(response) = response { + if *last_state == AnyState::MintQuoteState(response.state) { + continue; + } + *last_state = AnyState::MintQuoteState(response.state); + if let Err(err) = sender.try_send(NotificationPayload::MintQuoteBolt11Response(response)) { + tracing::error!("Error sending mint quote response: {:?}", err); + } + } + } + UrlType::Melt(id) => { + let response = http_client.get_melt_quote_status(mint_url.clone(), id).await; + if let Ok(response) = response { + if *last_state == AnyState::MeltQuoteState(response.state) { + continue; + } + *last_state = AnyState::MeltQuoteState(response.state); + if let Err(err) = sender.try_send(NotificationPayload::MeltQuoteBolt11Response(response)) { + tracing::error!("Error sending melt quote response: {:?}", err); + } + } + } + UrlType::PublicKey(id) => { + let responses = http_client.post_check_state(mint_url.clone(), CheckStateRequest { + ys: vec![*id], + }).await; + if let Ok(mut responses) = responses { + let response = if let Some(state) = responses.states.pop() { + state + } else { + continue; + }; + + if *last_state == AnyState::PublicKey(response.state) { + continue; + } + *last_state = AnyState::PublicKey(response.state); + if let Err(err) = sender.try_send(NotificationPayload::ProofState(response)) { + tracing::error!("Error sending proof state response: {:?}", err); + } + } + } + } + } + } + Some(subid) = new_subscription_recv.recv() => { + convert_subscription(subid, &subscriptions, &mut subscribed_to).await; + } + Some(id) = on_drop.recv() => { + subscribed_to.retain(|_, (_, sub_id, _)| *sub_id != id); + } + } + } +} diff --git a/crates/cdk/src/wallet/subscription/mod.rs b/crates/cdk/src/wallet/subscription/mod.rs new file mode 100644 index 00000000..3491a9c1 --- /dev/null +++ b/crates/cdk/src/wallet/subscription/mod.rs @@ -0,0 +1,326 @@ +//! Client for subscriptions +//! +//! Mint servers can send notifications to clients about changes in the state, +//! according to NUT-17, using the WebSocket protocol. This module provides a +//! subscription manager that allows clients to subscribe to notifications from +//! multiple mint servers using WebSocket or with a poll-based system, using +//! the HTTP client. +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::Arc; + +use tokio::sync::{mpsc, RwLock}; +use tokio::task::JoinHandle; +use tracing::error; + +use crate::mint_url::MintUrl; +use crate::nuts::nut17::Params; +use crate::pub_sub::SubId; +use crate::wallet::client::HttpClientMethods; + +mod http; +#[cfg(all( + not(feature = "http_subscription"), + feature = "mint", + not(target_arch = "wasm32") +))] +mod ws; + +type WsSubscriptionBody = (mpsc::Sender, Params); + +/// Subscription manager +/// +/// This structure should be instantiated once per wallet at most. It is +/// clonable since all its members are Arcs. +/// +/// The main goal is to provide a single interface to manage multiple +/// subscriptions to many servers to subscribe to events. If supported, the +/// WebSocket method is used to subscribe to server-side events. Otherwise, a +/// poll-based system is used, where a background task fetches information about +/// the resource every few seconds and notifies subscribers of any change +/// upstream. +/// +/// The subscribers have a simple-to-use interface, receiving an +/// ActiveSubscription struct, which can be used to receive updates and to +/// unsubscribe from updates automatically on the drop. +#[derive(Debug, Clone)] +pub struct SubscriptionManager { + all_connections: Arc>>, + http_client: Arc, +} + +impl SubscriptionManager { + /// Create a new subscription manager + pub fn new(http_client: Arc) -> Self { + Self { + all_connections: Arc::new(RwLock::new(HashMap::new())), + http_client, + } + } + + /// Subscribe to updates from a mint server with a given filter + pub async fn subscribe(&self, mint_url: MintUrl, filter: Params) -> ActiveSubscription { + let subscription_clients = self.all_connections.read().await; + let id = filter.id.clone(); + if let Some(subscription_client) = subscription_clients.get(&mint_url) { + let (on_drop_notif, receiver) = subscription_client.subscribe(filter).await; + ActiveSubscription::new(receiver, id, on_drop_notif) + } else { + drop(subscription_clients); + + #[cfg(all( + not(feature = "http_subscription"), + feature = "mint", + not(target_arch = "wasm32") + ))] + let is_ws_support = self + .http_client + .get_mint_info(mint_url.clone()) + .await + .map(|info| !info.nuts.nut17.supported.is_empty()) + .unwrap_or_default(); + + #[cfg(any( + feature = "http_subscription", + not(feature = "mint"), + target_arch = "wasm32" + ))] + let is_ws_support = false; + + tracing::debug!( + "Connect to {:?} to subscribe. WebSocket is supported ({})", + mint_url, + is_ws_support + ); + + let mut subscription_clients = self.all_connections.write().await; + let subscription_client = + SubscriptionClient::new(mint_url.clone(), self.http_client.clone(), is_ws_support); + let (on_drop_notif, receiver) = subscription_client.subscribe(filter).await; + subscription_clients.insert(mint_url, subscription_client); + + ActiveSubscription::new(receiver, id, on_drop_notif) + } + } +} + +/// Subscription client +/// +/// If the server supports WebSocket subscriptions, this client will be used, +/// otherwise the HTTP pool and pause will be used (which is the less efficient +/// method). +#[derive(Debug)] +pub struct SubscriptionClient { + new_subscription_notif: mpsc::Sender, + on_drop_notif: mpsc::Sender, + subscriptions: Arc>>, + worker: Option>, +} + +type NotificationPayload = crate::nuts::NotificationPayload; + +/// Active Subscription +pub struct ActiveSubscription { + sub_id: Option, + on_drop_notif: mpsc::Sender, + receiver: mpsc::Receiver, +} + +impl ActiveSubscription { + fn new( + receiver: mpsc::Receiver, + sub_id: SubId, + on_drop_notif: mpsc::Sender, + ) -> Self { + Self { + sub_id: Some(sub_id), + on_drop_notif, + receiver, + } + } + + /// Try to receive a notification + pub fn try_recv(&mut self) -> Result, Error> { + match self.receiver.try_recv() { + Ok(payload) => Ok(Some(payload)), + Err(mpsc::error::TryRecvError::Empty) => Ok(None), + Err(mpsc::error::TryRecvError::Disconnected) => Err(Error::Disconnected), + } + } + + /// Receive a notification asynchronously + pub async fn recv(&mut self) -> Option { + self.receiver.recv().await + } +} + +impl Drop for ActiveSubscription { + fn drop(&mut self) { + if let Some(sub_id) = self.sub_id.take() { + let _ = self.on_drop_notif.try_send(sub_id); + } + } +} + +/// Subscription client error +#[derive(thiserror::Error, Debug)] +pub enum Error { + /// Url error + #[error("Could not join paths: {0}")] + Url(#[from] crate::mint_url::Error), + /// Disconnected from the notification channel + #[error("Disconnected from the notification channel")] + Disconnected, +} + +impl SubscriptionClient { + /// Create new [`WebSocketClient`] + pub fn new( + url: MintUrl, + http_client: Arc, + prefer_ws_method: bool, + ) -> Self { + let subscriptions = Arc::new(RwLock::new(HashMap::new())); + let (new_subscription_notif, new_subscription_recv) = mpsc::channel(100); + let (on_drop_notif, on_drop_recv) = mpsc::channel(1000); + + Self { + new_subscription_notif, + on_drop_notif, + subscriptions: subscriptions.clone(), + worker: Some(Self::start_worker( + prefer_ws_method, + http_client, + url, + subscriptions, + new_subscription_recv, + on_drop_recv, + )), + } + } + + #[allow(unused_variables)] + fn start_worker( + prefer_ws_method: bool, + http_client: Arc, + url: MintUrl, + subscriptions: Arc>>, + new_subscription_recv: mpsc::Receiver, + on_drop_recv: mpsc::Receiver, + ) -> JoinHandle<()> { + #[cfg(any( + feature = "http_subscription", + not(feature = "mint"), + target_arch = "wasm32" + ))] + return Self::http_worker( + http_client, + url, + subscriptions, + new_subscription_recv, + on_drop_recv, + ); + + #[cfg(all( + not(feature = "http_subscription"), + feature = "mint", + not(target_arch = "wasm32") + ))] + if prefer_ws_method { + Self::ws_worker( + http_client, + url, + subscriptions, + new_subscription_recv, + on_drop_recv, + ) + } else { + Self::http_worker( + http_client, + url, + subscriptions, + new_subscription_recv, + on_drop_recv, + ) + } + } + + /// Subscribe to a WebSocket channel + pub async fn subscribe( + &self, + filter: Params, + ) -> (mpsc::Sender, mpsc::Receiver) { + let mut subscriptions = self.subscriptions.write().await; + let id = filter.id.clone(); + + let (sender, receiver) = mpsc::channel(10_000); + subscriptions.insert(id.clone(), (sender, filter)); + drop(subscriptions); + + let _ = self.new_subscription_notif.send(id).await; + (self.on_drop_notif.clone(), receiver) + } + + /// HTTP subscription client + /// + /// This is a poll based subscription, where the client will poll the server + /// from time to time to get updates, notifying the subscribers on changes + fn http_worker( + http_client: Arc, + url: MintUrl, + subscriptions: Arc>>, + new_subscription_recv: mpsc::Receiver, + on_drop: mpsc::Receiver, + ) -> JoinHandle<()> { + let http_worker = http::http_main( + vec![], + http_client, + url, + subscriptions, + new_subscription_recv, + on_drop, + ); + + #[cfg(target_arch = "wasm32")] + let ret = tokio::task::spawn_local(http_worker); + + #[cfg(not(target_arch = "wasm32"))] + let ret = tokio::spawn(http_worker); + + ret + } + + /// WebSocket subscription client + /// + /// This is a WebSocket based subscription, where the client will connect to + /// the server and stay there idle waiting for server-side notifications + #[allow(clippy::incompatible_msrv)] + #[cfg(all( + not(feature = "http_subscription"), + feature = "mint", + not(target_arch = "wasm32") + ))] + fn ws_worker( + http_client: Arc, + url: MintUrl, + subscriptions: Arc>>, + new_subscription_recv: mpsc::Receiver, + on_drop: mpsc::Receiver, + ) -> JoinHandle<()> { + tokio::spawn(ws::ws_main( + http_client, + url, + subscriptions, + new_subscription_recv, + on_drop, + )) + } +} + +impl Drop for SubscriptionClient { + fn drop(&mut self) { + if let Some(sender) = self.worker.take() { + sender.abort(); + } + } +} diff --git a/crates/cdk/src/wallet/subscription/ws.rs b/crates/cdk/src/wallet/subscription/ws.rs new file mode 100644 index 00000000..d388f396 --- /dev/null +++ b/crates/cdk/src/wallet/subscription/ws.rs @@ -0,0 +1,216 @@ +use std::collections::{HashMap, HashSet}; +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; + +use futures::{SinkExt, StreamExt}; +use tokio::sync::{mpsc, RwLock}; +use tokio_tungstenite::connect_async; +use tokio_tungstenite::tungstenite::Message; + +use super::http::http_main; +use super::WsSubscriptionBody; +use crate::mint_url::MintUrl; +use crate::nuts::nut17::ws::{ + WsMessageOrResponse, WsMethodRequest, WsRequest, WsUnsubscribeRequest, +}; +use crate::nuts::nut17::Params; +use crate::pub_sub::SubId; +use crate::wallet::client::HttpClientMethods; + +const MAX_ATTEMPT_FALLBACK_HTTP: usize = 10; + +async fn fallback_to_http>( + initial_state: S, + http_client: Arc, + mint_url: MintUrl, + subscriptions: Arc>>, + new_subscription_recv: mpsc::Receiver, + on_drop: mpsc::Receiver, +) { + http_main( + initial_state, + http_client, + mint_url, + subscriptions, + new_subscription_recv, + on_drop, + ) + .await +} + +#[allow(clippy::incompatible_msrv)] +#[inline] +pub async fn ws_main( + http_client: Arc, + mint_url: MintUrl, + subscriptions: Arc>>, + mut new_subscription_recv: mpsc::Receiver, + mut on_drop: mpsc::Receiver, +) { + let url = mint_url + .join_paths(&["v1", "ws"]) + .as_mut() + .map(|url| { + if url.scheme() == "https" { + url.set_scheme("wss").expect("Could not set scheme"); + } else { + url.set_scheme("ws").expect("Could not set scheme"); + } + url + }) + .expect("Could not join paths") + .to_string(); + + let mut active_subscriptions = HashMap::>::new(); + let mut failure_count = 0; + + loop { + tracing::debug!("Connecting to {}", url); + let (ws_stream, _) = if let Ok(result) = connect_async(&url).await { + result + } else { + failure_count += 1; + if failure_count > MAX_ATTEMPT_FALLBACK_HTTP { + tracing::error!( + "Could not connect to server after {MAX_ATTEMPT_FALLBACK_HTTP} attempts, falling back to HTTP-subscription client" + ); + return fallback_to_http( + active_subscriptions.into_keys(), + http_client, + mint_url, + subscriptions, + new_subscription_recv, + on_drop, + ) + .await; + } + continue; + }; + tracing::debug!("Connected to {}", url); + + failure_count = 0; + + let (mut write, mut read) = ws_stream.split(); + let req_id = AtomicUsize::new(0); + + let get_sub_request = |params: Params| -> Option<(usize, String)> { + let request: WsRequest = ( + WsMethodRequest::Subscribe(params), + req_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed), + ) + .into(); + + match serde_json::to_string(&request) { + Ok(json) => Some((request.id, json)), + Err(err) => { + tracing::error!("Could not serialize subscribe message: {:?}", err); + None + } + } + }; + + let get_unsub_request = |sub_id: SubId| -> Option { + let request: WsRequest = ( + WsMethodRequest::Unsubscribe(WsUnsubscribeRequest { sub_id }), + req_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed), + ) + .into(); + + match serde_json::to_string(&request) { + Ok(json) => Some(json), + Err(err) => { + tracing::error!("Could not serialize unsubscribe message: {:?}", err); + None + } + } + }; + + // Websocket reconnected, restore all subscriptions + let mut subscription_requests = HashSet::new(); + + let read_subscriptions = subscriptions.read().await; + for (sub_id, _) in active_subscriptions.iter() { + if let Some(Some((req_id, req))) = read_subscriptions + .get(sub_id) + .map(|(_, params)| get_sub_request(params.clone())) + { + let _ = write.send(Message::Text(req)).await; + subscription_requests.insert(req_id); + } + } + drop(read_subscriptions); + + loop { + tokio::select! { + Some(msg) = read.next() => { + let msg = match msg { + Ok(msg) => msg, + Err(_) => break, + }; + let msg = match msg { + Message::Text(msg) => msg, + _ => continue, + }; + let msg = match serde_json::from_str::(&msg) { + Ok(msg) => msg, + Err(_) => continue, + }; + + match msg { + WsMessageOrResponse::Notification(payload) => { + tracing::debug!("Received notification from server: {:?}", payload); + let _ = active_subscriptions.get(&payload.params.sub_id).map(|sender| { + let _ = sender.try_send(payload.params.payload); + }); + } + WsMessageOrResponse::Response(response) => { + tracing::debug!("Received response from server: {:?}", response); + subscription_requests.remove(&response.id); + } + WsMessageOrResponse::ErrorResponse(error) => { + tracing::error!("Received error from server: {:?}", error); + if subscription_requests.contains(&error.id) { + // If the server sends an error response to a subscription request, we should + // fallback to HTTP. + // TODO: Add some retry before giving up to HTTP. + return fallback_to_http( + active_subscriptions.into_keys(), + http_client, + mint_url, + subscriptions, + new_subscription_recv, + on_drop, + ).await; + } + } + } + + } + Some(subid) = new_subscription_recv.recv() => { + let subscription = subscriptions.read().await; + let sub = if let Some(subscription) = subscription.get(&subid) { + subscription + } else { + continue + }; + tracing::debug!("Subscribing to {:?}", sub.1); + active_subscriptions.insert(subid, sub.0.clone()); + if let Some((req_id, json)) = get_sub_request(sub.1.clone()) { + let _ = write.send(Message::Text(json)).await; + subscription_requests.insert(req_id); + } + }, + Some(subid) = on_drop.recv() => { + let mut subscription = subscriptions.write().await; + if let Some(sub) = subscription.remove(&subid) { + drop(sub); + } + tracing::debug!("Unsubscribing from {:?}", subid); + if let Some(json) = get_unsub_request(subid) { + let _ = write.send(Message::Text(json)).await; + } + } + } + } + } +} diff --git a/crates/cdk/src/wallet/websocket.rs b/crates/cdk/src/wallet/websocket.rs new file mode 100644 index 00000000..023156e9 --- /dev/null +++ b/crates/cdk/src/wallet/websocket.rs @@ -0,0 +1,54 @@ +//! Websocket types + +use serde::{Deserialize, Serialize}; + +use crate::{nuts::nut17::Params, pub_sub::SubId}; + +/// JSON RPC version +pub const JSON_RPC_VERSION: &str = "2.0"; + +/// Websocket request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsRequest { + pub jsonrpc: String, + #[serde(flatten)] + pub method: WsMethod, + pub id: usize, +} + +/// Websocket method +/// +/// List of possible methods that can be called on the websocket +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "method", content = "params")] +pub enum WsMethod { + /// Subscribe to a topic + Subscribe(Params), + /// Unsubscribe from a topic + Unsubscribe(UnsubscribeMethod), +} + +/// Unsubscribe method +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct UnsubscribeMethod { + #[serde(rename = "subId")] + pub sub_id: SubId, +} + +/// Websocket error response +#[derive(Debug, Clone, Serialize)] +pub struct WsErrorResponse { + code: i32, + message: String, +} + +/// Websocket response +#[derive(Debug, Clone, Serialize)] +pub struct WsResponse { + jsonrpc: String, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + id: usize, +} diff --git a/misc/itests.sh b/misc/itests.sh index 50eb7f7f..ef85d23b 100755 --- a/misc/itests.sh +++ b/misc/itests.sh @@ -83,6 +83,9 @@ done # Run cargo test cargo test -p cdk-integration-tests --test regtest +# Run cargo test with the http_subscription feature +cargo test -p cdk-integration-tests --test regtest --features http_subscription + # Capture the exit status of cargo test test_status=$?