diff --git a/Cargo.toml b/Cargo.toml index d767148..6e84da2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,6 @@ [workspace] members = ["crates/*"] resolver = "2" + +[workspace.dependencies] +tokio = { version = "1.39.2", default-features = false } \ No newline at end of file diff --git a/crates/miniserve/Cargo.toml b/crates/miniserve/Cargo.toml index e39ff98..897b706 100644 --- a/crates/miniserve/Cargo.toml +++ b/crates/miniserve/Cargo.toml @@ -4,5 +4,9 @@ version = "0.1.0" edition = "2021" [dependencies] +futures = { version = "0.3.30", default-features = false } http = "1.1.0" httparse = "1.9.4" +tokio = { workspace = true, features = ["net", "rt"] } +tokio-stream = { version = "0.1.15", default-features = false } +tokio-util = { version = "0.7.11", default-features = false, features = ["codec"] } diff --git a/crates/miniserve/src/lib.rs b/crates/miniserve/src/lib.rs index ad4322b..90d61bf 100644 --- a/crates/miniserve/src/lib.rs +++ b/crates/miniserve/src/lib.rs @@ -1,12 +1,7 @@ #![warn(clippy::pedantic)] -use std::{ - collections::HashMap, - io::{self}, - net::{TcpListener, TcpStream}, - sync::Arc, - thread, -}; +use std::{collections::HashMap, future::Future, io, pin::Pin, sync::Arc}; +use tokio::net::{TcpListener, TcpStream}; /// Re-export for library clients. pub use http; @@ -32,15 +27,26 @@ pub enum Content { pub type Response = Result; /// Trait alias for functions that can handle requests and return responses. -pub trait Handler: Fn(Request) -> Response + Send + Sync + 'static {} +pub trait Handler: Fn(Request) -> Self::Future + Send + Sync + 'static { + type Future: Future + Send + Sync + 'static; +} + +impl Handler for H +where + F: Future + Send + Sync + 'static, + H: Fn(Request) -> F + Send + Sync + 'static, +{ + type Future = F; +} -impl Handler for F where F: Fn(Request) -> Response + Send + Sync + 'static {} +type ErasedHandler = + Box Pin + Send + Sync>> + Send + Sync>; /// The main server data structure. #[derive(Default)] pub struct Server { /// Map from a route path (e.g., "/foo") to a handler function for that route. - routes: HashMap>, + routes: HashMap, } impl Server { @@ -55,7 +61,12 @@ impl Server { /// Adds a new route to the server. #[must_use] pub fn route(mut self, route: impl Into, handler: H) -> Self { - self.routes.insert(route.into(), Box::new(handler)); + let handler = Arc::new(handler); + let erased = Box::new(move |req| { + let handler_ref = Arc::clone(&handler); + Box::pin(handler_ref(req)) as Pin + Send + Sync>> + }); + self.routes.insert(route.into(), erased); self } @@ -66,21 +77,22 @@ impl Server { /// # Panics /// /// Panics if `127.0.0.1:3000` is not available. - pub fn run(self) { - let listener = - TcpListener::bind("127.0.0.1:3000").expect("Failed to connect to 127.0.0.1:3000"); + pub async fn run(self) { + let listener = TcpListener::bind("127.0.0.1:3000") + .await + .expect("Failed to connect to 127.0.0.1:3000"); let this = Arc::new(self); - for stream in listener.incoming().flatten() { - let this_ref = Arc::clone(&this); - thread::spawn(move || { - let _ = this_ref.handle(&stream); - }); + loop { + if let Ok((stream, _)) = listener.accept().await { + let this_ref = Arc::clone(&this); + tokio::spawn(async move { + let _ = this_ref.handle(stream).await; + }); + } } } - fn handle(&self, stream: &TcpStream) -> io::Result<()> { - protocol::handle(stream, |route, request| { - self.routes.get(route).map(move |handler| handler(request)) - }) + async fn handle(&self, stream: TcpStream) -> io::Result<()> { + protocol::handle(stream, &|route| self.routes.get(route)).await } } diff --git a/crates/miniserve/src/protocol.rs b/crates/miniserve/src/protocol.rs index b785634..74638a9 100644 --- a/crates/miniserve/src/protocol.rs +++ b/crates/miniserve/src/protocol.rs @@ -2,74 +2,86 @@ //! //! You should not need to deal with this module. -use std::{ - io::{self, BufRead, BufReader, BufWriter, Write}, - net::{Shutdown, TcpStream}, -}; - +use futures::SinkExt; use http::StatusCode; +use std::io; +use tokio::net::TcpStream; +use tokio_stream::StreamExt; +use tokio_util::{ + bytes::BytesMut, + codec::{Decoder, Encoder, Framed}, +}; -pub fn stringify_response(response: http::Response>) -> Vec { - let (parts, body) = response.into_parts(); +struct HttpCodec; - let mut buf = Vec::with_capacity(body.len() + 256); - buf.extend_from_slice(b"HTTP/1.1 "); - buf.extend(parts.status.as_str().as_bytes()); - if let Some(reason) = parts.status.canonical_reason() { - buf.extend_from_slice(b" "); - buf.extend(reason.as_bytes()); - } +impl Encoder>> for HttpCodec { + type Error = io::Error; - buf.extend_from_slice(b"\r\n"); + fn encode( + &mut self, + response: http::Response>, + buf: &mut BytesMut, + ) -> Result<(), Self::Error> { + let (parts, body) = response.into_parts(); - for (name, value) in parts.headers { - if let Some(name) = name { - buf.extend(name.as_str().as_bytes()); - buf.extend_from_slice(b": "); + buf.extend_from_slice(b"HTTP/1.1 "); + buf.extend(parts.status.as_str().as_bytes()); + if let Some(reason) = parts.status.canonical_reason() { + buf.extend_from_slice(b" "); + buf.extend(reason.as_bytes()); } - buf.extend(value.as_bytes()); + buf.extend_from_slice(b"\r\n"); - } - buf.extend_from_slice(b"\r\n"); - buf.extend(body); + for (name, value) in parts.headers { + if let Some(name) = name { + buf.extend(name.as_str().as_bytes()); + buf.extend_from_slice(b": "); + } + buf.extend(value.as_bytes()); + buf.extend_from_slice(b"\r\n"); + } + + buf.extend_from_slice(b"\r\n"); + buf.extend(body); - buf + Ok(()) + } } -#[allow(clippy::result_large_err)] -fn parse_request(src: &[u8]) -> Result>>, http::Response>> { - let mut headers = [httparse::EMPTY_HEADER; 64]; - let mut parsed_req = httparse::Request::new(&mut headers); - let Ok(status) = parsed_req.parse(src) else { - return Err(make_response( - StatusCode::BAD_REQUEST, - "Failed to parse request", - )); - }; - let amt = match status { - httparse::Status::Complete(amt) => amt, - httparse::Status::Partial => return Ok(None), - }; +impl Decoder for HttpCodec { + type Item = http::Request>; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + let mut headers = [httparse::EMPTY_HEADER; 64]; + let mut parsed_req = httparse::Request::new(&mut headers); + let status = parsed_req.parse(src).map_err(|e| { + let msg = format!("failed to parse http request: {e:?}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + let amt = match status { + httparse::Status::Complete(amt) => amt, + httparse::Status::Partial => return Ok(None), + }; - let Ok(method) = http::Method::try_from(parsed_req.method.unwrap()) else { - return Err(make_response( - StatusCode::BAD_REQUEST, - "Failed to parse request", - )); - }; + let method = http::Method::try_from(parsed_req.method.unwrap()).map_err(|e| { + let msg = format!("failed to parse http request: {e:?}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; - let data = &src[amt..]; + let data = &src[amt..]; - let mut builder = http::Request::builder() - .method(method) - .version(http::Version::HTTP_11) - .uri(parsed_req.path.unwrap()); - for header in parsed_req.headers { - builder = builder.header(header.name, header.value); - } + let mut builder = http::Request::builder() + .method(method) + .version(http::Version::HTTP_11) + .uri(parsed_req.path.unwrap()); + for header in parsed_req.headers { + builder = builder.header(header.name, header.value); + } - Ok(Some(builder.body(data.to_vec()).unwrap())) + Ok(Some(builder.body(data.to_vec()).unwrap())) + } } fn make_response(status: http::StatusCode, explanation: &str) -> http::Response> { @@ -79,9 +91,9 @@ fn make_response(status: http::StatusCode, explanation: &str) -> http::Response< .unwrap() } -fn generate_response( +async fn generate_response<'a>( req: http::Request>, - callback: impl Fn(&str, crate::Request) -> Option, + callback: impl Fn(&str) -> Option<&'a crate::ErasedHandler> + 'a, ) -> http::Response> { let (parts, body) = req.into_parts(); let request = match parts.method { @@ -90,10 +102,12 @@ fn generate_response( _ => return make_response(StatusCode::METHOD_NOT_ALLOWED, "Not implemented"), }; - let Some(response_res) = callback(parts.uri.path(), request) else { + let Some(handler) = callback(parts.uri.path()) else { return make_response(StatusCode::NOT_FOUND, "No valid route"); }; + let response_res = handler(request).await; + match response_res { Ok(content) => { let (body, ty) = match content { @@ -110,43 +124,23 @@ fn generate_response( } } -pub fn handle( - stream: &TcpStream, - callback: impl Fn(&str, crate::Request) -> Option, +pub async fn handle<'a>( + stream: TcpStream, + callback: &'a (impl Fn(&str) -> Option<&'a crate::ErasedHandler> + 'a), ) -> io::Result<()> { - let mut reader = BufReader::new(stream.try_clone()?); - let mut writer = BufWriter::new(stream.try_clone()?); - - loop { - let req = loop { - let buf = reader.fill_buf()?; - if buf.is_empty() { - stream.shutdown(Shutdown::Both)?; - return Ok(()); + let mut transport = Framed::new(stream, HttpCodec); + if let Some(request) = transport.next().await { + match request { + Ok(request) => { + let response = generate_response(request, callback).await; + transport.send(response).await?; } - - match parse_request(buf) { - Ok(None) => continue, - Ok(Some(req)) => { - let amt = buf.len(); - reader.consume(amt); - break Ok(req); - } - Err(resp) => { - let amt = buf.len(); - reader.consume(amt); - break Err(resp); - } + Err(e) => { + let response = make_response(StatusCode::BAD_REQUEST, &e.to_string()); + transport.send(response).await?; } - }; - - let resp = match req { - Ok(req) => generate_response(req, &callback), - Err(resp) => resp, - }; - - let buf = stringify_response(resp); - writer.write_all(&buf)?; - writer.flush()?; + } } + + Ok(()) }