Skip to content

Commit

Permalink
wip: cleanup + add http3 to testing matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioBenitez committed Mar 15, 2024
1 parent 6e66806 commit 300b0d5
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 96 deletions.
101 changes: 32 additions & 69 deletions core/lib/src/listener/quic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,56 @@ use std::io;
use std::fmt;
use std::net::SocketAddr;

use bytes::Bytes;
use futures::Stream;
use s2n_quic as quic;
use s2n_quic_h3 as quic_h3;
use quic_h3::h3 as h3;
use s2n_quic::provider::tls::rustls::{rustls, DEFAULT_CIPHERSUITES};
use s2n_quic::provider::tls::rustls::Server as H3TlsServer;

use bytes::Bytes;
use futures::Stream;
use tokio::sync::Mutex;
use tokio_stream::StreamExt;

use crate::listener::{Bindable, Listener};
use crate::tls::TlsConfig;
use crate::listener::{Listener, Connection, Endpoint};

use super::{Connection, Endpoint};

pub struct QuicBindable {
pub address: SocketAddr,
pub tls: TlsConfig,
}
type H3Conn = h3::server::Connection<quic_h3::Connection, bytes::Bytes>;

pub struct QuicListener {
endpoint: SocketAddr,
listener: Mutex<quic::Server>,
local_addr: SocketAddr,
}

impl Bindable for QuicBindable {
type Listener = QuicListener;
pub struct H3Stream(H3Conn);

pub struct H3Connection {
pub handle: quic::connection::Handle,
pub parts: http::request::Parts,
pub tx: QuicTx,
pub rx: QuicRx,
}

pub struct QuicRx(h3::server::RequestStream<quic_h3::RecvStream, Bytes>);

pub struct QuicTx(h3::server::RequestStream<quic_h3::SendStream<Bytes>, Bytes>);

type Error = io::Error;
impl QuicListener {
pub async fn bind(address: SocketAddr, tls: TlsConfig) -> Result<Self, io::Error> {
use quic::provider::tls::rustls::{rustls, DEFAULT_CIPHERSUITES, Server as H3TlsServer};

async fn bind(self) -> Result<Self::Listener, Self::Error> {
// FIXME: Remove this as soon as `s2n_quic` is on rustls 0.22.
let cert_chain = crate::tls::util::load_cert_chain(&mut self.tls.certs_reader().unwrap())
// FIXME: Remove this as soon as `s2n_quic` is on rustls >= 0.22.
let cert_chain = crate::tls::util::load_cert_chain(&mut tls.certs_reader().unwrap())
.unwrap()
.into_iter()
.map(|v| v.to_vec())
.map(rustls::Certificate)
.collect::<Vec<_>>();

let key = crate::tls::util::load_key(&mut self.tls.key_reader().unwrap())
let key = crate::tls::util::load_key(&mut tls.key_reader().unwrap())
.unwrap()
.secret_der()
.to_vec();

let mut tls = rustls::server::ServerConfig::builder()
let mut h3tls = rustls::server::ServerConfig::builder()
.with_cipher_suites(DEFAULT_CIPHERSUITES)
.with_safe_default_kx_groups()
.with_safe_default_protocol_versions()
Expand All @@ -56,40 +60,23 @@ impl Bindable for QuicBindable {
.with_single_cert(cert_chain, rustls::PrivateKey(key))
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?;

tls.alpn_protocols = vec![b"h3".to_vec()];
tls.ignore_client_order = self.tls.prefer_server_cipher_order;
tls.session_storage = rustls::server::ServerSessionMemoryCache::new(1024);
tls.ticketer = rustls::Ticketer::new()
h3tls.alpn_protocols = vec![b"h3".to_vec()];
h3tls.ignore_client_order = tls.prefer_server_cipher_order;
h3tls.session_storage = rustls::server::ServerSessionMemoryCache::new(1024);
h3tls.ticketer = rustls::Ticketer::new()
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS ticketer: {}", e)))?;

let listener = quic::Server::builder()
.with_tls(H3TlsServer::new(tls))
.with_tls(H3TlsServer::new(h3tls))
.unwrap_or_else(|e| match e { })
.with_io(self.address)?
.with_io(address)?
.start()
.map_err(io::Error::other)?;

let local_addr = listener.local_addr()?;

Ok(QuicListener { listener: Mutex::new(listener), local_addr })
Ok(QuicListener { endpoint: listener.local_addr()?, listener: Mutex::new(listener) })
}
}

type H3Conn = h3::server::Connection<quic_h3::Connection, bytes::Bytes>;

pub struct H3Stream(H3Conn);

pub struct H3Connection {
pub handle: quic::connection::Handle,
pub parts: http::request::Parts,
pub tx: QuicTx,
pub rx: QuicRx,
}

pub struct QuicRx(h3::server::RequestStream<quic_h3::RecvStream, Bytes>);

pub struct QuicTx(h3::server::RequestStream<quic_h3::SendStream<Bytes>, Bytes>);

impl Listener for QuicListener {
type Accept = quic::Connection;

Expand All @@ -109,7 +96,7 @@ impl Listener for QuicListener {
}

fn endpoint(&self) -> io::Result<Endpoint> {
Ok(self.local_addr.into())
Ok(self.endpoint.into())
}
}

Expand Down Expand Up @@ -152,8 +139,6 @@ impl QuicTx {
}

pub fn cancel(&mut self) {
use s2n_quic_h3::h3;

self.0.stop_stream(h3::error::Code::H3_NO_ERROR);
}
}
Expand Down Expand Up @@ -200,28 +185,6 @@ mod async_traits {
Poll::Ready(Ok(()))
}
}

// impl AsyncWrite for QuicTx {
// fn poll_write(
// mut self: Pin<&mut Self>,
// cx: &mut Context<'_>,
// buf: &[u8],
// ) -> Poll<io::Result<usize>> {
// let len = buf.len();
// let result = ready!(self.0.poll_send_data(cx, Bytes::copy_from_slice(buf)));
// result.map_err(io::Error::other)?;
// Poll::Ready(Ok(len))
// }
//
// fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
// Poll::Ready(Ok(()))
// }
//
// fn poll_shutdown(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
// self.0.stop_stream(h3::error::Code::H3_NO_ERROR);
// Poll::Ready(Ok(()))
// }
// }
}

