-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
--------- Co-authored-by: thesimplekid <[email protected]>
- Loading branch information
1 parent
479b4e7
commit 6973e53
Showing
26 changed files
with
1,384 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
use serde::{Deserialize, Serialize}; | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
/// Source: https://www.jsonrpc.org/specification#error_object | ||
pub enum WsError { | ||
/// Invalid JSON was received by the server. | ||
/// An error occurred on the server while parsing the JSON text. | ||
ParseError, | ||
/// The JSON sent is not a valid Request object. | ||
InvalidRequest, | ||
/// The method does not exist / is not available. | ||
MethodNotFound, | ||
/// Invalid method parameter(s). | ||
InvalidParams, | ||
/// Internal JSON-RPC error. | ||
InternalError, | ||
/// Custom error | ||
ServerError(i32, String), | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
use super::{WsContext, WsError, JSON_RPC_VERSION}; | ||
use serde::Serialize; | ||
|
||
impl From<WsError> for WsErrorResponse { | ||
fn from(val: WsError) -> Self { | ||
let (id, message) = match val { | ||
WsError::ParseError => (-32700, "Parse error".to_string()), | ||
WsError::InvalidRequest => (-32600, "Invalid Request".to_string()), | ||
WsError::MethodNotFound => (-32601, "Method not found".to_string()), | ||
WsError::InvalidParams => (-32602, "Invalid params".to_string()), | ||
WsError::InternalError => (-32603, "Internal error".to_string()), | ||
WsError::ServerError(code, message) => (code, message), | ||
}; | ||
WsErrorResponse { code: id, message } | ||
} | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize)] | ||
struct WsErrorResponse { | ||
code: i32, | ||
message: String, | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize)] | ||
struct WsResponse<T: Serialize + Sized> { | ||
jsonrpc: String, | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
result: Option<T>, | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
error: Option<WsErrorResponse>, | ||
id: usize, | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize)] | ||
pub struct WsNotification<T> { | ||
pub jsonrpc: String, | ||
pub method: String, | ||
pub params: T, | ||
} | ||
|
||
#[async_trait::async_trait] | ||
pub trait WsHandle { | ||
type Response: Serialize + Sized; | ||
|
||
async fn process( | ||
self, | ||
req_id: usize, | ||
context: &mut WsContext, | ||
) -> Result<serde_json::Value, serde_json::Error> | ||
where | ||
Self: Sized, | ||
{ | ||
serde_json::to_value(&match self.handle(context).await { | ||
Ok(response) => WsResponse { | ||
jsonrpc: JSON_RPC_VERSION.to_owned(), | ||
result: Some(response), | ||
error: None, | ||
id: req_id, | ||
}, | ||
Err(error) => WsResponse { | ||
jsonrpc: JSON_RPC_VERSION.to_owned(), | ||
result: None, | ||
error: Some(error.into()), | ||
id: req_id, | ||
}, | ||
}) | ||
} | ||
|
||
async fn handle(self, context: &mut WsContext) -> Result<Self::Response, WsError>; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
use crate::MintState; | ||
use axum::extract::ws::{Message, WebSocket}; | ||
use cdk::nuts::nut17::{NotificationPayload, SubId}; | ||
use futures::StreamExt; | ||
use handler::{WsHandle, WsNotification}; | ||
use serde::{Deserialize, Serialize}; | ||
use std::collections::HashMap; | ||
use subscribe::Notification; | ||
use tokio::sync::mpsc; | ||
|
||
mod error; | ||
mod handler; | ||
mod subscribe; | ||
mod unsubscribe; | ||
|
||
/// JSON RPC version | ||
pub const JSON_RPC_VERSION: &str = "2.0"; | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
pub struct WsRequest { | ||
jsonrpc: String, | ||
#[serde(flatten)] | ||
method: WsMethod, | ||
id: usize, | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
#[serde(rename_all = "snake_case", tag = "method", content = "params")] | ||
pub enum WsMethod { | ||
Subscribe(subscribe::Method), | ||
Unsubscribe(unsubscribe::Method), | ||
} | ||
|
||
impl WsMethod { | ||
pub async fn process( | ||
self, | ||
req_id: usize, | ||
context: &mut WsContext, | ||
) -> Result<serde_json::Value, serde_json::Error> { | ||
match self { | ||
WsMethod::Subscribe(sub) => sub.process(req_id, context), | ||
WsMethod::Unsubscribe(unsub) => unsub.process(req_id, context), | ||
} | ||
.await | ||
} | ||
} | ||
|
||
pub use error::WsError; | ||
|
||
pub struct WsContext { | ||
state: MintState, | ||
subscriptions: HashMap<SubId, tokio::task::JoinHandle<()>>, | ||
publisher: mpsc::Sender<(SubId, NotificationPayload)>, | ||
} | ||
|
||
/// Main function for websocket connections | ||
/// | ||
/// This function will handle all incoming websocket connections and keep them in their own loop. | ||
/// | ||
/// For simplicity sake this function will spawn tasks for each subscription and | ||
/// keep them in a hashmap, and will have a single subscriber for all of them. | ||
#[allow(clippy::incompatible_msrv)] | ||
pub async fn main_websocket(mut socket: WebSocket, state: MintState) { | ||
let (publisher, mut subscriber) = mpsc::channel(100); | ||
let mut context = WsContext { | ||
state, | ||
subscriptions: HashMap::new(), | ||
publisher, | ||
}; | ||
|
||
loop { | ||
tokio::select! { | ||
Some((sub_id, payload)) = subscriber.recv() => { | ||
if !context.subscriptions.contains_key(&sub_id) { | ||
// It may be possible an incoming message has come from a dropped Subscriptions that has not yet been | ||
// unsubscribed from the subscription manager, just ignore it. | ||
continue; | ||
} | ||
let notification: WsNotification<Notification> = (sub_id, payload).into(); | ||
let message = match serde_json::to_string(¬ification) { | ||
Ok(message) => message, | ||
Err(err) => { | ||
tracing::error!("Could not serialize notification: {}", err); | ||
continue; | ||
} | ||
}; | ||
|
||
if let Err(err)= socket.send(Message::Text(message)).await { | ||
tracing::error!("Could not send websocket message: {}", err); | ||
break; | ||
} | ||
} | ||
Some(Ok(Message::Text(text))) = socket.next() => { | ||
let request = match serde_json::from_str::<WsRequest>(&text) { | ||
Ok(request) => request, | ||
Err(err) => { | ||
tracing::error!("Could not parse request: {}", err); | ||
continue; | ||
} | ||
}; | ||
|
||
match request.method.process(request.id, &mut context).await { | ||
Ok(result) => { | ||
if let Err(err) = socket | ||
.send(Message::Text(result.to_string())) | ||
.await | ||
{ | ||
tracing::error!("Could not send request: {}", err); | ||
break; | ||
} | ||
} | ||
Err(err) => { | ||
tracing::error!("Error serializing response: {}", err); | ||
break; | ||
} | ||
} | ||
} | ||
else => { | ||
|
||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
use super::{ | ||
handler::{WsHandle, WsNotification}, | ||
WsContext, WsError, JSON_RPC_VERSION, | ||
}; | ||
use cdk::{ | ||
nuts::nut17::{NotificationPayload, Params}, | ||
pub_sub::SubId, | ||
}; | ||
|
||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] | ||
pub struct Method(Params); | ||
|
||
#[derive(Debug, Clone, serde::Serialize)] | ||
pub struct Response { | ||
status: String, | ||
#[serde(rename = "subId")] | ||
sub_id: SubId, | ||
} | ||
|
||
#[derive(Debug, Clone, serde::Serialize)] | ||
pub struct Notification { | ||
#[serde(rename = "subId")] | ||
pub sub_id: SubId, | ||
|
||
pub payload: NotificationPayload, | ||
} | ||
|
||
impl From<(SubId, NotificationPayload)> for WsNotification<Notification> { | ||
fn from((sub_id, payload): (SubId, NotificationPayload)) -> Self { | ||
WsNotification { | ||
jsonrpc: JSON_RPC_VERSION.to_owned(), | ||
method: "subscribe".to_string(), | ||
params: Notification { sub_id, payload }, | ||
} | ||
} | ||
} | ||
|
||
#[async_trait::async_trait] | ||
impl WsHandle for Method { | ||
type Response = Response; | ||
|
||
async fn handle(self, context: &mut WsContext) -> Result<Self::Response, WsError> { | ||
let sub_id = self.0.id.clone(); | ||
if context.subscriptions.contains_key(&sub_id) { | ||
return Err(WsError::InvalidParams); | ||
} | ||
let mut subscription = context.state.mint.pubsub_manager.subscribe(self.0).await; | ||
let publisher = context.publisher.clone(); | ||
context.subscriptions.insert( | ||
sub_id.clone(), | ||
tokio::spawn(async move { | ||
while let Some(response) = subscription.recv().await { | ||
let _ = publisher.send(response).await; | ||
} | ||
}), | ||
); | ||
Ok(Response { | ||
status: "OK".to_string(), | ||
sub_id, | ||
}) | ||
} | ||
} |
Oops, something went wrong.