Skip to content

Commit

Permalink
Merge pull request #2 from ovstinga/01-async-await-a
Browse files Browse the repository at this point in the history
Port miniserve to async
  • Loading branch information
ovstinga authored Dec 20, 2024
2 parents bfb00bd + ab25251 commit c0ea6df
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 114 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[workspace]
members = ["crates/*"]
resolver = "2"

[workspace.dependencies]
tokio = { version = "1.39.2", default-features = false }
4 changes: 4 additions & 0 deletions crates/miniserve/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,9 @@ version = "0.1.0"
edition = "2021"

[dependencies]
futures = { version = "0.3.30", default-features = false }
http = "1.1.0"
httparse = "1.9.4"
tokio = { workspace = true, features = ["net", "rt"] }
tokio-stream = { version = "0.1.15", default-features = false }
tokio-util = { version = "0.7.11", default-features = false, features = ["codec"] }
58 changes: 35 additions & 23 deletions crates/miniserve/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
#![warn(clippy::pedantic)]

use std::{
collections::HashMap,
io::{self},
net::{TcpListener, TcpStream},
sync::Arc,
thread,
};
use std::{collections::HashMap, future::Future, io, pin::Pin, sync::Arc};
use tokio::net::{TcpListener, TcpStream};

/// Re-export for library clients.
pub use http;
Expand All @@ -32,15 +27,26 @@ pub enum Content {
pub type Response = Result<Content, http::StatusCode>;

/// Trait alias for functions that can handle requests and return responses.
pub trait Handler: Fn(Request) -> Response + Send + Sync + 'static {}
pub trait Handler: Fn(Request) -> Self::Future + Send + Sync + 'static {
type Future: Future<Output = Response> + Send + Sync + 'static;
}

impl<F, H> Handler for H
where
F: Future<Output = Response> + Send + Sync + 'static,
H: Fn(Request) -> F + Send + Sync + 'static,
{
type Future = F;
}

impl<F> Handler for F where F: Fn(Request) -> Response + Send + Sync + 'static {}
type ErasedHandler =
Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + Sync>> + Send + Sync>;

/// The main server data structure.
#[derive(Default)]
pub struct Server {
/// Map from a route path (e.g., "/foo") to a handler function for that route.
routes: HashMap<String, Box<dyn Handler>>,
routes: HashMap<String, ErasedHandler>,
}

