diff --git a/src/meta/src/barrier/mod.rs b/src/meta/src/barrier/mod.rs index c99940fbc32fb..0b570d9b2aaa1 100644 --- a/src/meta/src/barrier/mod.rs +++ b/src/meta/src/barrier/mod.rs @@ -40,9 +40,7 @@ use risingwave_pb::meta::subscribe_response::{Info, Operation}; use risingwave_pb::meta::PausedReason; use risingwave_pb::stream_plan::barrier::BarrierKind; use risingwave_pb::stream_service::barrier_complete_response::CreateMviewProgress; -use risingwave_pb::stream_service::{ - streaming_control_stream_response, BarrierCompleteResponse, StreamingControlStreamResponse, -}; +use risingwave_pb::stream_service::BarrierCompleteResponse; use thiserror_ext::AsReport; use tokio::sync::oneshot::{Receiver, Sender}; use tokio::sync::Mutex; @@ -670,16 +668,10 @@ impl GlobalBarrierManager { _ => {} } } - resp_result = self.control_stream_manager.next_response() => { + resp_result = self.control_stream_manager.next_complete_barrier_response() => { match resp_result { Ok((worker_id, prev_epoch, resp)) => { - let resp: StreamingControlStreamResponse = resp; - match resp.response { - Some(streaming_control_stream_response::Response::CompleteBarrier(resp)) => { - self.checkpoint_control.barrier_collected(worker_id, prev_epoch, resp); - }, - resp => unreachable!("invalid response: {:?}", resp), - } + self.checkpoint_control.barrier_collected(worker_id, prev_epoch, resp); } Err(e) => { diff --git a/src/meta/src/barrier/recovery.rs b/src/meta/src/barrier/recovery.rs index 251ec401fa2fb..f2ec59c3d6f48 100644 --- a/src/meta/src/barrier/recovery.rs +++ b/src/meta/src/barrier/recovery.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::assert_matches::assert_matches; use std::collections::{BTreeSet, HashMap, HashSet}; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -29,9 +28,6 @@ use risingwave_pb::meta::{PausedReason, Recovery}; use risingwave_pb::stream_plan::barrier::BarrierKind; use risingwave_pb::stream_plan::barrier_mutation::Mutation; use risingwave_pb::stream_plan::AddMutation; -use risingwave_pb::stream_service::{ - streaming_control_stream_response, StreamingControlStreamResponse, -}; use thiserror_ext::AsReport; use tokio::sync::oneshot; use tokio_retry::strategy::{jitter, ExponentialBackoff}; @@ -502,15 +498,10 @@ impl GlobalBarrierManager { let mut node_to_collect = control_stream_manager.inject_barrier(command_ctx.clone())?; while !node_to_collect.is_empty() { - let (worker_id, _, resp) = control_stream_manager.next_response().await?; - assert_matches!( - resp, - StreamingControlStreamResponse { - response: Some( - streaming_control_stream_response::Response::CompleteBarrier(_) - ) - } - ); + let (worker_id, prev_epoch, _) = control_stream_manager + .next_complete_barrier_response() + .await?; + assert_eq!(prev_epoch, command_ctx.prev_epoch.value().0); assert!(node_to_collect.remove(&worker_id)); } diff --git a/src/meta/src/barrier/rpc.rs b/src/meta/src/barrier/rpc.rs index a098627afcd0c..b7ea512ffbbfa 100644 --- a/src/meta/src/barrier/rpc.rs +++ b/src/meta/src/barrier/rpc.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::collections::{HashMap, HashSet, VecDeque}; +use std::error::Error; use std::future::Future; use std::sync::Arc; use std::time::Duration; @@ -28,7 +29,7 @@ use risingwave_common::util::tracing::TracingContext; use risingwave_pb::common::{ActorInfo, WorkerNode}; use risingwave_pb::stream_plan::{Barrier, BarrierMutation, StreamActor}; use risingwave_pb::stream_service::{ - streaming_control_stream_request, streaming_control_stream_response, + streaming_control_stream_request, streaming_control_stream_response, BarrierCompleteResponse, BroadcastActorInfoTableRequest, BuildActorsRequest, DropActorsRequest, InjectBarrierRequest, StreamingControlStreamRequest, StreamingControlStreamResponse, UpdateActorsRequest, }; @@ -47,6 +48,8 @@ use super::GlobalBarrierManagerContext; use crate::manager::{MetaSrvEnv, WorkerId}; use crate::{MetaError, MetaResult}; +const COLLECT_ERROR_TIMEOUT: Duration = Duration::from_secs(3); + struct ControlStreamNode { worker: WorkerNode, sender: UnboundedSender, @@ -162,17 +165,25 @@ impl ControlStreamManager { Ok(()) } - pub(super) async fn next_response( + async fn next_response( + &mut self, + ) -> Option<(WorkerId, MetaResult)> { + let (worker_id, response_stream, result) = self.response_streams.next().await?; + if result.is_ok() { + self.response_streams + .push(into_future(worker_id, response_stream)); + } + Some((worker_id, result)) + } + + pub(super) async fn next_complete_barrier_response( &mut self, - ) -> MetaResult<(WorkerId, u64, StreamingControlStreamResponse)> { + ) -> MetaResult<(WorkerId, u64, BarrierCompleteResponse)> { loop { - let (worker_id, response_stream, result) = - pending_on_none(self.response_streams.next()).await; + let (worker_id, result) = pending_on_none(self.next_response()).await; match result { - Ok(resp) => match &resp.response { - Some(streaming_control_stream_response::Response::CompleteBarrier(_)) => { - self.response_streams - .push(into_future(worker_id, response_stream)); + Ok(resp) => match resp.response { + Some(streaming_control_stream_response::Response::CompleteBarrier(resp)) => { let node = self .nodes .get_mut(&worker_id) @@ -195,16 +206,39 @@ impl ControlStreamManager { // Note: No need to use `?` as the backtrace is from meta and not useful. warn!(node = ?node.worker, err = %err.as_report(), "get error from response stream"); if let Some(command) = node.inflight_barriers.pop_front() { + let errors = self.collect_errors(node.worker.id, err).await; + let err = merge_node_rpc_errors("get error from control stream", errors); self.context.report_collect_failure(&command, &err); break Err(err); } else { // for node with no inflight barrier, simply ignore the error + info!(node = ?node.worker, "no inflight barrier no node. Ignore error"); continue; } } } } } + + async fn collect_errors( + &mut self, + worker_id: WorkerId, + first_err: MetaError, + ) -> Vec<(WorkerId, MetaError)> { + let mut errors = vec![(worker_id, first_err)]; + #[cfg(not(madsim))] + { + let _ = timeout(COLLECT_ERROR_TIMEOUT, async { + while let Some((worker_id, result)) = self.next_response().await { + if let Err(e) = result { + errors.push((worker_id, e)); + } + } + }) + .await; + } + errors + } } impl ControlStreamManager { @@ -356,7 +390,7 @@ impl StreamRpcManager { let client = pool.get(node).await.map_err(|e| (node.id, e))?; f(client, input).await.map_err(|e| (node.id, e)) }); - let result = try_join_all_with_error_timeout(iters, Duration::from_secs(3)).await; + let result = try_join_all_with_error_timeout(iters, COLLECT_ERROR_TIMEOUT).await; result.map_err(|results_err| merge_node_rpc_errors("merged RPC Error", results_err)) } @@ -491,9 +525,9 @@ where Err(results_err) } -fn merge_node_rpc_errors( +fn merge_node_rpc_errors( message: &str, - errors: impl IntoIterator, + errors: impl IntoIterator, ) -> MetaError { use std::fmt::Write;