Skip to content

Commit

Permalink
introduce ignore_nulls: bool field in ast::Function
Browse files Browse the repository at this point in the history
  • Loading branch information
stdrc committed Aug 13, 2024
1 parent 196a998 commit f1710d9
Show file tree
Hide file tree
Showing 13 changed files with 235 additions and 168 deletions.
13 changes: 8 additions & 5 deletions src/frontend/src/binder/expr/function/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl Binder {
) -> Result<ExprImpl> {
self.ensure_aggregate_allowed()?;

let distinct = f.distinct;
let distinct = f.arg_list.distinct;
let filter_expr = f.filter.clone();

let (direct_args, args, order_by) = if matches!(kind, agg_kinds::ordered_set!()) {
Expand Down Expand Up @@ -105,14 +105,14 @@ impl Binder {

assert!(matches!(kind, agg_kinds::ordered_set!()));

if !f.order_by.is_empty() {
if !f.arg_list.order_by.is_empty() {
return Err(ErrorCode::InvalidInputSyntax(format!(
"ORDER BY is not allowed for ordered-set aggregation `{}`",
kind
))
.into());
}
if f.distinct {
if f.arg_list.distinct {
return Err(ErrorCode::InvalidInputSyntax(format!(
"DISTINCT is not allowed for ordered-set aggregation `{}`",
kind
Expand All @@ -128,6 +128,7 @@ impl Binder {
})?;

let mut direct_args: Vec<_> = f
.arg_list
.args
.into_iter()
.map(|arg| self.bind_function_arg(arg))
Expand Down Expand Up @@ -207,19 +208,21 @@ impl Binder {
}

let args: Vec<_> = f
.arg_list
.args
.iter()
.map(|arg| self.bind_function_arg(arg.clone()))
.flatten_ok()
.try_collect()?;
let order_by = OrderBy::new(
f.order_by
f.arg_list
.order_by
.into_iter()
.map(|e| self.bind_order_by_expr(e))
.try_collect()?,
);

if f.distinct {
if f.arg_list.distinct {
if matches!(
kind,
AggKind::Builtin(PbAggKind::ApproxCountDistinct)
Expand Down
34 changes: 23 additions & 11 deletions src/frontend/src/binder/expr/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ const SQL_UDF_MAX_CALLING_DEPTH: u32 = 16;

impl Binder {
pub(in crate::binder) fn bind_function(&mut self, f: Function) -> Result<ExprImpl> {
if f.arg_list.ignore_nulls {
bail_not_implemented!("IGNORE NULLS is not supported yet");
}

let function_name = match f.name.0.as_slice() {
[name] => name.real_value(),
[schema, name] => {
Expand Down Expand Up @@ -108,6 +112,7 @@ impl Binder {
}

let mut inputs: Vec<_> = f
.arg_list
.args
.iter()
.map(|arg| self.bind_function_arg(arg.clone()))
Expand Down Expand Up @@ -135,7 +140,11 @@ impl Binder {
}
UserDefinedFunction::new(func.clone(), scalar_inputs).into()
} else {
self.bind_builtin_scalar_function(&function_name, scalar_inputs, f.variadic)?
self.bind_builtin_scalar_function(
&function_name,
scalar_inputs,
f.arg_list.variadic,
)?
};
return self.bind_aggregate_function(f, AggKind::WrapScalar(scalar.to_expr_proto()));
}
Expand Down Expand Up @@ -180,7 +189,9 @@ impl Binder {

// The actual inline logic for sql udf
// Note that we will always create new udf context for each sql udf
let Ok(context) = UdfContext::create_udf_context(&f.args, &Arc::clone(func)) else {
let Ok(context) =
UdfContext::create_udf_context(&f.arg_list.args, &Arc::clone(func))
else {
return Err(ErrorCode::InvalidInputSyntax(
"failed to create the `udf_context`, please recheck your function definition and syntax".to_string()
)
Expand Down Expand Up @@ -265,7 +276,7 @@ impl Binder {
return self.bind_aggregate_function(f, AggKind::Builtin(kind));
}

if f.distinct || !f.order_by.is_empty() || f.filter.is_some() {
if f.arg_list.distinct || !f.arg_list.order_by.is_empty() || f.filter.is_some() {
return Err(ErrorCode::InvalidInputSyntax(format!(
"DISTINCT, ORDER BY or FILTER is only allowed in aggregation functions, but `{}` is not an aggregation function", function_name
)
Expand Down Expand Up @@ -303,17 +314,18 @@ impl Binder {
return Ok(TableFunction::new(function_type, inputs)?.into());
}

self.bind_builtin_scalar_function(function_name.as_str(), inputs, f.variadic)
self.bind_builtin_scalar_function(function_name.as_str(), inputs, f.arg_list.variadic)
}

fn bind_array_transform(&mut self, f: Function) -> Result<ExprImpl> {
let [array, lambda] = <[FunctionArg; 2]>::try_from(f.args).map_err(|args| -> RwError {
ErrorCode::BindError(format!(
"`array_transform` expect two inputs `array` and `lambda`, but {} were given",
args.len()
))
.into()
})?;
let [array, lambda] =
<[FunctionArg; 2]>::try_from(f.arg_list.args).map_err(|args| -> RwError {
ErrorCode::BindError(format!(
"`array_transform` expect two inputs `array` and `lambda`, but {} were given",
args.len()
))
.into()
})?;

let bound_array = self.bind_function_arg(array)?;
let [bound_array] = <[ExprImpl; 1]>::try_from(bound_array).map_err(|bound_array| -> RwError {
Expand Down
37 changes: 20 additions & 17 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -790,30 +790,33 @@ mod tests {
},
],
),
args: [
Unnamed(
Expr(
Value(
Number(
"0.5",
arg_list: FunctionArgList {
distinct: false,
args: [
Unnamed(
Expr(
Value(
Number(
"0.5",
),
),
),
),
),
Unnamed(
Expr(
Value(
Number(
"0.01",
Unnamed(
Expr(
Value(
Number(
"0.01",
),
),
),
),
),
],
variadic: false,
],
variadic: false,
order_by: [],
ignore_nulls: false,
},
over: None,
distinct: false,
order_by: [],
filter: None,
within_group: Some(
OrderByExpr {
Expand Down
7 changes: 2 additions & 5 deletions src/frontend/src/binder/relation/table_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use itertools::Itertools;
use risingwave_common::bail_not_implemented;
use risingwave_common::catalog::{Field, Schema, RW_INTERNAL_TABLE_FUNCTION_NAME};
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::{Function, FunctionArg, ObjectName, TableAlias};
use risingwave_sqlparser::ast::{Function, FunctionArg, FunctionArgList, ObjectName, TableAlias};

use super::watermark::is_watermark_func;
use super::{Binder, Relation, Result, WindowTableFunctionKind};
Expand Down Expand Up @@ -85,11 +85,8 @@ impl Binder {
let func = self.bind_function(Function {
scalar_as_agg: false,
name,
args,
variadic: false,
arg_list: FunctionArgList::args_only(args),
over: None,
distinct: false,
order_by: vec![],
filter: None,
within_group: None,
});
Expand Down
16 changes: 10 additions & 6 deletions src/meta/src/controller/rename.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ use risingwave_pb::expr::expr_node::RexNode;
use risingwave_pb::expr::{ExprNode, FunctionCall, UserDefinedFunction};
use risingwave_sqlparser::ast::{
Array, CreateSink, CreateSinkStatement, CreateSourceStatement, CreateSubscriptionStatement,
Distinct, Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, Query, SelectItem,
SetExpr, Statement, TableAlias, TableFactor, TableWithJoins,
Distinct, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgList, Ident, ObjectName,
Query, SelectItem, SetExpr, Statement, TableAlias, TableFactor, TableWithJoins,
};
use risingwave_sqlparser::parser::Parser;

Expand Down Expand Up @@ -264,14 +264,18 @@ impl QueryRewriter<'_> {
}
}

/// Visit function and update all references.
fn visit_function(&self, function: &mut Function) {
for arg in &mut function.args {
fn visit_function_arg_list(&self, arg_list: &mut FunctionArgList) {
for arg in &mut arg_list.args {
self.visit_function_arg(arg);
}
for expr in &mut function.order_by {
for expr in &mut arg_list.order_by {
self.visit_expr(&mut expr.expr)
}
}

/// Visit function and update all references.
fn visit_function(&self, function: &mut Function) {
self.visit_function_arg_list(&mut function.arg_list);
if let Some(over) = &mut function.over {
for expr in &mut over.partition_by {
self.visit_expr(expr);
Expand Down
107 changes: 77 additions & 30 deletions src/sqlparser/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2477,21 +2477,88 @@ impl fmt::Display for FunctionArg {
}
}

/// A list of function arguments, including additional modifiers like `DISTINCT` or `ORDER BY`.
/// This basically holds all the information between the `(` and `)` in a function call.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct FunctionArgList {
/// Aggregate function calls may have a `DISTINCT`, e.g. `count(DISTINCT x)`.
pub distinct: bool,
pub args: Vec<FunctionArg>,
/// Whether the last argument is variadic, e.g. `foo(a, b, VARIADIC c)`.
pub variadic: bool,
/// Aggregate function calls may have an `ORDER BY`, e.g. `array_agg(x ORDER BY y)`.
pub order_by: Vec<OrderByExpr>,
/// Window function calls may have an `IGNORE NULLS`, e.g. `first_value(x IGNORE NULLS)`.
pub ignore_nulls: bool,
}

impl fmt::Display for FunctionArgList {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "(")?;
if self.distinct {
write!(f, "DISTINCT ")?;
}
if self.variadic {
for arg in &self.args[0..self.args.len() - 1] {
write!(f, "{}, ", arg)?;
}
write!(f, "VARIADIC {}", self.args.last().unwrap())?;
} else {
write!(f, "{}", display_comma_separated(&self.args))?;
}
if !self.order_by.is_empty() {
write!(f, " ORDER BY {}", display_comma_separated(&self.order_by))?;
}
if self.ignore_nulls {
write!(f, " IGNORE NULLS")?;
}
write!(f, ")")?;
Ok(())
}
}

impl FunctionArgList {
pub fn empty() -> Self {
Self {
distinct: false,
args: vec![],
variadic: false,
order_by: vec![],
ignore_nulls: false,
}
}

pub fn args_only(args: Vec<FunctionArg>) -> Self {
Self {
distinct: false,
args,
variadic: false,
order_by: vec![],
ignore_nulls: false,
}
}

pub fn for_agg(distinct: bool, args: Vec<FunctionArg>, order_by: Vec<OrderByExpr>) -> Self {
Self {
distinct,
args,
variadic: false,
order_by: order_by,
ignore_nulls: false,
}
}
}

/// A function call
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Function {
/// Whether the function is prefixed with `aggregate:`
pub scalar_as_agg: bool,
pub name: ObjectName,
pub args: Vec<FunctionArg>,
/// whether the last argument is variadic, e.g. `foo(a, b, variadic c)`
pub variadic: bool,
pub arg_list: FunctionArgList,
pub over: Option<WindowSpec>,
// aggregate functions may specify eg `COUNT(DISTINCT x)`
pub distinct: bool,
// aggregate functions may contain order_by_clause
pub order_by: Vec<OrderByExpr>,
pub filter: Option<Box<Expr>>,
pub within_group: Option<Box<OrderByExpr>>,
}
Expand All @@ -2501,11 +2568,8 @@ impl Function {
Self {
scalar_as_agg: false,
name,
args: vec![],
variadic: false,
arg_list: FunctionArgList::empty(),
over: None,
distinct: false,
order_by: vec![],
filter: None,
within_group: None,
}
Expand All @@ -2515,26 +2579,9 @@ impl Function {
impl fmt::Display for Function {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.scalar_as_agg {
write!(f, "aggregate:")?;
write!(f, "AGGREGATE:")?;
}
write!(
f,
"{}({}",
self.name,
if self.distinct { "DISTINCT " } else { "" },
)?;
if self.variadic {
for arg in &self.args[0..self.args.len() - 1] {
write!(f, "{}, ", arg)?;
}
write!(f, "VARIADIC {}", self.args.last().unwrap())?;
} else {
write!(f, "{}", display_comma_separated(&self.args))?;
}
if !self.order_by.is_empty() {
write!(f, " ORDER BY {}", display_comma_separated(&self.order_by))?;
}
write!(f, ")")?;
write!(f, "{}{}", self.name, self.arg_list)?;
if let Some(o) = &self.over {
write!(f, " OVER ({})", o)?;
}
Expand Down
Loading

0 comments on commit f1710d9

Please sign in to comment.