diff --git a/Cargo.lock b/Cargo.lock index ef13fbe6..5f020dcc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -693,6 +693,7 @@ dependencies = [ "sync_wrapper 0.1.2", "thiserror 1.0.69", "tokio", + "tokio-tungstenite 0.19.0", "tracing", "url", "utoipa", @@ -1151,6 +1152,16 @@ dependencies = [ "yaml-rust", ] +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation" version = "0.10.0" @@ -1967,7 +1978,7 @@ dependencies = [ "hyper 1.5.1", "hyper-util", "rustls 0.23.19", - "rustls-native-certs", + "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", "tokio-rustls 0.26.0", @@ -3419,7 +3430,7 @@ dependencies = [ "pin-project-lite", "quinn", "rustls 0.23.19", - "rustls-native-certs", + "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", "rustls-pki-types", "serde", @@ -3621,6 +3632,18 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile 1.0.4", + "schannel", + "security-framework 2.11.1", +] + [[package]] name = "rustls-native-certs" version = "0.8.1" @@ -3630,7 +3653,7 @@ dependencies = [ "openssl-probe", "rustls-pki-types", "schannel", - "security-framework", + "security-framework 3.0.1", ] [[package]] @@ -3798,6 +3821,19 @@ dependencies = [ "cc", ] +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.6.0", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + [[package]] name = "security-framework" version = "3.0.1" @@ -3805,7 +3841,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1415a607e92bec364ea2cf9264646dcce0f91e6d65281bd6f2819cca3bf39c8" dependencies = [ "bitflags 2.6.0", - "core-foundation", + "core-foundation 0.10.0", "core-foundation-sys", "libc", "security-framework-sys", @@ -4478,6 +4514,21 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec509ac96e9a0c43427c74f003127d953a265737636129424288d27cb5c4b12c" +dependencies = [ + "futures-util", + "log", + "rustls 0.21.12", + "rustls-native-certs 0.6.3", + "tokio", + "tokio-rustls 0.24.1", + "tungstenite 0.19.0", +] + [[package]] name = "tokio-tungstenite" version = "0.20.1" @@ -4696,6 +4747,27 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15fba1a6d6bb030745759a9a2a588bfe8490fc8b4751a277db3a0be1c9ebbf67" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 0.2.12", + "httparse", + "log", + "rand", + "rustls 0.21.12", + "sha1", + "thiserror 1.0.69", + "url", + "utf-8", + "webpki", +] + [[package]] name = "tungstenite" version = "0.20.1" diff --git a/crates/cdk-axum/src/ws/error.rs b/crates/cdk-axum/src/ws/error.rs index 24fa4c8c..d67e3ef8 100644 --- a/crates/cdk-axum/src/ws/error.rs +++ b/crates/cdk-axum/src/ws/error.rs @@ -1,3 +1,4 @@ +use cdk::nuts::nut17::ws::WsErrorBody; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -17,3 +18,17 @@ pub enum WsError { /// Custom error ServerError(i32, String), } + +impl From for WsErrorBody { + fn from(val: WsError) -> Self { + let (id, message) = match val { + WsError::ParseError => (-32700, "Parse error".to_string()), + WsError::InvalidRequest => (-32600, "Invalid Request".to_string()), + WsError::MethodNotFound => (-32601, "Method not found".to_string()), + WsError::InvalidParams => (-32602, "Invalid params".to_string()), + WsError::InternalError => (-32603, "Internal error".to_string()), + WsError::ServerError(code, message) => (code, message), + }; + WsErrorBody { code: id, message } + } +} diff --git a/crates/cdk-axum/src/ws/handler.rs b/crates/cdk-axum/src/ws/handler.rs deleted file mode 100644 index b2298551..00000000 --- a/crates/cdk-axum/src/ws/handler.rs +++ /dev/null @@ -1,71 +0,0 @@ -use serde::Serialize; - -use super::{WsContext, WsError, JSON_RPC_VERSION}; - -impl From for WsErrorResponse { - fn from(val: WsError) -> Self { - let (id, message) = match val { - WsError::ParseError => (-32700, "Parse error".to_string()), - WsError::InvalidRequest => (-32600, "Invalid Request".to_string()), - WsError::MethodNotFound => (-32601, "Method not found".to_string()), - WsError::InvalidParams => (-32602, "Invalid params".to_string()), - WsError::InternalError => (-32603, "Internal error".to_string()), - WsError::ServerError(code, message) => (code, message), - }; - WsErrorResponse { code: id, message } - } -} - -#[derive(Debug, Clone, Serialize)] -struct WsErrorResponse { - code: i32, - message: String, -} - -#[derive(Debug, Clone, Serialize)] -struct WsResponse { - jsonrpc: String, - #[serde(skip_serializing_if = "Option::is_none")] - result: Option, - #[serde(skip_serializing_if = "Option::is_none")] - error: Option, - id: usize, -} - -#[derive(Debug, Clone, Serialize)] -pub struct WsNotification { - pub jsonrpc: String, - pub method: String, - pub params: T, -} - -#[async_trait::async_trait] -pub trait WsHandle { - type Response: Serialize + Sized; - - async fn process( - self, - req_id: usize, - context: &mut WsContext, - ) -> Result - where - Self: Sized, - { - serde_json::to_value(&match self.handle(context).await { - Ok(response) => WsResponse { - jsonrpc: JSON_RPC_VERSION.to_owned(), - result: Some(response), - error: None, - id: req_id, - }, - Err(error) => WsResponse { - jsonrpc: JSON_RPC_VERSION.to_owned(), - result: None, - error: Some(error.into()), - id: req_id, - }, - }) - } - - async fn handle(self, context: &mut WsContext) -> Result; -} diff --git a/crates/cdk-axum/src/ws/mod.rs b/crates/cdk-axum/src/ws/mod.rs index 4b71368e..4581e7f8 100644 --- a/crates/cdk-axum/src/ws/mod.rs +++ b/crates/cdk-axum/src/ws/mod.rs @@ -1,50 +1,33 @@ use std::collections::HashMap; use axum::extract::ws::{Message, WebSocket}; +use cdk::nuts::nut17::ws::{ + NotificationInner, WsErrorBody, WsMessageOrResponse, WsMethodRequest, WsRequest, +}; use cdk::nuts::nut17::{NotificationPayload, SubId}; use futures::StreamExt; -use handler::{WsHandle, WsNotification}; -use serde::{Deserialize, Serialize}; -use subscribe::Notification; use tokio::sync::mpsc; +use uuid::Uuid; use crate::MintState; mod error; -mod handler; mod subscribe; mod unsubscribe; -/// JSON RPC version -pub const JSON_RPC_VERSION: &str = "2.0"; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct WsRequest { - jsonrpc: String, - #[serde(flatten)] - method: WsMethod, - id: usize, -} +async fn process( + context: &mut WsContext, + body: WsRequest, +) -> Result { + let response = match body.method { + WsMethodRequest::Subscribe(sub) => subscribe::handle(context, sub).await, + WsMethodRequest::Unsubscribe(unsub) => unsubscribe::handle(context, unsub).await, + } + .map_err(WsErrorBody::from); -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "snake_case", tag = "method", content = "params")] -pub enum WsMethod { - Subscribe(subscribe::Method), - Unsubscribe(unsubscribe::Method), -} + let response: WsMessageOrResponse = (body.id, response).into(); -impl WsMethod { - pub async fn process( - self, - req_id: usize, - context: &mut WsContext, - ) -> Result { - match self { - WsMethod::Subscribe(sub) => sub.process(req_id, context), - WsMethod::Unsubscribe(unsub) => unsub.process(req_id, context), - } - .await - } + serde_json::to_value(response) } pub use error::WsError; @@ -52,7 +35,7 @@ pub use error::WsError; pub struct WsContext { state: MintState, subscriptions: HashMap>, - publisher: mpsc::Sender<(SubId, NotificationPayload)>, + publisher: mpsc::Sender<(SubId, NotificationPayload)>, } /// Main function for websocket connections @@ -78,7 +61,10 @@ pub async fn main_websocket(mut socket: WebSocket, state: MintState) { // unsubscribed from the subscription manager, just ignore it. continue; } - let notification: WsNotification = (sub_id, payload).into(); + let notification: WsMessageOrResponse= NotificationInner { + sub_id, + payload, + }.into(); let message = match serde_json::to_string(¬ification) { Ok(message) => message, Err(err) => { @@ -101,7 +87,7 @@ pub async fn main_websocket(mut socket: WebSocket, state: MintState) { } }; - match request.method.process(request.id, &mut context).await { + match process(&mut context, request).await { Ok(result) => { if let Err(err) = socket .send(Message::Text(result.to_string())) diff --git a/crates/cdk-axum/src/ws/subscribe.rs b/crates/cdk-axum/src/ws/subscribe.rs index f177ae08..30f25e30 100644 --- a/crates/cdk-axum/src/ws/subscribe.rs +++ b/crates/cdk-axum/src/ws/subscribe.rs @@ -1,72 +1,40 @@ -use cdk::nuts::nut17::{NotificationPayload, Params}; -use cdk::pub_sub::SubId; - -use super::handler::{WsHandle, WsNotification}; -use super::{WsContext, WsError, JSON_RPC_VERSION}; - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct Method(Params); - -#[derive(Debug, Clone, serde::Serialize)] -/// The response to a subscription request -pub struct Response { - /// Status - status: String, - /// Subscription ID - #[serde(rename = "subId")] - sub_id: SubId, -} - -#[derive(Debug, Clone, serde::Serialize)] -/// The notification -/// -/// This is the notification that is sent to the client when an event matches a -/// subscription -pub struct Notification { - /// The subscription ID - #[serde(rename = "subId")] - pub sub_id: SubId, - - /// The notification payload - pub payload: NotificationPayload, -} - -impl From<(SubId, NotificationPayload)> for WsNotification { - fn from((sub_id, payload): (SubId, NotificationPayload)) -> Self { - WsNotification { - jsonrpc: JSON_RPC_VERSION.to_owned(), - method: "subscribe".to_string(), - params: Notification { sub_id, payload }, - } +use cdk::nuts::nut17::ws::{WsResponseResult, WsSubscribeResponse}; +use cdk::nuts::nut17::Params; + +use super::{WsContext, WsError}; + +/// The `handle` method is called when a client sends a subscription request +pub(crate) async fn handle( + context: &mut WsContext, + params: Params, +) -> Result { + let sub_id = params.id.clone(); + if context.subscriptions.contains_key(&sub_id) { + // Subscription ID already exits. Returns an error instead of + // replacing the other subscription or avoiding it. + return Err(WsError::InvalidParams); } -} - -#[async_trait::async_trait] -impl WsHandle for Method { - type Response = Response; - - /// The `handle` method is called when a client sends a subscription request - async fn handle(self, context: &mut WsContext) -> Result { - let sub_id = self.0.id.clone(); - if context.subscriptions.contains_key(&sub_id) { - // Subscription ID already exits. Returns an error instead of - // replacing the other subscription or avoiding it. - return Err(WsError::InvalidParams); - } - let mut subscription = context.state.mint.pubsub_manager.subscribe(self.0).await; - let publisher = context.publisher.clone(); - context.subscriptions.insert( - sub_id.clone(), - tokio::spawn(async move { - while let Some(response) = subscription.recv().await { - let _ = publisher.send(response).await; - } - }), - ); - Ok(Response { - status: "OK".to_string(), - sub_id, - }) + let mut subscription = context + .state + .mint + .pubsub_manager + .try_subscribe(params) + .await + .map_err(|_| WsError::ParseError)?; + + let publisher = context.publisher.clone(); + context.subscriptions.insert( + sub_id.clone(), + tokio::spawn(async move { + while let Some(response) = subscription.recv().await { + let _ = publisher.send(response).await; + } + }), + ); + Ok(WsSubscribeResponse { + status: "OK".to_string(), + sub_id, } + .into()) } diff --git a/crates/cdk-axum/src/ws/unsubscribe.rs b/crates/cdk-axum/src/ws/unsubscribe.rs index 8a8d3660..0689e201 100644 --- a/crates/cdk-axum/src/ws/unsubscribe.rs +++ b/crates/cdk-axum/src/ws/unsubscribe.rs @@ -1,32 +1,18 @@ -use cdk::pub_sub::SubId; +use cdk::nuts::nut17::ws::{WsResponseResult, WsUnsubscribeRequest, WsUnsubscribeResponse}; -use super::handler::WsHandle; use super::{WsContext, WsError}; -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct Method { - #[serde(rename = "subId")] - pub sub_id: SubId, -} - -#[derive(Debug, Clone, serde::Serialize)] -pub struct Response { - status: String, - sub_id: SubId, -} - -#[async_trait::async_trait] -impl WsHandle for Method { - type Response = Response; - - async fn handle(self, context: &mut WsContext) -> Result { - if context.subscriptions.remove(&self.sub_id).is_some() { - Ok(Response { - status: "OK".to_string(), - sub_id: self.sub_id, - }) - } else { - Err(WsError::InvalidParams) +pub(crate) async fn handle( + context: &mut WsContext, + req: WsUnsubscribeRequest, +) -> Result { + if context.subscriptions.remove(&req.sub_id).is_some() { + Ok(WsUnsubscribeResponse { + status: "OK".to_string(), + sub_id: req.sub_id, } + .into()) + } else { + Err(WsError::InvalidParams) } } diff --git a/crates/cdk-cli/src/main.rs b/crates/cdk-cli/src/main.rs index e6c6c730..5352efd5 100644 --- a/crates/cdk-cli/src/main.rs +++ b/crates/cdk-cli/src/main.rs @@ -154,7 +154,7 @@ async fn main() -> Result<()> { )?; if let Some(proxy_url) = args.proxy.as_ref() { let http_client = HttpClient::with_proxy(mint_url, proxy_url.clone(), None, true)?; - wallet.set_client(Arc::from(http_client)); + wallet.set_client(http_client); } wallets.push(wallet); diff --git a/crates/cdk-cli/src/sub_commands/mint.rs b/crates/cdk-cli/src/sub_commands/mint.rs index 43e71a44..a73e178c 100644 --- a/crates/cdk-cli/src/sub_commands/mint.rs +++ b/crates/cdk-cli/src/sub_commands/mint.rs @@ -1,18 +1,16 @@ use std::str::FromStr; use std::sync::Arc; -use std::time::Duration; use anyhow::Result; use cdk::amount::SplitTarget; use cdk::cdk_database::{Error, WalletDatabase}; use cdk::mint_url::MintUrl; -use cdk::nuts::{CurrencyUnit, MintQuoteState}; +use cdk::nuts::{CurrencyUnit, MintQuoteState, NotificationPayload}; use cdk::wallet::multi_mint_wallet::WalletKey; -use cdk::wallet::{MultiMintWallet, Wallet}; +use cdk::wallet::{MultiMintWallet, Wallet, WalletSubscription}; use cdk::Amount; use clap::Args; use serde::{Deserialize, Serialize}; -use tokio::time::sleep; #[derive(Args, Serialize, Deserialize)] pub struct MintSubCommand { @@ -59,14 +57,18 @@ pub async fn mint( println!("Please pay: {}", quote.request); - loop { - let status = wallet.mint_quote_state("e.id).await?; + let mut subscription = wallet + .subscribe(WalletSubscription::Bolt11MintQuoteState(vec![quote + .id + .clone()])) + .await; - if status.state == MintQuoteState::Paid { - break; + while let Some(msg) = subscription.recv().await { + if let NotificationPayload::MintQuoteBolt11Response(response) = msg { + if response.state == MintQuoteState::Paid { + break; + } } - - sleep(Duration::from_secs(2)).await; } let receive_amount = wallet.mint("e.id, SplitTarget::default(), None).await?; diff --git a/crates/cdk-integration-tests/Cargo.toml b/crates/cdk-integration-tests/Cargo.toml index 9be7c4c8..51d793d6 100644 --- a/crates/cdk-integration-tests/Cargo.toml +++ b/crates/cdk-integration-tests/Cargo.toml @@ -11,7 +11,7 @@ rust-version = "1.63.0" [features] - +http_subscription = ["cdk/http_subscription"] [dependencies] axum = "0.6.20" diff --git a/crates/cdk-integration-tests/src/lib.rs b/crates/cdk-integration-tests/src/lib.rs index b0026696..7211c64c 100644 --- a/crates/cdk-integration-tests/src/lib.rs +++ b/crates/cdk-integration-tests/src/lib.rs @@ -1,7 +1,6 @@ use std::collections::{HashMap, HashSet}; use std::str::FromStr; use std::sync::Arc; -use std::time::Duration; use anyhow::{bail, Result}; use axum::Router; @@ -12,17 +11,19 @@ use cdk::cdk_lightning::MintLightning; use cdk::dhke::construct_proofs; use cdk::mint::FeeReserve; use cdk::mint_url::MintUrl; +use cdk::nuts::nut17::Params; use cdk::nuts::{ CurrencyUnit, Id, KeySet, MintBolt11Request, MintInfo, MintQuoteBolt11Request, MintQuoteState, - Nuts, PaymentMethod, PreMintSecrets, Proofs, State, + NotificationPayload, Nuts, PaymentMethod, PreMintSecrets, Proofs, State, }; use cdk::types::{LnKey, QuoteTTL}; use cdk::wallet::client::{HttpClient, MintConnector}; +use cdk::wallet::subscription::SubscriptionManager; +use cdk::wallet::WalletSubscription; use cdk::{Mint, Wallet}; use cdk_fake_wallet::FakeWallet; use init_regtest::{get_mint_addr, get_mint_port, get_mint_url}; use tokio::sync::Notify; -use tokio::time::sleep; use tower_http::cors::CorsLayer; pub mod init_fake_wallet; @@ -129,15 +130,18 @@ pub async fn wallet_mint( ) -> Result<()> { let quote = wallet.mint_quote(amount, description).await?; - loop { - let status = wallet.mint_quote_state("e.id).await?; + let mut subscription = wallet + .subscribe(WalletSubscription::Bolt11MintQuoteState(vec![quote + .id + .clone()])) + .await; - if status.state == MintQuoteState::Paid { - break; + while let Some(msg) = subscription.recv().await { + if let NotificationPayload::MintQuoteBolt11Response(response) = msg { + if response.state == MintQuoteState::Paid { + break; + } } - println!("{:?}", status); - - sleep(Duration::from_secs(2)).await; } let receive_amount = wallet.mint("e.id, split_target, None).await?; @@ -169,17 +173,25 @@ pub async fn mint_proofs( println!("Please pay: {}", mint_quote.request); - loop { - let status = wallet_client - .get_mint_quote_status(&mint_quote.quote) - .await?; + let subscription_client = SubscriptionManager::new(Arc::new(wallet_client.clone())); - if status.state == MintQuoteState::Paid { - break; - } - println!("{:?}", status.state); + let mut subscription = subscription_client + .subscribe( + mint_url.parse()?, + Params { + filters: vec![mint_quote.quote.clone()], + kind: cdk::nuts::nut17::Kind::Bolt11MintQuote, + id: "sub".into(), + }, + ) + .await; - sleep(Duration::from_secs(2)).await; + while let Some(msg) = subscription.recv().await { + if let NotificationPayload::MintQuoteBolt11Response(response) = msg { + if response.state == MintQuoteState::Paid { + break; + } + } } let premint_secrets = PreMintSecrets::random(keyset_id, amount, &SplitTarget::default())?; diff --git a/crates/cdk-integration-tests/tests/fake_wallet.rs b/crates/cdk-integration-tests/tests/fake_wallet.rs index cfea089c..8f34ada7 100644 --- a/crates/cdk-integration-tests/tests/fake_wallet.rs +++ b/crates/cdk-integration-tests/tests/fake_wallet.rs @@ -1,18 +1,17 @@ use std::sync::Arc; -use std::time::Duration; use anyhow::Result; use bip39::Mnemonic; use cdk::amount::SplitTarget; use cdk::cdk_database::WalletMemoryDatabase; use cdk::nuts::{ - CurrencyUnit, MeltBolt11Request, MeltQuoteState, MintQuoteState, PreMintSecrets, State, + CurrencyUnit, MeltBolt11Request, MeltQuoteState, MintQuoteState, NotificationPayload, + PreMintSecrets, State, }; use cdk::wallet::client::{HttpClient, MintConnector}; -use cdk::wallet::Wallet; +use cdk::wallet::{Wallet, WalletSubscription}; use cdk_fake_wallet::{create_fake_invoice, FakeInvoiceDescription}; use cdk_integration_tests::attempt_to_swap_pending; -use tokio::time::sleep; const MINT_URL: &str = "http://127.0.0.1:8086"; @@ -379,12 +378,19 @@ async fn test_fake_melt_change_in_quote() -> Result<()> { // Keep polling the state of the mint quote id until it's paid async fn wait_for_mint_to_be_paid(wallet: &Wallet, mint_quote_id: &str) -> Result<()> { - loop { - let status = wallet.mint_quote_state(mint_quote_id).await?; - if status.state == MintQuoteState::Paid { - return Ok(()); + let mut subscription = wallet + .subscribe(WalletSubscription::Bolt11MintQuoteState(vec![ + mint_quote_id.to_owned(), + ])) + .await; + + while let Some(msg) = subscription.recv().await { + if let NotificationPayload::MintQuoteBolt11Response(response) = msg { + if response.state == MintQuoteState::Paid { + break; + } } - - sleep(Duration::from_millis(5)).await; } + + Ok(()) } diff --git a/crates/cdk-integration-tests/tests/integration_tests_pure.rs b/crates/cdk-integration-tests/tests/integration_tests_pure.rs index 5d4617d6..914fc2f1 100644 --- a/crates/cdk-integration-tests/tests/integration_tests_pure.rs +++ b/crates/cdk-integration-tests/tests/integration_tests_pure.rs @@ -197,7 +197,7 @@ mod integration_tests_pure { let localstore = WalletMemoryDatabase::default(); let mut wallet = Wallet::new(&mint_url, unit, Arc::new(localstore), &seed, None)?; - wallet.set_client(Arc::from(connector)); + wallet.set_client(connector); Ok(Arc::new(wallet)) } diff --git a/crates/cdk-integration-tests/tests/mint.rs b/crates/cdk-integration-tests/tests/mint.rs index 3c9f1258..2bcd281a 100644 --- a/crates/cdk-integration-tests/tests/mint.rs +++ b/crates/cdk-integration-tests/tests/mint.rs @@ -230,12 +230,13 @@ pub async fn test_p2pk_swap() -> Result<()> { let mut listener = mint .pubsub_manager - .subscribe(Params { + .try_subscribe(Params { kind: cdk::nuts::nut17::Kind::ProofState, filters: public_keys_to_listen.clone(), id: "test".into(), }) - .await; + .await + .expect("valid subscription"); match mint.process_swap_request(swap_request).await { Ok(_) => bail!("Proofs spent without sig"), diff --git a/crates/cdk-integration-tests/tests/regtest.rs b/crates/cdk-integration-tests/tests/regtest.rs index 846c7eb6..eb1ff633 100644 --- a/crates/cdk-integration-tests/tests/regtest.rs +++ b/crates/cdk-integration-tests/tests/regtest.rs @@ -12,7 +12,7 @@ use cdk::nuts::{ PreMintSecrets, State, }; use cdk::wallet::client::{HttpClient, MintConnector}; -use cdk::wallet::Wallet; +use cdk::wallet::{Wallet, WalletSubscription}; use cdk_integration_tests::init_regtest::{ get_mint_url, get_mint_ws_url, init_cln_client, init_lnd_client, }; @@ -20,14 +20,14 @@ use futures::{SinkExt, StreamExt}; use lightning_invoice::Bolt11Invoice; use ln_regtest_rs::InvoiceStatus; use serde_json::json; -use tokio::time::{sleep, timeout}; +use tokio::time::timeout; use tokio_tungstenite::connect_async; use tokio_tungstenite::tungstenite::protocol::Message; async fn get_notification> + Unpin, E: Debug>( reader: &mut T, timeout_to_wait: Duration, -) -> (String, NotificationPayload) { +) -> (String, NotificationPayload) { let msg = timeout(timeout_to_wait, reader.next()) .await .expect("timeout") @@ -361,16 +361,18 @@ async fn test_cached_mint() -> Result<()> { let quote = wallet.mint_quote(mint_amount, None).await?; lnd_client.pay_invoice(quote.request).await?; - loop { - let status = wallet.mint_quote_state("e.id).await.unwrap(); + let mut subscription = wallet + .subscribe(WalletSubscription::Bolt11MintQuoteState(vec![quote + .id + .clone()])) + .await; - println!("Quote status: {}", status.state); - - if status.state == MintQuoteState::Paid { - break; + while let Some(msg) = subscription.recv().await { + if let NotificationPayload::MintQuoteBolt11Response(response) = msg { + if response.state == MintQuoteState::Paid { + break; + } } - - sleep(Duration::from_secs(5)).await; } let active_keyset_id = wallet.get_active_mint_keyset().await?.id; diff --git a/crates/cdk/Cargo.toml b/crates/cdk/Cargo.toml index c5f97b11..83ea0401 100644 --- a/crates/cdk/Cargo.toml +++ b/crates/cdk/Cargo.toml @@ -6,7 +6,7 @@ authors = ["CDK Developers"] description = "Core Cashu Development Kit library implementing the Cashu protocol" homepage = "https://github.com/cashubtc/cdk" repository = "https://github.com/cashubtc/cdk.git" -rust-version = "1.63.0" # MSRV +rust-version = "1.63.0" # MSRV license = "MIT" @@ -17,12 +17,18 @@ mint = ["dep:futures"] swagger = ["mint", "dep:utoipa"] wallet = ["dep:reqwest"] bench = [] +http_subscription = [] [dependencies] async-trait = "0.1" anyhow = { version = "1.0.43", features = ["backtrace"] } -bitcoin = { version= "0.32.2", features = ["base64", "serde", "rand", "rand-std"] } +bitcoin = { version = "0.32.2", features = [ + "base64", + "serde", + "rand", + "rand-std", +] } ciborium = { version = "0.2.2", default-features = false, features = ["std"] } cbor-diag = "0.1.12" lightning-invoice = { version = "0.32.0", features = ["serde", "std"] } @@ -37,9 +43,14 @@ reqwest = { version = "0.12", default-features = false, features = [ serde = { version = "1", default-features = false, features = ["derive"] } serde_json = "1" serde_with = "3" -tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] } +tracing = { version = "0.1", default-features = false, features = [ + "attributes", + "log", +] } thiserror = "1" -futures = { version = "0.3.28", default-features = false, optional = true, features = ["alloc"] } +futures = { version = "0.3.28", default-features = false, optional = true, features = [ + "alloc", +] } url = "2.3" utoipa = { version = "4", optional = true } uuid = { version = "1", features = ["v4", "serde"] } @@ -55,6 +66,11 @@ tokio = { version = "1.21", features = [ "macros", "sync", ] } +getrandom = { version = "0.2" } +tokio-tungstenite = { version = "0.19.0", features = [ + "rustls", + "rustls-tls-native-roots", +] } [target.'cfg(target_arch = "wasm32")'.dependencies] tokio = { version = "1.21", features = ["rt", "macros", "sync", "time"] } diff --git a/crates/cdk/examples/mint-token.rs b/crates/cdk/examples/mint-token.rs index 195fb0ff..cb2dce1e 100644 --- a/crates/cdk/examples/mint-token.rs +++ b/crates/cdk/examples/mint-token.rs @@ -1,15 +1,13 @@ use std::sync::Arc; -use std::time::Duration; use cdk::amount::SplitTarget; use cdk::cdk_database::WalletMemoryDatabase; use cdk::error::Error; -use cdk::nuts::{CurrencyUnit, MintQuoteState}; +use cdk::nuts::{CurrencyUnit, MintQuoteState, NotificationPayload}; use cdk::wallet::types::SendKind; -use cdk::wallet::Wallet; +use cdk::wallet::{Wallet, WalletSubscription}; use cdk::Amount; use rand::Rng; -use tokio::time::sleep; #[tokio::main] async fn main() -> Result<(), Error> { @@ -26,16 +24,18 @@ async fn main() -> Result<(), Error> { println!("Quote: {:#?}", quote); - loop { - let status = wallet.mint_quote_state("e.id).await.unwrap(); + let mut subscription = wallet + .subscribe(WalletSubscription::Bolt11MintQuoteState(vec![quote + .id + .clone()])) + .await; - println!("Quote status: {}", status.state); - - if status.state == MintQuoteState::Paid { - break; + while let Some(msg) = subscription.recv().await { + if let NotificationPayload::MintQuoteBolt11Response(response) = msg { + if response.state == MintQuoteState::Paid { + break; + } } - - sleep(Duration::from_secs(5)).await; } let receive_amount = wallet diff --git a/crates/cdk/examples/p2pk.rs b/crates/cdk/examples/p2pk.rs index 6e51f781..7edbe579 100644 --- a/crates/cdk/examples/p2pk.rs +++ b/crates/cdk/examples/p2pk.rs @@ -1,15 +1,13 @@ use std::sync::Arc; -use std::time::Duration; use cdk::amount::SplitTarget; use cdk::cdk_database::WalletMemoryDatabase; use cdk::error::Error; -use cdk::nuts::{CurrencyUnit, MintQuoteState, SecretKey, SpendingConditions}; +use cdk::nuts::{CurrencyUnit, MintQuoteState, NotificationPayload, SecretKey, SpendingConditions}; use cdk::wallet::types::SendKind; -use cdk::wallet::Wallet; +use cdk::wallet::{Wallet, WalletSubscription}; use cdk::Amount; use rand::Rng; -use tokio::time::sleep; #[tokio::main] async fn main() -> Result<(), Error> { @@ -26,16 +24,18 @@ async fn main() -> Result<(), Error> { println!("Minting nuts ..."); - loop { - let status = wallet.mint_quote_state("e.id).await.unwrap(); - - println!("Quote status: {}", status.state); - - if status.state == MintQuoteState::Paid { - break; + let mut subscription = wallet + .subscribe(WalletSubscription::Bolt11MintQuoteState(vec![quote + .id + .clone()])) + .await; + + while let Some(msg) = subscription.recv().await { + if let NotificationPayload::MintQuoteBolt11Response(response) = msg { + if response.state == MintQuoteState::Paid { + break; + } } - - sleep(Duration::from_secs(5)).await; } let _receive_amount = wallet diff --git a/crates/cdk/examples/proof-selection.rs b/crates/cdk/examples/proof-selection.rs index 210b7731..dcfab297 100644 --- a/crates/cdk/examples/proof-selection.rs +++ b/crates/cdk/examples/proof-selection.rs @@ -1,15 +1,13 @@ //! Wallet example with memory store use std::sync::Arc; -use std::time::Duration; use cdk::amount::SplitTarget; use cdk::cdk_database::WalletMemoryDatabase; -use cdk::nuts::{CurrencyUnit, MintQuoteState}; -use cdk::wallet::Wallet; +use cdk::nuts::{CurrencyUnit, MintQuoteState, NotificationPayload}; +use cdk::wallet::{Wallet, WalletSubscription}; use cdk::Amount; use rand::Rng; -use tokio::time::sleep; #[tokio::main] async fn main() { @@ -28,16 +26,18 @@ async fn main() { println!("Pay request: {}", quote.request); - loop { - let status = wallet.mint_quote_state("e.id).await.unwrap(); - - if status.state == MintQuoteState::Paid { - break; + let mut subscription = wallet + .subscribe(WalletSubscription::Bolt11MintQuoteState(vec![quote + .id + .clone()])) + .await; + + while let Some(msg) = subscription.recv().await { + if let NotificationPayload::MintQuoteBolt11Response(response) = msg { + if response.state == MintQuoteState::Paid { + break; + } } - - println!("Quote state: {}", status.state); - - sleep(Duration::from_secs(5)).await; } let receive_amount = wallet diff --git a/crates/cdk/src/lib.rs b/crates/cdk/src/lib.rs index effb04f9..80be76bd 100644 --- a/crates/cdk/src/lib.rs +++ b/crates/cdk/src/lib.rs @@ -35,7 +35,7 @@ pub use lightning_invoice::{self, Bolt11Invoice}; pub use mint::Mint; #[cfg(feature = "wallet")] #[doc(hidden)] -pub use wallet::Wallet; +pub use wallet::{Wallet, WalletSubscription}; #[doc(hidden)] pub use self::amount::Amount; diff --git a/crates/cdk/src/nuts/mod.rs b/crates/cdk/src/nuts/mod.rs index 7f913f49..0f3bc525 100644 --- a/crates/cdk/src/nuts/mod.rs +++ b/crates/cdk/src/nuts/mod.rs @@ -18,7 +18,6 @@ pub mod nut12; pub mod nut13; pub mod nut14; pub mod nut15; -#[cfg(feature = "mint")] pub mod nut17; pub mod nut18; pub mod nut19; @@ -51,5 +50,6 @@ pub use nut12::{BlindSignatureDleq, ProofDleq}; pub use nut14::HTLCWitness; pub use nut15::{Mpp, MppMethodSettings, Settings as NUT15Settings}; #[cfg(feature = "mint")] -pub use nut17::{NotificationPayload, PubSubManager}; +pub use nut17::PubSubManager; +pub use nut17::{NotificationPayload, SupportedSettings as Nut17SupportedSettings}; pub use nut18::{PaymentRequest, PaymentRequestPayload, Transport}; diff --git a/crates/cdk/src/nuts/nut04.rs b/crates/cdk/src/nuts/nut04.rs index fea9d9a5..9d1c6116 100644 --- a/crates/cdk/src/nuts/nut04.rs +++ b/crates/cdk/src/nuts/nut04.rs @@ -96,6 +96,18 @@ pub struct MintQuoteBolt11Response { pub expiry: Option, } +impl MintQuoteBolt11Response { + /// Convert the MintQuote with a quote type Q to a String + pub fn to_string_id(&self) -> MintQuoteBolt11Response { + MintQuoteBolt11Response { + quote: self.quote.to_string(), + request: self.request.clone(), + state: self.state, + expiry: self.expiry, + } + } +} + #[cfg(feature = "mint")] impl From> for MintQuoteBolt11Response { fn from(value: MintQuoteBolt11Response) -> Self { diff --git a/crates/cdk/src/nuts/nut05.rs b/crates/cdk/src/nuts/nut05.rs index 3a6f5ae3..bd2b5af5 100644 --- a/crates/cdk/src/nuts/nut05.rs +++ b/crates/cdk/src/nuts/nut05.rs @@ -115,6 +115,23 @@ pub struct MeltQuoteBolt11Response { pub change: Option>, } +impl MeltQuoteBolt11Response { + /// Convert a `MeltQuoteBolt11Response` with type Q (generic/unknown) to a + /// `MeltQuoteBolt11Response` with `String` + pub fn to_string_id(self) -> MeltQuoteBolt11Response { + MeltQuoteBolt11Response { + quote: self.quote.to_string(), + amount: self.amount, + fee_reserve: self.fee_reserve, + paid: self.paid, + state: self.state, + expiry: self.expiry, + payment_preimage: self.payment_preimage, + change: self.change, + } + } +} + #[cfg(feature = "mint")] impl From> for MeltQuoteBolt11Response { fn from(value: MeltQuoteBolt11Response) -> Self { diff --git a/crates/cdk/src/nuts/nut17/manager.rs b/crates/cdk/src/nuts/nut17/manager.rs new file mode 100644 index 00000000..21103d6e --- /dev/null +++ b/crates/cdk/src/nuts/nut17/manager.rs @@ -0,0 +1,225 @@ +//! Specific Subscription for the cdk crate +use std::ops::Deref; +use std::sync::Arc; + +use uuid::Uuid; + +use super::{Notification, NotificationPayload, OnSubscription}; +use crate::cdk_database::{self, MintDatabase}; +use crate::nuts::{ + BlindSignature, MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response, + MintQuoteState, ProofState, +}; +use crate::pub_sub; + +/// Manager +/// Publish–subscribe manager +/// +/// Nut-17 implementation is system-wide and not only through the WebSocket, so +/// it is possible for another part of the system to subscribe to events. +pub struct PubSubManager(pub_sub::Manager, Notification, OnSubscription>); + +#[allow(clippy::default_constructed_unit_structs)] +impl Default for PubSubManager { + fn default() -> Self { + PubSubManager(OnSubscription::default().into()) + } +} + +impl From + Send + Sync>> for PubSubManager { + fn from(val: Arc + Send + Sync>) -> Self { + PubSubManager(OnSubscription(Some(val)).into()) + } +} + +impl Deref for PubSubManager { + type Target = pub_sub::Manager, Notification, OnSubscription>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl PubSubManager { + /// Helper function to emit a ProofState status + pub fn proof_state>(&self, event: E) { + self.broadcast(event.into().into()); + } + + /// Helper function to emit a MintQuoteBolt11Response status + pub fn mint_quote_bolt11_status>>( + &self, + quote: E, + new_state: MintQuoteState, + ) { + let mut event = quote.into(); + event.state = new_state; + + self.broadcast(event.into()); + } + + /// Helper function to emit a MeltQuoteBolt11Response status + pub fn melt_quote_status>>( + &self, + quote: E, + payment_preimage: Option, + change: Option>, + new_state: MeltQuoteState, + ) { + let mut quote = quote.into(); + quote.state = new_state; + quote.paid = Some(new_state == MeltQuoteState::Paid); + quote.payment_preimage = payment_preimage; + quote.change = change; + self.broadcast(quote.into()); + } +} + +#[cfg(test)] +mod test { + use std::time::Duration; + + use tokio::time::sleep; + + use super::*; + use crate::nuts::nut17::{Kind, Params}; + use crate::nuts::{PublicKey, State}; + + #[tokio::test] + async fn active_and_drop() { + let manager = PubSubManager::default(); + let params = Params { + kind: Kind::ProofState, + filters: vec![ + "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2".to_owned(), + ], + id: "uno".into(), + }; + + // Although the same param is used, two subscriptions are created, that + // is because each index is unique, thanks to `Unique`, it is the + // responsibility of the implementor to make sure that SubId are unique + // either globally or per client + let subscriptions = vec![ + manager + .try_subscribe(params.clone()) + .await + .expect("valid subscription"), + manager + .try_subscribe(params) + .await + .expect("valid subscription"), + ]; + assert_eq!(2, manager.active_subscriptions()); + drop(subscriptions); + + sleep(Duration::from_millis(10)).await; + + assert_eq!(0, manager.active_subscriptions()); + } + + #[tokio::test] + async fn broadcast() { + let manager = PubSubManager::default(); + let mut subscriptions = [ + manager + .try_subscribe(Params { + kind: Kind::ProofState, + filters: vec![ + "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" + .to_string(), + ], + id: "uno".into(), + }) + .await + .expect("valid subscription"), + manager + .try_subscribe(Params { + kind: Kind::ProofState, + filters: vec![ + "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" + .to_string(), + ], + id: "dos".into(), + }) + .await + .expect("valid subscription"), + ]; + + let event = ProofState { + y: PublicKey::from_hex( + "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104", + ) + .expect("valid pk"), + state: State::Pending, + witness: None, + }; + + manager.broadcast(event.into()); + + sleep(Duration::from_millis(10)).await; + + let (sub1, _) = subscriptions[0].try_recv().expect("valid message"); + assert_eq!("uno", *sub1); + + let (sub1, _) = subscriptions[1].try_recv().expect("valid message"); + assert_eq!("dos", *sub1); + + assert!(subscriptions[0].try_recv().is_err()); + assert!(subscriptions[1].try_recv().is_err()); + } + + #[test] + fn parsing_request() { + let json = r#"{"kind":"proof_state","filters":["x"],"subId":"uno"}"#; + let params: Params = serde_json::from_str(json).expect("valid json"); + assert_eq!(params.kind, Kind::ProofState); + assert_eq!(params.filters, vec!["x"]); + assert_eq!(*params.id, "uno"); + } + + #[tokio::test] + async fn json_test() { + let manager = PubSubManager::default(); + let mut subscription = manager + .try_subscribe::( + serde_json::from_str(r#"{"kind":"proof_state","filters":["02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104"],"subId":"uno"}"#) + .expect("valid json"), + ) + .await.expect("valid subscription"); + + manager.broadcast( + ProofState { + y: PublicKey::from_hex( + "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104", + ) + .expect("valid pk"), + state: State::Pending, + witness: None, + } + .into(), + ); + + // no one is listening for this event + manager.broadcast( + ProofState { + y: PublicKey::from_hex( + "020000000000000000000000000000000000000000000000000000000000000001", + ) + .expect("valid pk"), + state: State::Pending, + witness: None, + } + .into(), + ); + + sleep(Duration::from_millis(10)).await; + let (sub1, msg) = subscription.try_recv().expect("valid message"); + assert_eq!("uno", *sub1); + assert_eq!( + r#"{"Y":"02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104","state":"PENDING","witness":null}"#, + serde_json::to_string(&msg).expect("valid json") + ); + assert!(subscription.try_recv().is_err()); + } +} diff --git a/crates/cdk/src/nuts/nut17/mod.rs b/crates/cdk/src/nuts/nut17/mod.rs index 424eb435..296a30e5 100644 --- a/crates/cdk/src/nuts/nut17/mod.rs +++ b/crates/cdk/src/nuts/nut17/mod.rs @@ -1,27 +1,30 @@ //! Specific Subscription for the cdk crate - -use std::ops::Deref; use std::str::FromStr; -use std::sync::Arc; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; - -mod on_subscription; - -pub use on_subscription::OnSubscription; use uuid::Uuid; use super::PublicKey; -use crate::cdk_database::{self, MintDatabase}; use crate::nuts::{ - BlindSignature, CurrencyUnit, MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response, - MintQuoteState, PaymentMethod, ProofState, + CurrencyUnit, MeltQuoteBolt11Response, MintQuoteBolt11Response, PaymentMethod, ProofState, }; +use crate::pub_sub::{Index, Indexable, SubscriptionGlobalId}; + +#[cfg(feature = "mint")] +mod manager; +#[cfg(feature = "mint")] +mod on_subscription; +#[cfg(feature = "mint")] +pub use manager::PubSubManager; +#[cfg(feature = "mint")] +pub use on_subscription::OnSubscription; + pub use crate::pub_sub::SubId; -use crate::pub_sub::{self, Index, Indexable, SubscriptionGlobalId}; +pub mod ws; /// Subscription Parameter according to the standard -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash, Deserialize)] pub struct Params { /// Kind pub kind: Kind, @@ -33,20 +36,12 @@ pub struct Params { } /// Check state Settings -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct SupportedSettings { /// Supported methods pub supported: Vec, } -impl Default for SupportedSettings { - fn default() -> Self { - SupportedSettings { - supported: vec![SupportedMethods::default()], - } - } -} - /// Supported WS Methods #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct SupportedMethods { @@ -64,11 +59,7 @@ impl SupportedMethods { Self { method, unit, - commands: vec![ - "bolt11_mint_quote".to_owned(), - "bolt11_melt_quote".to_owned(), - "proof_state".to_owned(), - ], + commands: Vec::new(), } } } @@ -88,31 +79,32 @@ impl Default for SupportedMethods { } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(bound = "T: Serialize + DeserializeOwned")] #[serde(untagged)] /// Subscription response -pub enum NotificationPayload { +pub enum NotificationPayload { /// Proof State ProofState(ProofState), /// Melt Quote Bolt11 Response - MeltQuoteBolt11Response(MeltQuoteBolt11Response), + MeltQuoteBolt11Response(MeltQuoteBolt11Response), /// Mint Quote Bolt11 Response - MintQuoteBolt11Response(MintQuoteBolt11Response), + MintQuoteBolt11Response(MintQuoteBolt11Response), } -impl From for NotificationPayload { - fn from(proof_state: ProofState) -> NotificationPayload { +impl From for NotificationPayload { + fn from(proof_state: ProofState) -> NotificationPayload { NotificationPayload::ProofState(proof_state) } } -impl From> for NotificationPayload { - fn from(melt_quote: MeltQuoteBolt11Response) -> NotificationPayload { +impl From> for NotificationPayload { + fn from(melt_quote: MeltQuoteBolt11Response) -> NotificationPayload { NotificationPayload::MeltQuoteBolt11Response(melt_quote) } } -impl From> for NotificationPayload { - fn from(mint_quote: MintQuoteBolt11Response) -> NotificationPayload { +impl From> for NotificationPayload { + fn from(mint_quote: MintQuoteBolt11Response) -> NotificationPayload { NotificationPayload::MintQuoteBolt11Response(mint_quote) } } @@ -128,7 +120,7 @@ pub enum Notification { MintQuoteBolt11(Uuid), } -impl Indexable for NotificationPayload { +impl Indexable for NotificationPayload { type Type = Notification; fn to_indexes(&self) -> Vec> { @@ -146,10 +138,9 @@ impl Indexable for NotificationPayload { } } +/// Kind #[derive(Debug, Clone, Copy, Eq, Ord, PartialOrd, PartialEq, Hash, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] - -/// Kind pub enum Kind { /// Bolt 11 Melt Quote Bolt11MeltQuote, @@ -165,8 +156,22 @@ impl AsRef for Params { } } -impl From for Vec> { - fn from(val: Params) -> Self { +/// Parsing error +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("Uuid Error: {0}")] + /// Uuid Error + Uuid(#[from] uuid::Error), + + #[error("PublicKey Error: {0}")] + /// PublicKey Error + PublicKey(#[from] crate::nuts::nut01::Error), +} + +impl TryFrom for Vec> { + type Error = Error; + + fn try_from(val: Params) -> Result { let sub_id: SubscriptionGlobalId = Default::default(); val.filters .into_iter() @@ -183,213 +188,6 @@ impl From for Vec> { Ok(Index::from((idx, val.id.clone(), sub_id))) }) - .collect::>() - .unwrap() - // TODO don't unwrap, move to try from - } -} - -/// Manager -/// Publish–subscribe manager -/// -/// Nut-17 implementation is system-wide and not only through the WebSocket, so -/// it is possible for another part of the system to subscribe to events. -pub struct PubSubManager(pub_sub::Manager); - -#[allow(clippy::default_constructed_unit_structs)] -impl Default for PubSubManager { - fn default() -> Self { - PubSubManager(OnSubscription::default().into()) - } -} - -impl From + Send + Sync>> for PubSubManager { - fn from(val: Arc + Send + Sync>) -> Self { - PubSubManager(OnSubscription(Some(val)).into()) - } -} - -impl Deref for PubSubManager { - type Target = pub_sub::Manager; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl PubSubManager { - /// Helper function to emit a ProofState status - pub fn proof_state>(&self, event: E) { - self.broadcast(event.into().into()); - } - - /// Helper function to emit a MintQuoteBolt11Response status - pub fn mint_quote_bolt11_status>>( - &self, - quote: E, - new_state: MintQuoteState, - ) { - let mut event = quote.into(); - event.state = new_state; - - self.broadcast(event.into()); - } - - /// Helper function to emit a MeltQuoteBolt11Response status - pub fn melt_quote_status>>( - &self, - quote: E, - payment_preimage: Option, - change: Option>, - new_state: MeltQuoteState, - ) { - let mut quote = quote.into(); - quote.state = new_state; - quote.paid = Some(new_state == MeltQuoteState::Paid); - quote.payment_preimage = payment_preimage; - quote.change = change; - self.broadcast(quote.into()); - } -} - -#[cfg(test)] -mod test { - use std::time::Duration; - - use tokio::time::sleep; - - use super::*; - use crate::nuts::{PublicKey, State}; - - #[tokio::test] - async fn active_and_drop() { - let manager = PubSubManager::default(); - let params = Params { - kind: Kind::ProofState, - filters: vec![PublicKey::from_hex( - "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104", - ) - .unwrap() - .to_string()], - id: "uno".into(), - }; - - // Although the same param is used, two subscriptions are created, that - // is because each index is unique, thanks to `Unique`, it is the - // responsibility of the implementor to make sure that SubId are unique - // either globally or per client - let subscriptions = vec![ - manager.subscribe(params.clone()).await, - manager.subscribe(params).await, - ]; - assert_eq!(2, manager.active_subscriptions()); - drop(subscriptions); - - sleep(Duration::from_millis(10)).await; - - assert_eq!(0, manager.active_subscriptions()); - } - - #[tokio::test] - async fn broadcast() { - let manager = PubSubManager::default(); - let mut subscriptions = [ - manager - .subscribe(Params { - kind: Kind::ProofState, - filters: vec![ - "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" - .to_string(), - ], - id: "uno".into(), - }) - .await, - manager - .subscribe(Params { - kind: Kind::ProofState, - filters: vec![ - "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" - .to_string(), - ], - id: "dos".into(), - }) - .await, - ]; - - let event = ProofState { - y: PublicKey::from_hex( - "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104", - ) - .expect("valid pk"), - state: State::Pending, - witness: None, - }; - - manager.broadcast(event.into()); - - sleep(Duration::from_millis(10)).await; - - let (sub1, _) = subscriptions[0].try_recv().expect("valid message"); - assert_eq!("uno", *sub1); - - let (sub1, _) = subscriptions[1].try_recv().expect("valid message"); - assert_eq!("dos", *sub1); - - assert!(subscriptions[0].try_recv().is_err()); - assert!(subscriptions[1].try_recv().is_err()); - } - - #[test] - fn parsing_request() { - let json = r#"{"kind":"proof_state","filters":["x"],"subId":"uno"}"#; - let params: Params = serde_json::from_str(json).expect("valid json"); - assert_eq!(params.kind, Kind::ProofState); - assert_eq!(params.filters, vec!["x"]); - assert_eq!(*params.id, "uno"); - } - - #[tokio::test] - async fn json_test() { - let manager = PubSubManager::default(); - let mut subscription = manager - .subscribe::( - serde_json::from_str(r#"{"kind":"proof_state","filters":["02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104"],"subId":"uno"}"#) - .expect("valid json"), - ) - .await; - - manager.broadcast( - ProofState { - y: PublicKey::from_hex( - "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104", - ) - .expect("valid pk"), - state: State::Pending, - witness: None, - } - .into(), - ); - - // no one is listening for this event - manager.broadcast( - ProofState { - y: PublicKey::from_hex( - "020000000000000000000000000000000000000000000000000000000000000001", - ) - .expect("valid pk"), - state: State::Pending, - witness: None, - } - .into(), - ); - - sleep(Duration::from_millis(10)).await; - let (sub1, msg) = subscription.try_recv().expect("valid message"); - assert_eq!("uno", *sub1); - assert_eq!( - r#"{"Y":"02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104","state":"PENDING","witness":null}"#, - serde_json::to_string(&msg).expect("valid json") - ); - assert!(subscription.try_recv().is_err()); + .collect::>() } } diff --git a/crates/cdk/src/nuts/nut17/on_subscription.rs b/crates/cdk/src/nuts/nut17/on_subscription.rs index d7465567..37d74587 100644 --- a/crates/cdk/src/nuts/nut17/on_subscription.rs +++ b/crates/cdk/src/nuts/nut17/on_subscription.rs @@ -22,7 +22,7 @@ pub struct OnSubscription( #[async_trait::async_trait] impl OnNewSubscription for OnSubscription { - type Event = NotificationPayload; + type Event = NotificationPayload; type Index = Notification; async fn on_new_subscription( diff --git a/crates/cdk/src/nuts/nut17/ws.rs b/crates/cdk/src/nuts/nut17/ws.rs new file mode 100644 index 00000000..e59d87b6 --- /dev/null +++ b/crates/cdk/src/nuts/nut17/ws.rs @@ -0,0 +1,215 @@ +//! Websocket types + +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use super::{NotificationPayload, Params, SubId}; + +/// JSON RPC version +pub const JSON_RPC_VERSION: &str = "2.0"; + +/// The response to a subscription request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsSubscribeResponse { + /// Status + pub status: String, + /// Subscription ID + #[serde(rename = "subId")] + pub sub_id: SubId, +} + +/// The response to an unsubscription request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsUnsubscribeResponse { + /// Status + pub status: String, + /// Subscription ID + #[serde(rename = "subId")] + pub sub_id: SubId, +} + +/// The notification +/// +/// This is the notification that is sent to the client when an event matches a +/// subscription +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "T: Serialize + DeserializeOwned")] +pub struct NotificationInner { + /// The subscription ID + #[serde(rename = "subId")] + pub sub_id: SubId, + + /// The notification payload + pub payload: NotificationPayload, +} + +impl From> for NotificationInner { + fn from(value: NotificationInner) -> Self { + NotificationInner { + sub_id: value.sub_id, + payload: match value.payload { + NotificationPayload::ProofState(pk) => NotificationPayload::ProofState(pk), + NotificationPayload::MeltQuoteBolt11Response(quote) => { + NotificationPayload::MeltQuoteBolt11Response(quote.to_string_id()) + } + NotificationPayload::MintQuoteBolt11Response(quote) => { + NotificationPayload::MintQuoteBolt11Response(quote.to_string_id()) + } + }, + } + } +} + +/// Responses from the web socket server +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum WsResponseResult { + /// A response to a subscription request + Subscribe(WsSubscribeResponse), + /// Unsubscribe + Unsubscribe(WsUnsubscribeResponse), +} + +impl From for WsResponseResult { + fn from(response: WsSubscribeResponse) -> Self { + WsResponseResult::Subscribe(response) + } +} + +impl From for WsResponseResult { + fn from(response: WsUnsubscribeResponse) -> Self { + WsResponseResult::Unsubscribe(response) + } +} + +/// The request to unsubscribe +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsUnsubscribeRequest { + /// Subscription ID + #[serde(rename = "subId")] + pub sub_id: SubId, +} + +/// The inner method of the websocket request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "method", content = "params")] +pub enum WsMethodRequest { + /// Subscribe method + Subscribe(Params), + /// Unsubscribe method + Unsubscribe(WsUnsubscribeRequest), +} + +/// Websocket request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsRequest { + /// JSON RPC version + pub jsonrpc: String, + /// The method body + #[serde(flatten)] + pub method: WsMethodRequest, + /// The request ID + pub id: usize, +} + +impl From<(WsMethodRequest, usize)> for WsRequest { + fn from((method, id): (WsMethodRequest, usize)) -> Self { + WsRequest { + jsonrpc: JSON_RPC_VERSION.to_owned(), + method, + id, + } + } +} + +/// Notification from the server +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsNotification { + /// JSON RPC version + pub jsonrpc: String, + /// The method + pub method: String, + /// The parameters + pub params: T, +} + +/// Websocket error +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct WsErrorBody { + /// Error code + pub code: i32, + /// Error message + pub message: String, +} + +/// Websocket response +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsResponse { + /// JSON RPC version + pub jsonrpc: String, + /// The result + pub result: WsResponseResult, + /// The request ID + pub id: usize, +} + +/// WebSocket error response +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsErrorResponse { + /// JSON RPC version + pub jsonrpc: String, + /// The result + pub error: WsErrorBody, + /// The request ID + pub id: usize, +} + +/// Message from the server to the client +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum WsMessageOrResponse { + /// A response to a request + Response(WsResponse), + /// An error response + ErrorResponse(WsErrorResponse), + /// A notification + Notification(WsNotification>), +} + +impl From<(usize, Result)> for WsMessageOrResponse { + fn from((id, result): (usize, Result)) -> Self { + match result { + Ok(result) => WsMessageOrResponse::Response(WsResponse { + jsonrpc: JSON_RPC_VERSION.to_owned(), + result, + id, + }), + Err(err) => WsMessageOrResponse::ErrorResponse(WsErrorResponse { + jsonrpc: JSON_RPC_VERSION.to_owned(), + error: err, + id, + }), + } + } +} + +impl From> for WsMessageOrResponse { + fn from(notification: NotificationInner) -> Self { + WsMessageOrResponse::Notification(WsNotification { + jsonrpc: JSON_RPC_VERSION.to_owned(), + method: "subscribe".to_string(), + params: notification.into(), + }) + } +} + +impl From> for WsMessageOrResponse { + fn from(notification: NotificationInner) -> Self { + WsMessageOrResponse::Notification(WsNotification { + jsonrpc: JSON_RPC_VERSION.to_owned(), + method: "subscribe".to_string(), + params: notification, + }) + } +} diff --git a/crates/cdk/src/pub_sub/mod.rs b/crates/cdk/src/pub_sub/mod.rs index a8269290..cff76970 100644 --- a/crates/cdk/src/pub_sub/mod.rs +++ b/crates/cdk/src/pub_sub/mod.rs @@ -157,16 +157,14 @@ where Self::broadcast_impl(&self.indexes, event).await; } - /// Subscribe to a specific event - pub async fn subscribe + Into>>>( + /// Specific of the subscription, this is the abstraction between `subscribe` and `try_subscribe` + #[inline(always)] + async fn subscribe_inner( &self, - params: P, + sub_id: SubId, + indexes: Vec>, ) -> ActiveSubscription { let (sender, receiver) = mpsc::channel(10); - let sub_id: SubId = params.as_ref().clone(); - - let indexes: Vec> = params.into(); - if let Some(on_new_subscription) = self.on_new_subscription.as_ref() { match on_new_subscription .on_new_subscription(&indexes.iter().map(|x| x.deref()).collect::>()) @@ -204,6 +202,25 @@ where } } + /// Try to subscribe to a specific event + pub async fn try_subscribe + TryInto>>>( + &self, + params: P, + ) -> Result, P::Error> { + Ok(self + .subscribe_inner(params.as_ref().clone(), params.try_into()?) + .await) + } + + /// Subscribe to a specific event + pub async fn subscribe + Into>>>( + &self, + params: P, + ) -> ActiveSubscription { + self.subscribe_inner(params.as_ref().clone(), params.into()) + .await + } + /// Return number of active subscriptions pub fn active_subscriptions(&self) -> usize { self.active_subscriptions.load(atomic::Ordering::SeqCst) diff --git a/crates/cdk/src/wallet/client.rs b/crates/cdk/src/wallet/client.rs index 31198173..d92eec0b 100644 --- a/crates/cdk/src/wallet/client.rs +++ b/crates/cdk/src/wallet/client.rs @@ -5,6 +5,7 @@ use std::fmt::Debug; use async_trait::async_trait; use reqwest::Client; use tracing::instrument; +#[cfg(not(target_arch = "wasm32"))] use url::Url; use super::Error; diff --git a/crates/cdk/src/wallet/mod.rs b/crates/cdk/src/wallet/mod.rs index afcf8321..10609481 100644 --- a/crates/cdk/src/wallet/mod.rs +++ b/crates/cdk/src/wallet/mod.rs @@ -7,7 +7,11 @@ use std::sync::Arc; use bitcoin::bip32::Xpriv; use bitcoin::Network; use client::MintConnector; +use getrandom::getrandom; +pub use multi_mint_wallet::MultiMintWallet; +use subscription::{ActiveSubscription, SubscriptionManager}; use tracing::instrument; +pub use types::{MeltQuote, MintQuote, SendKind}; use crate::amount::SplitTarget; use crate::cdk_database::{self, WalletDatabase}; @@ -16,6 +20,7 @@ use crate::error::Error; use crate::fees::calculate_fee; use crate::mint_url::MintUrl; use crate::nuts::nut00::token::Token; +use crate::nuts::nut17::{Kind, Params}; use crate::nuts::{ nut10, CurrencyUnit, Id, Keys, MintInfo, MintQuoteState, PreMintSecrets, Proof, Proofs, RestoreRequest, SpendingConditions, State, @@ -32,13 +37,11 @@ pub mod multi_mint_wallet; mod proofs; mod receive; mod send; +pub mod subscription; mod swap; pub mod types; pub mod util; -pub use multi_mint_wallet::MultiMintWallet; -pub use types::{MeltQuote, MintQuote, SendKind}; - use crate::nuts::nut00::ProofsMethods; /// CDK Wallet @@ -58,6 +61,54 @@ pub struct Wallet { pub target_proof_count: usize, xpriv: Xpriv, client: Arc, + subscription: SubscriptionManager, +} + +const ALPHANUMERIC: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + +/// Wallet Subscription filter +#[derive(Debug, Clone)] +pub enum WalletSubscription { + /// Proof subscription + ProofState(Vec), + /// Mint quote subscription + Bolt11MintQuoteState(Vec), + /// Melt quote subscription + Bolt11MeltQuoteState(Vec), +} + +impl From for Params { + fn from(val: WalletSubscription) -> Self { + let mut buffer = vec![0u8; 10]; + + getrandom(&mut buffer).expect("Failed to generate random bytes"); + + let id = buffer + .iter() + .map(|&byte| { + let index = byte as usize % ALPHANUMERIC.len(); // 62 alphanumeric characters (A-Z, a-z, 0-9) + ALPHANUMERIC[index] as char + }) + .collect::(); + + match val { + WalletSubscription::ProofState(filters) => Params { + filters, + kind: Kind::ProofState, + id: id.into(), + }, + WalletSubscription::Bolt11MintQuoteState(filters) => Params { + filters, + kind: Kind::Bolt11MintQuote, + id: id.into(), + }, + WalletSubscription::Bolt11MeltQuoteState(filters) => Params { + filters, + kind: Kind::Bolt11MeltQuote, + id: id.into(), + }, + } + } } impl Wallet { @@ -88,10 +139,13 @@ impl Wallet { let xpriv = Xpriv::new_master(Network::Bitcoin, seed).expect("Could not create master key"); let mint_url = MintUrl::from_str(mint_url)?; + let http_client = Arc::new(HttpClient::new(mint_url.clone())); + Ok(Self { mint_url: mint_url.clone(), unit, - client: Arc::new(HttpClient::new(mint_url)), + client: http_client.clone(), + subscription: SubscriptionManager::new(http_client), localstore, xpriv, target_proof_count: target_proof_count.unwrap_or(3), @@ -99,8 +153,16 @@ impl Wallet { } /// Change HTTP client - pub fn set_client(&mut self, client: Arc) { - self.client = client; + pub fn set_client(&mut self, client: C) { + self.client = Arc::new(client); + self.subscription = SubscriptionManager::new(self.client.clone()); + } + + /// Subscribe to events + pub async fn subscribe>(&self, query: T) -> ActiveSubscription { + self.subscription + .subscribe(self.mint_url.clone(), query.into()) + .await } /// Fee required for proof set diff --git a/crates/cdk/src/wallet/subscription/http.rs b/crates/cdk/src/wallet/subscription/http.rs new file mode 100644 index 00000000..d77a852e --- /dev/null +++ b/crates/cdk/src/wallet/subscription/http.rs @@ -0,0 +1,150 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::{mpsc, RwLock}; +use tokio::time; + +use super::WsSubscriptionBody; +use crate::nuts::nut17::Kind; +use crate::nuts::{nut01, nut04, nut05, nut07, CheckStateRequest, NotificationPayload}; +use crate::pub_sub::SubId; +use crate::wallet::client::MintConnector; + +#[derive(Debug, Hash, PartialEq, Eq)] +enum UrlType { + Mint(String), + Melt(String), + PublicKey(nut01::PublicKey), +} + +#[derive(Debug, Eq, PartialEq)] +enum AnyState { + MintQuoteState(nut04::QuoteState), + MeltQuoteState(nut05::QuoteState), + PublicKey(nut07::State), + Empty, +} + +type SubscribedTo = HashMap>, SubId, AnyState)>; + +async fn convert_subscription( + sub_id: SubId, + subscriptions: &Arc>>, + subscribed_to: &mut SubscribedTo, +) -> Option<()> { + let subscription = subscriptions.read().await; + let sub = subscription.get(&sub_id)?; + tracing::debug!("New subscription: {:?}", sub); + match sub.1.kind { + Kind::Bolt11MintQuote => { + for id in sub.1.filters.iter().map(|id| UrlType::Mint(id.clone())) { + subscribed_to.insert(id, (sub.0.clone(), sub.1.id.clone(), AnyState::Empty)); + } + } + Kind::Bolt11MeltQuote => { + for id in sub.1.filters.iter().map(|id| UrlType::Melt(id.clone())) { + subscribed_to.insert(id, (sub.0.clone(), sub.1.id.clone(), AnyState::Empty)); + } + } + Kind::ProofState => { + for id in sub + .1 + .filters + .iter() + .map(|id| nut01::PublicKey::from_hex(id).map(UrlType::PublicKey)) + { + match id { + Ok(id) => { + subscribed_to + .insert(id, (sub.0.clone(), sub.1.id.clone(), AnyState::Empty)); + } + Err(err) => { + tracing::error!("Error parsing public key: {:?}. Subscription ignored, will never yield any result", err); + } + } + } + } + } + + Some(()) +} + +#[allow(clippy::incompatible_msrv)] +#[inline] +pub async fn http_main>( + initial_state: S, + http_client: Arc, + subscriptions: Arc>>, + mut new_subscription_recv: mpsc::Receiver, + mut on_drop: mpsc::Receiver, +) { + let mut interval = time::interval(Duration::from_secs(2)); + let mut subscribed_to = HashMap::, _, AnyState)>::new(); + + for sub_id in initial_state { + convert_subscription(sub_id, &subscriptions, &mut subscribed_to).await; + } + + loop { + tokio::select! { + _ = interval.tick() => { + for (url, (sender, _, last_state)) in subscribed_to.iter_mut() { + tracing::debug!("Polling: {:?}", url); + match url { + UrlType::Mint(id) => { + let response = http_client.get_mint_quote_status(id).await; + if let Ok(response) = response { + if *last_state == AnyState::MintQuoteState(response.state) { + continue; + } + *last_state = AnyState::MintQuoteState(response.state); + if let Err(err) = sender.try_send(NotificationPayload::MintQuoteBolt11Response(response)) { + tracing::error!("Error sending mint quote response: {:?}", err); + } + } + } + UrlType::Melt(id) => { + let response = http_client.get_melt_quote_status(id).await; + if let Ok(response) = response { + if *last_state == AnyState::MeltQuoteState(response.state) { + continue; + } + *last_state = AnyState::MeltQuoteState(response.state); + if let Err(err) = sender.try_send(NotificationPayload::MeltQuoteBolt11Response(response)) { + tracing::error!("Error sending melt quote response: {:?}", err); + } + } + } + UrlType::PublicKey(id) => { + let responses = http_client.post_check_state(CheckStateRequest { + ys: vec![*id], + }).await; + if let Ok(mut responses) = responses { + let response = if let Some(state) = responses.states.pop() { + state + } else { + continue; + }; + + if *last_state == AnyState::PublicKey(response.state) { + continue; + } + *last_state = AnyState::PublicKey(response.state); + if let Err(err) = sender.try_send(NotificationPayload::ProofState(response)) { + tracing::error!("Error sending proof state response: {:?}", err); + } + } + } + } + } + } + Some(subid) = new_subscription_recv.recv() => { + convert_subscription(subid, &subscriptions, &mut subscribed_to).await; + } + Some(id) = on_drop.recv() => { + subscribed_to.retain(|_, (_, sub_id, _)| *sub_id != id); + } + } + } +} diff --git a/crates/cdk/src/wallet/subscription/mod.rs b/crates/cdk/src/wallet/subscription/mod.rs new file mode 100644 index 00000000..79ee1b7b --- /dev/null +++ b/crates/cdk/src/wallet/subscription/mod.rs @@ -0,0 +1,322 @@ +//! Client for subscriptions +//! +//! Mint servers can send notifications to clients about changes in the state, +//! according to NUT-17, using the WebSocket protocol. This module provides a +//! subscription manager that allows clients to subscribe to notifications from +//! multiple mint servers using WebSocket or with a poll-based system, using +//! the HTTP client. +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::Arc; + +use tokio::sync::{mpsc, RwLock}; +use tokio::task::JoinHandle; +use tracing::error; + +use crate::mint_url::MintUrl; +use crate::nuts::nut17::Params; +use crate::pub_sub::SubId; +use crate::wallet::client::MintConnector; + +mod http; +#[cfg(all( + not(feature = "http_subscription"), + feature = "mint", + not(target_arch = "wasm32") +))] +mod ws; + +type WsSubscriptionBody = (mpsc::Sender, Params); + +/// Subscription manager +/// +/// This structure should be instantiated once per wallet at most. It is +/// cloneable since all its members are Arcs. +/// +/// The main goal is to provide a single interface to manage multiple +/// subscriptions to many servers to subscribe to events. If supported, the +/// WebSocket method is used to subscribe to server-side events. Otherwise, a +/// poll-based system is used, where a background task fetches information about +/// the resource every few seconds and notifies subscribers of any change +/// upstream. +/// +/// The subscribers have a simple-to-use interface, receiving an +/// ActiveSubscription struct, which can be used to receive updates and to +/// unsubscribe from updates automatically on the drop. +#[derive(Debug, Clone)] +pub struct SubscriptionManager { + all_connections: Arc>>, + http_client: Arc, +} + +impl SubscriptionManager { + /// Create a new subscription manager + pub fn new(http_client: Arc) -> Self { + Self { + all_connections: Arc::new(RwLock::new(HashMap::new())), + http_client, + } + } + + /// Subscribe to updates from a mint server with a given filter + pub async fn subscribe(&self, mint_url: MintUrl, filter: Params) -> ActiveSubscription { + let subscription_clients = self.all_connections.read().await; + let id = filter.id.clone(); + if let Some(subscription_client) = subscription_clients.get(&mint_url) { + let (on_drop_notif, receiver) = subscription_client.subscribe(filter).await; + ActiveSubscription::new(receiver, id, on_drop_notif) + } else { + drop(subscription_clients); + + #[cfg(all( + not(feature = "http_subscription"), + feature = "mint", + not(target_arch = "wasm32") + ))] + let is_ws_support = self + .http_client + .get_mint_info() + .await + .map(|info| !info.nuts.nut17.supported.is_empty()) + .unwrap_or_default(); + + #[cfg(any( + feature = "http_subscription", + not(feature = "mint"), + target_arch = "wasm32" + ))] + let is_ws_support = false; + + tracing::debug!( + "Connect to {:?} to subscribe. WebSocket is supported ({})", + mint_url, + is_ws_support + ); + + let mut subscription_clients = self.all_connections.write().await; + let subscription_client = + SubscriptionClient::new(mint_url.clone(), self.http_client.clone(), is_ws_support); + let (on_drop_notif, receiver) = subscription_client.subscribe(filter).await; + subscription_clients.insert(mint_url, subscription_client); + + ActiveSubscription::new(receiver, id, on_drop_notif) + } + } +} + +/// Subscription client +/// +/// If the server supports WebSocket subscriptions, this client will be used, +/// otherwise the HTTP pool and pause will be used (which is the less efficient +/// method). +#[derive(Debug)] +pub struct SubscriptionClient { + new_subscription_notif: mpsc::Sender, + on_drop_notif: mpsc::Sender, + subscriptions: Arc>>, + worker: Option>, +} + +type NotificationPayload = crate::nuts::NotificationPayload; + +/// Active Subscription +pub struct ActiveSubscription { + sub_id: Option, + on_drop_notif: mpsc::Sender, + receiver: mpsc::Receiver, +} + +impl ActiveSubscription { + fn new( + receiver: mpsc::Receiver, + sub_id: SubId, + on_drop_notif: mpsc::Sender, + ) -> Self { + Self { + sub_id: Some(sub_id), + on_drop_notif, + receiver, + } + } + + /// Try to receive a notification + pub fn try_recv(&mut self) -> Result, Error> { + match self.receiver.try_recv() { + Ok(payload) => Ok(Some(payload)), + Err(mpsc::error::TryRecvError::Empty) => Ok(None), + Err(mpsc::error::TryRecvError::Disconnected) => Err(Error::Disconnected), + } + } + + /// Receive a notification asynchronously + pub async fn recv(&mut self) -> Option { + self.receiver.recv().await + } +} + +impl Drop for ActiveSubscription { + fn drop(&mut self) { + if let Some(sub_id) = self.sub_id.take() { + let _ = self.on_drop_notif.try_send(sub_id); + } + } +} + +/// Subscription client error +#[derive(thiserror::Error, Debug)] +pub enum Error { + /// Url error + #[error("Could not join paths: {0}")] + Url(#[from] crate::mint_url::Error), + /// Disconnected from the notification channel + #[error("Disconnected from the notification channel")] + Disconnected, +} + +impl SubscriptionClient { + /// Create new [`WebSocketClient`] + pub fn new( + url: MintUrl, + http_client: Arc, + prefer_ws_method: bool, + ) -> Self { + let subscriptions = Arc::new(RwLock::new(HashMap::new())); + let (new_subscription_notif, new_subscription_recv) = mpsc::channel(100); + let (on_drop_notif, on_drop_recv) = mpsc::channel(1000); + + Self { + new_subscription_notif, + on_drop_notif, + subscriptions: subscriptions.clone(), + worker: Some(Self::start_worker( + prefer_ws_method, + http_client, + url, + subscriptions, + new_subscription_recv, + on_drop_recv, + )), + } + } + + #[allow(unused_variables)] + fn start_worker( + prefer_ws_method: bool, + http_client: Arc, + url: MintUrl, + subscriptions: Arc>>, + new_subscription_recv: mpsc::Receiver, + on_drop_recv: mpsc::Receiver, + ) -> JoinHandle<()> { + #[cfg(any( + feature = "http_subscription", + not(feature = "mint"), + target_arch = "wasm32" + ))] + return Self::http_worker( + http_client, + subscriptions, + new_subscription_recv, + on_drop_recv, + ); + + #[cfg(all( + not(feature = "http_subscription"), + feature = "mint", + not(target_arch = "wasm32") + ))] + if prefer_ws_method { + Self::ws_worker( + http_client, + url, + subscriptions, + new_subscription_recv, + on_drop_recv, + ) + } else { + Self::http_worker( + http_client, + subscriptions, + new_subscription_recv, + on_drop_recv, + ) + } + } + + /// Subscribe to a WebSocket channel + pub async fn subscribe( + &self, + filter: Params, + ) -> (mpsc::Sender, mpsc::Receiver) { + let mut subscriptions = self.subscriptions.write().await; + let id = filter.id.clone(); + + let (sender, receiver) = mpsc::channel(10_000); + subscriptions.insert(id.clone(), (sender, filter)); + drop(subscriptions); + + let _ = self.new_subscription_notif.send(id).await; + (self.on_drop_notif.clone(), receiver) + } + + /// HTTP subscription client + /// + /// This is a poll based subscription, where the client will poll the server + /// from time to time to get updates, notifying the subscribers on changes + fn http_worker( + http_client: Arc, + subscriptions: Arc>>, + new_subscription_recv: mpsc::Receiver, + on_drop: mpsc::Receiver, + ) -> JoinHandle<()> { + let http_worker = http::http_main( + vec![], + http_client, + subscriptions, + new_subscription_recv, + on_drop, + ); + + #[cfg(target_arch = "wasm32")] + let ret = tokio::task::spawn_local(http_worker); + + #[cfg(not(target_arch = "wasm32"))] + let ret = tokio::spawn(http_worker); + + ret + } + + /// WebSocket subscription client + /// + /// This is a WebSocket based subscription, where the client will connect to + /// the server and stay there idle waiting for server-side notifications + #[allow(clippy::incompatible_msrv)] + #[cfg(all( + not(feature = "http_subscription"), + feature = "mint", + not(target_arch = "wasm32") + ))] + fn ws_worker( + http_client: Arc, + url: MintUrl, + subscriptions: Arc>>, + new_subscription_recv: mpsc::Receiver, + on_drop: mpsc::Receiver, + ) -> JoinHandle<()> { + tokio::spawn(ws::ws_main( + http_client, + url, + subscriptions, + new_subscription_recv, + on_drop, + )) + } +} + +impl Drop for SubscriptionClient { + fn drop(&mut self) { + if let Some(sender) = self.worker.take() { + sender.abort(); + } + } +} diff --git a/crates/cdk/src/wallet/subscription/ws.rs b/crates/cdk/src/wallet/subscription/ws.rs new file mode 100644 index 00000000..22c3a70e --- /dev/null +++ b/crates/cdk/src/wallet/subscription/ws.rs @@ -0,0 +1,214 @@ +use std::collections::{HashMap, HashSet}; +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; + +use futures::{SinkExt, StreamExt}; +use tokio::sync::{mpsc, RwLock}; +use tokio_tungstenite::connect_async; +use tokio_tungstenite::tungstenite::Message; + +use super::http::http_main; +use super::WsSubscriptionBody; +use crate::mint_url::MintUrl; +use crate::nuts::nut17::ws::{ + WsMessageOrResponse, WsMethodRequest, WsRequest, WsUnsubscribeRequest, +}; +use crate::nuts::nut17::Params; +use crate::pub_sub::SubId; +use crate::wallet::client::MintConnector; + +const MAX_ATTEMPT_FALLBACK_HTTP: usize = 10; + +async fn fallback_to_http>( + initial_state: S, + http_client: Arc, + subscriptions: Arc>>, + new_subscription_recv: mpsc::Receiver, + on_drop: mpsc::Receiver, +) { + http_main( + initial_state, + http_client, + subscriptions, + new_subscription_recv, + on_drop, + ) + .await +} + +#[allow(clippy::incompatible_msrv)] +#[inline] +pub async fn ws_main( + http_client: Arc, + mint_url: MintUrl, + subscriptions: Arc>>, + mut new_subscription_recv: mpsc::Receiver, + mut on_drop: mpsc::Receiver, +) { + let url = mint_url + .join_paths(&["v1", "ws"]) + .as_mut() + .map(|url| { + if url.scheme() == "https" { + url.set_scheme("wss").expect("Could not set scheme"); + } else { + url.set_scheme("ws").expect("Could not set scheme"); + } + url + }) + .expect("Could not join paths") + .to_string(); + + let mut active_subscriptions = HashMap::>::new(); + let mut failure_count = 0; + + loop { + tracing::debug!("Connecting to {}", url); + let ws_stream = match connect_async(&url).await { + Ok((ws_stream, _)) => ws_stream, + Err(err) => { + failure_count += 1; + tracing::error!("Could not connect to server: {:?}", err); + if failure_count > MAX_ATTEMPT_FALLBACK_HTTP { + tracing::error!( + "Could not connect to server after {MAX_ATTEMPT_FALLBACK_HTTP} attempts, falling back to HTTP-subscription client" + ); + return fallback_to_http( + active_subscriptions.into_keys(), + http_client, + subscriptions, + new_subscription_recv, + on_drop, + ) + .await; + } + continue; + } + }; + tracing::debug!("Connected to {}", url); + + failure_count = 0; + + let (mut write, mut read) = ws_stream.split(); + let req_id = AtomicUsize::new(0); + + let get_sub_request = |params: Params| -> Option<(usize, String)> { + let request: WsRequest = ( + WsMethodRequest::Subscribe(params), + req_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed), + ) + .into(); + + match serde_json::to_string(&request) { + Ok(json) => Some((request.id, json)), + Err(err) => { + tracing::error!("Could not serialize subscribe message: {:?}", err); + None + } + } + }; + + let get_unsub_request = |sub_id: SubId| -> Option { + let request: WsRequest = ( + WsMethodRequest::Unsubscribe(WsUnsubscribeRequest { sub_id }), + req_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed), + ) + .into(); + + match serde_json::to_string(&request) { + Ok(json) => Some(json), + Err(err) => { + tracing::error!("Could not serialize unsubscribe message: {:?}", err); + None + } + } + }; + + // Websocket reconnected, restore all subscriptions + let mut subscription_requests = HashSet::new(); + + let read_subscriptions = subscriptions.read().await; + for (sub_id, _) in active_subscriptions.iter() { + if let Some(Some((req_id, req))) = read_subscriptions + .get(sub_id) + .map(|(_, params)| get_sub_request(params.clone())) + { + let _ = write.send(Message::Text(req)).await; + subscription_requests.insert(req_id); + } + } + drop(read_subscriptions); + + loop { + tokio::select! { + Some(msg) = read.next() => { + let msg = match msg { + Ok(msg) => msg, + Err(_) => break, + }; + let msg = match msg { + Message::Text(msg) => msg, + _ => continue, + }; + let msg = match serde_json::from_str::(&msg) { + Ok(msg) => msg, + Err(_) => continue, + }; + + match msg { + WsMessageOrResponse::Notification(payload) => { + tracing::debug!("Received notification from server: {:?}", payload); + let _ = active_subscriptions.get(&payload.params.sub_id).map(|sender| { + let _ = sender.try_send(payload.params.payload); + }); + } + WsMessageOrResponse::Response(response) => { + tracing::debug!("Received response from server: {:?}", response); + subscription_requests.remove(&response.id); + } + WsMessageOrResponse::ErrorResponse(error) => { + tracing::error!("Received error from server: {:?}", error); + if subscription_requests.contains(&error.id) { + // If the server sends an error response to a subscription request, we should + // fallback to HTTP. + // TODO: Add some retry before giving up to HTTP. + return fallback_to_http( + active_subscriptions.into_keys(), + http_client, + subscriptions, + new_subscription_recv, + on_drop, + ).await; + } + } + } + + } + Some(subid) = new_subscription_recv.recv() => { + let subscription = subscriptions.read().await; + let sub = if let Some(subscription) = subscription.get(&subid) { + subscription + } else { + continue + }; + tracing::debug!("Subscribing to {:?}", sub.1); + active_subscriptions.insert(subid, sub.0.clone()); + if let Some((req_id, json)) = get_sub_request(sub.1.clone()) { + let _ = write.send(Message::Text(json)).await; + subscription_requests.insert(req_id); + } + }, + Some(subid) = on_drop.recv() => { + let mut subscription = subscriptions.write().await; + if let Some(sub) = subscription.remove(&subid) { + drop(sub); + } + tracing::debug!("Unsubscribing from {:?}", subid); + if let Some(json) = get_unsub_request(subid) { + let _ = write.send(Message::Text(json)).await; + } + } + } + } + } +} diff --git a/crates/cdk/src/wallet/websocket.rs b/crates/cdk/src/wallet/websocket.rs new file mode 100644 index 00000000..023156e9 --- /dev/null +++ b/crates/cdk/src/wallet/websocket.rs @@ -0,0 +1,54 @@ +//! Websocket types + +use serde::{Deserialize, Serialize}; + +use crate::{nuts::nut17::Params, pub_sub::SubId}; + +/// JSON RPC version +pub const JSON_RPC_VERSION: &str = "2.0"; + +/// Websocket request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsRequest { + pub jsonrpc: String, + #[serde(flatten)] + pub method: WsMethod, + pub id: usize, +} + +/// Websocket method +/// +/// List of possible methods that can be called on the websocket +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "method", content = "params")] +pub enum WsMethod { + /// Subscribe to a topic + Subscribe(Params), + /// Unsubscribe from a topic + Unsubscribe(UnsubscribeMethod), +} + +/// Unsubscribe method +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct UnsubscribeMethod { + #[serde(rename = "subId")] + pub sub_id: SubId, +} + +/// Websocket error response +#[derive(Debug, Clone, Serialize)] +pub struct WsErrorResponse { + code: i32, + message: String, +} + +/// Websocket response +#[derive(Debug, Clone, Serialize)] +pub struct WsResponse { + jsonrpc: String, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + id: usize, +} diff --git a/misc/itests.sh b/misc/itests.sh index 50eb7f7f..ef85d23b 100755 --- a/misc/itests.sh +++ b/misc/itests.sh @@ -83,6 +83,9 @@ done # Run cargo test cargo test -p cdk-integration-tests --test regtest +# Run cargo test with the http_subscription feature +cargo test -p cdk-integration-tests --test regtest --features http_subscription + # Capture the exit status of cargo test test_status=$?