Skip to content

Commit

Permalink
simplify local server when handling shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
kariy committed Jun 16, 2024
1 parent 6d9dccf commit cfbf44f
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 75 deletions.
90 changes: 23 additions & 67 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,30 @@
use std::io;
use std::net::{SocketAddr, TcpListener};
use std::task::{Context, Poll};
use std::{fmt, io};

use axum::Router;
use hyper::service::Service;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::mpsc::Receiver;
use tower_http::cors::CorsLayer;
use tower_layer::Layer;
use tower_http::trace::TraceLayer;

/// A simple local server.
#[derive(Debug)]
pub struct LocalServer {
router: Router,
listener: TcpListener,
shutdown_rx: Option<UnboundedReceiver<()>>,
shutdown_rx: Option<Receiver<()>>,
}

impl LocalServer {
pub fn new(router: Router) -> anyhow::Result<Self> {
// Port number of 0 requests OS to find an available port.
let listener = TcpListener::bind("localhost:0")?;

let (tx, rx) = unbounded_channel::<()>();
let router = router.layer(ShutdownLayer::new(tx));
// To view the logs emitted by the server, set `RUST_LOG=tower_http=trace`
let router = router.layer(TraceLayer::new_for_http());

Ok(Self {
router,
listener,
shutdown_rx: Some(rx),
shutdown_rx: None,
})
}

Expand All @@ -37,9 +35,8 @@ impl LocalServer {
}

/// Disable immediately shutdown the server upon handling the first request.
#[allow(dead_code)]
pub fn no_immediate_shutdown(mut self) -> Self {
self.shutdown_rx = None;
pub fn with_shutdown_signal(mut self, receiver: Receiver<()>) -> Self {
self.shutdown_rx = Some(receiver);
self
}

Expand All @@ -64,74 +61,33 @@ impl LocalServer {
}
}

/// Layer for handling sending a shutdown signal to the server upon
/// receiving the callback request.
#[derive(Clone)]
struct ShutdownLayer {
tx: UnboundedSender<()>,
}

impl ShutdownLayer {
pub fn new(tx: UnboundedSender<()>) -> Self {
Self { tx }
}
}

impl<S> Layer<S> for ShutdownLayer {
type Service = ShutdownService<S>;

fn layer(&self, service: S) -> Self::Service {
ShutdownService {
tx: self.tx.clone(),
service,
}
}
}

#[derive(Clone)]
pub struct ShutdownService<S> {
tx: UnboundedSender<()>,
service: S,
}

impl<S, Request> Service<Request> for ShutdownService<S>
where
S: Service<Request>,
Request: fmt::Debug,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}

fn call(&mut self, request: Request) -> Self::Future {
self.tx.send(()).expect("failed to send shutdown signal");
self.service.call(request)
}
}

#[cfg(test)]
mod tests {
use crate::server::LocalServer;
use axum::{routing::get, Router};

#[tokio::test]
async fn test_server_immediate_shutdown() {
let router = Router::new().route("/callback", get(|| async { "Hello, World!" }));
let server = LocalServer::new(router).unwrap();
async fn test_server_graceful_shutdown() {
let (tx, rx) = tokio::sync::mpsc::channel(1);

let router = Router::new().route("/callback", get(|| async { "Hello, World!" }));
let server = LocalServer::new(router).unwrap().with_shutdown_signal(rx);
let port = server.local_addr().unwrap().port();

let client = reqwest::Client::new();
let url = format!("http://localhost:{port}/callback");

// start the local server
tokio::spawn(server.start());

let url = format!("http://localhost:{port}/callback");
// first request should succeed
assert!(client.get(&url).send().await.is_ok());
// second request should fail as server should've been shutdown after first request

// send shutdown signal
tx.send(()).await.unwrap();

// sending request after sending the shutdown signal should fail as server
// should've been shutdown
assert!(client.get(url).send().await.is_err())
}
}
38 changes: 30 additions & 8 deletions src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use starknet::core::types::FieldElement;
use thiserror::Error;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tower_http::cors::CorsLayer;
use tracing::trace;
use tracing::info;
use url::Url;

use crate::credential::{self, Credentials};
Expand Down Expand Up @@ -46,6 +46,7 @@ pub struct SessionCredentials {
pub authorization: Vec<FieldElement>,
}

// TODO(kariy): unify the error types for the whole Slot crate.
#[derive(Debug, Error)]
pub enum Error {
#[error(transparent)]
Expand Down Expand Up @@ -104,6 +105,8 @@ where
Ok(rx.recv().await.context("Channel dropped.")?)
}

/// Get the session token of the chain id `chain` for the currently authenticated user. It will
/// use `config_dir` as the root path to look for the session file.
fn get_at<P>(config_dir: P, chain: FieldElement) -> Result<Option<SessionDetails>, Error>
where
P: AsRef<Path>,
Expand All @@ -123,6 +126,8 @@ where
}
}

/// Stores the session token of the chain id `chain` for the currently authenticated user. It will
/// use `config_dir` as the root path to store the session file.
fn store_at<P>(
config_dir: P,
chain: FieldElement,
Expand Down Expand Up @@ -198,17 +203,34 @@ fn prepare_query_params(
}

/// Create the callback server that will receive the session token from the browser.
fn callback_server(tx: Sender<SessionDetails>) -> anyhow::Result<LocalServer> {
let handler = move |tx: State<Sender<SessionDetails>>, session: Json<SessionDetails>| async move {
trace!("Received session token from the browser.");
tx.0.send(session.0).await.expect("qed; channel closed");
};
fn callback_server(result_sender: Sender<SessionDetails>) -> anyhow::Result<LocalServer> {
let handler =
move |State((res_sender, shutdown_sender)): State<(Sender<SessionDetails>, Sender<()>)>,
session: Json<SessionDetails>| async move {
info!("Received session token from the browser.");

res_sender
.send(session.0)
.await
.expect("qed; channel closed");

// send shutdown signal to the server ONLY after succesfully receiving and processing
// the session token.
shutdown_sender
.send(())
.await
.expect("failed to send shutdown signal.");
};

let (shutdown_tx, shutdown_rx) = tokio::sync::mpsc::channel(1);

let router = Router::new()
.route("/callback", post(handler))
.with_state(tx);
.with_state((result_sender, shutdown_tx));

Ok(LocalServer::new(router)?.cors(CorsLayer::permissive()))
Ok(LocalServer::new(router)?
.cors(CorsLayer::permissive())
.with_shutdown_signal(shutdown_rx))
}

fn get_user_relative_file_path(username: &str, chain_id: FieldElement) -> PathBuf {
Expand Down

0 comments on commit cfbf44f

Please sign in to comment.