diff --git a/src/frontend/planner_test/tests/testdata/input/agg.yaml b/src/frontend/planner_test/tests/testdata/input/agg.yaml index 1979e4ea1fb7..75aa7249accc 100644 --- a/src/frontend/planner_test/tests/testdata/input/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/input/agg.yaml @@ -1067,6 +1067,13 @@ sql: | CREATE TABLE t (v1 int, v2 int); SELECT sum(v1) as s1, approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1) as x, count(*), max(v2) as m2, approx_percentile(0.5, 0.01) WITHIN GROUP (order by v2) as y from t; + expected_outputs: + - logical_plan + - stream_plan +- name: test approx percentile with default relative_error + sql: | + CREATE TABLE t (v1 int); + SELECT approx_percentile(0.5) WITHIN GROUP (order by v1) from t; expected_outputs: - logical_plan - stream_plan \ No newline at end of file diff --git a/src/frontend/planner_test/tests/testdata/output/agg.yaml b/src/frontend/planner_test/tests/testdata/output/agg.yaml index eca739788bf6..9fd70f1fb28f 100644 --- a/src/frontend/planner_test/tests/testdata/output/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/output/agg.yaml @@ -2103,3 +2103,19 @@ └─StreamHashAgg { group_key: [$expr5], aggs: [sum(t.v1), count, max(t.v2)] } └─StreamProject { exprs: [t.v1, t.v1::Float64 as $expr3, t.v2, t.v2::Float64 as $expr4, t._row_id, Vnode(t._row_id) as $expr5] } └─StreamTableScan { table: t, columns: [t.v1, t.v2, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } +- name: test approx percentile with default relative_error + sql: | + CREATE TABLE t (v1 int); + SELECT approx_percentile(0.5) WITHIN GROUP (order by v1) from t; + logical_plan: |- + LogicalProject { exprs: [approx_percentile($expr1)] } + └─LogicalAgg { aggs: [approx_percentile($expr1)] } + └─LogicalProject { exprs: [t.v1::Float64 as $expr1] } + └─LogicalScan { table: t, columns: [t.v1, t._row_id] } + stream_plan: |- + StreamMaterialize { columns: [approx_percentile], stream_key: [], pk_columns: [], pk_conflict: NoCheck } + └─StreamGlobalApproxPercentile { quantile: 0.5:Float64, relative_error: 0.01:Float64 } + └─StreamExchange { dist: Single } + └─StreamLocalApproxPercentile { percentile_col: $expr1, quantile: 0.5:Float64, relative_error: 0.01:Float64 } + └─StreamProject { exprs: [t.v1::Float64 as $expr1, t._row_id] } + └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } diff --git a/src/frontend/src/binder/expr/function/aggregate.rs b/src/frontend/src/binder/expr/function/aggregate.rs index 1e7b76bf7629..d6410616c1d9 100644 --- a/src/frontend/src/binder/expr/function/aggregate.rs +++ b/src/frontend/src/binder/expr/function/aggregate.rs @@ -14,7 +14,7 @@ use itertools::Itertools; use risingwave_common::bail_not_implemented; -use risingwave_common::types::DataType; +use risingwave_common::types::{DataType, ScalarImpl}; use risingwave_expr::aggregate::{agg_kinds, AggKind, PbAggKind}; use risingwave_sqlparser::ast::{Function, FunctionArgExpr}; @@ -139,12 +139,9 @@ impl Binder { let order_by = OrderBy::new(vec![self.bind_order_by_expr(within_group)?]); // check signature and do implicit cast - match (&kind, direct_args.as_mut_slice(), args.as_mut_slice()) { - ( - AggKind::Builtin(PbAggKind::PercentileCont | PbAggKind::PercentileDisc), - [fraction], - [arg], - ) => { + match (&kind, direct_args.len(), args.as_mut_slice()) { + (AggKind::Builtin(PbAggKind::PercentileCont | PbAggKind::PercentileDisc), 1, [arg]) => { + let fraction = &mut direct_args[0]; decimal_to_float64(fraction, &kind)?; if matches!(&kind, AggKind::Builtin(PbAggKind::PercentileCont)) { arg.cast_implicit_mut(DataType::Float64).map_err(|_| { @@ -155,14 +152,30 @@ impl Binder { })?; } } - (AggKind::Builtin(PbAggKind::Mode), [], [_arg]) => {} - ( - AggKind::Builtin(PbAggKind::ApproxPercentile), - [percentile, relative_error], - [_percentile_col], - ) => { + (AggKind::Builtin(PbAggKind::Mode), 0, [_arg]) => {} + (AggKind::Builtin(PbAggKind::ApproxPercentile), 1..=2, [_percentile_col]) => { + let percentile = &mut direct_args[0]; decimal_to_float64(percentile, &kind)?; - decimal_to_float64(relative_error, &kind)?; + match direct_args.len() { + 2 => { + let relative_error = &mut direct_args[1]; + decimal_to_float64(relative_error, &kind)?; + } + 1 => { + let relative_error: ExprImpl = Literal::new( + ScalarImpl::Float64(0.01.into()).into(), + DataType::Float64, + ) + .into(); + direct_args.push(relative_error); + } + _ => { + return Err(ErrorCode::InvalidInputSyntax( + "invalid direct args for approx_percentile aggregation".to_string(), + ) + .into()) + } + } } _ => { return Err(ErrorCode::InvalidInputSyntax(format!(