From 3f491de74d9e1a15d0742753f27c10e341e92704 Mon Sep 17 00:00:00 2001 From: Noel Kwan Date: Thu, 18 Jul 2024 17:03:46 +0800 Subject: [PATCH] pass binder test --- src/expr/core/src/aggregate/def.rs | 7 +- .../impl/src/aggregate/approx_percentile.rs | 139 +++++++++++ src/expr/impl/src/aggregate/mod.rs | 1 + src/frontend/src/binder/expr/function.rs | 1 + src/frontend/src/binder/mod.rs | 217 ++++++++++++++++++ src/frontend/src/expr/agg_call.rs | 2 + 6 files changed, 366 insertions(+), 1 deletion(-) create mode 100644 src/expr/impl/src/aggregate/approx_percentile.rs diff --git a/src/expr/core/src/aggregate/def.rs b/src/expr/core/src/aggregate/def.rs index 3d729cef456c6..b807d2b88b5e0 100644 --- a/src/expr/core/src/aggregate/def.rs +++ b/src/expr/core/src/aggregate/def.rs @@ -464,10 +464,15 @@ pub mod agg_kinds { #[macro_export] macro_rules! ordered_set { () => { - AggKind::PercentileCont | AggKind::PercentileDisc | AggKind::Mode + AggKind::PercentileCont + | AggKind::PercentileDisc + | AggKind::Mode + | AggKind::ApproxPercentile }; } pub use ordered_set; + + use crate::aggregate::AggKind; } impl AggKind { diff --git a/src/expr/impl/src/aggregate/approx_percentile.rs b/src/expr/impl/src/aggregate/approx_percentile.rs new file mode 100644 index 0000000000000..9b5b0c9e4ac06 --- /dev/null +++ b/src/expr/impl/src/aggregate/approx_percentile.rs @@ -0,0 +1,139 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::ops::Range; + +use risingwave_common::array::*; +use risingwave_common::row::Row; +use risingwave_common::types::*; +use risingwave_common_estimate_size::EstimateSize; +use risingwave_expr::aggregate::{AggCall, AggStateDyn, AggregateFunction, AggregateState}; +use risingwave_expr::{build_aggregate, Result}; + +/// Computes the approx percentile, a value corresponding to the specified fraction within the +/// ordered set of aggregated argument values. This will interpolate between adjacent input items if +/// needed. +/// +/// ```slt +/// statement ok +/// create table t(x int, y bigint, z real, w double, v varchar); +/// +/// statement ok +/// insert into t values(1,10,100,1000,'10000'),(2,20,200,2000,'20000'),(3,30,300,3000,'30000'); +/// +/// query R +/// select percentile_cont(0.45) within group (order by x desc) from t; +/// ---- +/// 2.1 +/// +/// query R +/// select percentile_cont(0.45) within group (order by y desc) from t; +/// ---- +/// 21 +/// +/// query R +/// select percentile_cont(0.45) within group (order by z desc) from t; +/// ---- +/// 210 +/// +/// query R +/// select percentile_cont(0.45) within group (order by w desc) from t; +/// ---- +/// 2100 +/// +/// query R +/// select percentile_cont(NULL) within group (order by w desc) from t; +/// ---- +/// NULL +/// +/// statement ok +/// drop table t; +/// ``` +#[build_aggregate("approx_percentile(float8) -> float8")] +fn build(agg: &AggCall) -> Result> { + let fraction = agg.direct_args[0] + .literal() + .map(|x| (*x.as_float64()).into()); + Ok(Box::new(ApproxPercentile { fraction })) +} + +pub struct ApproxPercentile { + fraction: Option, +} + +#[derive(Debug, Default, EstimateSize)] +struct State(Vec); + +impl AggStateDyn for State {} + +impl ApproxPercentile { + fn add_datum(&self, state: &mut State, datum_ref: DatumRef<'_>) { + if let Some(datum) = datum_ref.to_owned_datum() { + state.0.push((*datum.as_float64()).into()); + } + } +} + +#[async_trait::async_trait] +impl AggregateFunction for ApproxPercentile { + fn return_type(&self) -> DataType { + DataType::Float64 + } + + fn create_state(&self) -> Result { + Ok(AggregateState::Any(Box::::default())) + } + + async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> { + let state = state.downcast_mut(); + for (_, row) in input.rows() { + self.add_datum(state, row.datum_at(0)); + } + Ok(()) + } + + async fn update_range( + &self, + state: &mut AggregateState, + input: &StreamChunk, + range: Range, + ) -> Result<()> { + let state = state.downcast_mut(); + for (_, row) in input.rows_in(range) { + self.add_datum(state, row.datum_at(0)); + } + Ok(()) + } + + async fn get_result(&self, state: &AggregateState) -> Result { + let state = &state.downcast_ref::().0; + Ok( + if let Some(fraction) = self.fraction + && !state.is_empty() + { + let rn = fraction * (state.len() - 1) as f64; + let crn = f64::ceil(rn); + let frn = f64::floor(rn); + let result = if crn == frn { + state[crn as usize] + } else { + (crn - rn) * state[frn as usize] + (rn - frn) * state[crn as usize] + }; + Some(result.into()) + } else { + None + }, + ) + } +} diff --git a/src/expr/impl/src/aggregate/mod.rs b/src/expr/impl/src/aggregate/mod.rs index c0b6a5ae64c31..349574018fedf 100644 --- a/src/expr/impl/src/aggregate/mod.rs +++ b/src/expr/impl/src/aggregate/mod.rs @@ -13,6 +13,7 @@ // limitations under the License. mod approx_count_distinct; +mod approx_percentile; mod array_agg; mod bit_and; mod bit_or; diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index 175a9c70fde58..8a0c471005b94 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -494,6 +494,7 @@ impl Binder { } } (AggKind::Mode, [], [_arg]) => {} + (AggKind::ApproxPercentile, [_arg, _arg2], [_arg3]) => {} _ => { return Err(ErrorCode::InvalidInputSyntax(format!( "invalid direct args or within group argument for `{}` aggregation", diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index cf12417334612..855e6c0cda00f 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -759,4 +759,221 @@ mod tests { expected.assert_eq(&format!("{:#?}", bound)); } + + #[tokio::test] + async fn test_bind_approx_percentile() { + let stmt = risingwave_sqlparser::parser::Parser::parse_sql( + "SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (ORDER BY generate_series) FROM generate_series(1, 100)", + ).unwrap().into_iter().next().unwrap(); + let parse_expected = expect![[r#" + Query( + Query { + with: None, + body: Select( + Select { + distinct: All, + projection: [ + UnnamedExpr( + Function( + Function { + name: ObjectName( + [ + Ident { + value: "approx_percentile", + quote_style: None, + }, + ], + ), + args: [ + Unnamed( + Expr( + Value( + Number( + "0.5", + ), + ), + ), + ), + Unnamed( + Expr( + Value( + Number( + "0.01", + ), + ), + ), + ), + ], + variadic: false, + over: None, + distinct: false, + order_by: [], + filter: None, + within_group: Some( + OrderByExpr { + expr: Identifier( + Ident { + value: "generate_series", + quote_style: None, + }, + ), + asc: None, + nulls_first: None, + }, + ), + }, + ), + ), + ], + from: [ + TableWithJoins { + relation: TableFunction { + name: ObjectName( + [ + Ident { + value: "generate_series", + quote_style: None, + }, + ], + ), + alias: None, + args: [ + Unnamed( + Expr( + Value( + Number( + "1", + ), + ), + ), + ), + Unnamed( + Expr( + Value( + Number( + "100", + ), + ), + ), + ), + ], + with_ordinality: false, + }, + joins: [], + }, + ], + lateral_views: [], + selection: None, + group_by: [], + having: None, + }, + ), + order_by: [], + limit: None, + offset: None, + fetch: None, + }, + )"#]]; + parse_expected.assert_eq(&format!("{:#?}", stmt)); + + let mut binder = mock_binder(); + let bound = binder.bind(stmt).unwrap(); + + let expected = expect![[r#" + Query( + BoundQuery { + body: Select( + BoundSelect { + distinct: All, + select_items: [ + AggCall( + AggCall { + agg_kind: ApproxPercentile, + return_type: Float64, + args: [ + FunctionCall( + FunctionCall { + func_type: Cast, + return_type: Float64, + inputs: [ + InputRef( + InputRef { + index: 0, + data_type: Int32, + }, + ), + ], + }, + ), + ], + filter: Condition { + conjunctions: [], + }, + }, + ), + ], + aliases: [ + Some( + "approx_percentile", + ), + ], + from: Some( + TableFunction { + expr: TableFunction( + FunctionCall { + function_type: GenerateSeries, + return_type: Int32, + args: [ + Literal( + Literal { + data: Some( + Int32( + 1, + ), + ), + data_type: Some( + Int32, + ), + }, + ), + Literal( + Literal { + data: Some( + Int32( + 100, + ), + ), + data_type: Some( + Int32, + ), + }, + ), + ], + }, + ), + with_ordinality: false, + }, + ), + where_clause: None, + group_by: GroupKey( + [], + ), + having: None, + schema: Schema { + fields: [ + approx_percentile:Float64, + ], + }, + }, + ), + order: [], + limit: None, + offset: None, + with_ties: false, + extra_order_exprs: [], + }, + )"#]]; + + expected.assert_eq(&format!("{:#?}", bound)); + } } diff --git a/src/frontend/src/expr/agg_call.rs b/src/frontend/src/expr/agg_call.rs index d72fba4dbcd2c..7bbd1f1952db2 100644 --- a/src/frontend/src/expr/agg_call.rs +++ b/src/frontend/src/expr/agg_call.rs @@ -65,7 +65,9 @@ impl AggCall { filter: Condition, direct_args: Vec, ) -> Result { + println!("find return type: args={:?}", args); let return_type = infer_type(agg_kind.into(), &mut args)?; + println!("find return type finished"); Ok(AggCall { agg_kind, return_type,