Skip to content

Commit

Permalink
chore(server): Use same non tls logic at server io stream
Browse files Browse the repository at this point in the history
  • Loading branch information
tottoto committed Dec 4, 2024
1 parent d2f0d97 commit df13009
Showing 1 changed file with 36 additions and 30 deletions.
66 changes: 36 additions & 30 deletions tonic/src/transport/server/io_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ use super::service::ServerIo;
#[cfg(feature = "_tls-any")]
use super::service::TlsAcceptor;

#[cfg(feature = "_tls-any")]
struct State<IO>(TlsAcceptor, JoinSet<Result<ServerIo<IO>, crate::BoxError>>);

#[pin_project]
pub(crate) struct ServerIoStream<S, IO, IE>
where
Expand All @@ -27,9 +30,7 @@ where
#[pin]
inner: S,
#[cfg(feature = "_tls-any")]
tls: Option<TlsAcceptor>,
#[cfg(feature = "_tls-any")]
tasks: JoinSet<Result<ServerIo<IO>, crate::BoxError>>,
state: Option<State<IO>>,
}

impl<S, IO, IE> ServerIoStream<S, IO, IE>
Expand All @@ -40,23 +41,17 @@ where
Self {
inner: incoming,
#[cfg(feature = "_tls-any")]
tls,
#[cfg(feature = "_tls-any")]
tasks: JoinSet::new(),
state: tls.map(|tls| State(tls, JoinSet::new())),
}
}
}

impl<S, IO, IE> Stream for ServerIoStream<S, IO, IE>
where
S: Stream<Item = Result<IO, IE>>,
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
IE: Into<crate::BoxError>,
{
type Item = Result<ServerIo<IO>, crate::BoxError>;

#[cfg(not(feature = "_tls-any"))]
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
fn poll_next_without_tls(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<ServerIo<IO>, crate::BoxError>>>
where
IE: Into<crate::BoxError>,
{
match ready!(self.as_mut().project().inner.as_mut().poll_next(cx)) {
Some(Ok(io)) => Poll::Ready(Some(Ok(ServerIo::new_io(io)))),
Some(Err(e)) => match handle_tcp_accept_error(e) {
Expand All @@ -69,29 +64,40 @@ where
None => Poll::Ready(None),
}
}
}

impl<S, IO, IE> Stream for ServerIoStream<S, IO, IE>
where
S: Stream<Item = Result<IO, IE>>,
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
IE: Into<crate::BoxError>,
{
type Item = Result<ServerIo<IO>, crate::BoxError>;

#[cfg(not(feature = "_tls-any"))]
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.poll_next_without_tls(cx)
}

#[cfg(feature = "_tls-any")]
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut projected = self.as_mut().project();

let tls = projected.tls;
let tasks = projected.tasks;
let Some(State(tls, tasks)) = projected.state else {
return self.poll_next_without_tls(cx);
};

let select_output = ready!(pin!(select(&mut projected.inner, tasks)).poll(cx));

match select_output {
SelectOutput::Incoming(stream) => {
if let Some(tls) = tls {
let tls = tls.clone();
tasks.spawn(async move {
let io = tls.accept(stream).await?;
Ok(ServerIo::new_tls_io(io))
});
cx.waker().wake_by_ref();
Poll::Pending
} else {
Poll::Ready(Some(Ok(ServerIo::new_io(stream))))
}
let tls = tls.clone();
tasks.spawn(async move {
let io = tls.accept(stream).await?;
Ok(ServerIo::new_tls_io(io))
});
cx.waker().wake_by_ref();
Poll::Pending
}

SelectOutput::Io(io) => Poll::Ready(Some(Ok(io))),
Expand Down

0 comments on commit df13009

Please sign in to comment.