diff --git a/tonic/src/transport/server/io_stream.rs b/tonic/src/transport/server/io_stream.rs index e873e4def..e48f44c16 100644 --- a/tonic/src/transport/server/io_stream.rs +++ b/tonic/src/transport/server/io_stream.rs @@ -27,9 +27,7 @@ where #[pin] inner: S, #[cfg(feature = "_tls-any")] - tls: Option, - #[cfg(feature = "_tls-any")] - tasks: JoinSet, crate::BoxError>>, + state: Option<(TlsAcceptor, JoinSet, crate::BoxError>>)>, } impl ServerIoStream @@ -40,23 +38,17 @@ where Self { inner: incoming, #[cfg(feature = "_tls-any")] - tls, - #[cfg(feature = "_tls-any")] - tasks: JoinSet::new(), + state: tls.map(|tls| (tls, JoinSet::new())), } } -} -impl Stream for ServerIoStream -where - S: Stream>, - IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, - IE: Into, -{ - type Item = Result, crate::BoxError>; - - #[cfg(not(feature = "_tls-any"))] - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next_without_tls( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, crate::BoxError>>> + where + IE: Into, + { match ready!(self.as_mut().project().inner.as_mut().poll_next(cx)) { Some(Ok(io)) => Poll::Ready(Some(Ok(ServerIo::new_io(io)))), Some(Err(e)) => match handle_tcp_accept_error(e) { @@ -69,29 +61,40 @@ where None => Poll::Ready(None), } } +} + +impl Stream for ServerIoStream +where + S: Stream>, + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IE: Into, +{ + type Item = Result, crate::BoxError>; + + #[cfg(not(feature = "_tls-any"))] + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_next_without_tls(cx) + } #[cfg(feature = "_tls-any")] fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut projected = self.as_mut().project(); - let tls = projected.tls; - let tasks = projected.tasks; + let Some((tls, tasks)) = projected.state else { + return self.poll_next_without_tls(cx); + }; let select_output = ready!(pin!(select(&mut projected.inner, tasks)).poll(cx)); match select_output { SelectOutput::Incoming(stream) => { - if let Some(tls) = tls { - let tls = tls.clone(); - tasks.spawn(async move { - let io = tls.accept(stream).await?; - Ok(ServerIo::new_tls_io(io)) - }); - cx.waker().wake_by_ref(); - Poll::Pending - } else { - Poll::Ready(Some(Ok(ServerIo::new_io(stream)))) - } + let tls = tls.clone(); + tasks.spawn(async move { + let io = tls.accept(stream).await?; + Ok(ServerIo::new_tls_io(io)) + }); + cx.waker().wake_by_ref(); + Poll::Pending } SelectOutput::Io(io) => Poll::Ready(Some(Ok(io))),