Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port miniserve to async #17

Open
wants to merge 3 commits into
base: 00-chat-route-b
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[workspace]
members = ["crates/*"]
resolver = "2"

[workspace.dependencies]
tokio = { version = "1.39.2", default-features = false }
4 changes: 4 additions & 0 deletions crates/miniserve/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,9 @@ version = "0.1.0"
edition = "2021"

[dependencies]
futures = { version = "0.3.30", default-features = false }
http = "1.1.0"
httparse = "1.9.4"
tokio = { workspace = true, features = ["net", "rt"] }
tokio-stream = { version = "0.1.15", default-features = false }
tokio-util = { version = "0.7.11", default-features = false, features = ["codec"] }
58 changes: 35 additions & 23 deletions crates/miniserve/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
#![warn(clippy::pedantic)]

use std::{
collections::HashMap,
io::{self},
net::{TcpListener, TcpStream},
sync::Arc,
thread,
};
use std::{collections::HashMap, future::Future, io, pin::Pin, sync::Arc};
use tokio::net::{TcpListener, TcpStream};

/// Re-export for library clients.
pub use http;
Expand All @@ -32,15 +27,26 @@ pub enum Content {
pub type Response = Result<Content, http::StatusCode>;

/// Trait alias for functions that can handle requests and return responses.
pub trait Handler: Fn(Request) -> Response + Send + Sync + 'static {}
pub trait Handler: Fn(Request) -> Self::Future + Send + Sync + 'static {
type Future: Future<Output = Response> + Send + Sync + 'static;
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Take a close look at the change to the Handler trait. Previously, a handler was a function that took a request and returned a response. Now, a handler is a function that takes a request and returns a future which eventually returns a response.


impl<F, H> Handler for H
where
F: Future<Output = Response> + Send + Sync + 'static,
H: Fn(Request) -> F + Send + Sync + 'static,
{
type Future = F;
}

impl<F> Handler for F where F: Fn(Request) -> Response + Send + Sync + 'static {}
type ErasedHandler =
Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + Sync>> + Send + Sync>;

/// The main server data structure.
#[derive(Default)]
pub struct Server {
/// Map from a route path (e.g., "/foo") to a handler function for that route.
routes: HashMap<String, Box<dyn Handler>>,
routes: HashMap<String, ErasedHandler>,
}

impl Server {
Expand All @@ -55,7 +61,12 @@ impl Server {
/// Adds a new route to the server.
#[must_use]
pub fn route<H: Handler>(mut self, route: impl Into<String>, handler: H) -> Self {
self.routes.insert(route.into(), Box::new(handler));
let handler = Arc::new(handler);
let erased = Box::new(move |req| {
let handler_ref = Arc::clone(&handler);
Box::pin(handler_ref(req)) as Pin<Box<dyn Future<Output = Response> + Send + Sync>>
});
self.routes.insert(route.into(), erased);
self
}

Expand All @@ -66,21 +77,22 @@ impl Server {
/// # Panics
///
/// Panics if `127.0.0.1:3000` is not available.
pub fn run(self) {
let listener =
TcpListener::bind("127.0.0.1:3000").expect("Failed to connect to 127.0.0.1:3000");
pub async fn run(self) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that the run function is now asynchronous, indicated by the async keyword.

let listener = TcpListener::bind("127.0.0.1:3000")
.await
.expect("Failed to connect to 127.0.0.1:3000");
let this = Arc::new(self);
for stream in listener.incoming().flatten() {
let this_ref = Arc::clone(&this);
thread::spawn(move || {
let _ = this_ref.handle(&stream);
});
loop {
if let Ok((stream, _)) = listener.accept().await {
let this_ref = Arc::clone(&this);
tokio::spawn(async move {
let _ = this_ref.handle(stream).await;
});
}
}
}

fn handle(&self, stream: &TcpStream) -> io::Result<()> {
protocol::handle(stream, |route, request| {
self.routes.get(route).map(move |handler| handler(request))
})
async fn handle(&self, stream: TcpStream) -> io::Result<()> {
protocol::handle(stream, &|route| self.routes.get(route)).await
}
}
176 changes: 85 additions & 91 deletions crates/miniserve/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,74 +2,86 @@
//!
//! You should not need to deal with this module.

use std::{
io::{self, BufRead, BufReader, BufWriter, Write},
net::{Shutdown, TcpStream},
};

use futures::SinkExt;
use http::StatusCode;
use std::io;
use tokio::net::TcpStream;
use tokio_stream::StreamExt;
use tokio_util::{
bytes::BytesMut,
codec::{Decoder, Encoder, Framed},
};

pub fn stringify_response(response: http::Response<Vec<u8>>) -> Vec<u8> {
let (parts, body) = response.into_parts();
struct HttpCodec;

let mut buf = Vec::with_capacity(body.len() + 256);
buf.extend_from_slice(b"HTTP/1.1 ");
buf.extend(parts.status.as_str().as_bytes());
if let Some(reason) = parts.status.canonical_reason() {
buf.extend_from_slice(b" ");
buf.extend(reason.as_bytes());
}
impl Encoder<http::Response<Vec<u8>>> for HttpCodec {
type Error = io::Error;

buf.extend_from_slice(b"\r\n");
fn encode(
&mut self,
response: http::Response<Vec<u8>>,
buf: &mut BytesMut,
) -> Result<(), Self::Error> {
let (parts, body) = response.into_parts();

for (name, value) in parts.headers {
if let Some(name) = name {
buf.extend(name.as_str().as_bytes());
buf.extend_from_slice(b": ");
buf.extend_from_slice(b"HTTP/1.1 ");
buf.extend(parts.status.as_str().as_bytes());
if let Some(reason) = parts.status.canonical_reason() {
buf.extend_from_slice(b" ");
buf.extend(reason.as_bytes());
}
buf.extend(value.as_bytes());

buf.extend_from_slice(b"\r\n");
}

buf.extend_from_slice(b"\r\n");
buf.extend(body);
for (name, value) in parts.headers {
if let Some(name) = name {
buf.extend(name.as_str().as_bytes());
buf.extend_from_slice(b": ");
}
buf.extend(value.as_bytes());
buf.extend_from_slice(b"\r\n");
}

buf.extend_from_slice(b"\r\n");
buf.extend(body);

buf
Ok(())
}
}

#[allow(clippy::result_large_err)]
fn parse_request(src: &[u8]) -> Result<Option<http::Request<Vec<u8>>>, http::Response<Vec<u8>>> {
let mut headers = [httparse::EMPTY_HEADER; 64];
let mut parsed_req = httparse::Request::new(&mut headers);
let Ok(status) = parsed_req.parse(src) else {
return Err(make_response(
StatusCode::BAD_REQUEST,
"Failed to parse request",
));
};
let amt = match status {
httparse::Status::Complete(amt) => amt,
httparse::Status::Partial => return Ok(None),
};
impl Decoder for HttpCodec {
type Item = http::Request<Vec<u8>>;
type Error = io::Error;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let mut headers = [httparse::EMPTY_HEADER; 64];
let mut parsed_req = httparse::Request::new(&mut headers);
let status = parsed_req.parse(src).map_err(|e| {
let msg = format!("failed to parse http request: {e:?}");
io::Error::new(io::ErrorKind::Other, msg)
})?;
let amt = match status {
httparse::Status::Complete(amt) => amt,
httparse::Status::Partial => return Ok(None),
};

let Ok(method) = http::Method::try_from(parsed_req.method.unwrap()) else {
return Err(make_response(
StatusCode::BAD_REQUEST,
"Failed to parse request",
));
};
let method = http::Method::try_from(parsed_req.method.unwrap()).map_err(|e| {
let msg = format!("failed to parse http request: {e:?}");
io::Error::new(io::ErrorKind::Other, msg)
})?;

let data = &src[amt..];
let data = &src[amt..];

let mut builder = http::Request::builder()
.method(method)
.version(http::Version::HTTP_11)
.uri(parsed_req.path.unwrap());
for header in parsed_req.headers {
builder = builder.header(header.name, header.value);
}
let mut builder = http::Request::builder()
.method(method)
.version(http::Version::HTTP_11)
.uri(parsed_req.path.unwrap());
for header in parsed_req.headers {
builder = builder.header(header.name, header.value);
}

Ok(Some(builder.body(data.to_vec()).unwrap()))
Ok(Some(builder.body(data.to_vec()).unwrap()))
}
}

fn make_response(status: http::StatusCode, explanation: &str) -> http::Response<Vec<u8>> {
Expand All @@ -79,9 +91,9 @@ fn make_response(status: http::StatusCode, explanation: &str) -> http::Response<
.unwrap()
}

fn generate_response(
async fn generate_response<'a>(
req: http::Request<Vec<u8>>,
callback: impl Fn(&str, crate::Request) -> Option<crate::Response>,
callback: impl Fn(&str) -> Option<&'a crate::ErasedHandler> + 'a,
) -> http::Response<Vec<u8>> {
let (parts, body) = req.into_parts();
let request = match parts.method {
Expand All @@ -90,10 +102,12 @@ fn generate_response(
_ => return make_response(StatusCode::METHOD_NOT_ALLOWED, "Not implemented"),
};

let Some(response_res) = callback(parts.uri.path(), request) else {
let Some(handler) = callback(parts.uri.path()) else {
return make_response(StatusCode::NOT_FOUND, "No valid route");
};

let response_res = handler(request).await;

match response_res {
Ok(content) => {
let (body, ty) = match content {
Expand All @@ -110,43 +124,23 @@ fn generate_response(
}
}

pub fn handle(
stream: &TcpStream,
callback: impl Fn(&str, crate::Request) -> Option<crate::Response>,
pub async fn handle<'a>(
stream: TcpStream,
callback: &'a (impl Fn(&str) -> Option<&'a crate::ErasedHandler> + 'a),
) -> io::Result<()> {
let mut reader = BufReader::new(stream.try_clone()?);
let mut writer = BufWriter::new(stream.try_clone()?);

loop {
let req = loop {
let buf = reader.fill_buf()?;
if buf.is_empty() {
stream.shutdown(Shutdown::Both)?;
return Ok(());
let mut transport = Framed::new(stream, HttpCodec);
if let Some(request) = transport.next().await {
match request {
Ok(request) => {
let response = generate_response(request, callback).await;
transport.send(response).await?;
}

match parse_request(buf) {
Ok(None) => continue,
Ok(Some(req)) => {
let amt = buf.len();
reader.consume(amt);
break Ok(req);
}
Err(resp) => {
let amt = buf.len();
reader.consume(amt);
break Err(resp);
}
Err(e) => {
let response = make_response(StatusCode::BAD_REQUEST, &e.to_string());
transport.send(response).await?;
}
};

let resp = match req {
Ok(req) => generate_response(req, &callback),
Err(resp) => resp,
};

let buf = stringify_response(resp);
writer.write_all(&buf)?;
writer.flush()?;
}
}

Ok(())
}
Loading