diff --git a/crates/cdk-axum/Cargo.toml b/crates/cdk-axum/Cargo.toml index 241da962..ae458dd4 100644 --- a/crates/cdk-axum/Cargo.toml +++ b/crates/cdk-axum/Cargo.toml @@ -5,21 +5,30 @@ edition = "2021" license = "MIT" homepage = "https://github.com/cashubtc/cdk" repository = "https://github.com/cashubtc/cdk.git" -rust-version = "1.63.0" # MSRV +rust-version = "1.63.0" # MSRV description = "Cashu CDK axum webserver" [dependencies] anyhow = "1" -async-trait = "0.1" -axum = "0.6.20" -cdk = { path = "../cdk", version = "0.4.0", default-features = false, features = ["mint"] } -tokio = { version = "1", default-features = false } -tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] } -utoipa = { version = "4", features = ["preserve_order", "preserve_path_order"], optional = true } +async-trait = "0.1.83" +axum = { version = "0.6.20", features = ["ws"] } +cdk = { path = "../cdk", version = "0.4.0", default-features = false, features = [ + "mint", +] } +tokio = { version = "1", default-features = false, features = ["io-util"] } +tracing = { version = "0.1", default-features = false, features = [ + "attributes", + "log", +] } +utoipa = { version = "4", features = [ + "preserve_order", + "preserve_path_order", +], optional = true } futures = { version = "0.3.28", default-features = false } moka = { version = "0.11.1", features = ["future"] } serde_json = "1" paste = "1.0.15" +serde = { version = "1.0.210", features = ["derive"] } [features] -swagger = ["cdk/swagger", "dep:utoipa"] \ No newline at end of file +swagger = ["cdk/swagger", "dep:utoipa"] diff --git a/crates/cdk-axum/src/lib.rs b/crates/cdk-axum/src/lib.rs index 9083163b..91ca512c 100644 --- a/crates/cdk-axum/src/lib.rs +++ b/crates/cdk-axum/src/lib.rs @@ -14,6 +14,7 @@ use moka::future::Cache; use router_handlers::*; mod router_handlers; +mod ws; #[cfg(feature = "swagger")] mod swagger_imports { @@ -154,6 +155,7 @@ pub async fn create_mint_router(mint: Arc, cache_ttl: u64, cache_tti: u64) ) .route("/mint/bolt11", post(cache_post_mint_bolt11)) .route("/melt/quote/bolt11", post(get_melt_bolt11_quote)) + .route("/ws", get(ws_handler)) .route( "/melt/quote/bolt11/:quote_id", get(get_check_melt_bolt11_quote), diff --git a/crates/cdk-axum/src/router_handlers.rs b/crates/cdk-axum/src/router_handlers.rs index c4ba781a..b6f462ba 100644 --- a/crates/cdk-axum/src/router_handlers.rs +++ b/crates/cdk-axum/src/router_handlers.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use axum::extract::{Json, Path, State}; +use axum::extract::{ws::WebSocketUpgrade, Json, Path, State}; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use cdk::error::ErrorResponse; @@ -13,7 +13,7 @@ use cdk::util::unix_time; use cdk::Error; use paste::paste; -use crate::MintState; +use crate::{ws::main_websocket, MintState}; macro_rules! post_cache_wrapper { ($handler:ident, $request_type:ty, $response_type:ty) => { @@ -174,6 +174,15 @@ pub async fn get_check_mint_bolt11_quote( Ok(Json(quote)) } +pub async fn ws_handler(State(state): State, ws: WebSocketUpgrade) -> impl IntoResponse { + ws.on_upgrade(|ws| main_websocket(ws, state)) +} + +/// Mint tokens by paying a BOLT11 Lightning invoice. +/// +/// Requests the minting of tokens belonging to a paid payment request. +/// +/// Call this endpoint after `POST /v1/mint/quote`. #[cfg_attr(feature = "swagger", utoipa::path( post, context_path = "/v1", @@ -184,11 +193,6 @@ pub async fn get_check_mint_bolt11_quote( (status = 500, description = "Server error", body = ErrorResponse, content_type = "application/json") ) ))] -/// Mint tokens by paying a BOLT11 Lightning invoice. -/// -/// Requests the minting of tokens belonging to a paid payment request. -/// -/// Call this endpoint after `POST /v1/mint/quote`. pub async fn post_mint_bolt11( State(state): State, Json(payload): Json, diff --git a/crates/cdk-axum/src/ws/error.rs b/crates/cdk-axum/src/ws/error.rs new file mode 100644 index 00000000..24fa4c8c --- /dev/null +++ b/crates/cdk-axum/src/ws/error.rs @@ -0,0 +1,19 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +/// Source: https://www.jsonrpc.org/specification#error_object +pub enum WsError { + /// Invalid JSON was received by the server. + /// An error occurred on the server while parsing the JSON text. + ParseError, + /// The JSON sent is not a valid Request object. + InvalidRequest, + /// The method does not exist / is not available. + MethodNotFound, + /// Invalid method parameter(s). + InvalidParams, + /// Internal JSON-RPC error. + InternalError, + /// Custom error + ServerError(i32, String), +} diff --git a/crates/cdk-axum/src/ws/handler.rs b/crates/cdk-axum/src/ws/handler.rs new file mode 100644 index 00000000..ea1ba3ae --- /dev/null +++ b/crates/cdk-axum/src/ws/handler.rs @@ -0,0 +1,70 @@ +use super::{WsContext, WsError, JSON_RPC_VERSION}; +use serde::Serialize; + +impl From for WsErrorResponse { + fn from(val: WsError) -> Self { + let (id, message) = match val { + WsError::ParseError => (-32700, "Parse error".to_string()), + WsError::InvalidRequest => (-32600, "Invalid Request".to_string()), + WsError::MethodNotFound => (-32601, "Method not found".to_string()), + WsError::InvalidParams => (-32602, "Invalid params".to_string()), + WsError::InternalError => (-32603, "Internal error".to_string()), + WsError::ServerError(code, message) => (code, message), + }; + WsErrorResponse { code: id, message } + } +} + +#[derive(Debug, Clone, Serialize)] +struct WsErrorResponse { + code: i32, + message: String, +} + +#[derive(Debug, Clone, Serialize)] +struct WsResponse { + jsonrpc: String, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + id: usize, +} + +#[derive(Debug, Clone, Serialize)] +pub struct WsNotification { + pub jsonrpc: String, + pub method: String, + pub params: T, +} + +#[async_trait::async_trait] +pub trait WsHandle { + type Response: Serialize + Sized; + + async fn process( + self, + req_id: usize, + context: &mut WsContext, + ) -> Result + where + Self: Sized, + { + serde_json::to_value(&match self.handle(context).await { + Ok(response) => WsResponse { + jsonrpc: JSON_RPC_VERSION.to_owned(), + result: Some(response), + error: None, + id: req_id, + }, + Err(error) => WsResponse { + jsonrpc: JSON_RPC_VERSION.to_owned(), + result: None, + error: Some(error.into()), + id: req_id, + }, + }) + } + + async fn handle(self, context: &mut WsContext) -> Result; +} diff --git a/crates/cdk-axum/src/ws/mod.rs b/crates/cdk-axum/src/ws/mod.rs new file mode 100644 index 00000000..4e7bf5f9 --- /dev/null +++ b/crates/cdk-axum/src/ws/mod.rs @@ -0,0 +1,121 @@ +use crate::MintState; +use axum::extract::ws::{Message, WebSocket}; +use cdk::nuts::nut17::{NotificationPayload, SubId}; +use futures::StreamExt; +use handler::{WsHandle, WsNotification}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use subscribe::Notification; +use tokio::sync::mpsc; + +mod error; +mod handler; +mod subscribe; +mod unsubscribe; + +/// JSON RPC version +pub const JSON_RPC_VERSION: &str = "2.0"; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsRequest { + jsonrpc: String, + #[serde(flatten)] + method: WsMethod, + id: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "method", content = "params")] +pub enum WsMethod { + Subscribe(subscribe::Method), + Unsubscribe(unsubscribe::Method), +} + +impl WsMethod { + pub async fn process( + self, + req_id: usize, + context: &mut WsContext, + ) -> Result { + match self { + WsMethod::Subscribe(sub) => sub.process(req_id, context), + WsMethod::Unsubscribe(unsub) => unsub.process(req_id, context), + } + .await + } +} + +pub use error::WsError; + +pub struct WsContext { + state: MintState, + subscriptions: HashMap>, + publisher: mpsc::Sender<(SubId, NotificationPayload)>, +} + +/// Main function for websocket connections +/// +/// This function will handle all incoming websocket connections and keep them in their own loop. +/// +/// For simplicity sake this function will spawn tasks for each subscription and +/// keep them in a hashmap, and will have a single subscriber for all of them. +#[allow(clippy::incompatible_msrv)] +pub async fn main_websocket(mut socket: WebSocket, state: MintState) { + let (publisher, mut subscriber) = mpsc::channel(100); + let mut context = WsContext { + state, + subscriptions: HashMap::new(), + publisher, + }; + + loop { + tokio::select! { + Some((sub_id, payload)) = subscriber.recv() => { + if !context.subscriptions.contains_key(&sub_id) { + // It may be possible an incoming message has come from a dropped Subscriptions that has not yet been + // unsubscribed from the subscription manager, just ignore it. + continue; + } + let notification: WsNotification = (sub_id, payload).into(); + let message = if let Ok(message) = serde_json::to_string(¬ification) { + message + } else { + tracing::error!("Could not serialize notification"); + continue; + }; + + if socket.send(Message::Text(message)).await.is_err() { + break; + } + } + Some(Ok(Message::Text(text))) = socket.next() => { + let request = match serde_json::from_str::(&text) { + Ok(request) => request, + Err(err) => { + tracing::error!("Could not parse request: {}", err); + continue; + } + }; + + match request.method.process(request.id, &mut context).await { + Ok(result) => { + if socket + .send(Message::Text(result.to_string())) + .await + .is_err() + { + break; + } + } + Err(err) => { + tracing::error!("Error serializing response: {}", err); + break; + } + } + } + else => { + + } + } + } +} diff --git a/crates/cdk-axum/src/ws/subscribe.rs b/crates/cdk-axum/src/ws/subscribe.rs new file mode 100644 index 00000000..893406f5 --- /dev/null +++ b/crates/cdk-axum/src/ws/subscribe.rs @@ -0,0 +1,61 @@ +use super::{ + handler::{WsHandle, WsNotification}, + WsContext, WsError, JSON_RPC_VERSION, +}; +use cdk::{ + nuts::nut17::{NotificationPayload, Params}, + pub_sub::SubId, +}; + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct Method(Params); + +#[derive(Debug, Clone, serde::Serialize)] +pub struct Response { + status: String, + sub_id: SubId, +} + +#[derive(Debug, Clone, serde::Serialize)] +pub struct Notification { + #[serde(rename = "subId")] + pub sub_id: SubId, + + pub payload: NotificationPayload, +} + +impl From<(SubId, NotificationPayload)> for WsNotification { + fn from((sub_id, payload): (SubId, NotificationPayload)) -> Self { + WsNotification { + jsonrpc: JSON_RPC_VERSION.to_owned(), + method: "subscribe".to_string(), + params: Notification { sub_id, payload }, + } + } +} + +#[async_trait::async_trait] +impl WsHandle for Method { + type Response = Response; + + async fn handle(self, context: &mut WsContext) -> Result { + let sub_id = self.0.id.clone(); + if context.subscriptions.contains_key(&sub_id) { + return Err(WsError::InvalidParams); + } + let mut subscription = context.state.mint.pubsub_manager.subscribe(self.0).await; + let publisher = context.publisher.clone(); + context.subscriptions.insert( + sub_id.clone(), + tokio::spawn(async move { + while let Some(response) = subscription.recv().await { + let _ = publisher.send(response).await; + } + }), + ); + Ok(Response { + status: "OK".to_string(), + sub_id, + }) + } +} diff --git a/crates/cdk-axum/src/ws/unsubscribe.rs b/crates/cdk-axum/src/ws/unsubscribe.rs new file mode 100644 index 00000000..6047468e --- /dev/null +++ b/crates/cdk-axum/src/ws/unsubscribe.rs @@ -0,0 +1,29 @@ +use super::{handler::WsHandle, WsContext, WsError}; +use cdk::pub_sub::SubId; + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct Method { + pub sub_id: SubId, +} + +#[derive(Debug, Clone, serde::Serialize)] +pub struct Response { + status: String, + sub_id: SubId, +} + +#[async_trait::async_trait] +impl WsHandle for Method { + type Response = Response; + + async fn handle(self, context: &mut WsContext) -> Result { + if context.subscriptions.remove(&self.sub_id).is_some() { + Ok(Response { + status: "OK".to_string(), + sub_id: self.sub_id, + }) + } else { + Err(WsError::InvalidParams) + } + } +} diff --git a/crates/cdk-integration-tests/Cargo.toml b/crates/cdk-integration-tests/Cargo.toml index 7291eba9..8753408c 100644 --- a/crates/cdk-integration-tests/Cargo.toml +++ b/crates/cdk-integration-tests/Cargo.toml @@ -7,7 +7,7 @@ description = "Core Cashu Development Kit library implementing the Cashu protoco license = "MIT" homepage = "https://github.com/cashubtc/cdk" repository = "https://github.com/cashubtc/cdk.git" -rust-version = "1.63.0" # MSRV +rust-version = "1.63.0" # MSRV [features] @@ -20,12 +20,14 @@ bip39 = { version = "2.0", features = ["rand"] } anyhow = "1" cdk = { path = "../cdk", version = "0.4.0", features = ["mint", "wallet"] } cdk-cln = { path = "../cdk-cln", version = "0.4.0" } -cdk-axum = { path = "../cdk-axum"} -cdk-sqlite = { path = "../cdk-sqlite"} -cdk-redb = { path = "../cdk-redb"} +cdk-axum = { path = "../cdk-axum" } +cdk-sqlite = { path = "../cdk-sqlite" } +cdk-redb = { path = "../cdk-redb" } cdk-fake-wallet = { path = "../cdk-fake-wallet" } tower-http = { version = "0.4.4", features = ["cors"] } -futures = { version = "0.3.28", default-features = false, features = ["executor"] } +futures = { version = "0.3.28", default-features = false, features = [ + "executor", +] } once_cell = "1.19.0" uuid = { version = "1", features = ["v4"] } serde = "1" @@ -33,9 +35,13 @@ serde_json = "1" # ln-regtest-rs = { path = "../../../../ln-regtest-rs" } ln-regtest-rs = { git = "https://github.com/thesimplekid/ln-regtest-rs", rev = "1d88d3d0b" } lightning-invoice = { version = "0.32.0", features = ["serde", "std"] } -tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] } +tracing = { version = "0.1", default-features = false, features = [ + "attributes", + "log", +] } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } tower-service = "0.3.3" +tokio-tungstenite = "0.24.0" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] tokio = { version = "1", features = [ @@ -52,7 +58,7 @@ instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] } [dev-dependencies] rand = "0.8.5" -bip39 = { version= "2.0", features = ["rand"] } +bip39 = { version = "2.0", features = ["rand"] } anyhow = "1" cdk = { path = "../cdk", features = ["mint", "wallet"] } cdk-axum = { path = "../cdk-axum" } diff --git a/crates/cdk-integration-tests/src/init_regtest.rs b/crates/cdk-integration-tests/src/init_regtest.rs index 769a3350..f9664c18 100644 --- a/crates/cdk-integration-tests/src/init_regtest.rs +++ b/crates/cdk-integration-tests/src/init_regtest.rs @@ -45,6 +45,10 @@ pub fn get_mint_url() -> String { format!("http://{}:{}", get_mint_addr(), get_mint_port()) } +pub fn get_mint_ws_url() -> String { + format!("ws://{}:{}/v1/ws", get_mint_addr(), get_mint_port()) +} + pub fn get_temp_dir() -> PathBuf { let dir = env::var("cdk_itests").expect("Temp dir set"); std::fs::create_dir_all(&dir).unwrap(); diff --git a/crates/cdk-integration-tests/tests/mint.rs b/crates/cdk-integration-tests/tests/mint.rs index c86e2dd3..c5cfbc58 100644 --- a/crates/cdk-integration-tests/tests/mint.rs +++ b/crates/cdk-integration-tests/tests/mint.rs @@ -6,16 +6,20 @@ use cdk::amount::{Amount, SplitTarget}; use cdk::cdk_database::mint_memory::MintMemoryDatabase; use cdk::dhke::construct_proofs; use cdk::mint::MintQuote; +use cdk::nuts::nut00::ProofsMethods; +use cdk::nuts::nut17::Params; use cdk::nuts::{ - CurrencyUnit, Id, MintBolt11Request, MintInfo, Nuts, PreMintSecrets, Proofs, SecretKey, - SpendingConditions, SwapRequest, + CurrencyUnit, Id, MintBolt11Request, MintInfo, NotificationPayload, Nuts, PreMintSecrets, + ProofState, Proofs, SecretKey, SpendingConditions, State, SwapRequest, }; use cdk::types::QuoteTTL; use cdk::util::unix_time; use cdk::Mint; use std::collections::HashMap; use std::sync::Arc; +use std::time::Duration; use tokio::sync::OnceCell; +use tokio::time::sleep; pub const MINT_URL: &str = "http://127.0.0.1:8088"; @@ -206,6 +210,31 @@ pub async fn test_p2pk_swap() -> Result<()> { let swap_request = SwapRequest::new(proofs.clone(), pre_swap.blinded_messages()); + let public_keys_to_listen: Vec<_> = swap_request + .inputs + .ys() + .expect("key") + .into_iter() + .enumerate() + .filter_map(|(key, pk)| { + if key % 2 == 0 { + // Only expect messages from every other key + Some(pk.to_string()) + } else { + None + } + }) + .collect(); + + let mut listener = mint + .pubsub_manager + .subscribe(Params { + kind: cdk::nuts::nut17::Kind::ProofState, + filters: public_keys_to_listen.clone(), + id: "test".into(), + }) + .await; + match mint.process_swap_request(swap_request).await { Ok(_) => bail!("Proofs spent without sig"), Err(err) => match err { @@ -227,6 +256,34 @@ pub async fn test_p2pk_swap() -> Result<()> { assert!(attempt_swap.is_ok()); + sleep(Duration::from_millis(10)).await; + + let mut msgs = HashMap::new(); + while let Ok((sub_id, msg)) = listener.try_recv() { + assert_eq!(sub_id, "test".into()); + match msg { + NotificationPayload::ProofState(ProofState { y, state, .. }) => { + let pk = y.to_string(); + msgs.get_mut(&pk) + .map(|x: &mut Vec| { + x.push(state); + }) + .unwrap_or_else(|| { + msgs.insert(pk, vec![state]); + }); + } + _ => bail!("Wrong message received"), + } + } + + for keys in public_keys_to_listen { + let statuses = msgs.remove(&keys).expect("some events"); + assert_eq!(statuses, vec![State::Pending, State::Pending, State::Spent]); + } + + assert!(listener.try_recv().is_err(), "no other event is happening"); + assert!(msgs.is_empty(), "Only expected key events are received"); + Ok(()) } diff --git a/crates/cdk-integration-tests/tests/regtest.rs b/crates/cdk-integration-tests/tests/regtest.rs index c10d5ffe..da378f2a 100644 --- a/crates/cdk-integration-tests/tests/regtest.rs +++ b/crates/cdk-integration-tests/tests/regtest.rs @@ -1,17 +1,56 @@ -use std::{str::FromStr, sync::Arc, time::Duration}; +use std::{fmt::Debug, str::FromStr, sync::Arc, time::Duration}; use anyhow::{bail, Result}; use bip39::Mnemonic; use cdk::{ amount::{Amount, SplitTarget}, cdk_database::WalletMemoryDatabase, - nuts::{CurrencyUnit, MeltQuoteState, MintQuoteState, PreMintSecrets, State}, + nuts::{ + CurrencyUnit, MeltQuoteState, MintQuoteState, NotificationPayload, PreMintSecrets, State, + }, wallet::{client::HttpClient, Wallet}, }; -use cdk_integration_tests::init_regtest::{get_mint_url, init_cln_client, init_lnd_client}; +use cdk_integration_tests::init_regtest::{ + get_mint_url, get_mint_ws_url, init_cln_client, init_lnd_client, +}; +use futures::{SinkExt, StreamExt}; use lightning_invoice::Bolt11Invoice; use ln_regtest_rs::InvoiceStatus; -use tokio::time::sleep; +use serde_json::json; +use tokio::time::{sleep, timeout}; +use tokio_tungstenite::{connect_async, tungstenite::protocol::Message}; + +async fn get_notification> + Unpin, E: Debug>( + reader: &mut T, + timeout_to_wait: Duration, +) -> (String, NotificationPayload) { + let msg = timeout(timeout_to_wait, reader.next()) + .await + .expect("timeout") + .unwrap() + .unwrap(); + + let mut response: serde_json::Value = + serde_json::from_str(&msg.to_text().unwrap()).expect("valid json"); + + let mut params_raw = response + .as_object_mut() + .expect("object") + .remove("params") + .expect("valid params"); + + let params_map = params_raw.as_object_mut().expect("params is object"); + + ( + params_map + .remove("subId") + .unwrap() + .as_str() + .unwrap() + .to_string(), + serde_json::from_value(params_map.remove("payload").unwrap()).unwrap(), + ) +} #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_regtest_mint_melt_round_trip() -> Result<()> { @@ -25,6 +64,11 @@ async fn test_regtest_mint_melt_round_trip() -> Result<()> { None, )?; + let (ws_stream, _) = connect_async(get_mint_ws_url()) + .await + .expect("Failed to connect"); + let (mut write, mut reader) = ws_stream.split(); + let mint_quote = wallet.mint_quote(100.into(), None).await?; lnd_client.pay_invoice(mint_quote.request).await?; @@ -39,11 +83,40 @@ async fn test_regtest_mint_melt_round_trip() -> Result<()> { let melt = wallet.melt_quote(invoice, None).await?; - let melt = wallet.melt(&melt.id).await.unwrap(); - - assert!(melt.preimage.is_some()); + write + .send(Message::Text(serde_json::to_string(&json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "subscribe", + "params": { + "kind": "bolt11_melt_quote", + "filters": [ + melt.id.clone(), + ], + "subId": "test-sub", + } + + }))?)) + .await?; - assert!(melt.state == MeltQuoteState::Paid); + assert_eq!( + reader.next().await.unwrap().unwrap().to_text().unwrap(), + r#"{"jsonrpc":"2.0","result":{"status":"OK","sub_id":"test-sub"},"id":2}"# + ); + + let melt_response = wallet.melt(&melt.id).await.unwrap(); + assert!(melt_response.preimage.is_some()); + assert!(melt_response.state == MeltQuoteState::Paid); + + let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await; + assert_eq!("test-sub", sub_id); + let payload = match payload { + NotificationPayload::MeltQuoteBolt11Response(melt) => melt, + _ => panic!("Wrong payload"), + }; + assert_eq!(payload.amount + payload.fee_reserve, 100.into()); + assert_eq!(payload.quote, melt.id); + assert_eq!(payload.state, MeltQuoteState::Paid); Ok(()) } diff --git a/crates/cdk-mintd/src/main.rs b/crates/cdk-mintd/src/main.rs index 4adc2809..194facbe 100644 --- a/crates/cdk-mintd/src/main.rs +++ b/crates/cdk-mintd/src/main.rs @@ -537,14 +537,15 @@ async fn check_pending_mint_quotes( for quote in unpaid_quotes { tracing::trace!("Checking status of mint quote: {}", quote.id); - let lookup_id = quote.request_lookup_id; - match ln.check_incoming_invoice_status(&lookup_id).await { + let lookup_id = quote.request_lookup_id.as_str(); + match ln.check_incoming_invoice_status(lookup_id).await { Ok(state) => { if state != quote.state { tracing::trace!("Mint quote status changed: {}", quote.id); mint.localstore .update_mint_quote_state("e.id, state) .await?; + mint.pubsub_manager.mint_quote_bolt11_status("e, state); } } diff --git a/crates/cdk/src/lib.rs b/crates/cdk/src/lib.rs index 3d243047..effb04f9 100644 --- a/crates/cdk/src/lib.rs +++ b/crates/cdk/src/lib.rs @@ -20,6 +20,8 @@ pub mod util; #[cfg(feature = "wallet")] pub mod wallet; +pub mod pub_sub; + pub mod fees; #[doc(hidden)] diff --git a/crates/cdk/src/mint/check_spendable.rs b/crates/cdk/src/mint/check_spendable.rs index 7527abea..3ac3ad34 100644 --- a/crates/cdk/src/mint/check_spendable.rs +++ b/crates/cdk/src/mint/check_spendable.rs @@ -57,6 +57,17 @@ impl Mint { return Err(Error::TokenAlreadySpent); } + for public_key in ys { + self.pubsub_manager.broadcast( + ProofState { + y: *public_key, + state: proof_state, + witness: None, + } + .into(), + ); + } + Ok(()) } } diff --git a/crates/cdk/src/mint/melt.rs b/crates/cdk/src/mint/melt.rs index 53268b4d..6a73b179 100644 --- a/crates/cdk/src/mint/melt.rs +++ b/crates/cdk/src/mint/melt.rs @@ -16,6 +16,7 @@ use crate::{ Amount, Error, }; +use super::ProofState; use super::{ CurrencyUnit, MeltBolt11Request, MeltQuote, MeltQuoteBolt11Request, MeltQuoteBolt11Response, Mint, PaymentMethod, PublicKey, State, @@ -358,6 +359,22 @@ impl Mint { .update_melt_quote_state(&melt_request.quote, MeltQuoteState::Unpaid) .await?; + if let Ok(Some(quote)) = self.localstore.get_melt_quote(&melt_request.quote).await { + self.pubsub_manager + .melt_quote_status("e, None, None, MeltQuoteState::Unpaid); + } + + for public_key in input_ys { + self.pubsub_manager.broadcast( + ProofState { + y: public_key, + state: State::Unspent, + witness: None, + } + .into(), + ); + } + Ok(()) } @@ -595,6 +612,24 @@ impl Mint { .update_melt_quote_state(&melt_request.quote, MeltQuoteState::Paid) .await?; + self.pubsub_manager.melt_quote_status( + "e, + payment_preimage.clone(), + None, + MeltQuoteState::Paid, + ); + + for public_key in input_ys { + self.pubsub_manager.broadcast( + ProofState { + y: public_key, + state: State::Spent, + witness: None, + } + .into(), + ); + } + let mut change = None; // Check if there is change to return diff --git a/crates/cdk/src/mint/mint_nut04.rs b/crates/cdk/src/mint/mint_nut04.rs index 7d46263d..11ce0016 100644 --- a/crates/cdk/src/mint/mint_nut04.rs +++ b/crates/cdk/src/mint/mint_nut04.rs @@ -4,7 +4,7 @@ use crate::{nuts::MintQuoteState, types::LnKey, util::unix_time, Amount, Error}; use super::{ nut04, CurrencyUnit, Mint, MintQuote, MintQuoteBolt11Request, MintQuoteBolt11Response, - PaymentMethod, PublicKey, + NotificationPayload, PaymentMethod, PublicKey, }; impl Mint { @@ -114,7 +114,12 @@ impl Mint { self.localstore.add_mint_quote(quote.clone()).await?; - Ok(quote.into()) + let quote: MintQuoteBolt11Response = quote.into(); + + self.pubsub_manager + .broadcast(NotificationPayload::MintQuoteBolt11Response(quote.clone())); + + Ok(quote) } /// Check mint quote @@ -201,7 +206,6 @@ impl Mint { "Received payment notification for mint quote {}", mint_quote.id ); - if mint_quote.state != MintQuoteState::Issued && mint_quote.state != MintQuoteState::Paid { @@ -233,6 +237,9 @@ impl Mint { mint_quote.state ); } + + self.pubsub_manager + .mint_quote_bolt11_status(&mint_quote, MintQuoteState::Paid); } Ok(()) } @@ -243,14 +250,12 @@ impl Mint { &self, mint_request: nut04::MintBolt11Request, ) -> Result { - if self - .localstore - .get_mint_quote(&mint_request.quote) - .await? - .is_none() - { - return Err(Error::UnknownQuote); - } + let mint_quote = + if let Some(mint_quote) = self.localstore.get_mint_quote(&mint_request.quote).await? { + mint_quote + } else { + return Err(Error::UnknownQuote); + }; let state = self .localstore @@ -295,6 +300,10 @@ impl Mint { .update_mint_quote_state(&mint_request.quote, MintQuoteState::Paid) .await .unwrap(); + + self.pubsub_manager + .mint_quote_bolt11_status(&mint_quote, MintQuoteState::Paid); + return Err(Error::BlindedMessageAlreadySigned); } @@ -321,6 +330,9 @@ impl Mint { .update_mint_quote_state(&mint_request.quote, MintQuoteState::Issued) .await?; + self.pubsub_manager + .mint_quote_bolt11_status(&mint_quote, MintQuoteState::Issued); + Ok(nut04::MintBolt11Response { signatures: blind_signatures, }) diff --git a/crates/cdk/src/mint/mod.rs b/crates/cdk/src/mint/mod.rs index 0699cf7f..29a0b16f 100644 --- a/crates/cdk/src/mint/mod.rs +++ b/crates/cdk/src/mint/mod.rs @@ -46,6 +46,8 @@ pub struct Mint { pub localstore: Arc + Send + Sync>, /// Ln backends for mint pub ln: HashMap + Send + Sync>>, + /// Subscription manager + pub pubsub_manager: Arc, /// Active Mint Keysets keysets: Arc>>, secp_ctx: Secp256k1, @@ -170,6 +172,7 @@ impl Mint { Ok(Self { mint_url: MintUrl::from_str(mint_url)?, keysets: Arc::new(RwLock::new(active_keysets)), + pubsub_manager: Default::default(), secp_ctx, quote_ttl, xpriv, diff --git a/crates/cdk/src/mint/swap.rs b/crates/cdk/src/mint/swap.rs index 16a72afe..682a052f 100644 --- a/crates/cdk/src/mint/swap.rs +++ b/crates/cdk/src/mint/swap.rs @@ -6,7 +6,7 @@ use crate::nuts::nut00::ProofsMethods; use crate::Error; use super::nut11::{enforce_sig_flag, EnforceSigFlag}; -use super::{Id, Mint, PublicKey, SigFlag, State, SwapRequest, SwapResponse}; +use super::{Id, Mint, ProofState, PublicKey, SigFlag, State, SwapRequest, SwapResponse}; impl Mint { /// Process Swap @@ -166,6 +166,17 @@ impl Mint { .update_proofs_states(&input_ys, State::Spent) .await?; + for pub_key in input_ys { + self.pubsub_manager.broadcast( + ProofState { + y: pub_key, + state: State::Spent, + witness: None, + } + .into(), + ); + } + self.localstore .add_blind_signatures( &swap_request diff --git a/crates/cdk/src/nuts/mod.rs b/crates/cdk/src/nuts/mod.rs index 07518bff..79bfb0d8 100644 --- a/crates/cdk/src/nuts/mod.rs +++ b/crates/cdk/src/nuts/mod.rs @@ -18,6 +18,7 @@ pub mod nut12; pub mod nut13; pub mod nut14; pub mod nut15; +pub mod nut17; pub mod nut18; pub use nut00::{ @@ -47,4 +48,5 @@ pub use nut11::{Conditions, P2PKWitness, SigFlag, SpendingConditions}; pub use nut12::{BlindSignatureDleq, ProofDleq}; pub use nut14::HTLCWitness; pub use nut15::{Mpp, MppMethodSettings, Settings as NUT15Settings}; +pub use nut17::{NotificationPayload, PubSubManager}; pub use nut18::{PaymentRequest, PaymentRequestPayload, Transport}; diff --git a/crates/cdk/src/nuts/nut06.rs b/crates/cdk/src/nuts/nut06.rs index 358cc4ff..95bc3c6c 100644 --- a/crates/cdk/src/nuts/nut06.rs +++ b/crates/cdk/src/nuts/nut06.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use super::nut01::PublicKey; -use super::{nut04, nut05, nut15, MppMethodSettings}; +use super::{nut04, nut05, nut15, nut17, MppMethodSettings}; /// Mint Version #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -232,6 +232,10 @@ pub struct Nuts { #[serde(default)] #[serde(rename = "15")] pub nut15: nut15::Settings, + /// NUT17 Settings + #[serde(default)] + #[serde(rename = "17")] + pub nut17: nut17::SupportedSettings, } impl Nuts { diff --git a/crates/cdk/src/nuts/nut17.rs b/crates/cdk/src/nuts/nut17.rs new file mode 100644 index 00000000..089fae9a --- /dev/null +++ b/crates/cdk/src/nuts/nut17.rs @@ -0,0 +1,358 @@ +//! Specific Subscription for the cdk crate + +#[cfg(feature = "mint")] +use crate::mint::{MeltQuote, MintQuote}; +use crate::{ + nuts::{ + MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response, MintQuoteState, + ProofState, + }, + pub_sub::{self, Index, Indexable, SubscriptionGlobalId}, +}; +use serde::{Deserialize, Serialize}; +use std::ops::Deref; + +/// Subscription Parameter according to the standard +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Params { + /// Kind + pub kind: Kind, + /// Filters + pub filters: Vec, + /// Subscription Id + #[serde(rename = "subId")] + pub id: SubId, +} + +/// Check state Settings +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct SupportedSettings { + supported: Vec, +} + +impl Default for SupportedSettings { + fn default() -> Self { + SupportedSettings { + supported: vec![SupportedMethods::default()], + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +struct SupportedMethods { + method: PaymentMethod, + unit: CurrencyUnit, + commands: Vec, +} + +impl Default for SupportedMethods { + fn default() -> Self { + SupportedMethods { + method: PaymentMethod::Bolt11, + unit: CurrencyUnit::Sat, + commands: vec![ + "bolt11_mint_quote".to_owned(), + "bolt11_melt_quote".to_owned(), + "proof_state".to_owned(), + ], + } + } +} + +pub use crate::pub_sub::SubId; + +use super::{BlindSignature, CurrencyUnit, PaymentMethod}; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +/// Subscription response +pub enum NotificationPayload { + /// Proof State + ProofState(ProofState), + /// Melt Quote Bolt11 Response + MeltQuoteBolt11Response(MeltQuoteBolt11Response), + /// Mint Quote Bolt11 Response + MintQuoteBolt11Response(MintQuoteBolt11Response), +} + +impl From for NotificationPayload { + fn from(proof_state: ProofState) -> NotificationPayload { + NotificationPayload::ProofState(proof_state) + } +} + +impl From for NotificationPayload { + fn from(melt_quote: MeltQuoteBolt11Response) -> NotificationPayload { + NotificationPayload::MeltQuoteBolt11Response(melt_quote) + } +} + +impl From for NotificationPayload { + fn from(mint_quote: MintQuoteBolt11Response) -> NotificationPayload { + NotificationPayload::MintQuoteBolt11Response(mint_quote) + } +} + +impl Indexable for NotificationPayload { + type Type = (String, Kind); + + fn to_indexes(&self) -> Vec> { + match self { + NotificationPayload::ProofState(proof_state) => { + vec![Index::from((proof_state.y.to_hex(), Kind::ProofState))] + } + NotificationPayload::MeltQuoteBolt11Response(melt_quote) => { + vec![Index::from(( + melt_quote.quote.clone(), + Kind::Bolt11MeltQuote, + ))] + } + NotificationPayload::MintQuoteBolt11Response(mint_quote) => { + vec![Index::from(( + mint_quote.quote.clone(), + Kind::Bolt11MintQuote, + ))] + } + } + } +} + +#[derive(Debug, Clone, Copy, Eq, Ord, PartialOrd, PartialEq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] + +/// Kind +pub enum Kind { + /// Bolt 11 Melt Quote + Bolt11MeltQuote, + /// Bolt 11 Mint Quote + Bolt11MintQuote, + /// Proof State + ProofState, +} + +impl AsRef for Params { + fn as_ref(&self) -> &SubId { + &self.id + } +} + +impl From for Vec> { + fn from(val: Params) -> Self { + let sub_id: SubscriptionGlobalId = Default::default(); + val.filters + .iter() + .map(|filter| Index::from(((filter.clone(), val.kind), val.id.clone(), sub_id))) + .collect() + } +} + +/// Manager +#[derive(Default)] +/// Publish–subscribe manager +/// +/// Nut-17 implementation is system-wide and not only through the WebSocket, so +/// it is possible for another part of the system to subscribe to events. +pub struct PubSubManager(pub_sub::Manager); + +impl Deref for PubSubManager { + type Target = pub_sub::Manager; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[cfg(feature = "mint")] +impl From<&MintQuote> for MintQuoteBolt11Response { + fn from(mint_quote: &MintQuote) -> MintQuoteBolt11Response { + MintQuoteBolt11Response { + quote: mint_quote.id.clone(), + request: mint_quote.request.clone(), + state: mint_quote.state, + expiry: Some(mint_quote.expiry), + } + } +} + +#[cfg(feature = "mint")] +impl From<&MeltQuote> for MeltQuoteBolt11Response { + fn from(melt_quote: &MeltQuote) -> MeltQuoteBolt11Response { + MeltQuoteBolt11Response { + quote: melt_quote.id.clone(), + payment_preimage: None, + change: None, + state: melt_quote.state, + paid: Some(melt_quote.state == MeltQuoteState::Paid), + expiry: melt_quote.expiry, + amount: melt_quote.amount, + fee_reserve: melt_quote.fee_reserve, + } + } +} + +impl PubSubManager { + /// Helper function to emit a MintQuoteBolt11Response status + pub fn mint_quote_bolt11_status>( + &self, + quote: E, + new_state: MintQuoteState, + ) { + let mut event = quote.into(); + event.state = new_state; + + self.broadcast(event.into()); + } + + /// Helper function to emit a MeltQuoteBolt11Response status + pub fn melt_quote_status>( + &self, + quote: E, + payment_preimage: Option, + change: Option>, + new_state: MeltQuoteState, + ) { + let mut quote = quote.into(); + quote.state = new_state; + quote.paid = Some(new_state == MeltQuoteState::Paid); + quote.payment_preimage = payment_preimage; + quote.change = change; + self.broadcast(quote.into()); + } +} + +#[cfg(test)] +mod test { + use crate::nuts::{PublicKey, State}; + + use super::*; + use std::time::Duration; + use tokio::time::sleep; + + #[tokio::test] + async fn active_and_drop() { + let manager = PubSubManager::default(); + let params = Params { + kind: Kind::ProofState, + filters: vec!["x".to_string()], + id: "uno".into(), + }; + + // Although the same param is used, two subscriptions are created, that + // is because each index is unique, thanks to `Unique`, it is the + // responsibility of the implementor to make sure that SubId are unique + // either globally or per client + let subscriptions = vec![ + manager.subscribe(params.clone()).await, + manager.subscribe(params).await, + ]; + assert_eq!(2, manager.active_subscriptions()); + drop(subscriptions); + + sleep(Duration::from_millis(10)).await; + + assert_eq!(0, manager.active_subscriptions()); + } + + #[tokio::test] + async fn broadcast() { + let manager = PubSubManager::default(); + let mut subscriptions = [ + manager + .subscribe(Params { + kind: Kind::ProofState, + filters: vec![ + "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" + .to_string(), + ], + id: "uno".into(), + }) + .await, + manager + .subscribe(Params { + kind: Kind::ProofState, + filters: vec![ + "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" + .to_string(), + ], + id: "dos".into(), + }) + .await, + ]; + + let event = ProofState { + y: PublicKey::from_hex( + "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104", + ) + .expect("valid pk"), + state: State::Pending, + witness: None, + }; + + manager.broadcast(event.into()); + + sleep(Duration::from_millis(10)).await; + + let (sub1, _) = subscriptions[0].try_recv().expect("valid message"); + assert_eq!("uno", *sub1); + + let (sub1, _) = subscriptions[1].try_recv().expect("valid message"); + assert_eq!("dos", *sub1); + + assert!(subscriptions[0].try_recv().is_err()); + assert!(subscriptions[1].try_recv().is_err()); + } + + #[test] + fn parsing_request() { + let json = r#"{"kind":"proof_state","filters":["x"],"subId":"uno"}"#; + let params: Params = serde_json::from_str(json).expect("valid json"); + assert_eq!(params.kind, Kind::ProofState); + assert_eq!(params.filters, vec!["x"]); + assert_eq!(*params.id, "uno"); + } + + #[tokio::test] + async fn json_test() { + let manager = PubSubManager::default(); + let mut subscription = manager + .subscribe::( + serde_json::from_str(r#"{"kind":"proof_state","filters":["02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104"],"subId":"uno"}"#) + .expect("valid json"), + ) + .await; + + manager.broadcast( + ProofState { + y: PublicKey::from_hex( + "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104", + ) + .expect("valid pk"), + state: State::Pending, + witness: None, + } + .into(), + ); + + // no one is listening for this event + manager.broadcast( + ProofState { + y: PublicKey::from_hex( + "020000000000000000000000000000000000000000000000000000000000000001", + ) + .expect("valid pk"), + state: State::Pending, + witness: None, + } + .into(), + ); + + sleep(Duration::from_millis(10)).await; + let (sub1, msg) = subscription.try_recv().expect("valid message"); + assert_eq!("uno", *sub1); + assert_eq!( + r#"{"Y":"02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104","state":"PENDING","witness":null}"#, + serde_json::to_string(&msg).expect("valid json") + ); + assert!(subscription.try_recv().is_err()); + } +} diff --git a/crates/cdk/src/pub_sub/index.rs b/crates/cdk/src/pub_sub/index.rs new file mode 100644 index 00000000..f255617e --- /dev/null +++ b/crates/cdk/src/pub_sub/index.rs @@ -0,0 +1,160 @@ +use super::SubId; +use std::{ + fmt::Debug, + ops::Deref, + sync::atomic::{AtomicUsize, Ordering}, +}; + +/// Indexable trait +pub trait Indexable { + /// The type of the index, it is unknown and it is up to the Manager's + /// generic type + type Type: PartialOrd + Ord + Send + Sync + Debug; + + /// To indexes + fn to_indexes(&self) -> Vec>; +} + +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone)] +/// Index +/// +/// The Index is a sorted structure that is used to quickly find matches +/// +/// The counter is used to make sure each Index is unique, even if the prefix +/// are the same, and also to make sure that earlier indexes matches first +pub struct Index +where + T: PartialOrd + Ord + Send + Sync + Debug, +{ + prefix: T, + counter: SubscriptionGlobalId, + id: super::SubId, +} + +impl From<&Index> for super::SubId +where + T: PartialOrd + Ord + Send + Sync + Debug, +{ + fn from(val: &Index) -> Self { + val.id.clone() + } +} + +impl Deref for Index +where + T: PartialOrd + Ord + Send + Sync + Debug, +{ + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.prefix + } +} + +impl Index +where + T: PartialOrd + Ord + Send + Sync + Debug, +{ + /// Compare the + pub fn cmp_prefix(&self, other: &Index) -> std::cmp::Ordering { + self.prefix.cmp(&other.prefix) + } + + /// Returns a globaly unique id for the Index + pub fn unique_id(&self) -> usize { + self.counter.0 + } +} + +impl From<(T, SubId, SubscriptionGlobalId)> for Index +where + T: PartialOrd + Ord + Send + Sync + Debug, +{ + fn from((prefix, id, counter): (T, SubId, SubscriptionGlobalId)) -> Self { + Self { + prefix, + id, + counter, + } + } +} + +impl From<(T, SubId)> for Index +where + T: PartialOrd + Ord + Send + Sync + Debug, +{ + fn from((prefix, id): (T, SubId)) -> Self { + Self { + prefix, + id, + counter: Default::default(), + } + } +} + +impl From for Index +where + T: PartialOrd + Ord + Send + Sync + Debug, +{ + fn from(prefix: T) -> Self { + Self { + prefix, + id: Default::default(), + counter: SubscriptionGlobalId(0), + } + } +} + +static COUNTER: AtomicUsize = AtomicUsize::new(0); + +/// Dummy type +/// +/// This is only use so each Index is unique, with the same prefix. +/// +/// The prefix is used to leverage the BTree to find things quickly, but each +/// entry/key must be unique, so we use this dummy type to make sure each Index +/// is unique. +/// +/// Unique is also used to make sure that the indexes are sorted by creation order +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Copy)] +pub struct SubscriptionGlobalId(usize); + +impl Default for SubscriptionGlobalId { + fn default() -> Self { + Self(COUNTER.fetch_add(1, Ordering::Relaxed)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_index_from_tuple() { + let sub_id = SubId::from("test_sub_id"); + let prefix = "test_prefix"; + let index: Index<&str> = Index::from((prefix, sub_id.clone())); + assert_eq!(index.prefix, "test_prefix"); + assert_eq!(index.id, sub_id); + } + + #[test] + fn test_index_cmp_prefix() { + let sub_id = SubId::from("test_sub_id"); + let index1: Index<&str> = Index::from(("a", sub_id.clone())); + let index2: Index<&str> = Index::from(("b", sub_id.clone())); + assert_eq!(index1.cmp_prefix(&index2), std::cmp::Ordering::Less); + } + + #[test] + fn test_sub_id_from_str() { + let sub_id = SubId::from("test_sub_id"); + assert_eq!(sub_id.0, "test_sub_id"); + } + + #[test] + fn test_sub_id_deref() { + let sub_id = SubId::from("test_sub_id"); + assert_eq!(&*sub_id, "test_sub_id"); + } +} diff --git a/crates/cdk/src/pub_sub/mod.rs b/crates/cdk/src/pub_sub/mod.rs new file mode 100644 index 00000000..c2952938 --- /dev/null +++ b/crates/cdk/src/pub_sub/mod.rs @@ -0,0 +1,311 @@ +//! Publish–subscribe pattern. +//! +//! This is a generic implementation for +//! [NUT-17(https://github.com/cashubtc/nuts/blob/main/17.md) with a type +//! agnostic Publish-subscribe manager. +//! +//! The manager has a method for subscribers to subscribe to events with a +//! generic type that must be converted to a vector of indexes. +//! +//! Events are also generic that should implement the `Indexable` trait. +use serde::{Deserialize, Serialize}; +use std::{ + cmp::Ordering, + collections::{BTreeMap, HashSet}, + fmt::Debug, + ops::{Deref, DerefMut}, + str::FromStr, + sync::{ + atomic::{self, AtomicUsize}, + Arc, + }, +}; +use tokio::{ + sync::{mpsc, RwLock}, + task::JoinHandle, +}; + +mod index; + +pub use index::{Index, Indexable, SubscriptionGlobalId}; + +type IndexTree = Arc, mpsc::Sender<(SubId, T)>>>>; + +/// Default size of the remove channel +pub const DEFAULT_REMOVE_SIZE: usize = 10_000; + +/// Default channel size for subscription buffering +pub const DEFAULT_CHANNEL_SIZE: usize = 10; + +/// Subscription manager +/// +/// This object keep track of all subscription listener and it is also +/// responsible for broadcasting events to all listeners +/// +/// The content of the notification is not relevant to this scope and it is up +/// to the application, therefore the generic T is used instead of a specific +/// type +pub struct Manager +where + T: Indexable + Clone + Send + Sync + 'static, + I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static, +{ + indexes: IndexTree, + unsubscription_sender: mpsc::Sender<(SubId, Vec>)>, + active_subscriptions: Arc, + background_subscription_remover: Option>, +} + +impl Default for Manager +where + T: Indexable + Clone + Send + Sync + 'static, + I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static, +{ + fn default() -> Self { + let (sender, receiver) = mpsc::channel(DEFAULT_REMOVE_SIZE); + let active_subscriptions: Arc = Default::default(); + let storage: IndexTree = Arc::new(Default::default()); + + Self { + background_subscription_remover: Some(tokio::spawn(Self::remove_subscription( + receiver, + storage.clone(), + active_subscriptions.clone(), + ))), + unsubscription_sender: sender, + active_subscriptions, + indexes: storage, + } + } +} + +impl Manager +where + T: Indexable + Clone + Send + Sync + 'static, + I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static, +{ + #[inline] + /// Broadcast an event to all listeners + /// + /// This function takes an Arc to the storage struct, the event_id, the kind + /// and the vent to broadcast + async fn broadcast_impl(storage: &IndexTree, event: T) { + let index_storage = storage.read().await; + let mut sent = HashSet::new(); + for index in event.to_indexes() { + for (key, sender) in index_storage.range(index.clone()..) { + if index.cmp_prefix(key) != Ordering::Equal { + break; + } + let sub_id = key.unique_id(); + if sent.contains(&sub_id) { + continue; + } + sent.insert(sub_id); + let _ = sender.try_send((key.into(), event.clone())); + } + } + } + + /// Broadcasts an event to all listeners + /// + /// This public method will not block the caller, it will spawn a new task + /// instead + pub fn broadcast(&self, event: T) { + let storage = self.indexes.clone(); + tokio::spawn(async move { + Self::broadcast_impl(&storage, event).await; + }); + } + + /// Broadcasts an event to all listeners + /// + /// This method is async and will await for the broadcast to be completed + pub async fn broadcast_async(&self, event: T) { + Self::broadcast_impl(&self.indexes, event).await; + } + + /// Subscribe to a specific event + pub async fn subscribe + Into>>>( + &self, + params: P, + ) -> ActiveSubscription { + let (sender, receiver) = mpsc::channel(10); + let sub_id: SubId = params.as_ref().clone(); + let indexes: Vec> = params.into(); + + let mut index_storage = self.indexes.write().await; + for index in indexes.clone() { + index_storage.insert(index, sender.clone()); + } + drop(index_storage); + + self.active_subscriptions + .fetch_add(1, atomic::Ordering::Relaxed); + + ActiveSubscription { + sub_id, + receiver, + indexes, + drop: self.unsubscription_sender.clone(), + } + } + + /// Return number of active subscriptions + pub fn active_subscriptions(&self) -> usize { + self.active_subscriptions.load(atomic::Ordering::SeqCst) + } + + /// Task to remove dropped subscriptions from the storage struct + /// + /// This task will run in the background (and will dropped when the Manager + /// is ) and will remove subscriptions from the storage struct it is dropped. + async fn remove_subscription( + mut receiver: mpsc::Receiver<(SubId, Vec>)>, + storage: IndexTree, + active_subscriptions: Arc, + ) { + while let Some((sub_id, indexes)) = receiver.recv().await { + tracing::info!("Removing subscription: {}", *sub_id); + + active_subscriptions.fetch_sub(1, atomic::Ordering::AcqRel); + + let mut index_storage = storage.write().await; + for key in indexes { + index_storage.remove(&key); + } + drop(index_storage); + } + } +} + +/// Manager goes out of scope, stop all background tasks +impl Drop for Manager +where + T: Indexable + Clone + Send + Sync + 'static, + I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static, +{ + fn drop(&mut self) { + if let Some(handler) = self.background_subscription_remover.take() { + handler.abort(); + } + } +} + +/// Active Subscription +/// +/// This struct is a wrapper around the mpsc::Receiver and it also used +/// to keep track of the subscription itself. When this struct goes out of +/// scope, it will notify the Manager about it, so it can be removed from the +/// list of active listeners +pub struct ActiveSubscription +where + T: Send + Sync, + I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static, +{ + /// The subscription ID + pub sub_id: SubId, + indexes: Vec>, + receiver: mpsc::Receiver<(SubId, T)>, + drop: mpsc::Sender<(SubId, Vec>)>, +} + +impl Deref for ActiveSubscription +where + T: Send + Sync, + I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static, +{ + type Target = mpsc::Receiver<(SubId, T)>; + + fn deref(&self) -> &Self::Target { + &self.receiver + } +} + +impl DerefMut for ActiveSubscription +where + T: Indexable + Clone + Send + Sync + 'static, + I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.receiver + } +} + +/// The ActiveSubscription is Drop out of scope, notify the Manager about it, so +/// it can be removed from the list of active listeners +/// +/// Having this in place, we can avoid memory leaks and also makes it super +/// simple to implement the Unsubscribe method +impl Drop for ActiveSubscription +where + T: Send + Sync, + I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static, +{ + fn drop(&mut self) { + let _ = self + .drop + .try_send((self.sub_id.clone(), self.indexes.drain(..).collect())); + } +} + +/// Subscription Id wrapper +/// +/// This is the place to add some sane default (like a max length) to the +/// subscription ID +#[derive(Debug, Clone, Default, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +pub struct SubId(String); + +impl From<&str> for SubId { + fn from(s: &str) -> Self { + Self(s.to_string()) + } +} + +impl From for SubId { + fn from(s: String) -> Self { + Self(s) + } +} + +impl FromStr for SubId { + type Err = (); + + fn from_str(s: &str) -> Result { + Ok(Self(s.to_string())) + } +} + +impl Deref for SubId { + type Target = String; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[cfg(test)] +mod test { + use super::*; + use tokio::sync::mpsc; + + #[test] + fn test_active_subscription_drop() { + let (tx, rx) = mpsc::channel::<(SubId, ())>(10); + let sub_id = SubId::from("test_sub_id"); + let indexes: Vec> = vec![Index::from(("test".to_string(), sub_id.clone()))]; + let (drop_tx, mut drop_rx) = mpsc::channel(10); + + { + let _active_subscription = ActiveSubscription { + sub_id: sub_id.clone(), + indexes, + receiver: rx, + drop: drop_tx, + }; + // When it goes out of scope, it should notify + } + assert_eq!(drop_rx.try_recv().unwrap().0, sub_id); // it should have notified + assert!(tx.try_send(("foo".into(), ())).is_err()); // subscriber is dropped + } +}