diff --git a/e2e_test/streaming/bug_fixes/issue_12140.slt b/e2e_test/streaming/bug_fixes/issue_12140.slt new file mode 100644 index 0000000000000..2240762868832 --- /dev/null +++ b/e2e_test/streaming/bug_fixes/issue_12140.slt @@ -0,0 +1,75 @@ +# https://github.com/risingwavelabs/risingwave/issues/12140 + +statement ok +CREATE TABLE t (c3 INT, c9 CHARACTER VARYING); + +statement ok +INSERT INTO t VALUES (1, 'interesting'), (2, 'interesting'), (3, 'interesting'), (4, 'IwfwuseZmg'), (5, 'ZuT4aIQVhA'); + +statement ok +CREATE MATERIALIZED VIEW mv AS +SELECT + first_value(DISTINCT t.c9 ORDER BY t.c9 ASC) +FROM + t; + +statement ok +DELETE FROM t WHERE c3 = 1; + +statement ok +DELETE FROM t WHERE c3 = 2; + +statement ok +DELETE FROM t WHERE c3 = 3; + +statement ok +drop materialized view mv; + +statement ok +drop table t; + +statement ok +CREATE TABLE t (c3 INT, c9 CHARACTER VARYING); + +statement ok +INSERT INTO t VALUES (1, 'interesting'), (2, 'interesting'), (3, 'interesting'), (1, 'boring'), (2, 'boring'), (3, 'boring'), (1, 'exciting'), (2, 'exciting'), (3, 'exciting'), (4, 'IwfwuseZmg'), (5, 'ZuT4aIQVhA'); + +statement ok +CREATE MATERIALIZED VIEW mv AS +SELECT + first_value(DISTINCT t.c9 ORDER BY t.c9 ASC), last_value(distinct c3 order by c3 asc) +FROM + t; + +statement ok +DELETE FROM t WHERE c3 = 1 and c9 = 'interesting'; + +statement ok +DELETE FROM t WHERE c3 = 2 and c9 = 'interesting'; + +statement ok +DELETE FROM t WHERE c3 = 3 and c9 = 'interesting'; + +statement ok +DELETE FROM t WHERE c3 = 1 and c9 = 'boring'; + +statement ok +DELETE FROM t WHERE c3 = 1 and c9 = 'exciting'; + +statement ok +DELETE FROM t WHERE c3 = 2 and c9 = 'boring'; + +statement ok +DELETE FROM t WHERE c3 = 2 and c9 = 'exciting'; + +statement ok +DELETE FROM t WHERE c3 = 3 and c9 = 'boring'; + +statement ok +DELETE FROM t WHERE c3 = 3 and c9 = 'exciting'; + +statement ok +drop materialized view mv; + +statement ok +drop table t; \ No newline at end of file diff --git a/proto/stream_plan.proto b/proto/stream_plan.proto index 149fba8439db1..2b7e224447dba 100644 --- a/proto/stream_plan.proto +++ b/proto/stream_plan.proto @@ -249,6 +249,16 @@ message AggCallState { reserved "table_state"; } +enum AggNodeVersion { + AGG_NODE_VERSION_UNSPECIFIED = 0; + + // https://github.com/risingwavelabs/risingwave/issues/12140#issuecomment-1776289808 + AGG_NODE_VERSION_ISSUE_12140 = 1; + + // Used for test only. + AGG_NODE_VERSION_MAX = 2147483647; +} + message SimpleAggNode { repeated expr.AggCall agg_calls = 1; // Only used for stateless simple agg. @@ -260,6 +270,7 @@ message SimpleAggNode { bool is_append_only = 5; map distinct_dedup_tables = 6; uint32 row_count_index = 7; + AggNodeVersion version = 8; } message HashAggNode { @@ -273,6 +284,7 @@ message HashAggNode { map distinct_dedup_tables = 6; uint32 row_count_index = 7; bool emit_on_window_close = 8; + AggNodeVersion version = 9; } message TopNNode { diff --git a/src/frontend/planner_test/tests/testdata/input/agg.yaml b/src/frontend/planner_test/tests/testdata/input/agg.yaml index 99aa94ff773b9..2b69b5a53145d 100644 --- a/src/frontend/planner_test/tests/testdata/input/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/input/agg.yaml @@ -912,6 +912,12 @@ expected_outputs: - batch_plan - stream_plan + - stream_dist_plan # check the state table schema +- sql: | + create table t (x int, y int); + select first_value(distinct x order by x asc) from t; + expected_outputs: + - stream_dist_plan # check the state table schema - sql: | create table t (x int, y int); select last_value(x order by y desc nulls last) from t; diff --git a/src/frontend/planner_test/tests/testdata/output/agg.yaml b/src/frontend/planner_test/tests/testdata/output/agg.yaml index aefb4df98ef4e..5d6ceab766ba2 100644 --- a/src/frontend/planner_test/tests/testdata/output/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/output/agg.yaml @@ -1613,6 +1613,108 @@ └─StreamSimpleAgg { aggs: [first_value(t.x order_by(t.y ASC)), count] } └─StreamExchange { dist: Single } └─StreamTableScan { table: t, columns: [t.x, t.y, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } + stream_dist_plan: |+ + Fragment 0 + StreamMaterialize { columns: [first_value], stream_key: [], pk_columns: [], pk_conflict: NoCheck } + ├── materialized table: 4294967294 + └── StreamProject { exprs: [first_value(t.x order_by(t.y ASC))] } + └── StreamSimpleAgg { aggs: [first_value(t.x order_by(t.y ASC)), count] } + ├── intermediate state table: 1 + ├── state tables: [ 0 ] + ├── distinct tables: [] + └── StreamExchange Single from 1 + + Fragment 1 + Chain { table: t, columns: [t.x, t.y, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } + ├── state table: 2 + ├── Upstream + └── BatchPlanNode + + Table 0 + ├── columns: [ t_y, t__row_id, t_x ] + ├── primary key: [ $0 ASC, $1 ASC ] + ├── value indices: [ 0, 1, 2 ] + ├── distribution key: [] + └── read pk prefix len hint: 0 + + Table 1 + ├── columns: [ first_value(t_x order_by(t_y ASC)), count ] + ├── primary key: [] + ├── value indices: [ 0, 1 ] + ├── distribution key: [] + └── read pk prefix len hint: 0 + + Table 2 + ├── columns: [ vnode, _row_id, t_backfill_finished, t_row_count ] + ├── primary key: [ $0 ASC ] + ├── value indices: [ 1, 2, 3 ] + ├── distribution key: [ 0 ] + ├── read pk prefix len hint: 1 + └── vnode column idx: 0 + + Table 4294967294 + ├── columns: [ first_value ] + ├── primary key: [] + ├── value indices: [ 0 ] + ├── distribution key: [] + └── read pk prefix len hint: 0 + +- sql: | + create table t (x int, y int); + select first_value(distinct x order by x asc) from t; + stream_dist_plan: |+ + Fragment 0 + StreamMaterialize { columns: [first_value], stream_key: [], pk_columns: [], pk_conflict: NoCheck } + ├── materialized table: 4294967294 + └── StreamProject { exprs: [first_value(distinct t.x order_by(t.x ASC))] } + └── StreamSimpleAgg { aggs: [first_value(distinct t.x order_by(t.x ASC)), count] } + ├── intermediate state table: 1 + ├── state tables: [ 0 ] + ├── distinct tables: [ (distinct key: t.x, table id: 2) ] + └── StreamExchange Single from 1 + + Fragment 1 + Chain { table: t, columns: [t.x, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } + ├── state table: 3 + ├── Upstream + └── BatchPlanNode + + Table 0 + ├── columns: [ t_x ] + ├── primary key: [ $0 ASC ] + ├── value indices: [ 0 ] + ├── distribution key: [] + └── read pk prefix len hint: 0 + + Table 1 + ├── columns: [ first_value(distinct t_x order_by(t_x ASC)), count ] + ├── primary key: [] + ├── value indices: [ 0, 1 ] + ├── distribution key: [] + └── read pk prefix len hint: 0 + + Table 2 + ├── columns: [ t_x, count_for_agg_call_0 ] + ├── primary key: [ $0 ASC ] + ├── value indices: [ 1 ] + ├── distribution key: [] + └── read pk prefix len hint: 1 + + Table 3 + ├── columns: [ vnode, _row_id, t_backfill_finished, t_row_count ] + ├── primary key: [ $0 ASC ] + ├── value indices: [ 1, 2, 3 ] + ├── distribution key: [ 0 ] + ├── read pk prefix len hint: 1 + └── vnode column idx: 0 + + Table 4294967294 + ├── columns: [ first_value ] + ├── primary key: [] + ├── value indices: [ 0 ] + ├── distribution key: [] + └── read pk prefix len hint: 0 + - sql: | create table t (x int, y int); select last_value(x order by y desc nulls last) from t; diff --git a/src/frontend/src/optimizer/plan_node/generic/agg.rs b/src/frontend/src/optimizer/plan_node/generic/agg.rs index 4c35f2540c99f..7cf19fbb613aa 100644 --- a/src/frontend/src/optimizer/plan_node/generic/agg.rs +++ b/src/frontend/src/optimizer/plan_node/generic/agg.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::collections::{BTreeMap, BTreeSet, HashMap}; -use std::fmt; +use std::{fmt, vec}; use fixedbitset::FixedBitSet; use itertools::{Either, Itertools}; @@ -348,6 +348,7 @@ impl Agg { let in_dist_key = self.input.distribution().dist_column_indices().to_vec(); let gen_materialized_input_state = |sort_keys: Vec<(OrderType, usize)>, + extra_keys: Vec, include_keys: Vec| -> MaterializedInputState { let (mut table_builder, mut included_upstream_indices, mut column_mapping) = @@ -375,7 +376,7 @@ impl Agg { for (order_type, idx) in sort_keys { add_column(idx, Some(order_type), true, &mut table_builder); } - for &idx in &in_pks { + for idx in extra_keys { add_column(idx, Some(OrderType::ascending()), true, &mut table_builder); } for idx in include_keys { @@ -458,6 +459,17 @@ impl Agg { _ => unreachable!(), } }; + + // columns to ensure each row unique + let extra_keys = if agg_call.distinct { + // if distinct, use distinct keys as extra keys + let distinct_key = agg_call.inputs[0].index; + vec![distinct_key] + } else { + // if not distinct, use primary keys as extra keys + in_pks.clone() + }; + // other columns that should be contained in state table let include_keys = match agg_call.agg_kind { AggKind::FirstValue @@ -470,7 +482,8 @@ impl Agg { } _ => vec![], }; - let state = gen_materialized_input_state(sort_keys, include_keys); + + let state = gen_materialized_input_state(sort_keys, extra_keys, include_keys); AggCallState::MaterializedInput(Box::new(state)) } agg_kinds::rewritten!() => { diff --git a/src/frontend/src/optimizer/plan_node/stream.rs b/src/frontend/src/optimizer/plan_node/stream.rs index 36b2c77d22aed..d9b11c87678f0 100644 --- a/src/frontend/src/optimizer/plan_node/stream.rs +++ b/src/frontend/src/optimizer/plan_node/stream.rs @@ -599,6 +599,7 @@ pub fn to_stream_prost_body( .into_iter() .map(|(key_idx, table)| (key_idx as u32, table.to_internal_table_prost())) .collect(), + version: PbAggNodeVersion::Max as _, }) } Node::GroupTopN(me) => { @@ -658,6 +659,7 @@ pub fn to_stream_prost_body( .map(|(key_idx, table)| (key_idx as u32, table.to_internal_table_prost())) .collect(), emit_on_window_close: me.emit_on_window_close(), + version: PbAggNodeVersion::Max as _, }) } Node::HashJoin(_) => { @@ -701,6 +703,7 @@ pub fn to_stream_prost_body( intermediate_state_table: None, is_append_only: me.input.0.append_only, distinct_dedup_tables: Default::default(), + version: PbAggNodeVersion::Max as _, }) } Node::Materialize(me) => { diff --git a/src/frontend/src/optimizer/plan_node/stream_hash_agg.rs b/src/frontend/src/optimizer/plan_node/stream_hash_agg.rs index 5f0cfc16a4171..45e22e670dfee 100644 --- a/src/frontend/src/optimizer/plan_node/stream_hash_agg.rs +++ b/src/frontend/src/optimizer/plan_node/stream_hash_agg.rs @@ -215,6 +215,7 @@ impl StreamNode for StreamHashAgg { .collect(), row_count_index: self.row_count_idx as u32, emit_on_window_close: self.base.emit_on_window_close, + version: PbAggNodeVersion::Issue12140 as _, }) } } diff --git a/src/frontend/src/optimizer/plan_node/stream_simple_agg.rs b/src/frontend/src/optimizer/plan_node/stream_simple_agg.rs index f0c0bab5fae77..44794dbeb2119 100644 --- a/src/frontend/src/optimizer/plan_node/stream_simple_agg.rs +++ b/src/frontend/src/optimizer/plan_node/stream_simple_agg.rs @@ -128,6 +128,7 @@ impl StreamNode for StreamSimpleAgg { }) .collect(), row_count_index: self.row_count_idx as u32, + version: PbAggNodeVersion::Issue12140 as _, }) } } diff --git a/src/frontend/src/optimizer/plan_node/stream_stateless_simple_agg.rs b/src/frontend/src/optimizer/plan_node/stream_stateless_simple_agg.rs index 639b6c5782bb5..f471cb4e1df3e 100644 --- a/src/frontend/src/optimizer/plan_node/stream_stateless_simple_agg.rs +++ b/src/frontend/src/optimizer/plan_node/stream_stateless_simple_agg.rs @@ -103,6 +103,7 @@ impl StreamNode for StreamStatelessSimpleAgg { intermediate_state_table: None, is_append_only: self.input().append_only(), distinct_dedup_tables: Default::default(), + version: AggNodeVersion::Issue12140 as _, }) } } diff --git a/src/stream/src/executor/agg_common.rs b/src/stream/src/executor/agg_common.rs index fbaa80c3fbeb7..d1ea23068d430 100644 --- a/src/stream/src/executor/agg_common.rs +++ b/src/stream/src/executor/agg_common.rs @@ -16,6 +16,7 @@ use std::collections::HashMap; use std::sync::Arc; use risingwave_expr::aggregate::AggCall; +use risingwave_pb::stream_plan::PbAggNodeVersion; use risingwave_storage::StateStore; use super::aggregation::AggStateStorage; @@ -27,6 +28,8 @@ use crate::task::AtomicU64Ref; /// Arguments needed to construct an `XxxAggExecutor`. pub struct AggExecutorArgs { + pub version: PbAggNodeVersion, + // basic pub input: Box, pub actor_ctx: ActorContextRef, diff --git a/src/stream/src/executor/aggregation/agg_group.rs b/src/stream/src/executor/aggregation/agg_group.rs index d854969120919..d0e97cd4783e9 100644 --- a/src/stream/src/executor/aggregation/agg_group.rs +++ b/src/stream/src/executor/aggregation/agg_group.rs @@ -25,6 +25,7 @@ use risingwave_common::must_match; use risingwave_common::row::{OwnedRow, Row, RowExt}; use risingwave_common::util::iter_util::ZipEqFast; use risingwave_expr::aggregate::{AggCall, BoxedAggregateFunction}; +use risingwave_pb::stream_plan::PbAggNodeVersion; use risingwave_storage::StateStore; use super::agg_state::{AggState, AggStateStorage}; @@ -192,6 +193,7 @@ impl AggGroup { /// For [`crate::executor::SimpleAggExecutor`], the `group_key` should be `None`. #[allow(clippy::too_many_arguments)] pub async fn create( + version: PbAggNodeVersion, group_key: Option, agg_calls: &[AggCall], agg_funcs: &[BoxedAggregateFunction], @@ -212,6 +214,7 @@ impl AggGroup { let mut states = Vec::with_capacity(agg_calls.len()); for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() { let state = AggState::create( + version, agg_call, agg_func, &storages[idx], @@ -242,6 +245,7 @@ impl AggGroup { /// Create a group from encoded states for EOWC. The previous output is set to `None`. #[allow(clippy::too_many_arguments)] pub fn create_eowc( + version: PbAggNodeVersion, group_key: Option, agg_calls: &[AggCall], agg_funcs: &[BoxedAggregateFunction], @@ -255,6 +259,7 @@ impl AggGroup { let mut states = Vec::with_capacity(agg_calls.len()); for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() { let state = AggState::create( + version, agg_call, agg_func, &storages[idx], diff --git a/src/stream/src/executor/aggregation/agg_state.rs b/src/stream/src/executor/aggregation/agg_state.rs index 0c1932c58831c..a0413ed4491d2 100644 --- a/src/stream/src/executor/aggregation/agg_state.rs +++ b/src/stream/src/executor/aggregation/agg_state.rs @@ -19,6 +19,7 @@ use risingwave_common::estimate_size::EstimateSize; use risingwave_common::must_match; use risingwave_common::types::Datum; use risingwave_expr::aggregate::{AggCall, AggregateState, BoxedAggregateFunction}; +use risingwave_pb::stream_plan::PbAggNodeVersion; use risingwave_storage::StateStore; use super::minput::MaterializedInputState; @@ -65,6 +66,7 @@ impl AggState { /// Create an [`AggState`] from a given [`AggCall`]. #[allow(clippy::too_many_arguments)] pub fn create( + version: PbAggNodeVersion, agg_call: &AggCall, agg_func: &BoxedAggregateFunction, storage: &AggStateStorage, @@ -83,6 +85,7 @@ impl AggState { } AggStateStorage::MaterializedInput { mapping, .. } => { Self::MaterializedInput(Box::new(MaterializedInputState::new( + version, agg_call, pk_indices, mapping, diff --git a/src/stream/src/executor/aggregation/minput.rs b/src/stream/src/executor/aggregation/minput.rs index 01a330bc8ec9f..aa473a7ccc97e 100644 --- a/src/stream/src/executor/aggregation/minput.rs +++ b/src/stream/src/executor/aggregation/minput.rs @@ -23,6 +23,7 @@ use risingwave_common::types::Datum; use risingwave_common::util::row_serde::OrderedRowSerde; use risingwave_common::util::sort_util::OrderType; use risingwave_expr::aggregate::{AggCall, AggKind, BoxedAggregateFunction}; +use risingwave_pb::stream_plan::PbAggNodeVersion; use risingwave_storage::store::PrefetchOptions; use risingwave_storage::StateStore; @@ -66,6 +67,7 @@ pub struct MaterializedInputState { impl MaterializedInputState { /// Create an instance from [`AggCall`]. pub fn new( + version: PbAggNodeVersion, agg_call: &AggCall, pk_indices: &PkIndices, col_mapping: &StateTableColumnMapping, @@ -100,9 +102,26 @@ impl MaterializedInputState { .unzip() }; - let pk_len = pk_indices.len(); - order_col_indices.extend(pk_indices.iter()); - order_types.extend(itertools::repeat_n(OrderType::ascending(), pk_len)); + if agg_call.distinct { + if version < PbAggNodeVersion::Issue12140 { + panic!( + "RisingWave versions before issue #12140 is resolved has critical bug, you must re-create current MV to ensure correctness." + ); + } + + // If distinct, we need to materialize input with the distinct keys + // As we only support single-column distinct for now, we use the + // `agg_call.args.val_indices()[0]` as the distinct key. + if !order_col_indices.contains(&agg_call.args.val_indices()[0]) { + order_col_indices.push(agg_call.args.val_indices()[0]); + order_types.push(OrderType::ascending()); + } + } else { + // If not distinct, we need to materialize input with the primary keys + let pk_len = pk_indices.len(); + order_col_indices.extend(pk_indices.iter()); + order_types.extend(itertools::repeat_n(OrderType::ascending(), pk_len)); + } // map argument columns to state table column indices let state_table_arg_col_indices = arg_col_indices @@ -251,6 +270,7 @@ mod tests { use risingwave_common::util::epoch::EpochPair; use risingwave_common::util::sort_util::OrderType; use risingwave_expr::aggregate::{build_append_only, AggCall}; + use risingwave_pb::stream_plan::PbAggNodeVersion; use risingwave_storage::memory::MemoryStateStore; use risingwave_storage::StateStore; @@ -323,6 +343,7 @@ mod tests { .await; let mut state = MaterializedInputState::new( + PbAggNodeVersion::Max, &agg_call, &input_pk_indices, &mapping, @@ -375,6 +396,7 @@ mod tests { { // test recovery (cold start) let mut state = MaterializedInputState::new( + PbAggNodeVersion::Max, &agg_call, &input_pk_indices, &mapping, @@ -416,6 +438,7 @@ mod tests { .await; let mut state = MaterializedInputState::new( + PbAggNodeVersion::Max, &agg_call, &input_pk_indices, &mapping, @@ -468,6 +491,7 @@ mod tests { { // test recovery (cold start) let mut state = MaterializedInputState::new( + PbAggNodeVersion::Max, &agg_call, &input_pk_indices, &mapping, @@ -525,6 +549,7 @@ mod tests { table_2.init_epoch(epoch); let mut state_1 = MaterializedInputState::new( + PbAggNodeVersion::Max, &agg_call_1, &input_pk_indices, &mapping_1, @@ -534,6 +559,7 @@ mod tests { .unwrap(); let mut state_2 = MaterializedInputState::new( + PbAggNodeVersion::Max, &agg_call_2, &input_pk_indices, &mapping_2, @@ -617,6 +643,7 @@ mod tests { .await; let mut state = MaterializedInputState::new( + PbAggNodeVersion::Max, &agg_call, &input_pk_indices, &mapping, @@ -668,6 +695,7 @@ mod tests { { // test recovery (cold start) let mut state = MaterializedInputState::new( + PbAggNodeVersion::Max, &agg_call, &input_pk_indices, &mapping, @@ -711,6 +739,7 @@ mod tests { table.init_epoch(epoch); let mut state = MaterializedInputState::new( + PbAggNodeVersion::Max, &agg_call, &input_pk_indices, &mapping, @@ -810,6 +839,7 @@ mod tests { .await; let mut state = MaterializedInputState::new( + PbAggNodeVersion::Max, &agg_call, &input_pk_indices, &mapping, @@ -917,6 +947,7 @@ mod tests { .await; let mut state = MaterializedInputState::new( + PbAggNodeVersion::Max, &agg_call, &input_pk_indices, &mapping, @@ -996,6 +1027,7 @@ mod tests { .await; let mut state = MaterializedInputState::new( + PbAggNodeVersion::Max, &agg_call, &input_pk_indices, &mapping, diff --git a/src/stream/src/executor/hash_agg.rs b/src/stream/src/executor/hash_agg.rs index dd9b23f9147d6..6eafdcfd18361 100644 --- a/src/stream/src/executor/hash_agg.rs +++ b/src/stream/src/executor/hash_agg.rs @@ -29,6 +29,7 @@ use risingwave_common::types::ScalarImpl; use risingwave_common::util::epoch::EpochPair; use risingwave_common::util::iter_util::ZipEqFast; use risingwave_expr::aggregate::{build_retractable, AggCall, BoxedAggregateFunction}; +use risingwave_pb::stream_plan::PbAggNodeVersion; use risingwave_storage::StateStore; use super::agg_common::{AggExecutorArgs, HashAggExecutorExtraArgs}; @@ -81,6 +82,9 @@ pub struct HashAggExecutor { struct ExecutorInner { _phantom: PhantomData, + /// Version of aggregation executors. + version: PbAggNodeVersion, + actor_ctx: ActorContextRef, info: ExecutorInfo, @@ -233,6 +237,7 @@ impl HashAggExecutor { input: args.input, inner: ExecutorInner { _phantom: PhantomData, + version: args.version, actor_ctx: args.actor_ctx, info: ExecutorInfo { schema, @@ -318,6 +323,7 @@ impl HashAggExecutor { // Create `AggGroup` for the current group if not exists. This will // restore agg states from the intermediate state table. let agg_group = AggGroup::create( + this.version, Some(GroupKey::new( key.deserialize(group_key_types)?, Some(this.group_key_table_pk_projection.clone()), @@ -466,6 +472,7 @@ impl HashAggExecutor { let states = row.into_iter().skip(this.group_key_indices.len()).collect(); let mut agg_group = AggGroup::create_eowc( + this.version, Some(GroupKey::new( group_key, Some(this.group_key_table_pk_projection.clone()), diff --git a/src/stream/src/executor/simple_agg.rs b/src/stream/src/executor/simple_agg.rs index 6e88241f48433..92730218ca148 100644 --- a/src/stream/src/executor/simple_agg.rs +++ b/src/stream/src/executor/simple_agg.rs @@ -18,6 +18,7 @@ use risingwave_common::array::StreamChunk; use risingwave_common::catalog::Schema; use risingwave_common::util::iter_util::ZipEqFast; use risingwave_expr::aggregate::{build_retractable, AggCall, BoxedAggregateFunction}; +use risingwave_pb::stream_plan::PbAggNodeVersion; use risingwave_storage::StateStore; use super::agg_common::{AggExecutorArgs, SimpleAggExecutorExtraArgs}; @@ -52,6 +53,9 @@ pub struct SimpleAggExecutor { } struct ExecutorInner { + /// Version of aggregation executors. + version: PbAggNodeVersion, + actor_ctx: ActorContextRef, info: ExecutorInfo, @@ -135,6 +139,7 @@ impl SimpleAggExecutor { Ok(Self { input: args.input, inner: ExecutorInner { + version: args.version, actor_ctx: args.actor_ctx, info: ExecutorInfo { schema, @@ -257,9 +262,23 @@ impl SimpleAggExecutor { table.init_epoch(barrier.epoch); }); + let mut distinct_dedup = DistinctDeduplicater::new( + &this.agg_calls, + &this.watermark_epoch, + &this.distinct_dedup_tables, + this.actor_ctx.id, + this.metrics.clone(), + ); + distinct_dedup.dedup_caches_mut().for_each(|cache| { + cache.update_epoch(barrier.epoch.curr); + }); + + yield Message::Barrier(barrier); + let mut vars = ExecutionVars { // This will fetch previous agg states from the intermediate state table. agg_group: AggGroup::create( + this.version, None, &this.agg_calls, &this.agg_funcs, @@ -271,22 +290,10 @@ impl SimpleAggExecutor { &this.input_schema, ) .await?, - distinct_dedup: DistinctDeduplicater::new( - &this.agg_calls, - &this.watermark_epoch, - &this.distinct_dedup_tables, - this.actor_ctx.id, - this.metrics.clone(), - ), + distinct_dedup, state_changed: false, }; - vars.distinct_dedup.dedup_caches_mut().for_each(|cache| { - cache.update_epoch(barrier.epoch.curr); - }); - - yield Message::Barrier(barrier); - #[for_await] for msg in input { let msg = msg?; diff --git a/src/stream/src/executor/test_utils.rs b/src/stream/src/executor/test_utils.rs index bb4864ac04ef8..5bcb0fe547976 100644 --- a/src/stream/src/executor/test_utils.rs +++ b/src/stream/src/executor/test_utils.rs @@ -272,6 +272,7 @@ pub mod agg_executor { use risingwave_common::types::DataType; use risingwave_common::util::sort_util::OrderType; use risingwave_expr::aggregate::{AggCall, AggKind}; + use risingwave_pb::stream_plan::PbAggNodeVersion; use risingwave_storage::StateStore; use crate::common::table::state_table::StateTable; @@ -436,6 +437,8 @@ pub mod agg_executor { .await; HashAggExecutor::::new(AggExecutorArgs { + version: PbAggNodeVersion::Max, + input, actor_ctx: ActorContext::create(123), pk_indices, @@ -499,6 +502,8 @@ pub mod agg_executor { .await; SimpleAggExecutor::new(AggExecutorArgs { + version: PbAggNodeVersion::Max, + input, actor_ctx, pk_indices, diff --git a/src/stream/src/from_proto/hash_agg.rs b/src/stream/src/from_proto/hash_agg.rs index a369f8124ebfb..faf8a1f7fdad1 100644 --- a/src/stream/src/from_proto/hash_agg.rs +++ b/src/stream/src/from_proto/hash_agg.rs @@ -97,6 +97,8 @@ impl ExecutorBuilder for HashAggExecutorBuilder { HashAggExecutorDispatcherArgs { args: AggExecutorArgs { + version: node.version(), + input, actor_ctx: params.actor_context, pk_indices: params.pk_indices, diff --git a/src/stream/src/from_proto/simple_agg.rs b/src/stream/src/from_proto/simple_agg.rs index 5423e4fd2043f..fdd6d877b99ed 100644 --- a/src/stream/src/from_proto/simple_agg.rs +++ b/src/stream/src/from_proto/simple_agg.rs @@ -58,6 +58,8 @@ impl ExecutorBuilder for SimpleAggExecutorBuilder { .await; Ok(SimpleAggExecutor::new(AggExecutorArgs { + version: node.version(), + input, actor_ctx: params.actor_context, pk_indices: params.pk_indices,