From 85b8e3e8363db2fbf6c55a3a3e7617725a3a50e2 Mon Sep 17 00:00:00 2001 From: Noel Kwan Date: Mon, 22 Jul 2024 16:05:17 +0800 Subject: [PATCH] fix agg key merge schema --- .../tests/testdata/input/agg.yaml | 9 ++- .../tests/testdata/output/agg.yaml | 77 +++++++++++-------- .../src/optimizer/plan_node/logical_agg.rs | 7 +- .../optimizer/plan_node/stream_keyed_merge.rs | 31 ++++++-- 4 files changed, 82 insertions(+), 42 deletions(-) diff --git a/src/frontend/planner_test/tests/testdata/input/agg.yaml b/src/frontend/planner_test/tests/testdata/input/agg.yaml index 9109b31e81a74..fe5e39b50730c 100644 --- a/src/frontend/planner_test/tests/testdata/input/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/input/agg.yaml @@ -1000,6 +1000,13 @@ expected_outputs: - batch_plan - stream_plan +- name: test duplicate agg + sql: | + CREATE TABLE t (v1 int); + SELECT sum(v1) as x, count(v1) as y, sum(v1) as z, count(v1) as w from t; + expected_outputs: + - logical_plan + - stream_plan - name: test simple approx_percentile alone sql: | CREATE TABLE t (v1 int); @@ -1017,7 +1024,7 @@ - name: test simple approx_percentile with other simple aggs (sum, count) sql: | CREATE TABLE t (v1 int); - SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1), sum(v1), count(v1) from t; + SELECT sum(v1) as s1, approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1), sum(v1) as s2, count(v1) from t; expected_outputs: - logical_plan - stream_plan \ No newline at end of file diff --git a/src/frontend/planner_test/tests/testdata/output/agg.yaml b/src/frontend/planner_test/tests/testdata/output/agg.yaml index 1cce4e53f67ec..7aceaa57c45b4 100644 --- a/src/frontend/planner_test/tests/testdata/output/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/output/agg.yaml @@ -1863,6 +1863,22 @@ └─StreamHashAgg { group_key: [t.a, t.b], aggs: [sum(t.c), sum(t.d), count(t.d), max(t.e), count] } └─StreamExchange { dist: HashShard(t.a, t.b) } └─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t.d, t.e, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } +- name: test duplicate agg + sql: | + CREATE TABLE t (v1 int); + SELECT sum(v1) as x, count(v1) as y, sum(v1) as z, count(v1) as w from t; + logical_plan: |- + LogicalProject { exprs: [sum(t.v1), count(t.v1), sum(t.v1), count(t.v1)] } + └─LogicalAgg { aggs: [sum(t.v1), count(t.v1)] } + └─LogicalProject { exprs: [t.v1] } + └─LogicalScan { table: t, columns: [t.v1, t._row_id] } + stream_plan: |- + StreamMaterialize { columns: [x, y, z, w], stream_key: [], pk_columns: [], pk_conflict: NoCheck } + └─StreamProject { exprs: [sum(sum(t.v1)), sum0(count(t.v1)), sum(sum(t.v1)), sum0(count(t.v1))] } + └─StreamSimpleAgg { aggs: [sum(sum(t.v1)), sum0(count(t.v1)), count] } + └─StreamExchange { dist: Single } + └─StreamStatelessSimpleAgg { aggs: [sum(t.v1), count(t.v1)] } + └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } - name: test simple approx_percentile alone sql: | CREATE TABLE t (v1 int); @@ -1888,39 +1904,40 @@ └─LogicalProject { exprs: [t.v1::Float64 as $expr1, t.v1] } └─LogicalScan { table: t, columns: [t.v1, t._row_id] } stream_plan: |- - StreamMaterialize { columns: [approx_percentile, sum], stream_key: [], pk_columns: [], pk_conflict: NoCheck } - └─StreamKeyedMerge { output: [approx_percentile($expr10011):Float64, sum(t.v1):Int64] } - ├─StreamSimpleAgg { aggs: [sum(sum(t.v1)), count] } - │ └─StreamExchange { dist: Single } - │ └─StreamStatelessSimpleAgg { aggs: [sum(t.v1)] } - │ └─StreamShare { id: 2 } - │ └─StreamProject { exprs: [t.v1::Float64 as $expr1, t.v1, t._row_id] } - │ └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } - └─StreamGlobalApproxPercentile - └─StreamLocalApproxPercentile { percentile_col: $expr1, quantile: 0.5:Decimal, relative_error: 0.01:Decimal } - └─StreamShare { id: 2 } - └─StreamProject { exprs: [t.v1::Float64 as $expr1, t.v1, t._row_id] } - └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } + StreamMaterialize { columns: [approx_percentile, sum], stream_key: [approx_percentile], pk_columns: [approx_percentile], pk_conflict: NoCheck } + └─StreamKeyedMerge { output: [approx_percentile:Float64, sum(sum(t.v1)):Int64] } + ├─StreamGlobalApproxPercentile + │ └─StreamLocalApproxPercentile { percentile_col: $expr1, quantile: 0.5:Decimal, relative_error: 0.01:Decimal } + │ └─StreamShare { id: 2 } + │ └─StreamProject { exprs: [t.v1::Float64 as $expr1, t.v1, t._row_id] } + │ └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } + └─StreamSimpleAgg { aggs: [sum(sum(t.v1)), count] } + └─StreamExchange { dist: Single } + └─StreamStatelessSimpleAgg { aggs: [sum(t.v1)] } + └─StreamShare { id: 2 } + └─StreamProject { exprs: [t.v1::Float64 as $expr1, t.v1, t._row_id] } + └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } - name: test simple approx_percentile with other simple aggs (sum, count) sql: | CREATE TABLE t (v1 int); - SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1), sum(v1), count(v1) from t; + SELECT sum(v1) as s1, approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1), sum(v1) as s2, count(v1) from t; logical_plan: |- - LogicalProject { exprs: [approx_percentile($expr1), sum(t.v1), count(t.v1)] } - └─LogicalAgg { aggs: [approx_percentile($expr1), sum(t.v1), count(t.v1)] } - └─LogicalProject { exprs: [t.v1::Float64 as $expr1, t.v1] } + LogicalProject { exprs: [sum(t.v1), approx_percentile($expr1), sum(t.v1), count(t.v1)] } + └─LogicalAgg { aggs: [sum(t.v1), approx_percentile($expr1), count(t.v1)] } + └─LogicalProject { exprs: [t.v1, t.v1::Float64 as $expr1] } └─LogicalScan { table: t, columns: [t.v1, t._row_id] } stream_plan: |- - StreamMaterialize { columns: [approx_percentile, sum, count], stream_key: [], pk_columns: [], pk_conflict: NoCheck } - └─StreamKeyedMerge { output: [approx_percentile($expr10011):Float64, sum(t.v1):Int64, count(t.v1):Int64] } - ├─StreamSimpleAgg { aggs: [sum(sum(t.v1)), sum0(count(t.v1)), count] } - │ └─StreamExchange { dist: Single } - │ └─StreamStatelessSimpleAgg { aggs: [sum(t.v1), count(t.v1)] } - │ └─StreamShare { id: 2 } - │ └─StreamProject { exprs: [t.v1::Float64 as $expr1, t.v1, t._row_id] } - │ └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } - └─StreamGlobalApproxPercentile - └─StreamLocalApproxPercentile { percentile_col: $expr1, quantile: 0.5:Decimal, relative_error: 0.01:Decimal } - └─StreamShare { id: 2 } - └─StreamProject { exprs: [t.v1::Float64 as $expr1, t.v1, t._row_id] } - └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } + StreamMaterialize { columns: [s1, approx_percentile, s2, count], stream_key: [s2], pk_columns: [s2], pk_conflict: NoCheck } + └─StreamProject { exprs: [sum(sum(t.v1)), approx_percentile, sum(sum(t.v1)), sum0(count(t.v1))] } + └─StreamKeyedMerge { output: [sum(sum(t.v1)):Int64, approx_percentile:Float64, sum0(count(t.v1)):Int64] } + ├─StreamGlobalApproxPercentile + │ └─StreamLocalApproxPercentile { percentile_col: $expr1, quantile: 0.5:Decimal, relative_error: 0.01:Decimal } + │ └─StreamShare { id: 2 } + │ └─StreamProject { exprs: [t.v1, t.v1::Float64 as $expr1, t._row_id] } + │ └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } + └─StreamSimpleAgg { aggs: [sum(sum(t.v1)), sum0(count(t.v1)), count] } + └─StreamExchange { dist: Single } + └─StreamStatelessSimpleAgg { aggs: [sum(t.v1), count(t.v1)] } + └─StreamShare { id: 2 } + └─StreamProject { exprs: [t.v1, t.v1::Float64 as $expr1, t._row_id] } + └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index b575d58fb9636..4499474358684 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -63,8 +63,6 @@ impl LogicalAgg { fn gen_stateless_two_phase_streaming_agg_plan(&self, stream_input: PlanRef) -> Result { debug_assert!(self.group_key().is_empty()); let mut core = self.core.clone(); - let schema = self.base.schema().clone(); - println!("agg schema: {:?}", schema); // First, handle approx percentile. let has_approx_percentile = self @@ -114,12 +112,11 @@ impl LogicalAgg { )); if let Some((approx_percentile_agg, lhs_mapping, rhs_mapping)) = approx_percentile_info { let keyed_merge = StreamKeyedMerge::new( - global_agg.into(), approx_percentile_agg, + global_agg.into(), lhs_mapping, rhs_mapping, - schema, - ); + )?; Ok(keyed_merge.into()) } else { Ok(global_agg.into()) diff --git a/src/frontend/src/optimizer/plan_node/stream_keyed_merge.rs b/src/frontend/src/optimizer/plan_node/stream_keyed_merge.rs index 86034a3b3889f..cac5ca63ed764 100644 --- a/src/frontend/src/optimizer/plan_node/stream_keyed_merge.rs +++ b/src/frontend/src/optimizer/plan_node/stream_keyed_merge.rs @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +use anyhow::anyhow; +use fixedbitset::FixedBitSet; use pretty_xmlish::{Pretty, XmlNode}; +use risingwave_common::bail; use risingwave_common::catalog::Schema; use risingwave_common::util::column_index_mapping::ColIndexMapping; use risingwave_pb::stream_plan::stream_node::PbNodeBody; @@ -29,6 +32,7 @@ use crate::optimizer::plan_node::{ }; use crate::stream_fragmenter::BuildFragmentGraphState; use crate::PlanRef; +use crate::error::Result; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct StreamKeyedMerge { @@ -47,9 +51,24 @@ impl StreamKeyedMerge { rhs_input: PlanRef, lhs_mapping: ColIndexMapping, rhs_mapping: ColIndexMapping, - schema: Schema, - ) -> Self { - println!("keyed merge schema: {:?}", schema); + ) -> Result { + assert_eq!(lhs_mapping.target_size(), rhs_mapping.target_size()); + let mut schema_fields = Vec::with_capacity(lhs_mapping.target_size()); + let mut o2i_lhs = lhs_mapping.inverse().ok_or_else(|| anyhow!("lhs_mapping should be invertible"))?; + let mut o2i_rhs = rhs_mapping.inverse().ok_or_else(|| anyhow!("rhs_mapping should be invertible"))?; + for output_idx in 0..lhs_mapping.target_size() { + if let Some(lhs_idx) = o2i_lhs.try_map(output_idx) { + schema_fields.push(lhs_input.schema().fields()[lhs_idx].clone()); + } else if let Some(rhs_idx) = o2i_rhs.try_map(output_idx) { + println!("rhs schema: {:?}", rhs_input.schema().fields()); + schema_fields.push(rhs_input.schema().fields()[rhs_idx].clone()); + } else { + bail!("output index {} not found in either lhs or rhs mapping", output_idx); + } + } + let schema = Schema::new(schema_fields); + let watermark_columns = FixedBitSet::with_capacity(schema.fields.len()); + // FIXME: schema is wrong. let base = PlanBase::new_stream( lhs_input.ctx(), @@ -59,16 +78,16 @@ impl StreamKeyedMerge { lhs_input.distribution().clone(), lhs_input.append_only(), lhs_input.emit_on_window_close(), - lhs_input.watermark_columns().clone(), + watermark_columns, lhs_input.columns_monotonicity().clone(), ); - Self { + Ok(Self { base, lhs_input, rhs_input, lhs_mapping, rhs_mapping, - } + }) } }