diff --git a/proto/expr.proto b/proto/expr.proto index 9babde3fecaf..da466d690f82 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -436,6 +436,8 @@ message AggCall { LAST_VALUE = 25; GROUPING = 26; INTERNAL_LAST_SEEN_VALUE = 27; + APPROX_PERCENTILE = 28; + // user defined aggregate function USER_DEFINED = 100; // wraps a scalar function that takes a list as input as an aggregate function. diff --git a/src/expr/core/src/aggregate/def.rs b/src/expr/core/src/aggregate/def.rs index cd735a1df49f..d35077698da6 100644 --- a/src/expr/core/src/aggregate/def.rs +++ b/src/expr/core/src/aggregate/def.rs @@ -313,7 +313,10 @@ pub mod agg_kinds { | PbAggKind::StddevSamp | PbAggKind::VarPop | PbAggKind::VarSamp - | PbAggKind::Grouping, + | PbAggKind::Grouping + // ApproxPercentile always uses custom agg executors, + // rather than an aggregation operator + | PbAggKind::ApproxPercentile ) }; } @@ -443,7 +446,10 @@ pub mod agg_kinds { macro_rules! ordered_set { () => { AggKind::Builtin( - PbAggKind::PercentileCont | PbAggKind::PercentileDisc | PbAggKind::Mode, + PbAggKind::PercentileCont + | PbAggKind::PercentileDisc + | PbAggKind::Mode + | PbAggKind::ApproxPercentile, ) }; } diff --git a/src/expr/impl/src/aggregate/approx_percentile.rs b/src/expr/impl/src/aggregate/approx_percentile.rs new file mode 100644 index 000000000000..7f9ae3fb4a53 --- /dev/null +++ b/src/expr/impl/src/aggregate/approx_percentile.rs @@ -0,0 +1,67 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::ops::Range; + +use risingwave_common::array::*; +use risingwave_common::types::*; +use risingwave_common_estimate_size::EstimateSize; +use risingwave_expr::aggregate::{AggCall, AggStateDyn, AggregateFunction, AggregateState}; +use risingwave_expr::{build_aggregate, Result}; + +#[build_aggregate("approx_percentile(float8) -> float8")] +fn build(agg: &AggCall) -> Result> { + let fraction = agg.direct_args[0] + .literal() + .map(|x| (*x.as_float64()).into()); + Ok(Box::new(ApproxPercentile { fraction })) +} + +#[allow(dead_code)] +pub struct ApproxPercentile { + fraction: Option, +} + +#[derive(Debug, Default, EstimateSize)] +struct State(Vec); + +impl AggStateDyn for State {} + +#[async_trait::async_trait] +impl AggregateFunction for ApproxPercentile { + fn return_type(&self) -> DataType { + DataType::Float64 + } + + fn create_state(&self) -> Result { + todo!() + } + + async fn update(&self, _state: &mut AggregateState, _input: &StreamChunk) -> Result<()> { + todo!() + } + + async fn update_range( + &self, + _state: &mut AggregateState, + _input: &StreamChunk, + _range: Range, + ) -> Result<()> { + todo!() + } + + async fn get_result(&self, _state: &AggregateState) -> Result { + todo!() + } +} diff --git a/src/expr/impl/src/aggregate/mod.rs b/src/expr/impl/src/aggregate/mod.rs index c0b6a5ae64c3..349574018fed 100644 --- a/src/expr/impl/src/aggregate/mod.rs +++ b/src/expr/impl/src/aggregate/mod.rs @@ -13,6 +13,7 @@ // limitations under the License. mod approx_count_distinct; +mod approx_percentile; mod array_agg; mod bit_and; mod bit_or; diff --git a/src/frontend/planner_test/tests/testdata/input/agg.yaml b/src/frontend/planner_test/tests/testdata/input/agg.yaml index 70f83549ff18..25f62054f1e6 100644 --- a/src/frontend/planner_test/tests/testdata/input/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/input/agg.yaml @@ -1000,3 +1000,59 @@ expected_outputs: - batch_plan - stream_plan +- name: test duplicate agg + sql: | + CREATE TABLE t (v1 int); + SELECT sum(v1) as x, count(v1) as y, sum(v1) as z, count(v1) as w from t; + expected_outputs: + - logical_plan + - stream_plan +- name: test simple approx_percentile alone + sql: | + CREATE TABLE t (v1 int); + SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1) from t; + expected_outputs: + - logical_plan + - stream_plan +- name: test simple approx_percentile with other simple aggs + sql: | + CREATE TABLE t (v1 int); + SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1), sum(v1) from t; + expected_outputs: + - logical_plan + - stream_plan +- name: test simple approx_percentile with other simple aggs (sum, count) + sql: | + CREATE TABLE t (v1 int); + SELECT sum(v1) as s1, approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1), sum(v1) as s2, count(v1) from t; + expected_outputs: + - logical_plan + - stream_plan +- name: test simple approx_percentile with duplicate approx_percentile + sql: | + CREATE TABLE t (v1 int); + SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1) as x, approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1) as y from t; + expected_outputs: + - logical_plan + - stream_plan +- name: test simple approx_percentile with different approx_percentile + sql: | + CREATE TABLE t (v1 int, v2 int); + SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1) as x, approx_percentile(0.5, 0.01) WITHIN GROUP (order by v2) as y from t; + expected_outputs: + - logical_plan + - stream_plan +- name: test simple approx_percentile with different approx_percentile interleaved with stateless simple aggs + sql: | + CREATE TABLE t (v1 int, v2 int); + SELECT sum(v1) as s1, approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1) as x, count(*), sum(v2) as s2, approx_percentile(0.5, 0.01) WITHIN GROUP (order by v2) as y from t; + expected_outputs: + - logical_plan + - stream_plan +- name: test simple approx_percentile with descending order + sql: | + CREATE TABLE t (v1 int, v2 int); + SELECT sum(v1) as s1, approx_percentile(0.2, 0.01) WITHIN GROUP (order by v1 desc) from t; + expected_outputs: + - logical_plan + - stream_plan \ No newline at end of file diff --git a/src/frontend/planner_test/tests/testdata/output/agg.yaml b/src/frontend/planner_test/tests/testdata/output/agg.yaml index 4c75b8318774..81e4185b128c 100644 --- a/src/frontend/planner_test/tests/testdata/output/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/output/agg.yaml @@ -1863,3 +1863,178 @@ └─StreamHashAgg { group_key: [t.a, t.b], aggs: [sum(t.c), sum(t.d), count(t.d), max(t.e), count] } └─StreamExchange { dist: HashShard(t.a, t.b) } └─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t.d, t.e, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } +- name: test duplicate agg + sql: | + CREATE TABLE t (v1 int); + SELECT sum(v1) as x, count(v1) as y, sum(v1) as z, count(v1) as w from t; + logical_plan: |- + LogicalProject { exprs: [sum(t.v1), count(t.v1), sum(t.v1), count(t.v1)] } + └─LogicalAgg { aggs: [sum(t.v1), count(t.v1)] } + └─LogicalProject { exprs: [t.v1] } + └─LogicalScan { table: t, columns: [t.v1, t._row_id] } + stream_plan: |- + StreamMaterialize { columns: [x, y, z, w], stream_key: [], pk_columns: [], pk_conflict: NoCheck } + └─StreamProject { exprs: [sum(sum(t.v1)), sum0(count(t.v1)), sum(sum(t.v1)), sum0(count(t.v1))] } + └─StreamSimpleAgg { aggs: [sum(sum(t.v1)), sum0(count(t.v1)), count] } + └─StreamExchange { dist: Single } + └─StreamStatelessSimpleAgg { aggs: [sum(t.v1), count(t.v1)] } + └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } +- name: test simple approx_percentile alone + sql: | + CREATE TABLE t (v1 int); + SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1) from t; + logical_plan: |- + LogicalProject { exprs: [approx_percentile($expr1)] } + └─LogicalAgg { aggs: [approx_percentile($expr1)] } + └─LogicalProject { exprs: [t.v1::Float64 as $expr1] } + └─LogicalScan { table: t, columns: [t.v1, t._row_id] } + stream_plan: |- + StreamMaterialize { columns: [approx_percentile], stream_key: [], pk_columns: [], pk_conflict: NoCheck } + └─StreamGlobalApproxPercentile { quantile: 0.5:Float64, relative_error: 0.01:Float64 } + └─StreamLocalApproxPercentile { percentile_col: $expr1, quantile: 0.5:Float64, relative_error: 0.01:Float64 } + └─StreamProject { exprs: [t.v1::Float64 as $expr1, t._row_id] } + └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } +- name: test simple approx_percentile with other simple aggs + sql: | + CREATE TABLE t (v1 int); + SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1), sum(v1) from t; + logical_plan: |- + LogicalProject { exprs: [approx_percentile($expr1), sum(t.v1)] } + └─LogicalAgg { aggs: [approx_percentile($expr1), sum(t.v1)] } + └─LogicalProject { exprs: [t.v1::Float64 as $expr1, t.v1] } + └─LogicalScan { table: t, columns: [t.v1, t._row_id] } + stream_plan: |- + StreamMaterialize { columns: [approx_percentile, sum], stream_key: [], pk_columns: [], pk_conflict: NoCheck } + └─StreamKeyedMerge { output: [approx_percentile:Float64, sum(sum(t.v1)):Int64] } + ├─StreamGlobalApproxPercentile { quantile: 0.5:Float64, relative_error: 0.01:Float64 } + │ └─StreamLocalApproxPercentile { percentile_col: $expr1, quantile: 0.5:Float64, relative_error: 0.01:Float64 } + │ └─StreamShare { id: 2 } + │ └─StreamProject { exprs: [t.v1::Float64 as $expr1, t.v1, t._row_id] } + │ └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } + └─StreamSimpleAgg { aggs: [sum(sum(t.v1)), count] } + └─StreamExchange { dist: Single } + └─StreamStatelessSimpleAgg { aggs: [sum(t.v1)] } + └─StreamShare { id: 2 } + └─StreamProject { exprs: [t.v1::Float64 as $expr1, t.v1, t._row_id] } + └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } +- name: test simple approx_percentile with other simple aggs (sum, count) + sql: | + CREATE TABLE t (v1 int); + SELECT sum(v1) as s1, approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1), sum(v1) as s2, count(v1) from t; + logical_plan: |- + LogicalProject { exprs: [sum(t.v1), approx_percentile($expr1), sum(t.v1), count(t.v1)] } + └─LogicalAgg { aggs: [sum(t.v1), approx_percentile($expr1), count(t.v1)] } + └─LogicalProject { exprs: [t.v1, t.v1::Float64 as $expr1] } + └─LogicalScan { table: t, columns: [t.v1, t._row_id] } + stream_plan: |- + StreamMaterialize { columns: [s1, approx_percentile, s2, count], stream_key: [], pk_columns: [], pk_conflict: NoCheck } + └─StreamProject { exprs: [sum(sum(t.v1)), approx_percentile, sum(sum(t.v1)), sum0(count(t.v1))] } + └─StreamKeyedMerge { output: [sum(sum(t.v1)):Int64, approx_percentile:Float64, sum0(count(t.v1)):Int64] } + ├─StreamGlobalApproxPercentile { quantile: 0.5:Float64, relative_error: 0.01:Float64 } + │ └─StreamLocalApproxPercentile { percentile_col: $expr1, quantile: 0.5:Float64, relative_error: 0.01:Float64 } + │ └─StreamShare { id: 2 } + │ └─StreamProject { exprs: [t.v1, t.v1::Float64 as $expr1, t._row_id] } + │ └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } + └─StreamSimpleAgg { aggs: [sum(sum(t.v1)), sum0(count(t.v1)), count] } + └─StreamExchange { dist: Single } + └─StreamStatelessSimpleAgg { aggs: [sum(t.v1), count(t.v1)] } + └─StreamShare { id: 2 } + └─StreamProject { exprs: [t.v1, t.v1::Float64 as $expr1, t._row_id] } + └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } +- name: test simple approx_percentile with duplicate approx_percentile + sql: | + CREATE TABLE t (v1 int); + SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1) as x, approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1) as y from t; + logical_plan: |- + LogicalProject { exprs: [approx_percentile($expr1), approx_percentile($expr1)] } + └─LogicalAgg { aggs: [approx_percentile($expr1)] } + └─LogicalProject { exprs: [t.v1::Float64 as $expr1] } + └─LogicalScan { table: t, columns: [t.v1, t._row_id] } + stream_plan: |- + StreamMaterialize { columns: [x, y], stream_key: [], pk_columns: [], pk_conflict: NoCheck } + └─StreamProject { exprs: [approx_percentile, approx_percentile] } + └─StreamGlobalApproxPercentile { quantile: 0.5:Float64, relative_error: 0.01:Float64 } + └─StreamLocalApproxPercentile { percentile_col: $expr1, quantile: 0.5:Float64, relative_error: 0.01:Float64 } + └─StreamProject { exprs: [t.v1::Float64 as $expr1, t._row_id] } + └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } +- name: test simple approx_percentile with different approx_percentile + sql: | + CREATE TABLE t (v1 int, v2 int); + SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1) as x, approx_percentile(0.5, 0.01) WITHIN GROUP (order by v2) as y from t; + logical_plan: |- + LogicalProject { exprs: [approx_percentile($expr1), approx_percentile($expr2)] } + └─LogicalAgg { aggs: [approx_percentile($expr1), approx_percentile($expr2)] } + └─LogicalProject { exprs: [t.v1::Float64 as $expr1, t.v2::Float64 as $expr2] } + └─LogicalScan { table: t, columns: [t.v1, t.v2, t._row_id] } + stream_plan: |- + StreamMaterialize { columns: [x, y], stream_key: [], pk_columns: [], pk_conflict: NoCheck } + └─StreamKeyedMerge { output: [approx_percentile:Float64, approx_percentile:Float64] } + ├─StreamKeyedMerge { output: [approx_percentile:Float64, approx_percentile:Float64] } + │ ├─StreamGlobalApproxPercentile { quantile: 0.5:Float64, relative_error: 0.01:Float64 } + │ │ └─StreamLocalApproxPercentile { percentile_col: $expr1, quantile: 0.5:Float64, relative_error: 0.01:Float64 } + │ │ └─StreamShare { id: 2 } + │ │ └─StreamProject { exprs: [t.v1::Float64 as $expr1, t.v2::Float64 as $expr2, t._row_id] } + │ │ └─StreamTableScan { table: t, columns: [t.v1, t.v2, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } + │ └─StreamGlobalApproxPercentile { quantile: 0.5:Float64, relative_error: 0.01:Float64 } + │ └─StreamLocalApproxPercentile { percentile_col: $expr2, quantile: 0.5:Float64, relative_error: 0.01:Float64 } + │ └─StreamShare { id: 2 } + │ └─StreamProject { exprs: [t.v1::Float64 as $expr1, t.v2::Float64 as $expr2, t._row_id] } + │ └─StreamTableScan { table: t, columns: [t.v1, t.v2, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } + └─StreamSimpleAgg { aggs: [count] } + └─StreamExchange { dist: Single } + └─StreamStatelessSimpleAgg { aggs: [] } + └─StreamShare { id: 2 } + └─StreamProject { exprs: [t.v1::Float64 as $expr1, t.v2::Float64 as $expr2, t._row_id] } + └─StreamTableScan { table: t, columns: [t.v1, t.v2, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } +- name: test simple approx_percentile with different approx_percentile interleaved with stateless simple aggs + sql: | + CREATE TABLE t (v1 int, v2 int); + SELECT sum(v1) as s1, approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1) as x, count(*), sum(v2) as s2, approx_percentile(0.5, 0.01) WITHIN GROUP (order by v2) as y from t; + logical_plan: |- + LogicalProject { exprs: [sum(t.v1), approx_percentile($expr1), count, sum(t.v2), approx_percentile($expr2)] } + └─LogicalAgg { aggs: [sum(t.v1), approx_percentile($expr1), count, sum(t.v2), approx_percentile($expr2)] } + └─LogicalProject { exprs: [t.v1, t.v1::Float64 as $expr1, t.v2, t.v2::Float64 as $expr2] } + └─LogicalScan { table: t, columns: [t.v1, t.v2, t._row_id] } + stream_plan: |- + StreamMaterialize { columns: [s1, x, count, s2, y], stream_key: [], pk_columns: [], pk_conflict: NoCheck } + └─StreamKeyedMerge { output: [sum(sum(t.v1)):Int64, approx_percentile:Float64, sum0(count):Int64, sum(sum(t.v2)):Int64, approx_percentile:Float64] } + ├─StreamKeyedMerge { output: [approx_percentile:Float64, approx_percentile:Float64] } + │ ├─StreamGlobalApproxPercentile { quantile: 0.5:Float64, relative_error: 0.01:Float64 } + │ │ └─StreamLocalApproxPercentile { percentile_col: $expr1, quantile: 0.5:Float64, relative_error: 0.01:Float64 } + │ │ └─StreamShare { id: 2 } + │ │ └─StreamProject { exprs: [t.v1, t.v1::Float64 as $expr1, t.v2, t.v2::Float64 as $expr2, t._row_id] } + │ │ └─StreamTableScan { table: t, columns: [t.v1, t.v2, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } + │ └─StreamGlobalApproxPercentile { quantile: 0.5:Float64, relative_error: 0.01:Float64 } + │ └─StreamLocalApproxPercentile { percentile_col: $expr2, quantile: 0.5:Float64, relative_error: 0.01:Float64 } + │ └─StreamShare { id: 2 } + │ └─StreamProject { exprs: [t.v1, t.v1::Float64 as $expr1, t.v2, t.v2::Float64 as $expr2, t._row_id] } + │ └─StreamTableScan { table: t, columns: [t.v1, t.v2, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } + └─StreamSimpleAgg { aggs: [sum(sum(t.v1)), sum0(count), sum(sum(t.v2)), count] } + └─StreamExchange { dist: Single } + └─StreamStatelessSimpleAgg { aggs: [sum(t.v1), count, sum(t.v2)] } + └─StreamShare { id: 2 } + └─StreamProject { exprs: [t.v1, t.v1::Float64 as $expr1, t.v2, t.v2::Float64 as $expr2, t._row_id] } + └─StreamTableScan { table: t, columns: [t.v1, t.v2, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } +- name: test simple approx_percentile with descending order + sql: | + CREATE TABLE t (v1 int, v2 int); + SELECT sum(v1) as s1, approx_percentile(0.2, 0.01) WITHIN GROUP (order by v1 desc) from t; + logical_plan: |- + LogicalProject { exprs: [sum(t.v1), approx_percentile($expr1)] } + └─LogicalAgg { aggs: [sum(t.v1), approx_percentile($expr1)] } + └─LogicalProject { exprs: [t.v1, t.v1::Float64 as $expr1] } + └─LogicalScan { table: t, columns: [t.v1, t.v2, t._row_id] } + stream_plan: |- + StreamMaterialize { columns: [s1, approx_percentile], stream_key: [], pk_columns: [], pk_conflict: NoCheck } + └─StreamKeyedMerge { output: [sum(sum(t.v1)):Int64, approx_percentile:Float64] } + ├─StreamGlobalApproxPercentile { quantile: 0.8:Float64, relative_error: 0.01:Float64 } + │ └─StreamLocalApproxPercentile { percentile_col: $expr1, quantile: 0.8:Float64, relative_error: 0.01:Float64 } + │ └─StreamShare { id: 2 } + │ └─StreamProject { exprs: [t.v1, t.v1::Float64 as $expr1, t._row_id] } + │ └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } + └─StreamSimpleAgg { aggs: [sum(sum(t.v1)), count] } + └─StreamExchange { dist: Single } + └─StreamStatelessSimpleAgg { aggs: [sum(t.v1)] } + └─StreamShare { id: 2 } + └─StreamProject { exprs: [t.v1, t.v1::Float64 as $expr1, t._row_id] } + └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index 1d08fc3e19cb..b9610819756b 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -424,6 +424,36 @@ impl Binder { )?))) } + fn decimal_to_float64(decimal_expr: &mut ExprImpl, kind: &AggKind) -> Result<()> { + if decimal_expr.cast_implicit_mut(DataType::Float64).is_err() { + return Err(ErrorCode::InvalidInputSyntax(format!( + "direct arg in `{}` must be castable to float64", + kind + )) + .into()); + } + + let Some(Ok(fraction_datum)) = decimal_expr.try_fold_const() else { + bail_not_implemented!( + issue = 14079, + "variable as direct argument of ordered-set aggregate", + ); + }; + + if let Some(ref fraction_value) = fraction_datum + && !(0.0..=1.0).contains(&fraction_value.as_float64().0) + { + return Err(ErrorCode::InvalidInputSyntax(format!( + "direct arg in `{}` must between 0.0 and 1.0", + kind + )) + .into()); + } + // note that the fraction can be NULL + *decimal_expr = Literal::new(fraction_datum, DataType::Float64).into(); + Ok(()) + } + fn bind_ordered_set_agg( &mut self, f: Function, @@ -474,33 +504,7 @@ impl Binder { [fraction], [arg], ) => { - if fraction.cast_implicit_mut(DataType::Float64).is_err() { - return Err(ErrorCode::InvalidInputSyntax(format!( - "direct arg in `{}` must be castable to float64", - kind - )) - .into()); - } - - let Some(Ok(fraction_datum)) = fraction.try_fold_const() else { - bail_not_implemented!( - issue = 14079, - "variable as direct argument of ordered-set aggregate", - ); - }; - - if let Some(ref fraction_value) = fraction_datum - && !(0.0..=1.0).contains(&fraction_value.as_float64().0) - { - return Err(ErrorCode::InvalidInputSyntax(format!( - "direct arg in `{}` must between 0.0 and 1.0", - kind - )) - .into()); - } - // note that the fraction can be NULL - *fraction = Literal::new(fraction_datum, DataType::Float64).into(); - + Self::decimal_to_float64(fraction, &kind)?; if matches!(&kind, AggKind::Builtin(PbAggKind::PercentileCont)) { arg.cast_implicit_mut(DataType::Float64).map_err(|_| { ErrorCode::InvalidInputSyntax(format!( @@ -511,6 +515,14 @@ impl Binder { } } (AggKind::Builtin(PbAggKind::Mode), [], [_arg]) => {} + ( + AggKind::Builtin(PbAggKind::ApproxPercentile), + [percentile, relative_error], + [_percentile_col], + ) => { + Self::decimal_to_float64(percentile, &kind)?; + Self::decimal_to_float64(relative_error, &kind)?; + } _ => { return Err(ErrorCode::InvalidInputSyntax(format!( "invalid direct args or within group argument for `{}` aggregation", @@ -568,7 +580,11 @@ impl Binder { ); if f.distinct { - if matches!(kind, AggKind::Builtin(PbAggKind::ApproxCountDistinct)) { + if matches!( + kind, + AggKind::Builtin(PbAggKind::ApproxCountDistinct) + | AggKind::Builtin(PbAggKind::ApproxPercentile) + ) { return Err(ErrorCode::InvalidInputSyntax(format!( "DISTINCT is not allowed for approximate aggregation `{}`", kind diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index 1ca640f54b56..af1be41a711e 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -763,4 +763,267 @@ mod tests { expected.assert_eq(&format!("{:#?}", bound)); } + + #[tokio::test] + async fn test_bind_approx_percentile() { + let stmt = risingwave_sqlparser::parser::Parser::parse_sql( + "SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (ORDER BY generate_series) FROM generate_series(1, 100)", + ).unwrap().into_iter().next().unwrap(); + let parse_expected = expect![[r#" + Query( + Query { + with: None, + body: Select( + Select { + distinct: All, + projection: [ + UnnamedExpr( + Function( + Function { + aggregate: false, + name: ObjectName( + [ + Ident { + value: "approx_percentile", + quote_style: None, + }, + ], + ), + args: [ + Unnamed( + Expr( + Value( + Number( + "0.5", + ), + ), + ), + ), + Unnamed( + Expr( + Value( + Number( + "0.01", + ), + ), + ), + ), + ], + variadic: false, + over: None, + distinct: false, + order_by: [], + filter: None, + within_group: Some( + OrderByExpr { + expr: Identifier( + Ident { + value: "generate_series", + quote_style: None, + }, + ), + asc: None, + nulls_first: None, + }, + ), + }, + ), + ), + ], + from: [ + TableWithJoins { + relation: TableFunction { + name: ObjectName( + [ + Ident { + value: "generate_series", + quote_style: None, + }, + ], + ), + alias: None, + args: [ + Unnamed( + Expr( + Value( + Number( + "1", + ), + ), + ), + ), + Unnamed( + Expr( + Value( + Number( + "100", + ), + ), + ), + ), + ], + with_ordinality: false, + }, + joins: [], + }, + ], + lateral_views: [], + selection: None, + group_by: [], + having: None, + }, + ), + order_by: [], + limit: None, + offset: None, + fetch: None, + }, + )"#]]; + parse_expected.assert_eq(&format!("{:#?}", stmt)); + + let mut binder = mock_binder(); + let bound = binder.bind(stmt).unwrap(); + + let expected = expect![[r#" + Query( + BoundQuery { + body: Select( + BoundSelect { + distinct: All, + select_items: [ + AggCall( + AggCall { + agg_kind: Builtin( + ApproxPercentile, + ), + return_type: Float64, + args: [ + FunctionCall( + FunctionCall { + func_type: Cast, + return_type: Float64, + inputs: [ + InputRef( + InputRef { + index: 0, + data_type: Int32, + }, + ), + ], + }, + ), + ], + filter: Condition { + conjunctions: [], + }, + distinct: false, + order_by: OrderBy { + sort_exprs: [ + OrderByExpr { + expr: InputRef( + InputRef { + index: 0, + data_type: Int32, + }, + ), + order_type: OrderType { + direction: Ascending, + nulls_are: Largest, + }, + }, + ], + }, + direct_args: [ + Literal { + data: Some( + Float64( + OrderedFloat( + 0.5, + ), + ), + ), + data_type: Some( + Float64, + ), + }, + Literal { + data: Some( + Float64( + OrderedFloat( + 0.01, + ), + ), + ), + data_type: Some( + Float64, + ), + }, + ], + }, + ), + ], + aliases: [ + Some( + "approx_percentile", + ), + ], + from: Some( + TableFunction { + expr: TableFunction( + FunctionCall { + function_type: GenerateSeries, + return_type: Int32, + args: [ + Literal( + Literal { + data: Some( + Int32( + 1, + ), + ), + data_type: Some( + Int32, + ), + }, + ), + Literal( + Literal { + data: Some( + Int32( + 100, + ), + ), + data_type: Some( + Int32, + ), + }, + ), + ], + }, + ), + with_ordinality: false, + }, + ), + where_clause: None, + group_by: GroupKey( + [], + ), + having: None, + schema: Schema { + fields: [ + approx_percentile:Float64, + ], + }, + }, + ), + order: [], + limit: None, + offset: None, + with_ties: false, + extra_order_exprs: [], + }, + )"#]]; + + expected.assert_eq(&format!("{:#?}", bound)); + } } diff --git a/src/frontend/src/expr/agg_call.rs b/src/frontend/src/expr/agg_call.rs index 9be381f4512a..452d37652d34 100644 --- a/src/frontend/src/expr/agg_call.rs +++ b/src/frontend/src/expr/agg_call.rs @@ -38,6 +38,9 @@ impl std::fmt::Debug for AggCall { .field("return_type", &self.return_type) .field("args", &self.args) .field("filter", &self.filter) + .field("distinct", &self.distinct) + .field("order_by", &self.order_by) + .field("direct_args", &self.direct_args) .finish() } else { let mut builder = f.debug_tuple(&format!("{}", self.agg_kind)); diff --git a/src/frontend/src/optimizer/plan_node/generic/agg.rs b/src/frontend/src/optimizer/plan_node/generic/agg.rs index a085e46e12ef..c0f41b8d82c7 100644 --- a/src/frontend/src/optimizer/plan_node/generic/agg.rs +++ b/src/frontend/src/optimizer/plan_node/generic/agg.rs @@ -105,8 +105,11 @@ impl Agg { && !self.agg_calls.is_empty() && self.agg_calls.iter().all(|call| { let agg_kind_ok = !matches!(call.agg_kind, agg_kinds::simply_cannot_two_phase!()); - let order_ok = matches!(call.agg_kind, agg_kinds::result_unaffected_by_order_by!()) - || call.order_by.is_empty(); + let order_ok = matches!( + call.agg_kind, + agg_kinds::result_unaffected_by_order_by!() + | AggKind::Builtin(PbAggKind::ApproxPercentile) + ) || call.order_by.is_empty(); let distinct_ok = matches!(call.agg_kind, agg_kinds::result_unaffected_by_distinct!()) || !call.distinct; @@ -133,6 +136,7 @@ impl Agg { self.agg_calls.iter().all(|c| { matches!(c.agg_kind, agg_kinds::single_value_state!()) || (matches!(c.agg_kind, agg_kinds::single_value_state_iff_in_append_only!() if stream_input_append_only)) + || (matches!(c.agg_kind, AggKind::Builtin(PbAggKind::ApproxPercentile))) }) } diff --git a/src/frontend/src/optimizer/plan_node/generic/mod.rs b/src/frontend/src/optimizer/plan_node/generic/mod.rs index d83ab50d8923..3e01dee8aa0b 100644 --- a/src/frontend/src/optimizer/plan_node/generic/mod.rs +++ b/src/frontend/src/optimizer/plan_node/generic/mod.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! This module contains the generic plan nodes that are shared by all the plan nodes. +//! They are meant to reuse the common fields between logical, batch and stream nodes. + use std::borrow::Cow; use std::hash::Hash; diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index ed2b4d308815..54c59883b10f 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -15,7 +15,7 @@ use fixedbitset::FixedBitSet; use itertools::Itertools; use risingwave_common::types::{DataType, Datum, ScalarImpl}; -use risingwave_common::util::sort_util::ColumnOrder; +use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use risingwave_common::{bail_not_implemented, not_implemented}; use risingwave_expr::aggregate::{agg_kinds, AggKind, PbAggKind}; @@ -23,8 +23,8 @@ use super::generic::{self, Agg, GenericPlanRef, PlanAggCall, ProjectBuilder}; use super::utils::impl_distill_by_unit; use super::{ BatchHashAgg, BatchSimpleAgg, ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, - PlanTreeNodeUnary, PredicatePushdown, StreamHashAgg, StreamProject, StreamSimpleAgg, - StreamStatelessSimpleAgg, ToBatch, ToStream, + PlanTreeNodeUnary, PredicatePushdown, StreamHashAgg, StreamProject, StreamShare, + StreamSimpleAgg, StreamStatelessSimpleAgg, ToBatch, ToStream, }; use crate::error::{ErrorCode, Result, RwError}; use crate::expr::{ @@ -33,6 +33,9 @@ use crate::expr::{ }; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::generic::GenericPlanNode; +use crate::optimizer::plan_node::stream_global_approx_percentile::StreamGlobalApproxPercentile; +use crate::optimizer::plan_node::stream_keyed_merge::StreamKeyedMerge; +use crate::optimizer::plan_node::stream_local_approx_percentile::StreamLocalApproxPercentile; use crate::optimizer::plan_node::{ gen_filter_and_pushdown, BatchSortAgg, ColumnPruningContext, LogicalDedup, LogicalProject, PredicatePushdownContext, RewriteStreamContext, ToStreamContext, @@ -42,6 +45,17 @@ use crate::utils::{ ColIndexMapping, ColIndexMappingRewriteExt, Condition, GroupBy, IndexSet, Substitute, }; +pub struct AggInfo { + pub calls: Vec, + pub col_mapping: ColIndexMapping, +} + +/// `SeparatedAggInfo` is used to separate normal and approx percentile aggs. +pub struct SeparatedAggInfo { + normal: AggInfo, + approx: AggInfo, +} + /// `LogicalAgg` groups input data by their group key and computes aggregation functions. /// /// It corresponds to the `GROUP BY` operator in a SQL query statement together with the aggregate @@ -60,22 +74,64 @@ impl LogicalAgg { fn gen_stateless_two_phase_streaming_agg_plan(&self, stream_input: PlanRef) -> Result { debug_assert!(self.group_key().is_empty()); let mut core = self.core.clone(); - core.input = stream_input; + + // ====== Handle approx percentile aggs + let SeparatedAggInfo { normal, approx } = self.separate_normal_and_special_agg(); + + let AggInfo { + calls: non_approx_percentile_agg_calls, + col_mapping: non_approx_percentile_col_mapping, + } = normal; + let AggInfo { + calls: approx_percentile_agg_calls, + col_mapping: approx_percentile_col_mapping, + } = approx; + + let needs_keyed_merge = (!non_approx_percentile_agg_calls.is_empty() + && !approx_percentile_agg_calls.is_empty()) + || approx_percentile_agg_calls.len() >= 2; + core.input = if needs_keyed_merge { + // If there's keyed merge, we need to share the input. + StreamShare::new_from_input(stream_input.clone()).into() + } else { + stream_input + }; + core.agg_calls = non_approx_percentile_agg_calls; + + let approx_percentile = + self.build_approx_percentile_aggs(core.input.clone(), &approx_percentile_agg_calls); + + // ====== Handle normal aggs + let total_agg_calls = core + .agg_calls + .iter() + .enumerate() + .map(|(partial_output_idx, agg_call)| { + agg_call.partial_to_total_agg_call(partial_output_idx) + }) + .collect_vec(); let local_agg = StreamStatelessSimpleAgg::new(core); let exchange = RequiredDist::single().enforce_if_not_satisfies(local_agg.into(), &Order::any())?; - let global_agg = new_stream_simple_agg(Agg::new( - self.agg_calls() - .iter() - .enumerate() - .map(|(partial_output_idx, agg_call)| { - agg_call.partial_to_total_agg_call(partial_output_idx) - }) - .collect(), - IndexSet::empty(), - exchange, - )); - Ok(global_agg.into()) + let global_agg = + new_stream_simple_agg(Agg::new(total_agg_calls, IndexSet::empty(), exchange)); + + // ====== Merge approx percentile and normal aggs + if let Some(approx_percentile) = approx_percentile { + if needs_keyed_merge { + let keyed_merge = StreamKeyedMerge::new( + approx_percentile, + global_agg.into(), + approx_percentile_col_mapping, + non_approx_percentile_col_mapping, + )?; + Ok(keyed_merge.into()) + } else { + Ok(approx_percentile) + } + } else { + Ok(global_agg.into()) + } } /// Generate plan for stateless/stateful 2-phase streaming agg. @@ -242,6 +298,94 @@ impl LogicalAgg { } } + fn separate_normal_and_special_agg(&self) -> SeparatedAggInfo { + let estimated_len = self.agg_calls().len() - 1; + let mut approx_percentile_agg_calls = Vec::with_capacity(estimated_len); + let mut non_approx_percentile_agg_calls = Vec::with_capacity(estimated_len); + let mut approx_percentile_col_mapping = Vec::with_capacity(estimated_len); + let mut non_approx_percentile_col_mapping = Vec::with_capacity(estimated_len); + for (output_idx, agg_call) in self.agg_calls().iter().enumerate() { + if agg_call.agg_kind == AggKind::Builtin(PbAggKind::ApproxPercentile) { + approx_percentile_agg_calls.push(agg_call.clone()); + approx_percentile_col_mapping.push(Some(output_idx)); + } else { + non_approx_percentile_agg_calls.push(agg_call.clone()); + non_approx_percentile_col_mapping.push(Some(output_idx)); + } + } + let normal = AggInfo { + calls: non_approx_percentile_agg_calls, + col_mapping: ColIndexMapping::new( + non_approx_percentile_col_mapping, + self.agg_calls().len(), + ), + }; + let approx = AggInfo { + calls: approx_percentile_agg_calls, + col_mapping: ColIndexMapping::new( + approx_percentile_col_mapping, + self.agg_calls().len(), + ), + }; + SeparatedAggInfo { normal, approx } + } + + fn build_approx_percentile_agg( + &self, + input: PlanRef, + approx_percentile_agg_call: &PlanAggCall, + ) -> PlanRef { + let local_approx_percentile = + StreamLocalApproxPercentile::new(input, approx_percentile_agg_call); + let global_approx_percentile = StreamGlobalApproxPercentile::new( + local_approx_percentile.into(), + approx_percentile_agg_call, + ); + global_approx_percentile.into() + } + + /// If only 1 approx percentile, just return it. + /// Otherwise build a tree of approx percentile with `KeyedMerge`. + /// e.g. + /// ApproxPercentile(col1, 0.5) as x, + /// ApproxPercentile(col2, 0.5) as y, + /// ApproxPercentile(col3, 0.5) as z + /// will be built as + /// `KeyedMerge` + /// / \ + /// `KeyedMerge` z + /// / \ + /// x y + + fn build_approx_percentile_aggs( + &self, + input: PlanRef, + approx_percentile_agg_call: &[PlanAggCall], + ) -> Option { + if approx_percentile_agg_call.is_empty() { + return None; + } + let approx_percentile_plans = approx_percentile_agg_call + .iter() + .map(|agg_call| self.build_approx_percentile_agg(input.clone(), agg_call)) + .collect_vec(); + assert!(!approx_percentile_plans.is_empty()); + let mut iter = approx_percentile_plans.into_iter(); + let mut acc = iter.next().unwrap(); + for (current_size, plan) in iter.enumerate().map(|(i, p)| (i + 1, p)) { + let new_size = current_size + 1; + let keyed_merge = StreamKeyedMerge::new( + acc, + plan, + ColIndexMapping::identity_or_none(current_size, new_size), + ColIndexMapping::new(vec![Some(current_size)], new_size), + ) + .expect("failed to build keyed merge"); + acc = keyed_merge.into(); + } + Some(acc) + } + pub fn core(&self) -> &Agg { &self.core } @@ -512,6 +656,35 @@ impl LogicalAggBuilder { _ => unreachable!(), } } + AggKind::Builtin(PbAggKind::ApproxPercentile) => { + if agg_call.order_by.sort_exprs[0].order_type == OrderType::descending() { + // Rewrite DESC into 1.0-percentile for approx_percentile. + let prev_percentile = agg_call.direct_args[0].clone(); + let new_percentile = 1.0 + - prev_percentile + .get_data() + .as_ref() + .unwrap() + .as_float64() + .into_inner(); + let new_percentile = Some(ScalarImpl::Float64(new_percentile.into())); + let new_percentile = Literal::new(new_percentile, DataType::Float64); + let new_direct_args = vec![new_percentile, agg_call.direct_args[1].clone()]; + + let new_agg_call = AggCall { + order_by: OrderBy::any(), + direct_args: new_direct_args, + ..agg_call + }; + Ok(push_agg_call(new_agg_call)?.into()) + } else { + let new_agg_call = AggCall { + order_by: OrderBy::any(), + ..agg_call + }; + Ok(push_agg_call(new_agg_call)?.into()) + } + } _ => Ok(push_agg_call(agg_call)?.into()), } } @@ -1130,8 +1303,26 @@ impl ToStream for LogicalAgg { }, final_agg.agg_calls().len(), ) + } else if let Some(_approx_percentile_agg) = plan.as_stream_global_approx_percentile() { + if eowc { + return Err(ErrorCode::InvalidInputSyntax( + "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`" + .to_string(), + ) + .into()); + } + (plan.clone(), 1) + } else if let Some(stream_keyed_merge) = plan.as_stream_keyed_merge() { + if eowc { + return Err(ErrorCode::InvalidInputSyntax( + "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`" + .to_string(), + ) + .into()); + } + (plan.clone(), stream_keyed_merge.base.schema().len()) } else { - panic!("the root PlanNode must be either StreamHashAgg or StreamSimpleAgg"); + panic!("the root PlanNode must be StreamHashAgg, StreamSimpleAgg, StreamGlobalApproxPercentile, or StreamKeyedMerge"); }; if self.agg_calls().len() == n_final_agg_calls { diff --git a/src/frontend/src/optimizer/plan_node/mod.rs b/src/frontend/src/optimizer/plan_node/mod.rs index ee2b16265e7a..2cf7e67dd2b6 100644 --- a/src/frontend/src/optimizer/plan_node/mod.rs +++ b/src/frontend/src/optimizer/plan_node/mod.rs @@ -893,10 +893,13 @@ mod stream_exchange; mod stream_expand; mod stream_filter; mod stream_fs_fetch; +mod stream_global_approx_percentile; mod stream_group_topn; mod stream_hash_agg; mod stream_hash_join; mod stream_hop_window; +mod stream_keyed_merge; +mod stream_local_approx_percentile; mod stream_materialize; mod stream_now; mod stream_over_window; @@ -1002,10 +1005,13 @@ pub use stream_exchange::StreamExchange; pub use stream_expand::StreamExpand; pub use stream_filter::StreamFilter; pub use stream_fs_fetch::StreamFsFetch; +pub use stream_global_approx_percentile::StreamGlobalApproxPercentile; pub use stream_group_topn::StreamGroupTopN; pub use stream_hash_agg::StreamHashAgg; pub use stream_hash_join::StreamHashJoin; pub use stream_hop_window::StreamHopWindow; +pub use stream_keyed_merge::StreamKeyedMerge; +pub use stream_local_approx_percentile::StreamLocalApproxPercentile; pub use stream_materialize::StreamMaterialize; pub use stream_now::StreamNow; pub use stream_over_window::StreamOverWindow; @@ -1150,6 +1156,9 @@ macro_rules! for_all_plan_nodes { , { Stream, OverWindow } , { Stream, FsFetch } , { Stream, ChangeLog } + , { Stream, GlobalApproxPercentile } + , { Stream, LocalApproxPercentile } + , { Stream, KeyedMerge } } }; } @@ -1276,6 +1285,9 @@ macro_rules! for_stream_plan_nodes { , { Stream, OverWindow } , { Stream, FsFetch } , { Stream, ChangeLog } + , { Stream, GlobalApproxPercentile } + , { Stream, LocalApproxPercentile } + , { Stream, KeyedMerge } } }; } diff --git a/src/frontend/src/optimizer/plan_node/stream_global_approx_percentile.rs b/src/frontend/src/optimizer/plan_node/stream_global_approx_percentile.rs new file mode 100644 index 000000000000..22fe3b33ab69 --- /dev/null +++ b/src/frontend/src/optimizer/plan_node/stream_global_approx_percentile.rs @@ -0,0 +1,115 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use fixedbitset::FixedBitSet; +use pretty_xmlish::{Pretty, XmlNode}; +use risingwave_common::catalog::{Field, Schema}; +use risingwave_common::types::DataType; +use risingwave_pb::stream_plan::stream_node::PbNodeBody; + +use crate::expr::{ExprRewriter, ExprVisitor, Literal}; +use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::plan_node::generic::GenericPlanRef; +use crate::optimizer::plan_node::stream::StreamPlanRef; +use crate::optimizer::plan_node::utils::{childless_record, Distill}; +use crate::optimizer::plan_node::{ + ExprRewritable, PlanAggCall, PlanBase, PlanTreeNodeUnary, Stream, StreamNode, +}; +use crate::optimizer::property::Distribution; +use crate::stream_fragmenter::BuildFragmentGraphState; +use crate::PlanRef; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct StreamGlobalApproxPercentile { + pub base: PlanBase, + input: PlanRef, + /// Quantile + quantile: Literal, + /// Used to compute the exponent bucket base. + relative_error: Literal, +} + +impl StreamGlobalApproxPercentile { + pub fn new(input: PlanRef, approx_percentile_agg_call: &PlanAggCall) -> Self { + let schema = Schema::new(vec![Field::with_name( + DataType::Float64, + "approx_percentile", + )]); + let watermark_columns = FixedBitSet::with_capacity(1); + let base = PlanBase::new_stream( + input.ctx(), + schema, + Some(vec![]), + input.functional_dependency().clone(), + Distribution::Single, + input.append_only(), + input.emit_on_window_close(), + watermark_columns, + input.columns_monotonicity().clone(), + ); + Self { + base, + input, + quantile: approx_percentile_agg_call.direct_args[0].clone(), + relative_error: approx_percentile_agg_call.direct_args[1].clone(), + } + } +} + +impl Distill for StreamGlobalApproxPercentile { + fn distill<'a>(&self) -> XmlNode<'a> { + let out = vec![ + ("quantile", Pretty::debug(&self.quantile)), + ("relative_error", Pretty::debug(&self.relative_error)), + ]; + childless_record("StreamGlobalApproxPercentile", out) + } +} + +impl PlanTreeNodeUnary for StreamGlobalApproxPercentile { + fn input(&self) -> PlanRef { + self.input.clone() + } + + fn clone_with_input(&self, input: PlanRef) -> Self { + Self { + base: self.base.clone(), + input, + quantile: self.quantile.clone(), + relative_error: self.relative_error.clone(), + } + } +} + +impl_plan_tree_node_for_unary! {StreamGlobalApproxPercentile} + +impl StreamNode for StreamGlobalApproxPercentile { + fn to_stream_prost_body(&self, _state: &mut BuildFragmentGraphState) -> PbNodeBody { + todo!() + } +} + +impl ExprRewritable for StreamGlobalApproxPercentile { + fn has_rewritable_expr(&self) -> bool { + false + } + + fn rewrite_exprs(&self, _rewriter: &mut dyn ExprRewriter) -> PlanRef { + unimplemented!() + } +} + +impl ExprVisitable for StreamGlobalApproxPercentile { + fn visit_exprs(&self, _v: &mut dyn ExprVisitor) {} +} diff --git a/src/frontend/src/optimizer/plan_node/stream_keyed_merge.rs b/src/frontend/src/optimizer/plan_node/stream_keyed_merge.rs new file mode 100644 index 000000000000..e84a2c9dd9b0 --- /dev/null +++ b/src/frontend/src/optimizer/plan_node/stream_keyed_merge.rs @@ -0,0 +1,156 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::anyhow; +use fixedbitset::FixedBitSet; +use pretty_xmlish::{Pretty, XmlNode}; +use risingwave_common::bail; +use risingwave_common::catalog::Schema; +use risingwave_common::util::column_index_mapping::ColIndexMapping; +use risingwave_pb::stream_plan::stream_node::PbNodeBody; + +use crate::error::Result; +use crate::expr::{ExprRewriter, ExprVisitor}; +use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::plan_node::generic::{GenericPlanRef, PhysicalPlanRef}; +use crate::optimizer::plan_node::stream::StreamPlanRef; +use crate::optimizer::plan_node::utils::{childless_record, Distill}; +use crate::optimizer::plan_node::{ + ExprRewritable, PlanBase, PlanTreeNodeBinary, Stream, StreamNode, +}; +use crate::stream_fragmenter::BuildFragmentGraphState; +use crate::PlanRef; + +/// `StreamKeyedMerge` is used for merging two streams with the same stream key and distribution. +/// It will buffer the outputs from its input streams until we receive a barrier. +/// On receiving a barrier, it will `Project` their outputs according +/// to the provided `lhs_mapping` and `rhs_mapping`. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct StreamKeyedMerge { + pub base: PlanBase, + pub lhs_input: PlanRef, + pub rhs_input: PlanRef, + /// Maps input from the lhs to the output. + pub lhs_mapping: ColIndexMapping, + /// Maps input from the rhs to the output. + pub rhs_mapping: ColIndexMapping, +} + +impl StreamKeyedMerge { + pub fn new( + lhs_input: PlanRef, + rhs_input: PlanRef, + lhs_mapping: ColIndexMapping, + rhs_mapping: ColIndexMapping, + ) -> Result { + assert_eq!(lhs_mapping.target_size(), rhs_mapping.target_size()); + assert_eq!(lhs_input.distribution(), rhs_input.distribution()); + assert_eq!(lhs_input.stream_key(), rhs_input.stream_key()); + let mut schema_fields = Vec::with_capacity(lhs_mapping.target_size()); + let o2i_lhs = lhs_mapping + .inverse() + .ok_or_else(|| anyhow!("lhs_mapping should be invertible"))?; + let o2i_rhs = rhs_mapping + .inverse() + .ok_or_else(|| anyhow!("rhs_mapping should be invertible"))?; + for output_idx in 0..lhs_mapping.target_size() { + if let Some(lhs_idx) = o2i_lhs.try_map(output_idx) { + schema_fields.push(lhs_input.schema().fields()[lhs_idx].clone()); + } else if let Some(rhs_idx) = o2i_rhs.try_map(output_idx) { + schema_fields.push(rhs_input.schema().fields()[rhs_idx].clone()); + } else { + bail!( + "output index {} not found in either lhs or rhs mapping", + output_idx + ); + } + } + let schema = Schema::new(schema_fields); + let watermark_columns = FixedBitSet::with_capacity(schema.fields.len()); + + let base = PlanBase::new_stream( + lhs_input.ctx(), + schema, + lhs_input.stream_key().map(|k| k.to_vec()), + lhs_input.functional_dependency().clone(), + lhs_input.distribution().clone(), + lhs_input.append_only(), + lhs_input.emit_on_window_close(), + watermark_columns, + lhs_input.columns_monotonicity().clone(), + ); + Ok(Self { + base, + lhs_input, + rhs_input, + lhs_mapping, + rhs_mapping, + }) + } +} + +impl Distill for StreamKeyedMerge { + fn distill<'a>(&self) -> XmlNode<'a> { + let mut out = Vec::with_capacity(1); + + if self.base.ctx().is_explain_verbose() { + let f = |t| Pretty::debug(&t); + let e = Pretty::Array(self.base.schema().fields().iter().map(f).collect()); + out = vec![("output", e)]; + } + childless_record("StreamKeyedMerge", out) + } +} + +impl PlanTreeNodeBinary for StreamKeyedMerge { + fn left(&self) -> PlanRef { + self.lhs_input.clone() + } + + fn right(&self) -> PlanRef { + self.rhs_input.clone() + } + + fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self { + Self { + base: self.base.clone(), + lhs_input: left, + rhs_input: right, + lhs_mapping: self.lhs_mapping.clone(), + rhs_mapping: self.rhs_mapping.clone(), + } + } +} + +impl_plan_tree_node_for_binary! { StreamKeyedMerge } + +impl StreamNode for StreamKeyedMerge { + fn to_stream_prost_body(&self, _state: &mut BuildFragmentGraphState) -> PbNodeBody { + todo!() + } +} + +impl ExprRewritable for StreamKeyedMerge { + fn has_rewritable_expr(&self) -> bool { + false + } + + fn rewrite_exprs(&self, _rewriter: &mut dyn ExprRewriter) -> PlanRef { + unimplemented!() + } +} + +impl ExprVisitable for StreamKeyedMerge { + fn visit_exprs(&self, _v: &mut dyn ExprVisitor) {} +} diff --git a/src/frontend/src/optimizer/plan_node/stream_local_approx_percentile.rs b/src/frontend/src/optimizer/plan_node/stream_local_approx_percentile.rs new file mode 100644 index 000000000000..a4fb2a602917 --- /dev/null +++ b/src/frontend/src/optimizer/plan_node/stream_local_approx_percentile.rs @@ -0,0 +1,127 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use fixedbitset::FixedBitSet; +use pretty_xmlish::{Pretty, XmlNode}; +use risingwave_common::catalog::{Field, Schema}; +use risingwave_common::types::DataType; +use risingwave_pb::stream_plan::stream_node::PbNodeBody; + +use crate::expr::{ExprRewriter, ExprVisitor, InputRef, InputRefDisplay, Literal}; +use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::plan_node::generic::{GenericPlanRef, PhysicalPlanRef}; +use crate::optimizer::plan_node::stream::StreamPlanRef; +use crate::optimizer::plan_node::utils::{childless_record, watermark_pretty, Distill}; +use crate::optimizer::plan_node::{ + ExprRewritable, PlanAggCall, PlanBase, PlanTreeNodeUnary, Stream, StreamNode, +}; +use crate::stream_fragmenter::BuildFragmentGraphState; +use crate::PlanRef; + +// Does not contain `core` because no other plan nodes share +// common fields and schema, even GlobalApproxPercentile. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct StreamLocalApproxPercentile { + pub base: PlanBase, + input: PlanRef, + quantile: Literal, + relative_error: Literal, + percentile_col: InputRef, +} + +impl StreamLocalApproxPercentile { + pub fn new(input: PlanRef, approx_percentile_agg_call: &PlanAggCall) -> Self { + let schema = Schema::new(vec![ + Field::with_name(DataType::Int64, "bucket_id"), + Field::with_name(DataType::Int64, "count"), + ]); + // FIXME(kwannoel): How does watermark work with FixedBitSet + let watermark_columns = FixedBitSet::with_capacity(2); + let base = PlanBase::new_stream( + input.ctx(), + schema, + input.stream_key().map(|k| k.to_vec()), + input.functional_dependency().clone(), + input.distribution().clone(), + input.append_only(), + input.emit_on_window_close(), + watermark_columns, + input.columns_monotonicity().clone(), + ); + Self { + base, + input, + quantile: approx_percentile_agg_call.direct_args[0].clone(), + relative_error: approx_percentile_agg_call.direct_args[1].clone(), + percentile_col: approx_percentile_agg_call.inputs[0].clone(), + } + } +} + +impl Distill for StreamLocalApproxPercentile { + fn distill<'a>(&self) -> XmlNode<'a> { + let mut out = Vec::with_capacity(5); + out.push(( + "percentile_col", + Pretty::display(&InputRefDisplay { + input_ref: &self.percentile_col, + input_schema: self.input.schema(), + }), + )); + out.push(("quantile", Pretty::debug(&self.quantile))); + out.push(("relative_error", Pretty::debug(&self.relative_error))); + if let Some(ow) = watermark_pretty(self.base.watermark_columns(), self.schema()) { + out.push(("output_watermarks", ow)); + } + childless_record("StreamLocalApproxPercentile", out) + } +} + +impl PlanTreeNodeUnary for StreamLocalApproxPercentile { + fn input(&self) -> PlanRef { + self.input.clone() + } + + fn clone_with_input(&self, input: PlanRef) -> Self { + Self { + base: self.base.clone(), + input, + quantile: self.quantile.clone(), + relative_error: self.relative_error.clone(), + percentile_col: self.percentile_col.clone(), + } + } +} + +impl_plan_tree_node_for_unary! {StreamLocalApproxPercentile} + +impl StreamNode for StreamLocalApproxPercentile { + fn to_stream_prost_body(&self, _state: &mut BuildFragmentGraphState) -> PbNodeBody { + todo!() + } +} + +impl ExprRewritable for StreamLocalApproxPercentile { + fn has_rewritable_expr(&self) -> bool { + false + } + + fn rewrite_exprs(&self, _rewriter: &mut dyn ExprRewriter) -> PlanRef { + unimplemented!() + } +} + +impl ExprVisitable for StreamLocalApproxPercentile { + fn visit_exprs(&self, _v: &mut dyn ExprVisitor) {} +} diff --git a/src/frontend/src/optimizer/plan_node/stream_share.rs b/src/frontend/src/optimizer/plan_node/stream_share.rs index b082d82b022d..7e6f87fa5c27 100644 --- a/src/frontend/src/optimizer/plan_node/stream_share.rs +++ b/src/frontend/src/optimizer/plan_node/stream_share.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::cell::RefCell; + use pretty_xmlish::XmlNode; use risingwave_pb::stream_plan::stream_node::PbNodeBody; use risingwave_pb::stream_plan::PbStreamNode; @@ -50,6 +52,13 @@ impl StreamShare { StreamShare { base, core } } + + pub fn new_from_input(input: PlanRef) -> Self { + let core = generic::Share { + input: RefCell::new(input), + }; + Self::new(core) + } } impl Distill for StreamShare { diff --git a/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs b/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs index eab534b1b9a3..33bb59e59bf1 100644 --- a/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs +++ b/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs @@ -172,6 +172,7 @@ impl Rule for ApplyAggTransposeRule { | PbAggKind::LastValue | PbAggKind::InternalLastSeenValue // All statistical aggregates only consider non-null inputs. + | PbAggKind::ApproxPercentile | PbAggKind::VarPop | PbAggKind::VarSamp | PbAggKind::StddevPop diff --git a/src/frontend/src/optimizer/rule/distinct_agg_rule.rs b/src/frontend/src/optimizer/rule/distinct_agg_rule.rs index 353b4ff7a7f6..52aad1336a7b 100644 --- a/src/frontend/src/optimizer/rule/distinct_agg_rule.rs +++ b/src/frontend/src/optimizer/rule/distinct_agg_rule.rs @@ -57,8 +57,11 @@ impl Rule for DistinctAggRule { c.agg_kind ); let agg_kind_ok = !matches!(c.agg_kind, agg_kinds::simply_cannot_two_phase!()); - let order_ok = matches!(c.agg_kind, agg_kinds::result_unaffected_by_order_by!()) - || c.order_by.is_empty(); + let order_ok = matches!( + c.agg_kind, + agg_kinds::result_unaffected_by_order_by!() + | AggKind::Builtin(PbAggKind::ApproxPercentile) + ) || c.order_by.is_empty(); agg_kind_ok && order_ok }) { tracing::warn!("DistinctAggRule: unsupported agg kind, fallback to backend impl");