diff --git a/src/frontend/planner_test/tests/testdata/input/agg.yaml b/src/frontend/planner_test/tests/testdata/input/agg.yaml index 99aa94ff773b..baa070aa9cd9 100644 --- a/src/frontend/planner_test/tests/testdata/input/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/input/agg.yaml @@ -912,6 +912,12 @@ expected_outputs: - batch_plan - stream_plan +- sql: | + create table t (x int, y int); + select first_value(distinct x order by x asc) from t; + expected_outputs: + - batch_plan + - stream_plan - sql: | create table t (x int, y int); select last_value(x order by y desc nulls last) from t; diff --git a/src/frontend/planner_test/tests/testdata/output/agg.yaml b/src/frontend/planner_test/tests/testdata/output/agg.yaml index aefb4df98ef4..69f495acad71 100644 --- a/src/frontend/planner_test/tests/testdata/output/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/output/agg.yaml @@ -1613,6 +1613,19 @@ └─StreamSimpleAgg { aggs: [first_value(t.x order_by(t.y ASC)), count] } └─StreamExchange { dist: Single } └─StreamTableScan { table: t, columns: [t.x, t.y, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } +- sql: | + create table t (x int, y int); + select first_value(distinct x order by x asc) from t; + batch_plan: |- + BatchSimpleAgg { aggs: [first_value(distinct t.x order_by(t.x ASC))] } + └─BatchExchange { order: [], dist: Single } + └─BatchScan { table: t, columns: [t.x], distribution: SomeShard } + stream_plan: |- + StreamMaterialize { columns: [first_value], stream_key: [], pk_columns: [], pk_conflict: NoCheck } + └─StreamProject { exprs: [first_value(distinct t.x order_by(t.x ASC))] } + └─StreamSimpleAgg { aggs: [first_value(distinct t.x order_by(t.x ASC)), count] } + └─StreamExchange { dist: Single } + └─StreamTableScan { table: t, columns: [t.x, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } - sql: | create table t (x int, y int); select last_value(x order by y desc nulls last) from t; diff --git a/src/frontend/src/optimizer/plan_node/generic/agg.rs b/src/frontend/src/optimizer/plan_node/generic/agg.rs index f83087d8ddbb..3de945bcf684 100644 --- a/src/frontend/src/optimizer/plan_node/generic/agg.rs +++ b/src/frontend/src/optimizer/plan_node/generic/agg.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::collections::{BTreeMap, BTreeSet, HashMap}; -use std::fmt; +use std::{fmt, vec}; use fixedbitset::FixedBitSet; use itertools::{Either, Itertools}; @@ -348,8 +348,8 @@ impl Agg { let in_dist_key = self.input.distribution().dist_column_indices().to_vec(); let gen_materialized_input_state = |sort_keys: Vec<(OrderType, usize)>, - include_keys: Vec, - distinct_key: Option| + extra_keys: Vec, + include_keys: Vec| -> MaterializedInputState { let (mut table_builder, mut included_upstream_indices, mut column_mapping) = self.create_table_builder(me.ctx(), window_col_idx); @@ -376,17 +376,8 @@ impl Agg { for (order_type, idx) in sort_keys { add_column(idx, Some(order_type), true, &mut table_builder); } - if let Some(distinct_key) = distinct_key { - add_column( - distinct_key, - Some(OrderType::ascending()), - true, - &mut table_builder, - ); - } else { - for &idx in &in_pks { - add_column(idx, Some(OrderType::ascending()), true, &mut table_builder); - } + for idx in extra_keys { + add_column(idx, Some(OrderType::ascending()), true, &mut table_builder); } for idx in include_keys { add_column(idx, None, true, &mut table_builder); @@ -468,6 +459,17 @@ impl Agg { _ => unreachable!(), } }; + + // columns to ensure each row unique + let extra_keys = if agg_call.distinct { + // if distinct, use distinct keys as extra keys + let distinct_key = agg_call.inputs[0].index; + vec![distinct_key] + } else { + // if not distinct, use primary keys as extra keys + in_pks.clone() + }; + // other columns that should be contained in state table let include_keys = match agg_call.agg_kind { AggKind::FirstValue @@ -480,12 +482,8 @@ impl Agg { } _ => vec![], }; - let state = if agg_call.distinct { - let distinct_key = agg_call.inputs[0].index; - gen_materialized_input_state(sort_keys, include_keys, Some(distinct_key)) - } else { - gen_materialized_input_state(sort_keys, include_keys, None) - }; + + let state = gen_materialized_input_state(sort_keys, extra_keys, include_keys); AggCallState::MaterializedInput(Box::new(state)) } agg_kinds::rewritten!() => { diff --git a/src/stream/src/executor/aggregation/minput.rs b/src/stream/src/executor/aggregation/minput.rs index 79ce53140359..5060ed8a1da0 100644 --- a/src/stream/src/executor/aggregation/minput.rs +++ b/src/stream/src/executor/aggregation/minput.rs @@ -103,11 +103,15 @@ impl MaterializedInputState { }; if agg_call.distinct { + // If distinct, we need to materialize input with the distinct keys + // As we only support single-column distinct for now, we use the + // `agg_call.args.val_indices()[0]` as the distinct key. if !order_col_indices.contains(&agg_call.args.val_indices()[0]) { order_col_indices.push(agg_call.args.val_indices()[0]); order_types.push(OrderType::ascending()); } } else { + // If not distinct, we need to materialize input with the primary keys let pk_len = pk_indices.len(); order_col_indices.extend(pk_indices.iter()); order_types.extend(itertools::repeat_n(OrderType::ascending(), pk_len));