impl Server {
Expand All @@ -55,7 +61,12 @@ impl Server {
/// Adds a new route to the server.
#[must_use]
pub fn route<H: Handler>(mut self, route: impl Into<String>, handler: H) -> Self {
self.routes.insert(route.into(), Box::new(handler));
let handler = Arc::new(handler);
let erased = Box::new(move |req| {
let handler_ref = Arc::clone(&handler);
Box::pin(handler_ref(req)) as Pin<Box<dyn Future<Output = Response> + Send + Sync>>
});
self.routes.insert(route.into(), erased);
self
}

Expand All @@ -66,21 +77,22 @@ impl Server {
/// # Panics
///
/// Panics if `127.0.0.1:3000` is not available.
pub fn run(self) {
let listener =
TcpListener::bind("127.0.0.1:3000").expect("Failed to connect to 127.0.0.1:3000");
pub async fn run(self) {
let listener = TcpListener::bind("127.0.0.1:3000")
.await
.expect("Failed to connect to 127.0.0.1:3000");
let this = Arc::new(self);
for stream in listener.incoming().flatten() {
let this_ref = Arc::clone(&this);
thread::spawn(move || {
let _ = this_ref.handle(&stream);
});
loop {
if let Ok((stream, _)) = listener.accept().await {
let this_ref = Arc::clone(&this);
tokio::spawn(async move {
let _ = this_ref.handle(stream).await;
});
}
}
}

fn handle(&self, stream: &TcpStream) -> io::Result<()> {
protocol::handle(stream, |route, request| {
self.routes.get(route).map(move |handler| handler(request))
})
async fn handle(&self, stream: TcpStream) -> io::Result<()> {
protocol::handle(stream, &|route| self.routes.get(route)).await
}
}
176 changes: 85 additions & 91 deletions crates/miniserve/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,74 +2,86 @@
//!
//! You should not need to deal with this module.
use std::{
io::{self, BufRead, BufReader, BufWriter, Write},
net::{Shutdown, TcpStream},
};

use futures::SinkExt;
use http::StatusCode;
use std::io;
use tokio::net::TcpStream;
use tokio_stream::StreamExt;
use tokio_util::{
bytes::BytesMut,
codec::{Decoder, Encoder, Framed},
};

pub fn stringify_response(response: http::Response<Vec<u8>>) -> Vec<u8> {
let (parts, body) = response.into_parts();
struct HttpCodec;

let mut buf = Vec::with_capacity(body.len() + 256);
buf.extend_from_slice(b"HTTP/1.1 ");
buf.extend(parts.status.as_str().as_bytes());
if let Some(reason) = parts.status.canonical_reason() {
buf.extend_from_slice(b" ");
buf.extend(reason.as_bytes());
}
impl Encoder<http::Response<Vec<u8>>> for HttpCodec {
type Error = io::Error;

buf.extend_from_slice(b"\r\n");
fn encode(
&mut self,
response: http::Response<Vec<u8>>,
buf: &mut BytesMut,
) -> Result<(), Self::Error> {
let (parts, body) = response.into_parts();

for (name, value) in parts.headers {
if let Some(name) = name {
buf.extend(name.as_str().as_bytes());
buf.extend_from_slice(b": ");
buf.extend_from_slice(b"HTTP/1.1 ");
buf.extend(parts.status.as_str().as_bytes());
if let Some(reason) = parts.status.canonical_reason() {
buf.extend_from_slice(b" ");
buf.extend(reason.as_bytes());
}
buf.extend(value.as_bytes());

buf.extend_from_slice(b"\r\n");
}

buf.extend_from_slice(b"\r\n");
buf.extend(body);
for (name, value) in parts.headers {
if let Some(name) = name {
buf.extend(name.as_str().as_bytes());
buf.extend_from_slice(b": ");
}
buf.extend(value.as_bytes());
buf.extend_from_slice(b"\r\n");
}

buf.extend_from_slice(b"\r\n");
buf.extend(body);

buf
Ok(())
}
}

#[allow(clippy::result_large_err)]
fn parse_request(src: &[u8]) -> Result<Option<http::Request<Vec<u8>>>, http::Response<Vec<u8>>> {
let mut headers = [httparse::EMPTY_HEADER; 64];
let mut parsed_req = httparse::Request::new(&mut headers);
let Ok(status) = parsed_req.parse(src) else {
return Err(make_response(
StatusCode::BAD_REQUEST,
"Failed to parse request",
));
};
let amt = match status {
httparse::Status::Complete(amt) => amt,
httparse::Status::Partial => return Ok(None),
};
impl Decoder for HttpCodec {
type Item = http::Request<Vec<u8>>;
type Error = io::Error;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let mut headers = [httparse::EMPTY_HEADER; 64];
let mut parsed_req = httparse::Request::new(&mut headers);
let status = parsed_req.parse(src).map_err(|e| {
let msg = format!("failed to parse http request: {e:?}");
io::Error::new(io::ErrorKind::Other, msg)
})?;
let amt = match status {
httparse::Status::Complete(amt) => amt,
httparse::Status::Partial => return Ok(None),
};

let Ok(method) = http::Method::try_from(parsed_req.method.unwrap()) else {
return Err(make_response(
StatusCode::BAD_REQUEST,
"Failed to parse request",
));
};
let method = http::Method::try_from(parsed_req.method.unwrap()).map_err(|e| {
let msg = format!("failed to parse http request: {e:?}");
io::Error::new(io::ErrorKind::Other, msg)
})?;

let data = &src[amt..];
let data = &src[amt..];

let mut builder = http::Request::builder()
.method(method)
.version(http::Version::HTTP_11)
.uri(parsed_req.path.unwrap());
for header in parsed_req.headers {
builder = builder.header(header.name, header.value);
}
let mut builder = http::Request::builder()
.method(method)
.version(http::Version::HTTP_11)
.uri(parsed_req.path.unwrap());
for header in parsed_req.headers {
builder = builder.header(header.name, header.value);
}

Ok(Some(builder.body(data.to_vec()).unwrap()))
Ok(Some(builder.body(data.to_vec()).unwrap()))
}
}

fn make_response(status: http::StatusCode, explanation: &str) -> http::Response<Vec<u8>> {
Expand All @@ -79,9 +91,9 @@ fn make_response(status: http::StatusCode, explanation: &str) -> http::Response<
.unwrap()
}

fn generate_response(
async fn generate_response<'a>(
req: http::Request<Vec<u8>>,
callback: impl Fn(&str, crate::Request) -> Option<crate::Response>,
callback: impl Fn(&str) -> Option<&'a crate::ErasedHandler> + 'a,
) -> http::Response<Vec<u8>> {
let (parts, body) = req.into_parts();
let request = match parts.method {
Expand All @@ -90,10 +102,12 @@ fn generate_response(
_ => return make_response(StatusCode::METHOD_NOT_ALLOWED, "Not implemented"),
};

let Some(response_res) = callback(parts.uri.path(), request) else {
let Some(handler) = callback(parts.uri.path()) else {
return make_response(StatusCode::NOT_FOUND, "No valid route");
};

let response_res = handler(request).await;

match response_res {
Ok(content) => {
let (body, ty) = match content {
Expand All @@ -110,43 +124,23 @@ fn generate_response(
}
}

pub fn handle(
stream: &TcpStream,
callback: impl Fn(&str, crate::Request) -> Option<crate::Response>,
pub async fn handle<'a>(
stream: TcpStream,
callback: &'a (impl Fn(&str) -> Option<&'a crate::ErasedHandler> + 'a),
) -> io::Result<()> {
let mut reader = BufReader::new(stream.try_clone()?);
let mut writer = BufWriter::new(stream.try_clone()?);

loop {
let req = loop {
let buf = reader.fill_buf()?;
if buf.is_empty() {
stream.shutdown(Shutdown::Both)?;
return Ok(());
let mut transport = Framed::new(stream, HttpCodec);
if let Some(request) = transport.next().await {
match request {
Ok(request) => {
let response = generate_response(request, callback).await;
transport.send(response).await?;
}

match parse_request(buf) {
Ok(None) => continue,
Ok(Some(req)) => {
let amt = buf.len();
reader.consume(amt);
break Ok(req);
}
Err(resp) => {
let amt = buf.len();
reader.consume(amt);
break Err(resp);
}
Err(e) => {
let response = make_response(StatusCode::BAD_REQUEST, &e.to_string());
transport.send(response).await?;
}
};

let resp = match req {
Ok(req) => generate_response(req, &callback),
Err(resp) => resp,
};

let buf = stringify_response(resp);
writer.write_all(&buf)?;
writer.flush()?;
}
}

Ok(())
}

0 comments on commit c0ea6df

Please sign in to comment.