From 4c71ed6ee3747ff97ca7ff9cd8abefac7636fe0b Mon Sep 17 00:00:00 2001 From: lobsterwise Date: Fri, 16 Aug 2024 22:47:26 -0400 Subject: [PATCH] Fix support for HTTP extension methods. --- core/lib/src/request/request.rs | 7 +++--- testbench/src/client.rs | 6 +++-- testbench/src/servers/http_extensions.rs | 29 ++++++++++++++++++++++++ testbench/src/servers/mod.rs | 1 + 4 files changed, 38 insertions(+), 5 deletions(-) create mode 100644 testbench/src/servers/http_extensions.rs diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index 93912383de..e99e721d3b 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ -2,6 +2,7 @@ use std::{io, fmt}; use std::ops::RangeFrom; use std::sync::{Arc, atomic::Ordering}; use std::borrow::Cow; +use std::str::FromStr; use std::future::Future; use std::net::IpAddr; @@ -1086,7 +1087,7 @@ impl<'r> Request<'r> { // Keep track of parsing errors; emit a `BadRequest` if any exist. let mut errors = vec![]; - // Ensure that the method is known. TODO: Allow made-up methods? + // Ensure that the method is known. let method = match hyper.method { hyper::Method::GET => Method::Get, hyper::Method::PUT => Method::Put, @@ -1097,10 +1098,10 @@ impl<'r> Request<'r> { hyper::Method::TRACE => Method::Trace, hyper::Method::CONNECT => Method::Connect, hyper::Method::PATCH => Method::Patch, - _ => { + ref ext => Method::from_str(ext.as_str()).unwrap_or_else(|_| { errors.push(RequestError::BadMethod(hyper.method.clone())); Method::Get - } + }), }; // TODO: Keep around not just the path/query, but the rest, if there? diff --git a/testbench/src/client.rs b/testbench/src/client.rs index 953b2f907a..607ba9fa11 100644 --- a/testbench/src/client.rs +++ b/testbench/src/client.rs @@ -26,7 +26,9 @@ impl Client { .connect_timeout(Duration::from_secs(5)) } - pub fn request(&self, server: &Server, method: Method, url: &str) -> Result { + pub fn request(&self, server: &Server, method: M, url: &str) -> Result + where M: AsRef + { let uri = match Uri::parse_any(url).map_err(|e| e.into_owned())? { Uri::Origin(uri) => { let proto = if server.tls { "https" } else { "http" }; @@ -45,7 +47,7 @@ impl Client { uri => return Err(Error::InvalidUri(uri.into_owned())), }; - let method = reqwest::Method::from_str(method.as_str()).unwrap(); + let method = reqwest::Method::from_str(method.as_ref()).unwrap(); Ok(self.client.request(method, uri.to_string())) } diff --git a/testbench/src/servers/http_extensions.rs b/testbench/src/servers/http_extensions.rs new file mode 100644 index 0000000000..42b990c76e --- /dev/null +++ b/testbench/src/servers/http_extensions.rs @@ -0,0 +1,29 @@ +//! Test that HTTP method extensions unlike POST or GET work. + +use crate::prelude::*; + +use rocket::http::Method; + +#[route(PROPFIND, uri = "/")] +fn route() -> &'static str { + "Hello, World!" +} + +pub fn test_http_extensions() -> Result<()> { + let server = spawn! { + Rocket::default().mount("/", routes![route]) + }?; + + let client = Client::default(); + let response = client.request(&server, Method::PropFind, "/")?.send()?; + assert_eq!(response.status(), 200); + assert_eq!(response.text()?, "Hello, World!"); + + // Make sure that verbs outside of extensions are marked as errors + let res = client.request(&server, "BAKEMEACOOKIE", "/")?.send()?; + assert_eq!(res.status(), 400); + + Ok(()) +} + +register!(test_http_extensions); diff --git a/testbench/src/servers/mod.rs b/testbench/src/servers/mod.rs index 30fd6ff97c..d05aa00c69 100644 --- a/testbench/src/servers/mod.rs +++ b/testbench/src/servers/mod.rs @@ -1,5 +1,6 @@ pub mod ignite_failure; pub mod bind; +pub mod http_extensions; pub mod infinite_stream; pub mod tls_resolver; pub mod mtls;