From 4e160dc37bfc2c615f4741c3a7cdc0b40a113092 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 7 Aug 2024 15:50:11 +0800 Subject: [PATCH] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 49 +++++++++++++++++--------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 7817cb6e..eeeef41b 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -162,12 +162,12 @@ impl std::fmt::Debug for Config { /// Callback fn that processes [`WebSocket`] pub trait Callback: Send + 'static { /// Called when a connection upgrade succeeds - fn call(self, _: WebSocket) -> impl Future + Send; + fn call(self, _: WebSocket) -> impl Future + Send; } impl Callback for C where - Fut: Future + Send + 'static, + Fut: Future + Send + 'static, C: FnOnce(WebSocket) -> Fut + Send + 'static, C: Copy, { @@ -362,7 +362,7 @@ where /// [`WebSocketUpgrade::on_protocol`], otherwise use `default_callback`. pub fn on_upgrade(self, default_callback: C1) -> ServerResponse where - Fut: Future + Send + 'static, + Fut: Future + Send + 'static, C1: FnOnce(WebSocket) -> Fut + Send + Sync + 'static, { let on_upgrade = self.on_upgrade; @@ -493,7 +493,10 @@ impl FromContext for WebSocketUpgrade { .remove::() .ok_or(WebSocketUpgradeRejectionError::ConnectionNotUpgradable)?; - let sec_websocket_protocol = parts.headers.get(http::header::SEC_WEBSOCKET_PROTOCOL).cloned(); + let sec_websocket_protocol = parts + .headers + .get(http::header::SEC_WEBSOCKET_PROTOCOL) + .cloned(); Ok(Self { config: Default::default(), @@ -513,7 +516,7 @@ mod websocket_tests { use std::{net, ops::Add}; use futures_util::{SinkExt, StreamExt}; - use http::{Uri}; + use http::Uri; use motore::Service; use tokio::net::TcpStream; use tokio_tungstenite::{ @@ -538,7 +541,7 @@ mod websocket_tests { ) -> (WebSocketStream>, ServerResponse) where R: IntoClientRequest + Unpin, - Fut: Future + Send + 'static, + Fut: Future + Send + 'static, C: FnOnce(WebSocketUpgrade) -> Fut + Send + 'static, C: Send + Sync + Clone, { @@ -624,23 +627,23 @@ mod websocket_tests { ws.set_config( WebSocketConfig::new().set_protocols(["graphql-ws", "graphql-transport-ws"]), ) - .on_protocol(HashMap::from([( - "graphql-ws", - |mut socket: WebSocket| async move { - while let Some(Ok(msg)) = socket.next().await { - match msg { - Message::Text(text) => { - socket - .send(Message::Text(text.add("-graphql-ws"))) - .await - .unwrap(); - } - _ => {} + .on_protocol(HashMap::from([( + "graphql-ws", + |mut socket: WebSocket| async move { + while let Some(Ok(msg)) = socket.next().await { + match msg { + Message::Text(text) => { + socket + .send(Message::Text(text.add("-graphql-ws"))) + .await + .unwrap(); } + _ => {} } - }, - )])) - .on_upgrade(|_| async {}) + } + }, + )])) + .on_upgrade(|_| async {}) } let addr = Address::Ip(net::SocketAddr::new( @@ -653,7 +656,7 @@ mod websocket_tests { .parse::() .unwrap(), ) - .with_sub_protocol("graphql-ws"); + .with_sub_protocol("graphql-ws"); let (mut ws_stream, _response) = run_ws_handler(addr, ws_handler, builder).await; let input = Message::Text("foobar".to_owned()); @@ -696,7 +699,7 @@ mod websocket_tests { |ws: WebSocketUpgrade| std::future::ready(ws.on_upgrade(handle_socket)), builder, ) - .await; + .await; let input = Message::Text("foobar".to_owned()); ws_stream.send(input.clone()).await.unwrap();