Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mint Websockets (NUT-17) #394

Merged
merged 10 commits into from
Nov 6, 2024
25 changes: 17 additions & 8 deletions crates/cdk-axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,30 @@ edition = "2021"
license = "MIT"
homepage = "https://github.com/cashubtc/cdk"
repository = "https://github.com/cashubtc/cdk.git"
rust-version = "1.63.0" # MSRV
rust-version = "1.63.0" # MSRV
description = "Cashu CDK axum webserver"

[dependencies]
anyhow = "1"
async-trait = "0.1"
axum = "0.6.20"
cdk = { path = "../cdk", version = "0.4.0", default-features = false, features = ["mint"] }
tokio = { version = "1", default-features = false }
tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
utoipa = { version = "4", features = ["preserve_order", "preserve_path_order"], optional = true }
async-trait = "0.1.83"
axum = { version = "0.6.20", features = ["ws"] }
cdk = { path = "../cdk", version = "0.4.0", default-features = false, features = [
"mint",
] }
tokio = { version = "1", default-features = false, features = ["io-util"] }
tracing = { version = "0.1", default-features = false, features = [
"attributes",
"log",
] }
utoipa = { version = "4", features = [
"preserve_order",
"preserve_path_order",
], optional = true }
futures = { version = "0.3.28", default-features = false }
moka = { version = "0.11.1", features = ["future"] }
serde_json = "1"
paste = "1.0.15"
serde = { version = "1.0.210", features = ["derive"] }

[features]
swagger = ["cdk/swagger", "dep:utoipa"]
swagger = ["cdk/swagger", "dep:utoipa"]
2 changes: 2 additions & 0 deletions crates/cdk-axum/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use moka::future::Cache;
use router_handlers::*;

mod router_handlers;
mod ws;

