From 8e9d3226d61cfc0cc75cee8bfc2a1b452f18a066 Mon Sep 17 00:00:00 2001 From: chubei <914745487@qq.com> Date: Wed, 27 Sep 2023 16:13:43 +0800 Subject: [PATCH] fix: `AddrStream::pool_shutdown` was not called pooled until completion --- dozer-api/src/grpc/mod.rs | 80 ++++++++++++++++----------------------- 1 file changed, 32 insertions(+), 48 deletions(-) diff --git a/dozer-api/src/grpc/mod.rs b/dozer-api/src/grpc/mod.rs index c1898fadd2..b56696921c 100644 --- a/dozer-api/src/grpc/mod.rs +++ b/dozer-api/src/grpc/mod.rs @@ -33,38 +33,38 @@ use tower::{Layer, Service}; use crate::shutdown::ShutdownReceiver; #[derive(Debug)] -enum ShutdownAddrStream { - Alive { inner: AddrStream, shutdown: F }, - Shutdown { inner: AddrStream }, - Temp, +struct ShutdownAddrStream { + inner: AddrStream, + state: ShutdownState, +} + +#[derive(Debug)] +enum ShutdownState { + SignalPending(F), + ShutdownPending, + Done, } impl + Unpin> ShutdownAddrStream { fn check_shutdown(&mut self, cx: &mut Context<'_>) -> Result<(), io::Error> { - loop { - match self { - Self::Alive { shutdown, .. } => { - if let Poll::Ready(()) = Pin::new(shutdown).poll(cx) { - let mut temp = Self::Temp; - std::mem::swap(self, &mut temp); - let Self::Alive { inner, .. } = temp else { - unreachable!() - }; - *self = Self::Shutdown { inner }; - continue; - } else { - return Ok(()); - } + match &mut self.state { + ShutdownState::SignalPending(signal) => { + if let Poll::Ready(()) = Pin::new(signal).poll(cx) { + self.state = ShutdownState::ShutdownPending; + self.check_shutdown(cx) + } else { + Ok(()) } - Self::Shutdown { inner } => { - if let Poll::Ready(Err(e)) = Pin::new(inner).poll_shutdown(cx) { - return Err(e); - } else { - return Ok(()); - } - } - Self::Temp => unreachable!(), } + ShutdownState::ShutdownPending => match Pin::new(&mut self.inner).poll_shutdown(cx) { + Poll::Ready(Ok(())) => { + self.state = ShutdownState::Done; + Ok(()) + } + Poll::Ready(Err(e)) => Err(e), + Poll::Pending => Ok(()), + }, + ShutdownState::Done => Ok(()), } } @@ -78,11 +78,7 @@ impl + Unpin> ShutdownAddrStream { return Poll::Ready(Err(e)); } - match this { - Self::Alive { inner, .. } => func(Pin::new(inner), cx), - Self::Shutdown { inner } => func(Pin::new(inner), cx), - Self::Temp => unreachable!(), - } + func(Pin::new(&mut this.inner), cx) } } @@ -97,11 +93,7 @@ impl + Unpin> AsyncRead for ShutdownAddrStream { return Poll::Ready(Err(e)); } - match this { - Self::Alive { inner, .. } => Pin::new(inner).poll_read(cx, buf), - Self::Shutdown { inner } => Pin::new(inner).poll_read(cx, buf), - Self::Temp => unreachable!(), - } + Pin::new(&mut this.inner).poll_read(cx, buf) } } @@ -116,11 +108,7 @@ impl + Unpin> AsyncWrite for ShutdownAddrStream { return Poll::Ready(Err(e)); } - match this { - Self::Alive { inner, .. } => Pin::new(inner).poll_write(cx, buf), - Self::Shutdown { inner } => Pin::new(inner).poll_write(cx, buf), - Self::Temp => unreachable!(), - } + Pin::new(&mut this.inner).poll_write(cx, buf) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -136,11 +124,7 @@ impl Connected for ShutdownAddrStream { type ConnectInfo = TcpConnectInfo; fn connect_info(&self) -> Self::ConnectInfo { - match self { - Self::Alive { inner, .. } => inner.connect_info(), - Self::Shutdown { inner } => inner.connect_info(), - Self::Temp => unreachable!(), - } + self.inner.connect_info() } } @@ -160,9 +144,9 @@ where let incoming = incoming.map(|stream| { stream.map(|stream| { let shutdown = shutdown.create_shutdown_future(); - ShutdownAddrStream::Alive { + ShutdownAddrStream { inner: stream, - shutdown: Box::pin(shutdown), + state: ShutdownState::SignalPending(Box::pin(shutdown)), } }) });