Skip to content

Commit

Permalink
pass binder test
Browse files Browse the repository at this point in the history
  • Loading branch information
kwannoel committed Jul 18, 2024
1 parent d40caaf commit 3f491de
Show file tree
Hide file tree
Showing 6 changed files with 366 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/expr/core/src/aggregate/def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,10 +464,15 @@ pub mod agg_kinds {
/// Expands to an or-pattern matching every ordered-set aggregate kind, i.e. the aggregates that
/// are invoked with a `WITHIN GROUP (ORDER BY ...)` clause.
///
/// Intended for use in pattern position, e.g. `matches!(kind, ordered_set!())` or as a `match`
/// arm. The pattern refers to `AggKind` by name, so `AggKind` must be in scope at the call site.
#[macro_export]
macro_rules! ordered_set {
    () => {
        AggKind::PercentileCont
            | AggKind::PercentileDisc
            | AggKind::Mode
            | AggKind::ApproxPercentile
    };
}
pub use ordered_set;

use crate::aggregate::AggKind;
}

impl AggKind {
Expand Down
139 changes: 139 additions & 0 deletions src/expr/impl/src/aggregate/approx_percentile.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::ops::Range;

use risingwave_common::array::*;
use risingwave_common::row::Row;
use risingwave_common::types::*;
use risingwave_common_estimate_size::EstimateSize;
use risingwave_expr::aggregate::{AggCall, AggStateDyn, AggregateFunction, AggregateState};
use risingwave_expr::{build_aggregate, Result};

/// Builds the `approx_percentile` aggregate: given a fraction, it yields the value located at
/// that fraction of the aggregated inputs, linearly interpolating between the two neighbouring
/// inputs when the position falls between them.
///
/// The fraction is read from the first direct argument. If that argument is not available as a
/// literal, the fraction is recorded as absent and the aggregate evaluates to NULL.
///
/// NOTE(review): the cases below call `percentile_cont`, not `approx_percentile` — they look
/// carried over from the `percentile_cont` implementation; confirm and replace with cases that
/// exercise this function.
///
/// ```slt
/// statement ok
/// create table t(x int, y bigint, z real, w double, v varchar);
///
/// statement ok
/// insert into t values(1,10,100,1000,'10000'),(2,20,200,2000,'20000'),(3,30,300,3000,'30000');
///
/// query R
/// select percentile_cont(0.45) within group (order by x desc) from t;
/// ----
/// 2.1
///
/// query R
/// select percentile_cont(0.45) within group (order by y desc) from t;
/// ----
/// 21
///
/// query R
/// select percentile_cont(0.45) within group (order by z desc) from t;
/// ----
/// 210
///
/// query R
/// select percentile_cont(0.45) within group (order by w desc) from t;
/// ----
/// 2100
///
/// query R
/// select percentile_cont(NULL) within group (order by w desc) from t;
/// ----
/// NULL
///
/// statement ok
/// drop table t;
/// ```
#[build_aggregate("approx_percentile(float8) -> float8")]
fn build(agg: &AggCall) -> Result<Box<dyn AggregateFunction>> {
    // Only the first direct argument (the fraction) is consumed here.
    let first_direct_arg = &agg.direct_args[0];
    let fraction: Option<f64> = first_direct_arg
        .literal()
        .map(|value| (*value.as_float64()).into());
    let function = ApproxPercentile { fraction };
    Ok(Box::new(function))
}

/// Aggregate function object for `approx_percentile`.
///
/// Created by `build`; the per-group data lives in a separate `State` value, so this struct only
/// carries the query-level parameter.
pub struct ApproxPercentile {
// Requested fraction, taken from the first direct argument when it is a literal.
// `None` makes `get_result` return NULL.
fraction: Option<f64>,
}

/// Per-group aggregation state: every non-NULL input value seen so far, stored as `f64` in the
/// order the values were appended.
#[derive(Debug, Default, EstimateSize)]
struct State(Vec<f64>);

// Empty impl: `State` relies entirely on the trait's provided behaviour.
// NOTE(review): presumably this is what allows `State` to be boxed into `AggregateState::Any`
// — confirm against the `AggStateDyn` definition.
impl AggStateDyn for State {}

impl ApproxPercentile {
    /// Appends one input value to `state`.
    ///
    /// NULL inputs are ignored; any other input is read as a `float8` and pushed onto the end
    /// of the sample buffer.
    fn add_datum(&self, state: &mut State, datum_ref: DatumRef<'_>) {
        let Some(scalar) = datum_ref.to_owned_datum() else {
            // NULL input: nothing to record.
            return;
        };
        let sample: f64 = (*scalar.as_float64()).into();
        state.0.push(sample);
    }
}

/// `AggregateFunction` implementation for `approx_percentile`.
///
/// Inputs are buffered in a `State` in the order they are received and the result is computed
/// from that buffer by positional linear interpolation.
#[async_trait::async_trait]
impl AggregateFunction for ApproxPercentile {
    /// The aggregate always produces a `float8`.
    fn return_type(&self) -> DataType {
        DataType::Float64
    }

    /// Creates an empty sample buffer.
    fn create_state(&self) -> Result<AggregateState> {
        Ok(AggregateState::Any(Box::<State>::default()))
    }

    /// Feeds column 0 of every row in `input` into the state.
    async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
        let state = state.downcast_mut();
        for (_, row) in input.rows() {
            self.add_datum(state, row.datum_at(0));
        }
        Ok(())
    }

    /// Feeds column 0 of the rows of `input` that fall inside `range` into the state.
    async fn update_range(
        &self,
        state: &mut AggregateState,
        input: &StreamChunk,
        range: Range<usize>,
    ) -> Result<()> {
        let state = state.downcast_mut();
        for (_, row) in input.rows_in(range) {
            self.add_datum(state, row.datum_at(0));
        }
        Ok(())
    }