#[cfg(feature = "swagger")]
mod swagger_imports {
Expand Down Expand Up @@ -154,6 +155,7 @@ pub async fn create_mint_router(mint: Arc<Mint>, cache_ttl: u64, cache_tti: u64)
)
.route("/mint/bolt11", post(cache_post_mint_bolt11))
.route("/melt/quote/bolt11", post(get_melt_bolt11_quote))
.route("/ws", get(ws_handler))
.route(
"/melt/quote/bolt11/:quote_id",
get(get_check_melt_bolt11_quote),
Expand Down
18 changes: 11 additions & 7 deletions crates/cdk-axum/src/router_handlers.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use anyhow::Result;
use axum::extract::{Json, Path, State};
use axum::extract::{ws::WebSocketUpgrade, Json, Path, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use cdk::error::ErrorResponse;
Expand All @@ -13,7 +13,7 @@ use cdk::util::unix_time;
use cdk::Error;
use paste::paste;

use crate::MintState;
use crate::{ws::main_websocket, MintState};

macro_rules! post_cache_wrapper {
($handler:ident, $request_type:ty, $response_type:ty) => {
Expand Down Expand Up @@ -174,6 +174,15 @@ pub async fn get_check_mint_bolt11_quote(
Ok(Json(quote))
}

pub async fn ws_handler(State(state): State<MintState>, ws: WebSocketUpgrade) -> impl IntoResponse {
ws.on_upgrade(|ws| main_websocket(ws, state))
}

/// Mint tokens by paying a BOLT11 Lightning invoice.
///
/// Requests the minting of tokens belonging to a paid payment request.
///
/// Call this endpoint after `POST /v1/mint/quote`.
#[cfg_attr(feature = "swagger", utoipa::path(
post,
context_path = "/v1",
Expand All @@ -184,11 +193,6 @@ pub async fn get_check_mint_bolt11_quote(
(status = 500, description = "Server error", body = ErrorResponse, content_type = "application/json")
)
))]
/// Mint tokens by paying a BOLT11 Lightning invoice.
///
/// Requests the minting of tokens belonging to a paid payment request.
///
/// Call this endpoint after `POST /v1/mint/quote`.
pub async fn post_mint_bolt11(
State(state): State<MintState>,
Json(payload): Json<MintBolt11Request>,
Expand Down
19 changes: 19 additions & 0 deletions crates/cdk-axum/src/ws/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
/// Source: https://www.jsonrpc.org/specification#error_object
pub enum WsError {
/// Invalid JSON was received by the server.
/// An error occurred on the server while parsing the JSON text.
ParseError,
/// The JSON sent is not a valid Request object.
InvalidRequest,
/// The method does not exist / is not available.
MethodNotFound,
/// Invalid method parameter(s).
InvalidParams,
/// Internal JSON-RPC error.
InternalError,
/// Custom error
ServerError(i32, String),
}
70 changes: 70 additions & 0 deletions crates/cdk-axum/src/ws/handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use super::{WsContext, WsError, JSON_RPC_VERSION};
use serde::Serialize;

impl From<WsError> for WsErrorResponse {
fn from(val: WsError) -> Self {
let (id, message) = match val {
WsError::ParseError => (-32700, "Parse error".to_string()),
WsError::InvalidRequest => (-32600, "Invalid Request".to_string()),
WsError::MethodNotFound => (-32601, "Method not found".to_string()),
WsError::InvalidParams => (-32602, "Invalid params".to_string()),
WsError::InternalError => (-32603, "Internal error".to_string()),
WsError::ServerError(code, message) => (code, message),
};
WsErrorResponse { code: id, message }
}
}

#[derive(Debug, Clone, Serialize)]
struct WsErrorResponse {
code: i32,
message: String,
}

#[derive(Debug, Clone, Serialize)]
struct WsResponse<T: Serialize + Sized> {
jsonrpc: String,
#[serde(skip_serializing_if = "Option::is_none")]
result: Option<T>,
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<WsErrorResponse>,
id: usize,
}

#[derive(Debug, Clone, Serialize)]
pub struct WsNotification<T> {
pub jsonrpc: String,
pub method: String,
pub params: T,
}

#[async_trait::async_trait]
pub trait WsHandle {
type Response: Serialize + Sized;

async fn process(
self,
req_id: usize,
context: &mut WsContext,
) -> Result<serde_json::Value, serde_json::Error>
where
Self: Sized,
{
serde_json::to_value(&match self.handle(context).await {
Ok(response) => WsResponse {
jsonrpc: JSON_RPC_VERSION.to_owned(),
result: Some(response),
error: None,
id: req_id,
},
Err(error) => WsResponse {
jsonrpc: JSON_RPC_VERSION.to_owned(),
result: None,
error: Some(error.into()),
id: req_id,
},
})
}

async fn handle(self, context: &mut WsContext) -> Result<Self::Response, WsError>;
}
121 changes: 121 additions & 0 deletions crates/cdk-axum/src/ws/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
use crate::MintState;
use axum::extract::ws::{Message, WebSocket};
use cdk::nuts::nut17::{NotificationPayload, SubId};
use futures::StreamExt;
use handler::{WsHandle, WsNotification};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use subscribe::Notification;
use tokio::sync::mpsc;

mod error;
mod handler;
mod subscribe;
mod unsubscribe;

/// JSON RPC version
pub const JSON_RPC_VERSION: &str = "2.0";

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsRequest {
jsonrpc: String,
#[serde(flatten)]
method: WsMethod,
id: usize,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "method", content = "params")]
pub enum WsMethod {
Subscribe(subscribe::Method),
Unsubscribe(unsubscribe::Method),
}

impl WsMethod {
pub async fn process(
self,
req_id: usize,
context: &mut WsContext,
) -> Result<serde_json::Value, serde_json::Error> {
match self {
WsMethod::Subscribe(sub) => sub.process(req_id, context),
WsMethod::Unsubscribe(unsub) => unsub.process(req_id, context),
}
.await
}
}

pub use error::WsError;

pub struct WsContext {
state: MintState,
subscriptions: HashMap<SubId, tokio::task::JoinHandle<()>>,
publisher: mpsc::Sender<(SubId, NotificationPayload)>,
}

/// Main function for websocket connections
///
/// This function will handle all incoming websocket connections and keep them in their own loop.
///
/// For simplicity sake this function will spawn tasks for each subscription and
/// keep them in a hashmap, and will have a single subscriber for all of them.
#[allow(clippy::incompatible_msrv)]
pub async fn main_websocket(mut socket: WebSocket, state: MintState) {
let (publisher, mut subscriber) = mpsc::channel(100);
let mut context = WsContext {
state,
subscriptions: HashMap::new(),
publisher,
};

loop {
tokio::select! {
Some((sub_id, payload)) = subscriber.recv() => {
if !context.subscriptions.contains_key(&sub_id) {
// It may be possible an incoming message has come from a dropped Subscriptions that has not yet been
// unsubscribed from the subscription manager, just ignore it.
continue;
}
let notification: WsNotification<Notification> = (sub_id, payload).into();
let message = if let Ok(message) = serde_json::to_string(&notification) {
message
} else {
tracing::error!("Could not serialize notification");
continue;
};

if socket.send(Message::Text(message)).await.is_err() {
break;
}
crodas marked this conversation as resolved.
Show resolved Hide resolved
}
Some(Ok(Message::Text(text))) = socket.next() => {
let request = match serde_json::from_str::<WsRequest>(&text) {
Ok(request) => request,
Err(err) => {
tracing::error!("Could not parse request: {}", err);
continue;
}
};

match request.method.process(request.id, &mut context).await {
Ok(result) => {
if socket
.send(Message::Text(result.to_string()))
.await
.is_err()
{
break;
}
crodas marked this conversation as resolved.
Show resolved Hide resolved
}
Err(err) => {
tracing::error!("Error serializing response: {}", err);
break;
}
}
}
else => {

}
}
}
}
61 changes: 61 additions & 0 deletions crates/cdk-axum/src/ws/subscribe.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use super::{
handler::{WsHandle, WsNotification},
WsContext, WsError, JSON_RPC_VERSION,
};
use cdk::{
nuts::nut17::{NotificationPayload, Params},
pub_sub::SubId,
};

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Method(Params);

#[derive(Debug, Clone, serde::Serialize)]
pub struct Response {
status: String,
sub_id: SubId,
crodas marked this conversation as resolved.
Show resolved Hide resolved
}

#[derive(Debug, Clone, serde::Serialize)]
pub struct Notification {
#[serde(rename = "subId")]
pub sub_id: SubId,

pub payload: NotificationPayload,
}

impl From<(SubId, NotificationPayload)> for WsNotification<Notification> {
fn from((sub_id, payload): (SubId, NotificationPayload)) -> Self {
WsNotification {
jsonrpc: JSON_RPC_VERSION.to_owned(),
method: "subscribe".to_string(),
params: Notification { sub_id, payload },
}
}
}

#[async_trait::async_trait]
impl WsHandle for Method {
type Response = Response;

async fn handle(self, context: &mut WsContext) -> Result<Self::Response, WsError> {
let sub_id = self.0.id.clone();
if context.subscriptions.contains_key(&sub_id) {
return Err(WsError::InvalidParams);
}
let mut subscription = context.state.mint.pubsub_manager.subscribe(self.0).await;
let publisher = context.publisher.clone();
context.subscriptions.insert(
sub_id.clone(),
tokio::spawn(async move {
while let Some(response) = subscription.recv().await {
let _ = publisher.send(response).await;
}
}),
);
Ok(Response {
status: "OK".to_string(),
sub_id,
})
}
}
Loading
Loading