From 77ea801609ca90826d8999d0c4707a9043a24613 Mon Sep 17 00:00:00 2001
From: William Wen <william123.wen@gmail.com>
Date: Mon, 18 Nov 2024 18:04:25 +0800
Subject: [PATCH] store pending state

---
 src/stream/src/task/barrier_manager.rs        |   4 +-
 .../src/task/barrier_manager/managed_state.rs | 152 ++++++++++--------
 2 files changed, 91 insertions(+), 65 deletions(-)

diff --git a/src/stream/src/task/barrier_manager.rs b/src/stream/src/task/barrier_manager.rs
index b618f72f9607f..ec8251f857a12 100644
--- a/src/stream/src/task/barrier_manager.rs
+++ b/src/stream/src/task/barrier_manager.rs
@@ -320,8 +320,8 @@ impl LocalBarrierWorker {
         loop {
             select! {
                 biased;
-                (partial_graph_id, barrier, create_mview_progress, table_ids) = self.state.next_collected_epoch() => {
-                    self.complete_barrier(partial_graph_id, barrier, create_mview_progress, table_ids);
+                (partial_graph_id, barrier) = self.state.next_collected_epoch() => {
+                    self.complete_barrier(partial_graph_id, barrier.epoch.prev);
                 }
                 (partial_graph_id, barrier, result) = rw_futures_util::pending_on_none(self.await_epoch_completed_futures.next()) => {
                     match result {
diff --git a/src/stream/src/task/barrier_manager/managed_state.rs b/src/stream/src/task/barrier_manager/managed_state.rs
index 0710c794b4adc..6be6e2446ec4b 100644
--- a/src/stream/src/task/barrier_manager/managed_state.rs
+++ b/src/stream/src/task/barrier_manager/managed_state.rs
@@ -16,6 +16,7 @@ use std::cell::LazyCell;
 use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
 use std::fmt::{Debug, Display, Formatter};
 use std::future::{pending, poll_fn, Future};
+use std::mem::replace;
 use std::sync::Arc;
 use std::task::Poll;
 
@@ -44,24 +45,32 @@ struct IssuedState {
     pub remaining_actors: BTreeSet<ActorId>,
 
     pub barrier_inflight_latency: HistogramTimer,
-
-    /// Only be `Some(_)` when `kind` is `Checkpoint`
-    pub table_ids: Option<HashSet<TableId>>,
 }
 
 impl Debug for IssuedState {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("IssuedState")
             .field("remaining_actors", &self.remaining_actors)
-            .field("table_ids", &self.table_ids)
             .finish()
     }
 }
 
+/// The state machine of local barrier manager.
+#[derive(Debug)]
+enum ManagedBarrierStateInner {
+    /// Meta service has issued a `send_barrier` request. We're collecting barriers now.
+    Issued(IssuedState),
+
+    /// The barrier has been collected by all remaining actors. Holds the create-mview progress.
+    AllCollected(Vec<PbCreateMviewProgress>),
+}
+
 #[derive(Debug)]
 pub(super) struct BarrierState {
     barrier: Barrier,
-    inner: IssuedState,
+    /// Only `Some(_)` when `barrier.kind` is `Checkpoint`.
+    table_ids: Option<HashSet<TableId>>,
+    inner: ManagedBarrierStateInner,
 }
 
 mod await_epoch_completed_future {
@@ -119,6 +128,7 @@ mod await_epoch_completed_future {
 }
 
 pub(crate) use await_epoch_completed_future::*;
+use risingwave_common::must_match;
 use risingwave_pb::stream_plan::SubscriptionUpstreamInfo;
 use risingwave_pb::stream_service::barrier_complete_response::PbCreateMviewProgress;
 use risingwave_pb::stream_service::streaming_control_stream_request::InitialPartialGraph;
@@ -167,9 +177,8 @@ impl Display for &'_ PartialGraphManagedBarrierState {
         let mut prev_epoch = 0u64;
         for (epoch, barrier_state) in &self.epoch_barrier_state_map {
             write!(f, "> Epoch {}: ", epoch)?;
-            {
-                {
-                    let state = &barrier_state.inner;
+            match &barrier_state.inner {
+                ManagedBarrierStateInner::Issued(state) => {
                     write!(
                         f,
                         "Issued [{:?}]. Remaining actors: [",
@@ -178,7 +187,10 @@ impl Display for &'_ PartialGraphManagedBarrierState {
                     let mut is_prev_epoch_issued = false;
                     if prev_epoch != 0 {
                         let bs = &self.epoch_barrier_state_map[&prev_epoch];
-                        let remaining_actors_prev = &bs.inner.remaining_actors;
+                        if let ManagedBarrierStateInner::Issued(IssuedState {
+                            remaining_actors: remaining_actors_prev,
+                            ..
+                        }) = &bs.inner
                         {
                             // Only show the actors that are not in the previous epoch.
                             is_prev_epoch_issued = true;
@@ -202,6 +214,9 @@ impl Display for &'_ PartialGraphManagedBarrierState {
                     }
                     write!(f, "]")?;
                 }
+                ManagedBarrierStateInner::AllCollected(_) => {
+                    write!(f, "AllCollected")?;
+                }
             }
             prev_epoch = *epoch;
             writeln!(f)?;
@@ -639,23 +654,14 @@ impl ManagedBarrierState {
 
     pub(super) fn next_collected_epoch(
         &mut self,
-    ) -> impl Future<
-        Output = (
-            PartialGraphId,
-            Barrier,
-            Vec<PbCreateMviewProgress>,
-            Option<HashSet<TableId>>,
-        ),
-    > {
+    ) -> impl Future<Output = (PartialGraphId, Barrier)> {
         let mut output = None;
         for (partial_graph_id, graph_state) in &mut self.graph_states {
-            if let Some((barrier, create_mview_progress, table_ids)) =
-                graph_state.may_have_collected_all()
-            {
+            if let Some(barrier) = graph_state.may_have_collected_all() {
                 if let Some(actors_to_stop) = barrier.all_stop_actors() {
                     self.current_shared_context.drop_actors(actors_to_stop);
                 }
-                output = Some((*partial_graph_id, barrier, create_mview_progress, table_ids));
+                output = Some((*partial_graph_id, barrier));
                 break;
             }
         }
@@ -688,29 +694,21 @@ impl PartialGraphManagedBarrierState {
     /// This method is called when barrier state is modified in either `Issued` or `Stashed`
     /// to transform the state to `AllCollected` and start state store `sync` when the barrier
     /// has been collected from all actors for an `Issued` barrier.
-    fn may_have_collected_all(
-        &mut self,
-    ) -> Option<(
-        Barrier,
-        Vec<PbCreateMviewProgress>,
-        Option<HashSet<TableId>>,
-    )> {
-        if let Some((_, barrier_state)) = self.epoch_barrier_state_map.first_key_value()
-            && barrier_state.inner.remaining_actors.is_empty()
-        {
-            self.streaming_metrics.barrier_manager_progress.inc();
-
-            let (_, barrier_state) = self.epoch_barrier_state_map.pop_first().expect("non-empty");
+    fn may_have_collected_all(&mut self) -> Option<Barrier> {
+        for barrier_state in self.epoch_barrier_state_map.values_mut() {
+            match &barrier_state.inner {
+                ManagedBarrierStateInner::Issued(IssuedState {
+                    remaining_actors, ..
+                }) if remaining_actors.is_empty() => {}
+                ManagedBarrierStateInner::AllCollected(_) => {
+                    continue;
+                }
+                ManagedBarrierStateInner::Issued(_) => {
+                    break;
+                }
+            }
 
-            let table_ids = {
-                let IssuedState {
-                    barrier_inflight_latency: timer,
-                    table_ids,
-                    ..
-                } = barrier_state.inner;
-                timer.observe_duration();
-                table_ids
-            };
+            self.streaming_metrics.barrier_manager_progress.inc();
 
             let create_mview_progress = self
                 .create_mview_progress
@@ -719,24 +717,44 @@ impl PartialGraphManagedBarrierState {
                 .into_iter()
                 .map(|(actor, state)| state.to_pb(actor))
                 .collect();
-            Some((barrier_state.barrier, create_mview_progress, table_ids))
-        } else {
-            None
+
+            let prev_state = replace(
+                &mut barrier_state.inner,
+                ManagedBarrierStateInner::AllCollected(create_mview_progress),
+            );
+
+            must_match!(prev_state, ManagedBarrierStateInner::Issued(IssuedState {
+                barrier_inflight_latency: timer,
+                ..
+            }) => {
+                timer.observe_duration();
+            });
+
+            return Some(barrier_state.barrier.clone());
         }
+        None
     }
 }
 
 impl LocalBarrierWorker {
-    pub(super) fn complete_barrier(
-        &mut self,
-        partial_graph_id: PartialGraphId,
-        barrier: Barrier,
-        create_mview_progress: Vec<PbCreateMviewProgress>,
-        table_ids: Option<HashSet<TableId>>,
-    ) {
+    pub(super) fn complete_barrier(&mut self, partial_graph_id: PartialGraphId, prev_epoch: u64) {
         {
-            let prev_epoch = barrier.epoch.prev;
-            let complete_barrier_future = match &barrier.kind {
+            let (popped_prev_epoch, barrier_state) = self
+                .state
+                .graph_states
+                .get_mut(&partial_graph_id)
+                .expect("should exist")
+                .epoch_barrier_state_map
+                .pop_first()
+                .expect("should exist");
+
+            assert_eq!(prev_epoch, popped_prev_epoch);
+
+            let create_mview_progress = must_match!(barrier_state.inner, ManagedBarrierStateInner::AllCollected(create_mview_progress) => {
+                create_mview_progress
+            });
+
+            let complete_barrier_future = match &barrier_state.barrier.kind {
                 BarrierKind::Unspecified => unreachable!(),
                 BarrierKind::Initial => {
                     tracing::info!(
@@ -753,7 +771,9 @@ impl LocalBarrierWorker {
                             state_store,
                             &self.actor_manager.streaming_metrics,
                             prev_epoch,
-                            table_ids.expect("should be Some on BarrierKind::Checkpoint"),
+                            barrier_state
+                                .table_ids
+                                .expect("should be Some on BarrierKind::Checkpoint"),
                         ))
                     })
                 }
@@ -763,7 +783,7 @@ impl LocalBarrierWorker {
                 instrument_complete_barrier_future(
                     partial_graph_id,
                     complete_barrier_future,
-                    barrier,
+                    barrier_state.barrier,
                     self.actor_manager.await_tree_reg.as_ref(),
                     create_mview_progress,
                 )
@@ -794,10 +814,10 @@ impl PartialGraphManagedBarrierState {
             Some(&mut BarrierState {
                 ref barrier,
                 inner:
-                    IssuedState {
+                    ManagedBarrierStateInner::Issued(IssuedState {
                         ref mut remaining_actors,
                         ..
-                    },
+                    }),
                 ..
             }) => {
                 let exist = remaining_actors.remove(&actor_id);
@@ -808,6 +828,12 @@ impl PartialGraphManagedBarrierState {
                 );
                 assert_eq!(barrier.epoch.curr, epoch.curr);
             }
+            Some(BarrierState { inner, .. }) => {
+                panic!(
+                    "cannot collect new actor barrier {:?} at current state: {:?}",
+                    epoch, inner
+                )
+            }
         }
     }
 
@@ -882,18 +908,18 @@ impl PartialGraphManagedBarrierState {
             barrier.epoch.prev,
             BarrierState {
                 barrier: barrier.clone(),
-                inner: IssuedState {
+                inner: ManagedBarrierStateInner::Issued(IssuedState {
                     remaining_actors: BTreeSet::from_iter(actor_ids_to_collect),
                     barrier_inflight_latency: timer,
-                    table_ids,
-                },
+                }),
+                table_ids,
             },
         );
     }
 
     #[cfg(test)]
     async fn pop_next_completed_epoch(&mut self) -> u64 {
-        if let Some((barrier, _, _)) = self.may_have_collected_all() {
+        if let Some(barrier) = self.may_have_collected_all() {
             return barrier.epoch.prev;
         }
         pending().await