impl fmt::Debug for H3Stream {
Expand Down
20 changes: 0 additions & 20 deletions core/lib/src/rocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,27 +702,7 @@ impl Rocket<Ignite> {
return Err(ErrorKind::Liftoff(rocket, Box::new(e)).into());
}

#[cfg(not(feature = "http3"))]
rocket.clone().serve(listener).await?;

#[cfg(feature = "http3")] {
use crate::listener::quic::QuicBindable;

let endpoint = rocket.endpoint();
if let (Some(address), Some(tls)) = (endpoint.tcp(), endpoint.tls_config()) {
let quic_bindable = QuicBindable { address, tls: tls.clone() };
let http3 = tokio::task::spawn(rocket.clone().serve3(quic_bindable.bind().await?));
let http12 = tokio::task::spawn(rocket.clone().serve(listener));
let (r1, r2) = tokio::join!(http12, http3);
r1.map_err(|e| ErrorKind::Liftoff(Err(rocket.clone()), Box::new(e)))??;
r2.map_err(|e| ErrorKind::Liftoff(Err(rocket.clone()), Box::new(e)))??;
} else {
warn!("HTTP/3 feature is enabled, but listener is not TCP/TLS.");
warn_!("HTTP/3 server cannot be started.");
rocket.clone().serve(listener).await?;
}
}

Ok(rocket.try_wait_shutdown().await.map_err(ErrorKind::Shutdown)?)
}
}
Expand Down
43 changes: 36 additions & 7 deletions core/lib/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ impl Rocket<Orbit> {
upgrade: Option<hyper::upgrade::OnUpgrade>,
connection: ConnectionMeta,
) -> Result<hyper::Response<ReaderStream<ErasedResponse>>, http::Error> {
let _http3_addr = self.endpoint().tls_config().and_then(|_| self.endpoint().tcp());
#[cfg(feature = "http3")]
let _http3_addr = self.http3_config().map(|(addr, _)| addr);

let request = ErasedRequest::new(self, parts, |rocket, parts| {
Request::from_hyp(rocket, parts, connection).unwrap_or_else(|e| e)
});
Expand Down Expand Up @@ -93,6 +95,33 @@ async fn io_handler_task<S>(stream: S, mut handler: ErasedIoHandler)

impl Rocket<Orbit> {
pub(crate) async fn serve<L>(self: Arc<Self>, listener: L) -> Result<()>
where L: Listener + 'static, L::Connection: AsyncRead + AsyncWrite
{
#[cfg(not(feature = "http3"))]
self.clone().serve12(listener).await?;

#[cfg(feature = "http3")] {
use crate::error::ErrorKind;
use crate::listener::quic::QuicListener;

if let Some((address, tls)) = self.http3_config() {
let h3listener = QuicListener::bind(address, tls.clone()).await?;
let http3 = tokio::task::spawn(self.clone().serve3(h3listener));
let http12 = tokio::task::spawn(self.clone().serve12(listener));
let (r1, r2) = tokio::join!(http12, http3);
r1.map_err(|e| ErrorKind::Liftoff(Err(self.clone()), Box::new(e)))??;
r2.map_err(|e| ErrorKind::Liftoff(Err(self.clone()), Box::new(e)))??;
} else {
warn!("HTTP/3 cannot start without valid TCP/TLS configuration.");
info_!("Falling back to HTTP/1 and HTTP/2 server.");
self.clone().serve12(listener).await?;
}
}

Ok(())
}

pub(crate) async fn serve12<L>(self: Arc<Self>, listener: L) -> Result<()>
where L: Listener + 'static,
L::Connection: AsyncRead + AsyncWrite
{
Expand Down Expand Up @@ -141,14 +170,14 @@ impl Rocket<Orbit> {

Ok(())
}
}

#[cfg(feature = "http3")]
use crate::listener::quic::QuicListener;
#[cfg(feature = "http3")]
fn http3_config(&self) -> Option<(std::net::SocketAddr, &crate::tls::TlsConfig)> {
Some((self.endpoint().tcp()?, self.endpoint().tls_config()?))
}

#[cfg(feature = "http3")]
impl Rocket<Orbit> {
pub(crate) async fn serve3(self: Arc<Self>, listener: QuicListener) -> Result<()> {
#[cfg(feature = "http3")]
async fn serve3(self: Arc<Self>, listener: crate::listener::quic::QuicListener) -> Result<()> {
let rocket = self.clone();
let listener = Arc::new(listener.bounced());
while let Some(accept) = listener.accept().try_until(rocket.shutdown()).await? {
Expand Down
1 change: 1 addition & 0 deletions scripts/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ function test_core() {
FEATURES=(
tokio-macros
http2
http3
secrets
tls
mtls
Expand Down

0 comments on commit 300b0d5

Please sign in to comment.