Skip to content

Commit

Permalink
feat(frontend): bind default value for approx_percentile relative_e…
Browse files Browse the repository at this point in the history
…rror (#18082)
  • Loading branch information
kwannoel authored Aug 19, 2024
1 parent 33dc6fd commit d16847d
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 14 deletions.
7 changes: 7 additions & 0 deletions src/frontend/planner_test/tests/testdata/input/agg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,13 @@
sql: |
CREATE TABLE t (v1 int, v2 int);
SELECT sum(v1) as s1, approx_percentile(0.5, 0.01) WITHIN GROUP (order by v1) as x, count(*), max(v2) as m2, approx_percentile(0.5, 0.01) WITHIN GROUP (order by v2) as y from t;
expected_outputs:
- logical_plan
- stream_plan
- name: test approx percentile with default relative_error
sql: |
CREATE TABLE t (v1 int);
SELECT approx_percentile(0.5) WITHIN GROUP (order by v1) from t;
expected_outputs:
- logical_plan
- stream_plan
16 changes: 16 additions & 0 deletions src/frontend/planner_test/tests/testdata/output/agg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2103,3 +2103,19 @@
└─StreamHashAgg { group_key: [$expr5], aggs: [sum(t.v1), count, max(t.v2)] }
└─StreamProject { exprs: [t.v1, t.v1::Float64 as $expr3, t.v2, t.v2::Float64 as $expr4, t._row_id, Vnode(t._row_id) as $expr5] }
└─StreamTableScan { table: t, columns: [t.v1, t.v2, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
- name: test approx percentile with default relative_error
sql: |
CREATE TABLE t (v1 int);
SELECT approx_percentile(0.5) WITHIN GROUP (order by v1) from t;
logical_plan: |-
LogicalProject { exprs: [approx_percentile($expr1)] }
└─LogicalAgg { aggs: [approx_percentile($expr1)] }
└─LogicalProject { exprs: [t.v1::Float64 as $expr1] }
└─LogicalScan { table: t, columns: [t.v1, t._row_id] }
stream_plan: |-
StreamMaterialize { columns: [approx_percentile], stream_key: [], pk_columns: [], pk_conflict: NoCheck }
└─StreamGlobalApproxPercentile { quantile: 0.5:Float64, relative_error: 0.01:Float64 }
└─StreamExchange { dist: Single }
└─StreamLocalApproxPercentile { percentile_col: $expr1, quantile: 0.5:Float64, relative_error: 0.01:Float64 }
└─StreamProject { exprs: [t.v1::Float64 as $expr1, t._row_id] }
└─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
41 changes: 27 additions & 14 deletions src/frontend/src/binder/expr/function/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use itertools::Itertools;
use risingwave_common::bail_not_implemented;
use risingwave_common::types::DataType;
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_expr::aggregate::{agg_kinds, AggKind, PbAggKind};
use risingwave_sqlparser::ast::{Function, FunctionArgExpr};

Expand Down Expand Up @@ -139,12 +139,9 @@ impl Binder {
let order_by = OrderBy::new(vec![self.bind_order_by_expr(within_group)?]);

// check signature and do implicit cast
match (&kind, direct_args.as_mut_slice(), args.as_mut_slice()) {
(
AggKind::Builtin(PbAggKind::PercentileCont | PbAggKind::PercentileDisc),
[fraction],
[arg],
) => {
match (&kind, direct_args.len(), args.as_mut_slice()) {
(AggKind::Builtin(PbAggKind::PercentileCont | PbAggKind::PercentileDisc), 1, [arg]) => {
let fraction = &mut direct_args[0];
decimal_to_float64(fraction, &kind)?;
if matches!(&kind, AggKind::Builtin(PbAggKind::PercentileCont)) {
arg.cast_implicit_mut(DataType::Float64).map_err(|_| {
Expand All @@ -155,14 +152,30 @@ impl Binder {
})?;
}
}
(AggKind::Builtin(PbAggKind::Mode), [], [_arg]) => {}
(
AggKind::Builtin(PbAggKind::ApproxPercentile),
[percentile, relative_error],
[_percentile_col],
) => {
(AggKind::Builtin(PbAggKind::Mode), 0, [_arg]) => {}
(AggKind::Builtin(PbAggKind::ApproxPercentile), 1..=2, [_percentile_col]) => {
let percentile = &mut direct_args[0];
decimal_to_float64(percentile, &kind)?;
decimal_to_float64(relative_error, &kind)?;
match direct_args.len() {
2 => {
let relative_error = &mut direct_args[1];
decimal_to_float64(relative_error, &kind)?;
}
1 => {
let relative_error: ExprImpl = Literal::new(
ScalarImpl::Float64(0.01.into()).into(),
DataType::Float64,
)
.into();
direct_args.push(relative_error);
}
_ => {
return Err(ErrorCode::InvalidInputSyntax(
"invalid direct args for approx_percentile aggregation".to_string(),
)
.into())
}
}
}
_ => {
return Err(ErrorCode::InvalidInputSyntax(format!(
Expand Down

0 comments on commit d16847d

Please sign in to comment.