Skip to content

Commit

Permalink
feat(http): support websocket server
Browse files Browse the repository at this point in the history
  • Loading branch information
StellarisW committed Aug 7, 2024
1 parent 44f12ac commit 4e160dc
Showing 1 changed file with 26 additions and 23 deletions.
49 changes: 26 additions & 23 deletions volo-http/src/server/utils/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,12 @@ impl std::fmt::Debug for Config {
/// Callback fn that processes [`WebSocket`]
pub trait Callback: Send + 'static {
/// Called when a connection upgrade succeeds
fn call(self, _: WebSocket) -> impl Future<Output=()> + Send;
fn call(self, _: WebSocket) -> impl Future<Output = ()> + Send;
}

impl<Fut, C> Callback for C
where
Fut: Future<Output=()> + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
C: FnOnce(WebSocket) -> Fut + Send + 'static,
C: Copy,
{
Expand Down Expand Up @@ -362,7 +362,7 @@ where
/// [`WebSocketUpgrade::on_protocol`], otherwise use `default_callback`.
pub fn on_upgrade<Fut, C1>(self, default_callback: C1) -> ServerResponse
where
Fut: Future<Output=()> + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
C1: FnOnce(WebSocket) -> Fut + Send + Sync + 'static,
{
let on_upgrade = self.on_upgrade;
Expand Down Expand Up @@ -493,7 +493,10 @@ impl FromContext for WebSocketUpgrade<DefaultCallback> {
.remove::<hyper::upgrade::OnUpgrade>()
.ok_or(WebSocketUpgradeRejectionError::ConnectionNotUpgradable)?;

let sec_websocket_protocol = parts.headers.get(http::header::SEC_WEBSOCKET_PROTOCOL).cloned();
let sec_websocket_protocol = parts
.headers
.get(http::header::SEC_WEBSOCKET_PROTOCOL)
.cloned();

Ok(Self {
config: Default::default(),
Expand All @@ -513,7 +516,7 @@ mod websocket_tests {
use std::{net, ops::Add};

use futures_util::{SinkExt, StreamExt};
use http::{Uri};
use http::Uri;
use motore::Service;
use tokio::net::TcpStream;
use tokio_tungstenite::{
Expand All @@ -538,7 +541,7 @@ mod websocket_tests {
) -> (WebSocketStream<MaybeTlsStream<TcpStream>>, ServerResponse)
where
R: IntoClientRequest + Unpin,
Fut: Future<Output=ServerResponse> + Send + 'static,
Fut: Future<Output = ServerResponse> + Send + 'static,
C: FnOnce(WebSocketUpgrade) -> Fut + Send + 'static,
C: Send + Sync + Clone,
{
Expand Down Expand Up @@ -624,23 +627,23 @@ mod websocket_tests {
ws.set_config(
WebSocketConfig::new().set_protocols(["graphql-ws", "graphql-transport-ws"]),
)
.on_protocol(HashMap::from([(
"graphql-ws",
|mut socket: WebSocket| async move {
while let Some(Ok(msg)) = socket.next().await {
match msg {
Message::Text(text) => {
socket
.send(Message::Text(text.add("-graphql-ws")))
.await
.unwrap();
}
_ => {}
.on_protocol(HashMap::from([(
"graphql-ws",
|mut socket: WebSocket| async move {
while let Some(Ok(msg)) = socket.next().await {
match msg {
Message::Text(text) => {
socket
.send(Message::Text(text.add("-graphql-ws")))
.await
.unwrap();
}
_ => {}
}
},
)]))
.on_upgrade(|_| async {})
}
},
)]))
.on_upgrade(|_| async {})
}

let addr = Address::Ip(net::SocketAddr::new(
Expand All @@ -653,7 +656,7 @@ mod websocket_tests {
.parse::<Uri>()
.unwrap(),
)
.with_sub_protocol("graphql-ws");
.with_sub_protocol("graphql-ws");
let (mut ws_stream, _response) = run_ws_handler(addr, ws_handler, builder).await;

let input = Message::Text("foobar".to_owned());
Expand Down Expand Up @@ -696,7 +699,7 @@ mod websocket_tests {
|ws: WebSocketUpgrade| std::future::ready(ws.on_upgrade(handle_socket)),
builder,
)
.await;
.await;

let input = Message::Text("foobar".to_owned());
ws_stream.send(input.clone()).await.unwrap();
Expand Down

0 comments on commit 4e160dc

Please sign in to comment.