Skip to content

Commit

Permalink
wip: hyper 1.0 upgrade + custom listeners
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioBenitez committed Jan 26, 2024
1 parent e9b568d commit 2dfcc1e
Show file tree
Hide file tree
Showing 87 changed files with 3,428 additions and 2,854 deletions.
1 change: 0 additions & 1 deletion contrib/ws/src/duplex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ use crate::result::{Result, Error};
///
/// [`StreamExt`]: rocket::futures::StreamExt
/// [`SinkExt`]: rocket::futures::SinkExt
pub struct DuplexStream(tokio_tungstenite::WebSocketStream<IoStream>);

impl DuplexStream {
Expand Down
19 changes: 8 additions & 11 deletions contrib/ws/src/websocket.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::io;
use std::pin::Pin;

use rocket::data::{IoHandler, IoStream};
use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream};
Expand Down Expand Up @@ -37,10 +36,6 @@ pub struct WebSocket {
}

impl WebSocket {
fn new(key: String) -> WebSocket {
WebSocket { config: Config::default(), key }
}

/// Change the default connection configuration to `config`.
///
/// # Example
Expand Down Expand Up @@ -202,7 +197,9 @@ impl<'r> FromRequest<'r> for WebSocket {
let is_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13");
let key = headers.get_one("Sec-WebSocket-Key").map(|k| derive_accept_key(k.as_bytes()));
match key {
Some(key) if is_upgrade && is_ws && is_13 => Outcome::Success(WebSocket::new(key)),
Some(key) if is_upgrade && is_ws && is_13 => {
Outcome::Success(WebSocket { key, config: Config::default() })
},
Some(_) | None => Outcome::Forward(Status::BadRequest)
}
}
Expand Down Expand Up @@ -232,9 +229,9 @@ impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>

#[rocket::async_trait]
impl IoHandler for Channel<'_> {
async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
let channel = Pin::into_inner(self);
let result = (channel.handler)(DuplexStream::new(io, channel.ws.config).await).await;
async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
let stream = DuplexStream::new(io, self.ws.config).await;
let result = (self.handler)(stream).await;
handle_result(result).map(|_| ())
}
}
Expand All @@ -243,9 +240,9 @@ impl IoHandler for Channel<'_> {
impl<'r, S> IoHandler for MessageStream<'r, S>
where S: futures::Stream<Item = Result<Message>> + Send + 'r
{
async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split();
let stream = (Pin::into_inner(self).handler)(source);
let stream = (self.handler)(source);
rocket::tokio::pin!(stream);
while let Some(msg) = stream.next().await {
let result = match msg {
Expand Down
61 changes: 61 additions & 0 deletions core/codegen/src/attribute/async_bound/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use proc_macro2::{TokenStream, Span};
use devise::{Spanned, Result, ext::SpanDiagnosticExt};
use syn::{Token, parse_quote, parse_quote_spanned};
use syn::{TraitItemFn, TypeParamBound, ReturnType, Attribute};
use syn::punctuated::Punctuated;
use syn::parse::Parser;

fn _async_bound(
args: proc_macro::TokenStream,
input: proc_macro::TokenStream
) -> Result<TokenStream> {
let bounds = <Punctuated<TypeParamBound, Token![+]>>::parse_terminated.parse(args)?;
if bounds.is_empty() {
return Ok(input.into());
}

let mut func: TraitItemFn = syn::parse(input)?;
let original: TraitItemFn = func.clone();
if !func.sig.asyncness.is_some() {
let diag = Span::call_site()
.error("attribute can only be applied to async fns")
.span_help(func.sig.span(), "this fn declaration must be `async`");

return Err(diag);
}

let doc: Attribute = parse_quote! {
#[doc = concat!(
"# Future Bounds",
"\n",
"**The `Future` generated by this `async fn` must be `", stringify!(#bounds), "`**."
)]
};

func.sig.asyncness = None;
func.sig.output = match func.sig.output {
ReturnType::Type(arrow, ty) => parse_quote_spanned!(ty.span() =>
#arrow impl ::core::future::Future<Output = #ty> + #bounds
),
default@ReturnType::Default => default
};

Ok(quote! {
#[cfg(all(not(doc), rust_analyzer))]
#original

#[cfg(all(doc, not(rust_analyzer)))]
#doc
#original

#[cfg(not(any(doc, rust_analyzer)))]
#func
})
}

pub fn async_bound(
args: proc_macro::TokenStream,
input: proc_macro::TokenStream
) -> TokenStream {
_async_bound(args, input).unwrap_or_else(|d| d.emit_as_item_tokens())
}
1 change: 1 addition & 0 deletions core/codegen/src/attribute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pub mod entry;
pub mod catch;
pub mod route;
pub mod param;
pub mod async_bound;
7 changes: 7 additions & 0 deletions core/codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1497,3 +1497,10 @@ pub fn internal_guide_tests(input: TokenStream) -> TokenStream {
pub fn export(input: TokenStream) -> TokenStream {
emit!(bang::export_internal(input))
}

/// Private Rocket attribute: `async_bound(Bounds + On + Returned + Future)`.
#[doc(hidden)]
#[proc_macro_attribute]
pub fn async_bound(args: TokenStream, input: TokenStream) -> TokenStream {
emit!(attribute::async_bound::async_bound(args, input))
}
24 changes: 0 additions & 24 deletions core/http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,22 @@ rust-version = "1.64"

[features]
default = []
tls = ["rustls", "tokio-rustls", "rustls-pemfile"]
mtls = ["tls", "x509-parser"]
http2 = ["hyper/http2"]
private-cookies = ["cookie/private", "cookie/key-expansion"]
serde = ["uncased/with-serde-alloc", "serde_"]
uuid = ["uuid_"]

[dependencies]
smallvec = { version = "1.11", features = ["const_generics", "const_new"] }
percent-encoding = "2"
http = "0.2"
time = { version = "0.3", features = ["formatting", "macros"] }
indexmap = "2"
rustls = { version = "0.22", optional = true }
tokio-rustls = { version = "0.25", optional = true }
rustls-pemfile = { version = "2.0.0", optional = true }
tokio = { version = "1.6.1", features = ["net", "sync", "time"] }
log = "0.4"
ref-cast = "1.0"
uncased = "0.9.6"
either = "1"
pear = "0.2.8"
pin-project-lite = "0.2"
memchr = "2"
stable-pattern = "0.1"
cookie = { version = "0.18", features = ["percent-encode"] }
state = "0.6"
futures = { version = "0.3", default-features = false }

[dependencies.x509-parser]
version = "0.13"
optional = true

[dependencies.hyper]
version = "0.14.9"
default-features = false
features = ["http1", "runtime", "server", "stream"]

[dependencies.serde_]
package = "serde"
Expand All @@ -67,6 +46,3 @@ package = "uuid"
version = "1"
optional = true
default-features = false

[dev-dependencies]
rocket = { path = "../lib", features = ["mtls"] }
3 changes: 1 addition & 2 deletions core/http/src/header/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -745,8 +745,7 @@ impl<'h> HeaderMap<'h> {
/// WARNING: This is unstable! Do not use this method outside of Rocket!
#[doc(hidden)]
#[inline]
pub fn into_iter_raw(self)
-> impl Iterator<Item=(Uncased<'h>, Vec<Cow<'h, str>>)> {
pub fn into_iter_raw(self) -> impl Iterator<Item=(Uncased<'h>, Vec<Cow<'h, str>>)> {
self.headers.into_iter()
}
}
Expand Down
35 changes: 0 additions & 35 deletions core/http/src/hyper.rs

This file was deleted.

13 changes: 1 addition & 12 deletions core/http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,11 @@
//! Types that map to concepts in HTTP.
//!
//! This module exports types that map to HTTP concepts or to the underlying
//! HTTP library when needed. Because the underlying HTTP library is likely to
//! change (see [#17]), types in [`hyper`] should be considered unstable.
//!
//! [#17]: https://github.com/rwf2/Rocket/issues/17
//! HTTP library when needed.
#[macro_use]
extern crate pear;

pub mod hyper;
pub mod uri;
pub mod ext;

Expand All @@ -22,7 +18,6 @@ mod method;
mod status;
mod raw_str;
mod parse;
mod listener;

/// Case-preserving, ASCII case-insensitive string types.
///
Expand All @@ -39,14 +34,8 @@ pub mod uncased {
pub mod private {
pub use crate::parse::Indexed;
pub use smallvec::{SmallVec, Array};
pub use crate::listener::{TcpListener, Incoming, Listener, Connection, Certificates};
pub use cookie;
}

#[doc(hidden)]
#[cfg(feature = "tls")]
pub mod tls;

pub use crate::method::Method;
pub use crate::status::{Status, StatusClass};
pub use crate::raw_str::{RawStr, RawStrBuf};
Expand Down
Loading

0 comments on commit 2dfcc1e

Please sign in to comment.