Skip to content

Commit

Permalink
Mint Websockets (NUT-17) (#394)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: thesimplekid <[email protected]>
  • Loading branch information
crodas and thesimplekid authored Nov 6, 2024
1 parent 479b4e7 commit 6973e53
Show file tree
Hide file tree
Showing 26 changed files with 1,384 additions and 46 deletions.
25 changes: 17 additions & 8 deletions crates/cdk-axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,30 @@ edition = "2021"
license = "MIT"
homepage = "https://github.com/cashubtc/cdk"
repository = "https://github.com/cashubtc/cdk.git"
rust-version = "1.63.0" # MSRV
rust-version = "1.63.0" # MSRV
description = "Cashu CDK axum webserver"

[dependencies]
anyhow = "1"
async-trait = "0.1"
axum = "0.6.20"
cdk = { path = "../cdk", version = "0.4.0", default-features = false, features = ["mint"] }
tokio = { version = "1", default-features = false }
tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
utoipa = { version = "4", features = ["preserve_order", "preserve_path_order"], optional = true }
async-trait = "0.1.83"
axum = { version = "0.6.20", features = ["ws"] }
cdk = { path = "../cdk", version = "0.4.0", default-features = false, features = [
"mint",
] }
tokio = { version = "1", default-features = false, features = ["io-util"] }
tracing = { version = "0.1", default-features = false, features = [
"attributes",
"log",
] }
utoipa = { version = "4", features = [
"preserve_order",
"preserve_path_order",
], optional = true }
futures = { version = "0.3.28", default-features = false }
moka = { version = "0.11.1", features = ["future"] }
serde_json = "1"
paste = "1.0.15"
serde = { version = "1.0.210", features = ["derive"] }

[features]
swagger = ["cdk/swagger", "dep:utoipa"]
swagger = ["cdk/swagger", "dep:utoipa"]
2 changes: 2 additions & 0 deletions crates/cdk-axum/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use moka::future::Cache;
use router_handlers::*;

mod router_handlers;
mod ws;

#[cfg(feature = "swagger")]
mod swagger_imports {
Expand Down Expand Up @@ -154,6 +155,7 @@ pub async fn create_mint_router(mint: Arc<Mint>, cache_ttl: u64, cache_tti: u64)
)
.route("/mint/bolt11", post(cache_post_mint_bolt11))
.route("/melt/quote/bolt11", post(post_melt_bolt11_quote))
.route("/ws", get(ws_handler))
.route(
"/melt/quote/bolt11/:quote_id",
get(get_check_melt_bolt11_quote),
Expand Down
18 changes: 11 additions & 7 deletions crates/cdk-axum/src/router_handlers.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use anyhow::Result;
use axum::extract::{Json, Path, State};
use axum::extract::{ws::WebSocketUpgrade, Json, Path, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use cdk::error::ErrorResponse;
Expand All @@ -13,7 +13,7 @@ use cdk::util::unix_time;
use cdk::Error;
use paste::paste;

use crate::MintState;
use crate::{ws::main_websocket, MintState};

macro_rules! post_cache_wrapper {
($handler:ident, $request_type:ty, $response_type:ty) => {
Expand Down Expand Up @@ -174,6 +174,15 @@ pub async fn get_check_mint_bolt11_quote(
Ok(Json(quote))
}

pub async fn ws_handler(State(state): State<MintState>, ws: WebSocketUpgrade) -> impl IntoResponse {
ws.on_upgrade(|ws| main_websocket(ws, state))
}

/// Mint tokens by paying a BOLT11 Lightning invoice.
///
/// Requests the minting of tokens belonging to a paid payment request.
///
/// Call this endpoint after `POST /v1/mint/quote`.
#[cfg_attr(feature = "swagger", utoipa::path(
post,
context_path = "/v1",
Expand All @@ -184,11 +193,6 @@ pub async fn get_check_mint_bolt11_quote(
(status = 500, description = "Server error", body = ErrorResponse, content_type = "application/json")
)
))]
/// Mint tokens by paying a BOLT11 Lightning invoice.
///
/// Requests the minting of tokens belonging to a paid payment request.
///
/// Call this endpoint after `POST /v1/mint/quote`.
pub async fn post_mint_bolt11(
State(state): State<MintState>,
Json(payload): Json<MintBolt11Request>,
Expand Down
19 changes: 19 additions & 0 deletions crates/cdk-axum/src/ws/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
/// Source: https://www.jsonrpc.org/specification#error_object
pub enum WsError {
/// Invalid JSON was received by the server.
/// An error occurred on the server while parsing the JSON text.
ParseError,
/// The JSON sent is not a valid Request object.
InvalidRequest,
/// The method does not exist / is not available.
MethodNotFound,
/// Invalid method parameter(s).
InvalidParams,
/// Internal JSON-RPC error.
InternalError,
/// Custom error
ServerError(i32, String),
}
70 changes: 70 additions & 0 deletions crates/cdk-axum/src/ws/handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use super::{WsContext, WsError, JSON_RPC_VERSION};
use serde::Serialize;

impl From<WsError> for WsErrorResponse {
fn from(val: WsError) -> Self {
let (id, message) = match val {
WsError::ParseError => (-32700, "Parse error".to_string()),
WsError::InvalidRequest => (-32600, "Invalid Request".to_string()),
WsError::MethodNotFound => (-32601, "Method not found".to_string()),
WsError::InvalidParams => (-32602, "Invalid params".to_string()),
WsError::InternalError => (-32603, "Internal error".to_string()),
WsError::ServerError(code, message) => (code, message),
};
WsErrorResponse { code: id, message }
}
}

#[derive(Debug, Clone, Serialize)]
struct WsErrorResponse {
code: i32,
message: String,
}

#[derive(Debug, Clone, Serialize)]
struct WsResponse<T: Serialize + Sized> {
jsonrpc: String,
#[serde(skip_serializing_if = "Option::is_none")]
result: Option<T>,
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<WsErrorResponse>,
id: usize,
}

#[derive(Debug, Clone, Serialize)]
pub struct WsNotification<T> {
pub jsonrpc: String,
pub method: String,
pub params: T,
}

#[async_trait::async_trait]
pub trait WsHandle {
type Response: Serialize + Sized;

async fn process(
self,
req_id: usize,
context: &mut WsContext,
) -> Result<serde_json::Value, serde_json::Error>
where
Self: Sized,
{
serde_json::to_value(&match self.handle(context).await {
Ok(response) => WsResponse {
jsonrpc: JSON_RPC_VERSION.to_owned(),
result: Some(response),
error: None,
id: req_id,
},
Err(error) => WsResponse {
jsonrpc: JSON_RPC_VERSION.to_owned(),
result: None,
error: Some(error.into()),
id: req_id,
},
})
}

async fn handle(self, context: &mut WsContext) -> Result<Self::Response, WsError>;
}
123 changes: 123 additions & 0 deletions crates/cdk-axum/src/ws/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
use crate::MintState;
use axum::extract::ws::{Message, WebSocket};
use cdk::nuts::nut17::{NotificationPayload, SubId};
use futures::StreamExt;
use handler::{WsHandle, WsNotification};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use subscribe::Notification;
use tokio::sync::mpsc;

mod error;
mod handler;
mod subscribe;
mod unsubscribe;

/// JSON RPC version
pub const JSON_RPC_VERSION: &str = "2.0";

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsRequest {
jsonrpc: String,
#[serde(flatten)]
method: WsMethod,
id: usize,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "method", content = "params")]
pub enum WsMethod {
Subscribe(subscribe::Method),
Unsubscribe(unsubscribe::Method),
}

impl WsMethod {
pub async fn process(
self,
req_id: usize,
context: &mut WsContext,
) -> Result<serde_json::Value, serde_json::Error> {
match self {
WsMethod::Subscribe(sub) => sub.process(req_id, context),
WsMethod::Unsubscribe(unsub) => unsub.process(req_id, context),
}
.await
}
}

pub use error::WsError;

pub struct WsContext {
state: MintState,
subscriptions: HashMap<SubId, tokio::task::JoinHandle<()>>,
publisher: mpsc::Sender<(SubId, NotificationPayload)>,
}

/// Main function for websocket connections
///
/// This function will handle all incoming websocket connections and keep them in their own loop.
///
/// For simplicity sake this function will spawn tasks for each subscription and
/// keep them in a hashmap, and will have a single subscriber for all of them.
#[allow(clippy::incompatible_msrv)]
pub async fn main_websocket(mut socket: WebSocket, state: MintState) {
let (publisher, mut subscriber) = mpsc::channel(100);
let mut context = WsContext {
state,
subscriptions: HashMap::new(),
publisher,
};

loop {
tokio::select! {
Some((sub_id, payload)) = subscriber.recv() => {
if !context.subscriptions.contains_key(&sub_id) {
// It may be possible an incoming message has come from a dropped Subscriptions that has not yet been
// unsubscribed from the subscription manager, just ignore it.
continue;
}
let notification: WsNotification<Notification> = (sub_id, payload).into();
let message = match serde_json::to_string(&notification) {
Ok(message) => message,
Err(err) => {
tracing::error!("Could not serialize notification: {}", err);
continue;
}
};

if let Err(err)= socket.send(Message::Text(message)).await {
tracing::error!("Could not send websocket message: {}", err);
break;
}
}
Some(Ok(Message::Text(text))) = socket.next() => {
let request = match serde_json::from_str::<WsRequest>(&text) {
Ok(request) => request,
Err(err) => {
tracing::error!("Could not parse request: {}", err);
continue;
}
};

match request.method.process(request.id, &mut context).await {
Ok(result) => {
if let Err(err) = socket
.send(Message::Text(result.to_string()))
.await
{
tracing::error!("Could not send request: {}", err);
break;
}
}
Err(err) => {
tracing::error!("Error serializing response: {}", err);
break;
}
}
}
else => {

}
}
}
}
62 changes: 62 additions & 0 deletions crates/cdk-axum/src/ws/subscribe.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use super::{
handler::{WsHandle, WsNotification},
WsContext, WsError, JSON_RPC_VERSION,
};
use cdk::{
nuts::nut17::{NotificationPayload, Params},
pub_sub::SubId,
};

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Method(Params);

#[derive(Debug, Clone, serde::Serialize)]
pub struct Response {
status: String,
#[serde(rename = "subId")]
sub_id: SubId,
}

#[derive(Debug, Clone, serde::Serialize)]
pub struct Notification {
#[serde(rename = "subId")]
pub sub_id: SubId,

pub payload: NotificationPayload,
}

impl From<(SubId, NotificationPayload)> for WsNotification<Notification> {
fn from((sub_id, payload): (SubId, NotificationPayload)) -> Self {
WsNotification {
jsonrpc: JSON_RPC_VERSION.to_owned(),
method: "subscribe".to_string(),
params: Notification { sub_id, payload },
}
}
}

#[async_trait::async_trait]
impl WsHandle for Method {
type Response = Response;

async fn handle(self, context: &mut WsContext) -> Result<Self::Response, WsError> {
let sub_id = self.0.id.clone();
if context.subscriptions.contains_key(&sub_id) {
return Err(WsError::InvalidParams);
}
let mut subscription = context.state.mint.pubsub_manager.subscribe(self.0).await;
let publisher = context.publisher.clone();
context.subscriptions.insert(
sub_id.clone(),
tokio::spawn(async move {
while let Some(response) = subscription.recv().await {
let _ = publisher.send(response).await;
}
}),
);
Ok(Response {
status: "OK".to_string(),
sub_id,
})
}
}
Loading

0 comments on commit 6973e53

Please sign in to comment.