diff --git a/src/server.rs b/src/server.rs index 5fcc6c8..8853b9f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,32 +1,30 @@ +use std::io; use std::net::{SocketAddr, TcpListener}; -use std::task::{Context, Poll}; -use std::{fmt, io}; use axum::Router; -use hyper::service::Service; -use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; +use tokio::sync::mpsc::Receiver; use tower_http::cors::CorsLayer; -use tower_layer::Layer; +use tower_http::trace::TraceLayer; +/// A simple local server. #[derive(Debug)] pub struct LocalServer { router: Router, listener: TcpListener, - shutdown_rx: Option>, + shutdown_rx: Option>, } impl LocalServer { pub fn new(router: Router) -> anyhow::Result { // Port number of 0 requests OS to find an available port. let listener = TcpListener::bind("localhost:0")?; - - let (tx, rx) = unbounded_channel::<()>(); - let router = router.layer(ShutdownLayer::new(tx)); + // To view the logs emitted by the server, set `RUST_LOG=tower_http=trace` + let router = router.layer(TraceLayer::new_for_http()); Ok(Self { router, listener, - shutdown_rx: Some(rx), + shutdown_rx: None, }) } @@ -37,9 +35,8 @@ impl LocalServer { } /// Disable immediately shutdown the server upon handling the first request. - #[allow(dead_code)] - pub fn no_immediate_shutdown(mut self) -> Self { - self.shutdown_rx = None; + pub fn with_shutdown_signal(mut self, receiver: Receiver<()>) -> Self { + self.shutdown_rx = Some(receiver); self } @@ -64,74 +61,33 @@ impl LocalServer { } } -/// Layer for handling sending a shutdown signal to the server upon -/// receiving the callback request. -#[derive(Clone)] -struct ShutdownLayer { - tx: UnboundedSender<()>, -} - -impl ShutdownLayer { - pub fn new(tx: UnboundedSender<()>) -> Self { - Self { tx } - } -} - -impl Layer for ShutdownLayer { - type Service = ShutdownService; - - fn layer(&self, service: S) -> Self::Service { - ShutdownService { - tx: self.tx.clone(), - service, - } - } -} - -#[derive(Clone)] -pub struct ShutdownService { - tx: UnboundedSender<()>, - service: S, -} - -impl Service for ShutdownService -where - S: Service, - Request: fmt::Debug, -{ - type Response = S::Response; - type Error = S::Error; - type Future = S::Future; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx) - } - - fn call(&mut self, request: Request) -> Self::Future { - self.tx.send(()).expect("failed to send shutdown signal"); - self.service.call(request) - } -} - #[cfg(test)] mod tests { use crate::server::LocalServer; use axum::{routing::get, Router}; #[tokio::test] - async fn test_server_immediate_shutdown() { - let router = Router::new().route("/callback", get(|| async { "Hello, World!" })); - let server = LocalServer::new(router).unwrap(); + async fn test_server_graceful_shutdown() { + let (tx, rx) = tokio::sync::mpsc::channel(1); + let router = Router::new().route("/callback", get(|| async { "Hello, World!" })); + let server = LocalServer::new(router).unwrap().with_shutdown_signal(rx); let port = server.local_addr().unwrap().port(); + let client = reqwest::Client::new(); + let url = format!("http://localhost:{port}/callback"); + // start the local server tokio::spawn(server.start()); - let url = format!("http://localhost:{port}/callback"); // first request should succeed assert!(client.get(&url).send().await.is_ok()); - // second request should fail as server should've been shutdown after first request + + // send shutdown signal + tx.send(()).await.unwrap(); + + // sending request after sending the shutdown signal should fail as server + // should've been shutdown assert!(client.get(url).send().await.is_err()) } } diff --git a/src/session.rs b/src/session.rs index 3106efa..33386f7 100644 --- a/src/session.rs +++ b/src/session.rs @@ -9,7 +9,7 @@ use starknet::core::types::FieldElement; use thiserror::Error; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tower_http::cors::CorsLayer; -use tracing::trace; +use tracing::info; use url::Url; use crate::credential::{self, Credentials}; @@ -46,6 +46,7 @@ pub struct SessionCredentials { pub authorization: Vec, } +// TODO(kariy): unify the error types for the whole Slot crate. #[derive(Debug, Error)] pub enum Error { #[error(transparent)] @@ -104,6 +105,8 @@ where Ok(rx.recv().await.context("Channel dropped.")?) } +/// Get the session token of the chain id `chain` for the currently authenticated user. It will +/// use `config_dir` as the root path to look for the session file. fn get_at

(config_dir: P, chain: FieldElement) -> Result, Error> where P: AsRef, @@ -123,6 +126,8 @@ where } } +/// Stores the session token of the chain id `chain` for the currently authenticated user. It will +/// use `config_dir` as the root path to store the session file. fn store_at

( config_dir: P, chain: FieldElement, @@ -198,17 +203,34 @@ fn prepare_query_params( } /// Create the callback server that will receive the session token from the browser. -fn callback_server(tx: Sender) -> anyhow::Result { - let handler = move |tx: State>, session: Json| async move { - trace!("Received session token from the browser."); - tx.0.send(session.0).await.expect("qed; channel closed"); - }; +fn callback_server(result_sender: Sender) -> anyhow::Result { + let handler = + move |State((res_sender, shutdown_sender)): State<(Sender, Sender<()>)>, + session: Json| async move { + info!("Received session token from the browser."); + + res_sender + .send(session.0) + .await + .expect("qed; channel closed"); + + // send shutdown signal to the server ONLY after succesfully receiving and processing + // the session token. + shutdown_sender + .send(()) + .await + .expect("failed to send shutdown signal."); + }; + + let (shutdown_tx, shutdown_rx) = tokio::sync::mpsc::channel(1); let router = Router::new() .route("/callback", post(handler)) - .with_state(tx); + .with_state((result_sender, shutdown_tx)); - Ok(LocalServer::new(router)?.cors(CorsLayer::permissive())) + Ok(LocalServer::new(router)? + .cors(CorsLayer::permissive()) + .with_shutdown_signal(shutdown_rx)) } fn get_user_relative_file_path(username: &str, chain_id: FieldElement) -> PathBuf {