diff --git a/core/lib/src/listener/quic.rs b/core/lib/src/listener/quic.rs index 8c8ac3d09a..bd6013716a 100644 --- a/core/lib/src/listener/quic.rs +++ b/core/lib/src/listener/quic.rs @@ -2,52 +2,56 @@ use std::io; use std::fmt; use std::net::SocketAddr; -use bytes::Bytes; -use futures::Stream; use s2n_quic as quic; use s2n_quic_h3 as quic_h3; use quic_h3::h3 as h3; -use s2n_quic::provider::tls::rustls::{rustls, DEFAULT_CIPHERSUITES}; -use s2n_quic::provider::tls::rustls::Server as H3TlsServer; +use bytes::Bytes; +use futures::Stream; use tokio::sync::Mutex; use tokio_stream::StreamExt; -use crate::listener::{Bindable, Listener}; use crate::tls::TlsConfig; +use crate::listener::{Listener, Connection, Endpoint}; -use super::{Connection, Endpoint}; - -pub struct QuicBindable { - pub address: SocketAddr, - pub tls: TlsConfig, -} +type H3Conn = h3::server::Connection; pub struct QuicListener { + endpoint: SocketAddr, listener: Mutex, - local_addr: SocketAddr, } -impl Bindable for QuicBindable { - type Listener = QuicListener; +pub struct H3Stream(H3Conn); + +pub struct H3Connection { + pub handle: quic::connection::Handle, + pub parts: http::request::Parts, + pub tx: QuicTx, + pub rx: QuicRx, +} + +pub struct QuicRx(h3::server::RequestStream); + +pub struct QuicTx(h3::server::RequestStream, Bytes>); - type Error = io::Error; +impl QuicListener { + pub async fn bind(address: SocketAddr, tls: TlsConfig) -> Result { + use quic::provider::tls::rustls::{rustls, DEFAULT_CIPHERSUITES, Server as H3TlsServer}; - async fn bind(self) -> Result { - // FIXME: Remove this as soon as `s2n_quic` is on rustls 0.22. - let cert_chain = crate::tls::util::load_cert_chain(&mut self.tls.certs_reader().unwrap()) + // FIXME: Remove this as soon as `s2n_quic` is on rustls >= 0.22. + let cert_chain = crate::tls::util::load_cert_chain(&mut tls.certs_reader().unwrap()) .unwrap() .into_iter() .map(|v| v.to_vec()) .map(rustls::Certificate) .collect::>(); - let key = crate::tls::util::load_key(&mut self.tls.key_reader().unwrap()) + let key = crate::tls::util::load_key(&mut tls.key_reader().unwrap()) .unwrap() .secret_der() .to_vec(); - let mut tls = rustls::server::ServerConfig::builder() + let mut h3tls = rustls::server::ServerConfig::builder() .with_cipher_suites(DEFAULT_CIPHERSUITES) .with_safe_default_kx_groups() .with_safe_default_protocol_versions() @@ -56,40 +60,23 @@ impl Bindable for QuicBindable { .with_single_cert(cert_chain, rustls::PrivateKey(key)) .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?; - tls.alpn_protocols = vec![b"h3".to_vec()]; - tls.ignore_client_order = self.tls.prefer_server_cipher_order; - tls.session_storage = rustls::server::ServerSessionMemoryCache::new(1024); - tls.ticketer = rustls::Ticketer::new() + h3tls.alpn_protocols = vec![b"h3".to_vec()]; + h3tls.ignore_client_order = tls.prefer_server_cipher_order; + h3tls.session_storage = rustls::server::ServerSessionMemoryCache::new(1024); + h3tls.ticketer = rustls::Ticketer::new() .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS ticketer: {}", e)))?; let listener = quic::Server::builder() - .with_tls(H3TlsServer::new(tls)) + .with_tls(H3TlsServer::new(h3tls)) .unwrap_or_else(|e| match e { }) - .with_io(self.address)? + .with_io(address)? .start() .map_err(io::Error::other)?; - let local_addr = listener.local_addr()?; - - Ok(QuicListener { listener: Mutex::new(listener), local_addr }) + Ok(QuicListener { endpoint: listener.local_addr()?, listener: Mutex::new(listener) }) } } -type H3Conn = h3::server::Connection; - -pub struct H3Stream(H3Conn); - -pub struct H3Connection { - pub handle: quic::connection::Handle, - pub parts: http::request::Parts, - pub tx: QuicTx, - pub rx: QuicRx, -} - -pub struct QuicRx(h3::server::RequestStream); - -pub struct QuicTx(h3::server::RequestStream, Bytes>); - impl Listener for QuicListener { type Accept = quic::Connection; @@ -109,7 +96,7 @@ impl Listener for QuicListener { } fn endpoint(&self) -> io::Result { - Ok(self.local_addr.into()) + Ok(self.endpoint.into()) } } @@ -152,8 +139,6 @@ impl QuicTx { } pub fn cancel(&mut self) { - use s2n_quic_h3::h3; - self.0.stop_stream(h3::error::Code::H3_NO_ERROR); } } @@ -200,28 +185,6 @@ mod async_traits { Poll::Ready(Ok(())) } } - - // impl AsyncWrite for QuicTx { - // fn poll_write( - // mut self: Pin<&mut Self>, - // cx: &mut Context<'_>, - // buf: &[u8], - // ) -> Poll> { - // let len = buf.len(); - // let result = ready!(self.0.poll_send_data(cx, Bytes::copy_from_slice(buf))); - // result.map_err(io::Error::other)?; - // Poll::Ready(Ok(len)) - // } - // - // fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - // Poll::Ready(Ok(())) - // } - // - // fn poll_shutdown(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - // self.0.stop_stream(h3::error::Code::H3_NO_ERROR); - // Poll::Ready(Ok(())) - // } - // } } impl fmt::Debug for H3Stream { diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index 6b4ec29f23..8d3ae3bba4 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -702,27 +702,7 @@ impl Rocket { return Err(ErrorKind::Liftoff(rocket, Box::new(e)).into()); } - #[cfg(not(feature = "http3"))] rocket.clone().serve(listener).await?; - - #[cfg(feature = "http3")] { - use crate::listener::quic::QuicBindable; - - let endpoint = rocket.endpoint(); - if let (Some(address), Some(tls)) = (endpoint.tcp(), endpoint.tls_config()) { - let quic_bindable = QuicBindable { address, tls: tls.clone() }; - let http3 = tokio::task::spawn(rocket.clone().serve3(quic_bindable.bind().await?)); - let http12 = tokio::task::spawn(rocket.clone().serve(listener)); - let (r1, r2) = tokio::join!(http12, http3); - r1.map_err(|e| ErrorKind::Liftoff(Err(rocket.clone()), Box::new(e)))??; - r2.map_err(|e| ErrorKind::Liftoff(Err(rocket.clone()), Box::new(e)))??; - } else { - warn!("HTTP/3 feature is enabled, but listener is not TCP/TLS."); - warn_!("HTTP/3 server cannot be started."); - rocket.clone().serve(listener).await?; - } - } - Ok(rocket.try_wait_shutdown().await.map_err(ErrorKind::Shutdown)?) } } diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index 7be2634a7f..fe5167ae78 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -28,7 +28,9 @@ impl Rocket { upgrade: Option, connection: ConnectionMeta, ) -> Result>, http::Error> { - let _http3_addr = self.endpoint().tls_config().and_then(|_| self.endpoint().tcp()); + #[cfg(feature = "http3")] + let _http3_addr = self.http3_config().map(|(addr, _)| addr); + let request = ErasedRequest::new(self, parts, |rocket, parts| { Request::from_hyp(rocket, parts, connection).unwrap_or_else(|e| e) }); @@ -93,6 +95,33 @@ async fn io_handler_task(stream: S, mut handler: ErasedIoHandler) impl Rocket { pub(crate) async fn serve(self: Arc, listener: L) -> Result<()> + where L: Listener + 'static, L::Connection: AsyncRead + AsyncWrite + { + #[cfg(not(feature = "http3"))] + self.clone().serve12(listener).await?; + + #[cfg(feature = "http3")] { + use crate::error::ErrorKind; + use crate::listener::quic::QuicListener; + + if let Some((address, tls)) = self.http3_config() { + let h3listener = QuicListener::bind(address, tls.clone()).await?; + let http3 = tokio::task::spawn(self.clone().serve3(h3listener)); + let http12 = tokio::task::spawn(self.clone().serve12(listener)); + let (r1, r2) = tokio::join!(http12, http3); + r1.map_err(|e| ErrorKind::Liftoff(Err(self.clone()), Box::new(e)))??; + r2.map_err(|e| ErrorKind::Liftoff(Err(self.clone()), Box::new(e)))??; + } else { + warn!("HTTP/3 cannot start without valid TCP/TLS configuration."); + info_!("Falling back to HTTP/1 and HTTP/2 server."); + self.clone().serve(listener).await?; + } + } + + Ok(()) + } + + pub(crate) async fn serve12(self: Arc, listener: L) -> Result<()> where L: Listener + 'static, L::Connection: AsyncRead + AsyncWrite { @@ -141,14 +170,14 @@ impl Rocket { Ok(()) } -} -#[cfg(feature = "http3")] -use crate::listener::quic::QuicListener; + #[cfg(feature = "http3")] + fn http3_config(&self) -> Option<(std::net::SocketAddr, &crate::tls::TlsConfig)> { + Some((self.endpoint().tcp()?, self.endpoint().tls_config()?)) + } -#[cfg(feature = "http3")] -impl Rocket { - pub(crate) async fn serve3(self: Arc, listener: QuicListener) -> Result<()> { + #[cfg(feature = "http3")] + async fn serve3(self: Arc, listener: crate::listener::quic::QuicListener) -> Result<()> { let rocket = self.clone(); let listener = Arc::new(listener.bounced()); while let Some(accept) = listener.accept().try_until(rocket.shutdown()).await? {