    /// Returns the value at position `fraction * (n - 1)` of the buffered inputs, linearly
    /// interpolating between the two neighbouring inputs when that position is not an integer.
    ///
    /// Returns NULL when the fraction is absent, when no input has been buffered, or when the
    /// fraction is NaN or outside `[0, 1]`.
    async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
        let samples = &state.downcast_ref::<State>().0;
        let Some(fraction) = self.fraction else {
            return Ok(None);
        };
        // A fraction above 1 would place the interpolation position past the last sample and
        // make the indexing below panic; a negative or NaN fraction has no meaningful position.
        // NOTE(review): yielding NULL keeps the executor alive; consider rejecting such
        // fractions with an error at bind/build time instead.
        if samples.is_empty() || !(0.0..=1.0).contains(&fraction) {
            return Ok(None);
        }

        // With `fraction` in [0, 1], `rn` lies in [0, len - 1], so both indices are in bounds.
        let rn = fraction * (samples.len() - 1) as f64;
        let crn = f64::ceil(rn);
        let frn = f64::floor(rn);
        let result = if crn == frn {
            samples[crn as usize]
        } else {
            (crn - rn) * samples[frn as usize] + (rn - frn) * samples[crn as usize]
        };
        Ok(Some(result.into()))
    }
}
1 change: 1 addition & 0 deletions src/expr/impl/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

mod approx_count_distinct;
mod approx_percentile;
mod array_agg;
mod bit_and;
mod bit_or;
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ impl Binder {
}
}
(AggKind::Mode, [], [_arg]) => {}
(AggKind::ApproxPercentile, [_arg, _arg2], [_arg3]) => {}
_ => {
return Err(ErrorCode::InvalidInputSyntax(format!(
"invalid direct args or within group argument for `{}` aggregation",
Expand Down
217 changes: 217 additions & 0 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -759,4 +759,221 @@ mod tests {

expected.assert_eq(&format!("{:#?}", bound));
}

// Parses and binds an `approx_percentile(0.5, 0.01) WITHIN GROUP (ORDER BY ...)` query, then
// snapshot-checks the `{:#?}` rendering of both the parsed statement and the bound query.
#[tokio::test]
async fn test_bind_approx_percentile() {
// Take the first (and only) statement produced by the parser.
let stmt = risingwave_sqlparser::parser::Parser::parse_sql(
"SELECT approx_percentile(0.5, 0.01) WITHIN GROUP (ORDER BY generate_series) FROM generate_series(1, 100)",
).unwrap().into_iter().next().unwrap();
// Expected parse tree: both direct arguments ("0.5" and "0.01") appear under `args`, and the
// ORDER BY expression appears under `within_group`.
let parse_expected = expect![[r#"
Query(
Query {
with: None,
body: Select(
Select {
distinct: All,
projection: [
UnnamedExpr(
Function(
Function {
name: ObjectName(
[
Ident {
value: "approx_percentile",
quote_style: None,
},
],
),
args: [
Unnamed(
Expr(
Value(
Number(
"0.5",
),
),
),
),
Unnamed(
Expr(
Value(
Number(
"0.01",
),
),
),
),
],
variadic: false,
over: None,
distinct: false,
order_by: [],
filter: None,
within_group: Some(
OrderByExpr {
expr: Identifier(
Ident {
value: "generate_series",
quote_style: None,
},
),
asc: None,
nulls_first: None,
},
),
},
),
),
],
from: [
TableWithJoins {
relation: TableFunction {
name: ObjectName(
[
Ident {
value: "generate_series",
quote_style: None,
},
],
),
alias: None,
args: [
Unnamed(
Expr(
Value(
Number(
"1",
),
),
),
),
Unnamed(
Expr(
Value(
Number(
"100",
),
),
),
),
],
with_ordinality: false,
},
joins: [],
},
],
lateral_views: [],
selection: None,
group_by: [],
having: None,
},
),
order_by: [],
limit: None,
offset: None,
fetch: None,
},
)"#]];
parse_expected.assert_eq(&format!("{:#?}", stmt));

// Binding must succeed with the mock binder.
let mut binder = mock_binder();
let bound = binder.bind(stmt).unwrap();

// Expected bound query: the aggregate's `args` contains only the ordered column cast to
// Float64; the two direct arguments do not appear in this rendering.
let expected = expect![[r#"
Query(
BoundQuery {
body: Select(
BoundSelect {
distinct: All,
select_items: [
AggCall(
AggCall {
agg_kind: ApproxPercentile,
return_type: Float64,
args: [
FunctionCall(
FunctionCall {
func_type: Cast,
return_type: Float64,
inputs: [
InputRef(
InputRef {
index: 0,
data_type: Int32,
},
),
],
},
),
],
filter: Condition {
conjunctions: [],
},
},
),
],
aliases: [
Some(
"approx_percentile",
),
],
from: Some(
TableFunction {
expr: TableFunction(
FunctionCall {
function_type: GenerateSeries,
return_type: Int32,
args: [
Literal(
Literal {
data: Some(
Int32(
1,
),
),
data_type: Some(
Int32,
),
},
),
Literal(
Literal {
data: Some(
Int32(
100,
),
),
data_type: Some(
Int32,
),
},
),
],
},
),
with_ordinality: false,
},
),
where_clause: None,
group_by: GroupKey(
[],
),
having: None,
schema: Schema {
fields: [
approx_percentile:Float64,
],
},
},
),
order: [],
limit: None,
offset: None,
with_ties: false,
extra_order_exprs: [],
},
)"#]];

expected.assert_eq(&format!("{:#?}", bound));
}
}
Loading

0 comments on commit 3f491de

Please sign in to comment.