From c47b9ed8f31c45b056eda13f1598c94488362613 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 28 Nov 2023 16:49:31 +0800 Subject: [PATCH] feat(expr): support implicit cast of arguments for table and aggregate functions (#13545) Signed-off-by: Runji Wang --- e2e_test/batch/types/jsonb.slt.part | 54 ++++---- src/common/src/error.rs | 2 + src/expr/core/src/aggregate/mod.rs | 22 ++-- src/expr/core/src/sig/mod.rs | 115 ++++++++++++------ .../src/window_function/state/aggregate.rs | 9 +- src/expr/impl/src/aggregate/general.rs | 84 ++++++++++++- src/expr/impl/src/scalar/date_trunc.rs | 15 +-- src/expr/impl/src/scalar/to_char.rs | 17 +-- src/expr/impl/src/scalar/to_timestamp.rs | 14 +-- src/expr/macro/src/gen.rs | 35 ++++-- src/expr/macro/src/lib.rs | 4 +- src/expr/macro/src/parse.rs | 4 +- .../tests/testdata/output/expr.yaml | 15 +-- .../tests/testdata/output/struct_query.yaml | 5 +- src/frontend/src/binder/select.rs | 14 +-- src/frontend/src/expr/agg_call.rs | 67 ++++------ src/frontend/src/expr/function_call.rs | 2 +- src/frontend/src/expr/table_function.rs | 10 +- src/frontend/src/expr/type_inference/func.rs | 52 ++++---- src/frontend/src/expr/window_function.rs | 6 +- .../src/optimizer/plan_node/generic/agg.rs | 39 ++++-- .../src/optimizer/plan_node/logical_agg.rs | 27 ++-- 22 files changed, 358 insertions(+), 254 deletions(-) diff --git a/e2e_test/batch/types/jsonb.slt.part b/e2e_test/batch/types/jsonb.slt.part index 949b18af315c0..8e8ec7aa0e60f 100644 --- a/e2e_test/batch/types/jsonb.slt.part +++ b/e2e_test/batch/types/jsonb.slt.part @@ -159,113 +159,113 @@ SELECT '{ # jsonb_array_elements query T -select * from jsonb_array_elements('[1,true, [2,false]]'::jsonb); +select * from jsonb_array_elements('[1,true, [2,false]]'); ---- 1 true [2, false] statement error cannot extract elements -select * from jsonb_array_elements('null'::jsonb) +select * from jsonb_array_elements('null'); statement error cannot extract elements -select * from 
jsonb_array_elements('1'::jsonb) +select * from jsonb_array_elements('1'); statement error cannot extract elements -select * from jsonb_array_elements('"string"'::jsonb) +select * from jsonb_array_elements('"string"'); statement error cannot extract elements -select * from jsonb_array_elements('{}'::jsonb) +select * from jsonb_array_elements('{}'); # jsonb_array_elements_text query T -select * from jsonb_array_elements_text('["foo", "bar"]'::jsonb); +select * from jsonb_array_elements_text('["foo", "bar"]'); ---- foo bar statement error cannot extract elements -select * from jsonb_array_elements_text('null'::jsonb) +select * from jsonb_array_elements_text('null'); statement error cannot extract elements -select * from jsonb_array_elements_text('1'::jsonb) +select * from jsonb_array_elements_text('1'); statement error cannot extract elements -select * from jsonb_array_elements_text('"string"'::jsonb) +select * from jsonb_array_elements_text('"string"'); statement error cannot extract elements -select * from jsonb_array_elements_text('{}'::jsonb) +select * from jsonb_array_elements_text('{}'); # jsonb_object_keys query T -select * from jsonb_object_keys('{"f1":"abc","f2":{"f3":"a", "f4":"b"}}'::jsonb); +select * from jsonb_object_keys('{"f1":"abc","f2":{"f3":"a", "f4":"b"}}'); ---- f1 f2 statement error cannot call jsonb_object_keys -select * from jsonb_object_keys('null'::jsonb) +select * from jsonb_object_keys('null'); statement error cannot call jsonb_object_keys -select * from jsonb_object_keys('1'::jsonb) +select * from jsonb_object_keys('1'); statement error cannot call jsonb_object_keys -select * from jsonb_object_keys('"string"'::jsonb) +select * from jsonb_object_keys('"string"'); statement error cannot call jsonb_object_keys -select * from jsonb_object_keys('[]'::jsonb) +select * from jsonb_object_keys('[]'); # jsonb_each query TT -select * from jsonb_each('{"a":"foo", "b":"bar"}'::jsonb); +select * from jsonb_each('{"a":"foo", "b":"bar"}'); ---- a "foo" b 
"bar" query T -select jsonb_each('{"a":"foo", "b":"bar"}'::jsonb); +select jsonb_each('{"a":"foo", "b":"bar"}'); ---- (a,"""foo""") (b,"""bar""") statement error cannot deconstruct -select * from jsonb_each('null'::jsonb) +select * from jsonb_each('null'); statement error cannot deconstruct -select * from jsonb_each('1'::jsonb) +select * from jsonb_each('1'); statement error cannot deconstruct -select * from jsonb_each('"string"'::jsonb) +select * from jsonb_each('"string"'); statement error cannot deconstruct -select * from jsonb_each('[]'::jsonb) +select * from jsonb_each('[]'); # jsonb_each_text query TT -select * from jsonb_each_text('{"a":"foo", "b":"bar"}'::jsonb); +select * from jsonb_each_text('{"a":"foo", "b":"bar"}'); ---- a foo b bar query T -select jsonb_each_text('{"a":"foo", "b":"bar"}'::jsonb); +select jsonb_each_text('{"a":"foo", "b":"bar"}'); ---- (a,foo) (b,bar) statement error cannot deconstruct -select * from jsonb_each_text('null'::jsonb) +select * from jsonb_each_text('null'); statement error cannot deconstruct -select * from jsonb_each_text('1'::jsonb) +select * from jsonb_each_text('1'); statement error cannot deconstruct -select * from jsonb_each_text('"string"'::jsonb) +select * from jsonb_each_text('"string"'); statement error cannot deconstruct -select * from jsonb_each_text('[]'::jsonb) +select * from jsonb_each_text('[]'); query TTTTT SELECT js, diff --git a/src/common/src/error.rs b/src/common/src/error.rs index d97fbdc0c73e9..a7b27e9dd9f89 100644 --- a/src/common/src/error.rs +++ b/src/common/src/error.rs @@ -102,6 +102,8 @@ pub enum ErrorCode { // Tips: Use this only if it's intended to reject the query #[error("Not supported: {0}\nHINT: {1}")] NotSupported(String, String), + #[error("function {0} does not exist")] + NoFunction(String), #[error(transparent)] IoError(#[from] IoError), #[error("Storage error: {0}")] diff --git a/src/expr/core/src/aggregate/mod.rs b/src/expr/core/src/aggregate/mod.rs index 74e3afdb0904c..4eeb3b4256f12 
100644 --- a/src/expr/core/src/aggregate/mod.rs +++ b/src/expr/core/src/aggregate/mod.rs @@ -21,6 +21,7 @@ use risingwave_common::array::StreamChunk; use risingwave_common::estimate_size::EstimateSize; use risingwave_common::types::{DataType, Datum}; +use crate::sig::FuncBuilder; use crate::{ExprError, Result}; // aggregate definition @@ -131,18 +132,16 @@ pub fn build_retractable(agg: &AggCall) -> Result { build(agg, false) } -/// Build an `Aggregator` from `AggCall`. +/// Build an aggregate function. +/// +/// If `prefer_append_only` is true, and both append-only and retractable implementations exist, +/// the append-only version will be used. /// /// NOTE: This function ignores argument indices, `column_orders`, `filter` and `distinct` in /// `AggCall`. Such operations should be done in batch or streaming executors. -pub fn build(agg: &AggCall, append_only: bool) -> Result { - let desc = crate::sig::FUNCTION_REGISTRY - .get_aggregate( - agg.kind, - agg.args.arg_types(), - &agg.return_type, - append_only, - ) +pub fn build(agg: &AggCall, prefer_append_only: bool) -> Result { + let sig = crate::sig::FUNCTION_REGISTRY + .get(agg.kind, agg.args.arg_types(), &agg.return_type) .ok_or_else(|| { ExprError::UnsupportedFunction(format!( "{}({}) -> {}", @@ -152,5 +151,8 @@ pub fn build(agg: &AggCall, append_only: bool) -> Result )) })?; - desc.build_aggregate(agg) + if let FuncBuilder::Aggregate{ append_only: Some(f), .. } = sig.build && prefer_append_only { + return f(agg); + } + sig.build_aggregate(agg) } diff --git a/src/expr/core/src/sig/mod.rs b/src/expr/core/src/sig/mod.rs index c2e71b585d49c..738b4f6b9eaf9 100644 --- a/src/expr/core/src/sig/mod.rs +++ b/src/expr/core/src/sig/mod.rs @@ -49,7 +49,42 @@ pub struct FunctionRegistry(HashMap>); impl FunctionRegistry { /// Inserts a function signature. 
pub fn insert(&mut self, sig: FuncSign) { - self.0.entry(sig.name).or_default().push(sig) + let list = self.0.entry(sig.name).or_default(); + if sig.is_aggregate() { + // merge retractable and append-only aggregate + if let Some(existing) = list + .iter_mut() + .find(|d| d.inputs_type == sig.inputs_type && d.ret_type == sig.ret_type) + { + let ( + FuncBuilder::Aggregate { + retractable, + append_only, + retractable_state_type, + append_only_state_type, + }, + FuncBuilder::Aggregate { + retractable: r1, + append_only: a1, + retractable_state_type: rs1, + append_only_state_type: as1, + }, + ) = (&mut existing.build, sig.build) + else { + panic!("expected aggregate function") + }; + if let Some(f) = r1 { + *retractable = Some(f); + *retractable_state_type = rs1; + } + if let Some(f) = a1 { + *append_only = Some(f); + *append_only_state_type = as1; + } + return; + } + } + list.push(sig); } /// Returns a function signature with the same type, argument types and return type. @@ -76,27 +111,8 @@ impl FunctionRegistry { } } - /// Returns a function signature with the given type, argument types, return type. - /// - /// The `prefer_append_only` flag only works when both append-only and retractable version exist. - /// Otherwise, return the signature of the only version. - pub fn get_aggregate( - &self, - ty: AggregateFunctionType, - args: &[DataType], - ret: &DataType, - prefer_append_only: bool, - ) -> Option<&FuncSign> { - let v = self.0.get(&ty.into())?; - let mut iter = v.iter().filter(|d| d.match_args_ret(args, ret)); - if iter.clone().count() == 2 { - iter.find(|d| d.append_only == prefer_append_only) - } else { - iter.next() - } - } - /// Returns the return type for the given function and arguments. + /// Deprecated functions are excluded. 
pub fn get_return_type( &self, name: impl Into, @@ -109,7 +125,7 @@ impl FunctionRegistry { .ok_or_else(|| ExprError::UnsupportedFunction(name.to_string()))?; let sig = v .iter() - .find(|d| d.match_args(args)) + .find(|d| d.match_args(args) && !d.deprecated) .ok_or_else(|| ExprError::UnsupportedFunction(name.to_string()))?; (sig.type_infer)(args) } @@ -154,13 +170,6 @@ pub struct FuncSign { /// Whether the function is deprecated and should not be used in the frontend. /// For backward compatibility, it is still available in the backend. pub deprecated: bool, - - /// The state type of the aggregate function. - /// `None` means equal to the return type. - pub state_type: Option, - - /// Whether the aggregate function is append-only. - pub append_only: bool, } impl fmt::Debug for FuncSign { @@ -182,9 +191,6 @@ impl fmt::Debug for FuncSign { if self.name.is_table() { "setof " } else { "" }, self.ret_type, )?; - if self.append_only { - write!(f, " [append-only]")?; - } if self.deprecated { write!(f, " [deprecated]")?; } @@ -235,6 +241,28 @@ impl FuncSign { matches!(self.name, FuncName::Aggregate(_)) } + /// Returns true if the aggregate function is append-only. + pub const fn is_append_only(&self) -> bool { + matches!( + self.build, + FuncBuilder::Aggregate { + retractable: None, + .. + } + ) + } + + /// Returns true if the aggregate function has a retractable version. + pub const fn is_retractable(&self) -> bool { + matches!( + self.build, + FuncBuilder::Aggregate { + retractable: Some(_), + .. + } + ) + } + /// Builds the scalar function. pub fn build_scalar( &self, @@ -260,10 +288,15 @@ impl FuncSign { } } - /// Builds the aggregate function. + /// Builds the aggregate function. If both retractable and append-only versions exist, the + /// retractable version will be built. pub fn build_aggregate(&self, agg: &AggCall) -> Result { match self.build { - FuncBuilder::Aggregate(f) => f(agg), + FuncBuilder::Aggregate { + retractable, + append_only, + .. 
+ } => retractable.or(append_only).unwrap()(agg), _ => panic!("Expected an aggregate function"), } } @@ -385,7 +418,7 @@ impl SigDataType { } } -#[derive(Clone, Copy)] +#[derive(Clone)] pub enum FuncBuilder { Scalar(fn(return_type: DataType, children: Vec) -> Result), Table( @@ -395,7 +428,17 @@ pub enum FuncBuilder { children: Vec, ) -> Result, ), - Aggregate(fn(agg: &AggCall) -> Result), + // An aggregate function may contain both or either one of retractable and append-only versions. + Aggregate { + retractable: Option Result>, + append_only: Option Result>, + /// The state type of the retractable aggregate function. + /// `None` means equal to the return type. + retractable_state_type: Option, + /// The state type of the append-only aggregate function. + /// `None` means equal to the return type. + append_only_state_type: Option, + }, } /// Register a function into global registry. diff --git a/src/expr/core/src/window_function/state/aggregate.rs b/src/expr/core/src/window_function/state/aggregate.rs index 38958b50b8c38..19bec0a4bb572 100644 --- a/src/expr/core/src/window_function/state/aggregate.rs +++ b/src/expr/core/src/window_function/state/aggregate.rs @@ -63,16 +63,11 @@ impl AggregateState { direct_args: vec![], }; let agg_func_sig = FUNCTION_REGISTRY - .get_aggregate( - agg_kind, - &arg_data_types, - &call.return_type, - false, // means prefer retractable version - ) + .get(agg_kind, &arg_data_types, &call.return_type) .expect("the agg func must exist"); let agg_func = agg_func_sig.build_aggregate(&agg_call)?; let (agg_impl, enable_delta) = - if !agg_func_sig.append_only && call.frame.exclusion.is_no_others() { + if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() { let init_state = agg_func.create_state(); (AggImpl::Incremental(init_state), true) } else { diff --git a/src/expr/impl/src/aggregate/general.rs b/src/expr/impl/src/aggregate/general.rs index f47c94d45f24d..993b567590263 100644 --- a/src/expr/impl/src/aggregate/general.rs 
+++ b/src/expr/impl/src/aggregate/general.rs @@ -19,14 +19,14 @@ use risingwave_expr::{aggregate, ExprError, Result}; #[aggregate("sum(int2) -> int8")] #[aggregate("sum(int4) -> int8")] -#[aggregate("sum(int8) -> int8")] #[aggregate("sum(int8) -> decimal")] #[aggregate("sum(float4) -> float4")] #[aggregate("sum(float8) -> float8")] #[aggregate("sum(decimal) -> decimal")] #[aggregate("sum(interval) -> interval")] #[aggregate("sum(int256) -> int256")] -#[aggregate("sum0(int8) -> int8", init_state = "0i64")] +#[aggregate("sum(int8) -> int8", internal)] // used internally for 2-phase sum(int2) and sum(int4) +#[aggregate("sum0(int8) -> int8", internal, init_state = "0i64")] // used internally for 2-phase count fn sum(state: S, input: T, retract: bool) -> Result where S: Default + From + CheckedAdd + CheckedSub, @@ -42,12 +42,86 @@ where } } -#[aggregate("min(*) -> auto", state = "ref")] +#[aggregate("avg(int2) -> decimal", rewritten)] +#[aggregate("avg(int4) -> decimal", rewritten)] +#[aggregate("avg(int8) -> decimal", rewritten)] +#[aggregate("avg(decimal) -> decimal", rewritten)] +#[aggregate("avg(float4) -> float8", rewritten)] +#[aggregate("avg(float8) -> float8", rewritten)] +#[aggregate("avg(int256) -> float8", rewritten)] +#[aggregate("avg(interval) -> interval", rewritten)] +fn _avg() {} + +#[aggregate("stddev_pop(int2) -> decimal", rewritten)] +#[aggregate("stddev_pop(int4) -> decimal", rewritten)] +#[aggregate("stddev_pop(int8) -> decimal", rewritten)] +#[aggregate("stddev_pop(decimal) -> decimal", rewritten)] +#[aggregate("stddev_pop(float4) -> float8", rewritten)] +#[aggregate("stddev_pop(float8) -> float8", rewritten)] +#[aggregate("stddev_pop(int256) -> float8", rewritten)] +fn _stddev_pop() {} + +#[aggregate("stddev_samp(int2) -> decimal", rewritten)] +#[aggregate("stddev_samp(int4) -> decimal", rewritten)] +#[aggregate("stddev_samp(int8) -> decimal", rewritten)] +#[aggregate("stddev_samp(decimal) -> decimal", rewritten)] +#[aggregate("stddev_samp(float4) 
-> float8", rewritten)] +#[aggregate("stddev_samp(float8) -> float8", rewritten)] +#[aggregate("stddev_samp(int256) -> float8", rewritten)] +fn _stddev_samp() {} + +#[aggregate("var_pop(int2) -> decimal", rewritten)] +#[aggregate("var_pop(int4) -> decimal", rewritten)] +#[aggregate("var_pop(int8) -> decimal", rewritten)] +#[aggregate("var_pop(decimal) -> decimal", rewritten)] +#[aggregate("var_pop(float4) -> float8", rewritten)] +#[aggregate("var_pop(float8) -> float8", rewritten)] +#[aggregate("var_pop(int256) -> float8", rewritten)] +fn _var_pop() {} + +#[aggregate("var_samp(int2) -> decimal", rewritten)] +#[aggregate("var_samp(int4) -> decimal", rewritten)] +#[aggregate("var_samp(int8) -> decimal", rewritten)] +#[aggregate("var_samp(decimal) -> decimal", rewritten)] +#[aggregate("var_samp(float4) -> float8", rewritten)] +#[aggregate("var_samp(float8) -> float8", rewritten)] +#[aggregate("var_samp(int256) -> float8", rewritten)] +fn _var_samp() {} + +// no `min(boolean)` and `min(jsonb)` +#[aggregate("min(*int) -> auto", state = "ref")] +#[aggregate("min(*float) -> auto", state = "ref")] +#[aggregate("min(decimal) -> auto", state = "ref")] +#[aggregate("min(int256) -> auto", state = "ref")] +#[aggregate("min(serial) -> auto", state = "ref")] +#[aggregate("min(date) -> auto", state = "ref")] +#[aggregate("min(time) -> auto", state = "ref")] +#[aggregate("min(interval) -> auto", state = "ref")] +#[aggregate("min(timestamp) -> auto", state = "ref")] +#[aggregate("min(timestamptz) -> auto", state = "ref")] +#[aggregate("min(varchar) -> auto", state = "ref")] +#[aggregate("min(bytea) -> auto", state = "ref")] +#[aggregate("min(anyarray) -> auto", state = "ref")] +#[aggregate("min(struct) -> auto", state = "ref")] fn min(state: T, input: T) -> T { state.min(input) } -#[aggregate("max(*) -> auto", state = "ref")] +// no `max(boolean)` and `max(jsonb)` +#[aggregate("max(*int) -> auto", state = "ref")] +#[aggregate("max(*float) -> auto", state = "ref")] 
+#[aggregate("max(decimal) -> auto", state = "ref")] +#[aggregate("max(int256) -> auto", state = "ref")] +#[aggregate("max(serial) -> auto", state = "ref")] +#[aggregate("max(date) -> auto", state = "ref")] +#[aggregate("max(time) -> auto", state = "ref")] +#[aggregate("max(interval) -> auto", state = "ref")] +#[aggregate("max(timestamp) -> auto", state = "ref")] +#[aggregate("max(timestamptz) -> auto", state = "ref")] +#[aggregate("max(varchar) -> auto", state = "ref")] +#[aggregate("max(bytea) -> auto", state = "ref")] +#[aggregate("max(anyarray) -> auto", state = "ref")] +#[aggregate("max(struct) -> auto", state = "ref")] fn max(state: T, input: T) -> T { state.max(input) } @@ -62,7 +136,7 @@ fn last_value(_: T, input: T) -> T { input } -#[aggregate("internal_last_seen_value(*) -> auto", state = "ref")] +#[aggregate("internal_last_seen_value(*) -> auto", state = "ref", internal)] fn internal_last_seen_value(state: T, input: T, retract: bool) -> T { if retract { state diff --git a/src/expr/impl/src/scalar/date_trunc.rs b/src/expr/impl/src/scalar/date_trunc.rs index 35ba48631bfc4..afa02961bf3fc 100644 --- a/src/expr/impl/src/scalar/date_trunc.rs +++ b/src/expr/impl/src/scalar/date_trunc.rs @@ -13,8 +13,7 @@ // limitations under the License. use risingwave_common::types::{Interval, Timestamp, Timestamptz}; -use risingwave_expr::expr::BoxedExpression; -use risingwave_expr::{build_function, function, ExprError, Result}; +use risingwave_expr::{function, ExprError, Result}; use super::timestamptz::timestamp_at_time_zone; @@ -53,16 +52,8 @@ pub fn date_trunc_timestamp(field: &str, ts: Timestamp) -> Result { }) } -// Only to register this signature to function signature map. 
-#[build_function("date_trunc(varchar, timestamptz) -> timestamptz")] -fn build_date_trunc_timestamptz_implicit_zone( - _return_type: risingwave_common::types::DataType, - _children: Vec, -) -> Result { - Err(ExprError::UnsupportedFunction( - "date_trunc of timestamptz should have been rewritten to include timezone".into(), - )) -} +#[function("date_trunc(varchar, timestamptz) -> timestamptz", rewritten)] +fn _date_trunc_timestamptz() {} #[function("date_trunc(varchar, timestamptz, varchar) -> timestamptz")] pub fn date_trunc_timestamptz_at_timezone( diff --git a/src/expr/impl/src/scalar/to_char.rs b/src/expr/impl/src/scalar/to_char.rs index 9d28d62eca7a6..4d4edb2d390ba 100644 --- a/src/expr/impl/src/scalar/to_char.rs +++ b/src/expr/impl/src/scalar/to_char.rs @@ -17,9 +17,8 @@ use std::sync::LazyLock; use aho_corasick::{AhoCorasick, AhoCorasickBuilder}; use chrono::format::StrftimeItems; -use risingwave_common::types::{DataType, Timestamp, Timestamptz}; -use risingwave_expr::expr::BoxedExpression; -use risingwave_expr::{build_function, function, ExprError, Result}; +use risingwave_common::types::{Timestamp, Timestamptz}; +use risingwave_expr::{function, ExprError, Result}; use super::timestamptz::time_zone_err; @@ -120,16 +119,8 @@ fn timestamp_to_char(data: Timestamp, pattern: &ChronoPattern, writer: &mut impl write!(writer, "{}", format).unwrap(); } -// Only to register this signature to function signature map. 
-#[build_function("to_char(timestamptz, varchar) -> varchar")] -fn timestamptz_to_char( - _return_type: DataType, - _children: Vec, -) -> Result { - Err(ExprError::UnsupportedFunction( - "to_char(timestamptz, varchar) should have been rewritten to include timezone".into(), - )) -} +#[function("to_char(timestamptz, varchar) -> varchar", rewritten)] +fn _timestamptz_to_char() {} #[function( "to_char(timestamptz, varchar, varchar) -> varchar", diff --git a/src/expr/impl/src/scalar/to_timestamp.rs b/src/expr/impl/src/scalar/to_timestamp.rs index e4ef9edc235eb..186f50e9c5cc0 100644 --- a/src/expr/impl/src/scalar/to_timestamp.rs +++ b/src/expr/impl/src/scalar/to_timestamp.rs @@ -13,9 +13,8 @@ // limitations under the License. use chrono::format::Parsed; -use risingwave_common::types::{DataType, Date, Timestamp, Timestamptz}; -use risingwave_expr::expr::BoxedExpression; -use risingwave_expr::{build_function, function, ExprError, Result}; +use risingwave_common::types::{Date, Timestamp, Timestamptz}; +use risingwave_expr::{function, ExprError, Result}; use super::timestamptz::{timestamp_at_time_zone, timestamptz_at_time_zone}; use super::to_char::ChronoPattern; @@ -94,13 +93,8 @@ pub fn to_timestamp(s: &str, timezone: &str, tmpl: &ChronoPattern) -> Result timestamptz")] -fn build_dummy(_return_type: DataType, _children: Vec) -> Result { - Err(ExprError::UnsupportedFunction( - "to_timestamp should have been rewritten to include timezone".into(), - )) -} +#[function("to_timestamp1(varchar, varchar) -> timestamptz", rewritten)] +fn _to_timestamp1() {} #[function( "char_to_date(varchar, varchar) -> date", diff --git a/src/expr/macro/src/gen.rs b/src/expr/macro/src/gen.rs index 325d9101cb2f2..718c4a0b72e79 100644 --- a/src/expr/macro/src/gen.rs +++ b/src/expr/macro/src/gen.rs @@ -117,6 +117,8 @@ impl FunctionAttr { let build_fn = if build_fn { let name = format_ident!("{}", user_fn.name); quote! { #name } + } else if self.rewritten { + quote! 
{ |_, _| Err(ExprError::UnsupportedFunction(#name.into())) } } else { self.generate_build_scalar_function(user_fn, true)? }; @@ -137,8 +139,6 @@ impl FunctionAttr { build: FuncBuilder::Scalar(#build_fn), type_infer: #type_infer_fn, deprecated: #deprecated, - state_type: None, - append_only: false, }) }; } }) @@ -550,9 +550,27 @@ impl FunctionAttr { let build_fn = if build_fn { let name = format_ident!("{}", user_fn.as_fn().name); quote! { #name } + } else if self.rewritten { + quote! { |_| Err(ExprError::UnsupportedFunction(#name.into())) } } else { self.generate_agg_build_fn(user_fn)? }; + let build_retractable = match append_only { + true => quote! { None }, + false => quote! { Some(#build_fn) }, + }; + let build_append_only = match append_only { + false => quote! { None }, + true => quote! { Some(#build_fn) }, + }; + let retractable_state_type = match append_only { + true => quote! { None }, + false => state_type.clone(), + }; + let append_only_state_type = match append_only { + false => quote! { None }, + true => state_type, + }; let type_infer_fn = self.generate_type_infer_fn()?; let deprecated = self.deprecated; @@ -567,10 +585,13 @@ impl FunctionAttr { inputs_type: vec![#(#args),*], variadic: false, ret_type: #ret, - build: FuncBuilder::Aggregate(#build_fn), + build: FuncBuilder::Aggregate { + retractable: #build_retractable, + append_only: #build_append_only, + retractable_state_type: #retractable_state_type, + append_only_state_type: #append_only_state_type, + }, type_infer: #type_infer_fn, - state_type: #state_type, - append_only: #append_only, deprecated: #deprecated, }) }; } @@ -876,6 +897,8 @@ impl FunctionAttr { let build_fn = if build_fn { let name = format_ident!("{}", user_fn.name); quote! { #name } + } else if self.rewritten { + quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) } } else { self.generate_build_table_function(user_fn)? 
}; @@ -896,8 +919,6 @@ impl FunctionAttr { build: FuncBuilder::Table(#build_fn), type_infer: #type_infer_fn, deprecated: #deprecated, - state_type: None, - append_only: false, }) }; } }) diff --git a/src/expr/macro/src/lib.rs b/src/expr/macro/src/lib.rs index 50a99cf3fda22..544b369072d79 100644 --- a/src/expr/macro/src/lib.rs +++ b/src/expr/macro/src/lib.rs @@ -503,8 +503,10 @@ struct FunctionAttr { generic: Option, /// Whether the function is volatile. volatile: bool, - /// Whether the function is deprecated. + /// If true, the function is unavailable on the frontend. deprecated: bool, + /// If true, the function is not implemented on the backend, but its signature is defined. + rewritten: bool, } /// Attributes from function signature `fn(..)` diff --git a/src/expr/macro/src/parse.rs b/src/expr/macro/src/parse.rs index fc9e4d45437e2..574d573894655 100644 --- a/src/expr/macro/src/parse.rs +++ b/src/expr/macro/src/parse.rs @@ -79,8 +79,10 @@ impl Parse for FunctionAttr { parsed.generic = Some(get_value()?); } else if meta.path().is_ident("volatile") { parsed.volatile = true; - } else if meta.path().is_ident("deprecated") { + } else if meta.path().is_ident("deprecated") || meta.path().is_ident("internal") { parsed.deprecated = true; + } else if meta.path().is_ident("rewritten") { + parsed.rewritten = true; } else if meta.path().is_ident("append_only") { parsed.append_only = true; } else { diff --git a/src/frontend/planner_test/tests/testdata/output/expr.yaml b/src/frontend/planner_test/tests/testdata/output/expr.yaml index 36ee6acbe9163..016bff05efe90 100644 --- a/src/frontend/planner_test/tests/testdata/output/expr.yaml +++ b/src/frontend/planner_test/tests/testdata/output/expr.yaml @@ -143,8 +143,7 @@ Failed to bind expression: round(true) Caused by: - Feature is not yet implemented: Round[Boolean] - Tracking issue: https://github.com/risingwavelabs/risingwave/issues/112 + function round(boolean) does not exist - sql: | -- Single quoted literal can be treated as 
number without error. values(round('123')); @@ -164,8 +163,7 @@ Failed to bind expression: 1 NOT LIKE 1.23 Caused by: - Feature is not yet implemented: Like[Int32, Decimal] - Tracking issue: https://github.com/risingwavelabs/risingwave/issues/112 + function like(integer, numeric) does not exist - sql: | select length(trim(trailing '1' from '12'))+length(trim(leading '2' from '23'))+length(trim(both '3' from '34')); batch_plan: 'BatchValues { rows: [[4:Int32]] }' @@ -217,8 +215,7 @@ Failed to bind expression: (CASE v1 WHEN 1 THEN 1 WHEN true THEN 2 ELSE 0.0 END) Caused by: - Feature is not yet implemented: Equal[Int32, Boolean] - Tracking issue: https://github.com/risingwavelabs/risingwave/issues/112 + function equal(integer, boolean) does not exist - sql: | create table t (v1 int); select nullif(v1, 1) as expr from t; @@ -245,8 +242,7 @@ Failed to bind expression: nullif(v1, true) Caused by: - Feature is not yet implemented: Equal[Int32, Boolean] - Tracking issue: https://github.com/risingwavelabs/risingwave/issues/112 + function equal(integer, boolean) does not exist - sql: | create table t (v1 int); select coalesce(v1, 1) as expr from t; @@ -406,8 +402,7 @@ Failed to bind expression: 1 < SOME(CAST(NULL AS CHARACTER VARYING[])) Caused by: - Feature is not yet implemented: LessThan[Int32, Varchar] - Tracking issue: https://github.com/risingwavelabs/risingwave/issues/112 + function less_than(integer, character varying) does not exist - sql: | select 1 < SOME(null::date); binder_error: | diff --git a/src/frontend/planner_test/tests/testdata/output/struct_query.yaml b/src/frontend/planner_test/tests/testdata/output/struct_query.yaml index fb6c498321471..f55c30f4436a8 100644 --- a/src/frontend/planner_test/tests/testdata/output/struct_query.yaml +++ b/src/frontend/planner_test/tests/testdata/output/struct_query.yaml @@ -274,8 +274,7 @@ Failed to bind expression: (country + country) Caused by: - Feature is not yet implemented: Add[Struct(StructType { field_names: 
["address", "zipcode"], field_types: [Varchar, Varchar] }), Struct(StructType { field_names: ["address", "zipcode"], field_types: [Varchar, Varchar] })] - Tracking issue: https://github.com/risingwavelabs/risingwave/issues/112 + function add(struct
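<address character varying, zipcode character varying>, struct<address character varying, zipcode character varying>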
) does not exist create_source: format: plain encode: protobuf @@ -300,7 +299,7 @@ Failed to bind expression: avg(country) Caused by: - Invalid input syntax: Invalid aggregation: avg(struct
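<address character varying, zipcode character varying>) + function avg(struct<address character varying, zipcode character varying>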
) does not exist create_source: format: plain encode: protobuf diff --git a/src/frontend/src/binder/select.rs b/src/frontend/src/binder/select.rs index ceb7d55312f46..c34f2126d7bdc 100644 --- a/src/frontend/src/binder/select.rs +++ b/src/frontend/src/binder/select.rs @@ -40,10 +40,9 @@ use crate::catalog::system_catalog::rw_catalog::{ }; use crate::expr::{ AggCall, CorrelatedId, CorrelatedInputRef, Depth, Expr as _, ExprImpl, ExprType, FunctionCall, - InputRef, OrderBy, + InputRef, }; use crate::utils::group_by::GroupBy; -use crate::utils::Condition; #[derive(Debug, Clone)] pub struct BoundSelect { @@ -635,15 +634,8 @@ impl Binder { .into(); // There could be multiple indexes on a table so aggregate the sizes of all indexes - let select_items: Vec = vec![AggCall::new( - AggKind::Sum0, - vec![sum], - false, - OrderBy::any(), - Condition::true_cond(), - vec![], - )? - .into()]; + let select_items: Vec = + vec![AggCall::new_unchecked(AggKind::Sum0, vec![sum], DataType::Int64)?.into()]; let indrelid_col = PG_INDEX_COLUMNS[1].1; let indrelid_ref = self.bind_column(&[indrelid_col.into()])?; diff --git a/src/frontend/src/expr/agg_call.rs b/src/frontend/src/expr/agg_call.rs index c9fe56b841290..73031ff060177 100644 --- a/src/frontend/src/expr/agg_call.rs +++ b/src/frontend/src/expr/agg_call.rs @@ -12,13 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use itertools::Itertools; -use risingwave_common::error::{ErrorCode, Result, RwError}; +use risingwave_common::error::Result; use risingwave_common::types::DataType; use risingwave_expr::aggregate::AggKind; -use risingwave_expr::sig::FUNCTION_REGISTRY; -use super::{Expr, ExprImpl, Literal, OrderBy}; +use super::{infer_type, Expr, ExprImpl, Literal, OrderBy}; use crate::utils::Condition; #[derive(Clone, Eq, PartialEq, Hash)] @@ -52,57 +50,17 @@ impl std::fmt::Debug for AggCall { } impl AggCall { - /// Infer the return type for the given agg call. 
- /// Returns error if not supported or the arguments are invalid. - pub fn infer_return_type(agg_kind: AggKind, args: &[DataType]) -> Result { - // The function signatures are aligned with postgres, see - // https://www.postgresql.org/docs/current/functions-aggregate.html. - use DataType::*; - let err = || { - RwError::from(ErrorCode::InvalidInputSyntax(format!( - "Invalid aggregation: {}({})", - agg_kind, - args.iter().map(|t| format!("{}", t)).join(", ") - ))) - }; - Ok(match (agg_kind, args) { - // XXX: some special cases that can not be handled by signature map. - - // min/max allowed for all types except for bool and jsonb (#7981) - (AggKind::Min | AggKind::Max, [DataType::Jsonb]) => return Err(err()), - // functions that are rewritten in the frontend and don't exist in the expr crate - (AggKind::Avg, [input]) => match input { - Int16 | Int32 | Int64 | Decimal => Decimal, - Float32 | Float64 | Int256 => Float64, - Interval => Interval, - _ => return Err(err()), - }, - ( - AggKind::StddevPop | AggKind::StddevSamp | AggKind::VarPop | AggKind::VarSamp, - [input], - ) => match input { - Int16 | Int32 | Int64 | Decimal => Decimal, - Float32 | Float64 | Int256 => Float64, - _ => return Err(err()), - }, - (AggKind::Grouping, _) => Int32, - // other functions are handled by signature map - _ => FUNCTION_REGISTRY.get_return_type(agg_kind, args)?, - }) - } - /// Returns error if the function name matches with an existing function /// but with illegal arguments. pub fn new( agg_kind: AggKind, - args: Vec, + mut args: Vec, distinct: bool, order_by: OrderBy, filter: Condition, direct_args: Vec, ) -> Result { - let data_types = args.iter().map(ExprImpl::return_type).collect_vec(); - let return_type = Self::infer_return_type(agg_kind, &data_types)?; + let return_type = infer_type(agg_kind.into(), &mut args)?; Ok(AggCall { agg_kind, return_type, @@ -114,6 +72,23 @@ impl AggCall { }) } + /// Constructs an `AggCall` without type inference. 
+ pub fn new_unchecked( + agg_kind: AggKind, + args: Vec, + return_type: DataType, + ) -> Result { + Ok(AggCall { + agg_kind, + return_type, + args, + distinct: false, + order_by: OrderBy::any(), + filter: Condition::true_cond(), + direct_args: vec![], + }) + } + pub fn decompose( self, ) -> ( diff --git a/src/frontend/src/expr/function_call.rs b/src/frontend/src/expr/function_call.rs index ad0ddc8fc08a5..87c2c6d95595d 100644 --- a/src/frontend/src/expr/function_call.rs +++ b/src/frontend/src/expr/function_call.rs @@ -102,7 +102,7 @@ impl FunctionCall { // number of arguments are checked // [elsewhere](crate::expr::type_inference::build_type_derive_map). pub fn new(func_type: ExprType, mut inputs: Vec) -> RwResult { - let return_type = infer_type(func_type, &mut inputs)?; + let return_type = infer_type(func_type.into(), &mut inputs)?; Ok(Self::new_unchecked(func_type, inputs, return_type)) } diff --git a/src/frontend/src/expr/table_function.rs b/src/frontend/src/expr/table_function.rs index dfb028d605705..05e7455617d67 100644 --- a/src/frontend/src/expr/table_function.rs +++ b/src/frontend/src/expr/table_function.rs @@ -16,13 +16,12 @@ use std::sync::Arc; use itertools::Itertools; use risingwave_common::types::DataType; -use risingwave_expr::sig::FUNCTION_REGISTRY; pub use risingwave_pb::expr::table_function::PbType as TableFunctionType; use risingwave_pb::expr::{ TableFunction as TableFunctionPb, UserDefinedTableFunction as UserDefinedTableFunctionPb, }; -use super::{Expr, ExprImpl, ExprRewriter, RwResult}; +use super::{infer_type, Expr, ExprImpl, ExprRewriter, RwResult}; use crate::catalog::function_catalog::{FunctionCatalog, FunctionKind}; /// A table function takes a row as input and returns a table. It is also known as Set-Returning @@ -42,11 +41,8 @@ pub struct TableFunction { impl TableFunction { /// Create a `TableFunction` expr with the return type inferred from `func_type` and types of /// `inputs`. 
- pub fn new(func_type: TableFunctionType, args: Vec) -> RwResult { - let return_type = FUNCTION_REGISTRY.get_return_type( - func_type, - &args.iter().map(|c| c.return_type()).collect_vec(), - )?; + pub fn new(func_type: TableFunctionType, mut args: Vec) -> RwResult { + let return_type = infer_type(func_type.into(), &mut args)?; Ok(TableFunction { args, return_type, diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index b58935e810176..a7ee638bca835 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -14,10 +14,10 @@ use itertools::Itertools as _; use num_integer::Integer as _; -use risingwave_common::bail_not_implemented; use risingwave_common::error::{ErrorCode, Result}; use risingwave_common::types::{DataType, StructType}; use risingwave_common::util::iter_util::ZipEqFast; +use risingwave_expr::aggregate::AggKind; pub use risingwave_expr::sig::*; use super::{align_types, cast_ok_base, CastContext}; @@ -28,10 +28,14 @@ use crate::expr::{cast_ok, is_row_function, Expr as _, ExprImpl, ExprType, Funct /// is not supported on backend. /// /// It also mutates the `inputs` by adding necessary casts. 
-pub fn infer_type(func_type: ExprType, inputs: &mut [ExprImpl]) -> Result { - if let Some(res) = infer_type_for_special(func_type, inputs).transpose() { +pub fn infer_type(func_name: FuncName, inputs: &mut [ExprImpl]) -> Result { + // special cases + if let FuncName::Scalar(func_type) = func_name && let Some(res) = infer_type_for_special(func_type, inputs).transpose() { return res; } + if let FuncName::Aggregate(AggKind::Grouping) = func_name { + return Ok(DataType::Int32); + } let actuals = inputs .iter() @@ -40,7 +44,7 @@ pub fn infer_type(func_type: ExprType, inputs: &mut [ExprImpl]) -> Result Some(e.return_type()), }) .collect_vec(); - let sig = infer_type_name(&FUNCTION_REGISTRY, func_type, &actuals)?; + let sig = infer_type_name(&FUNCTION_REGISTRY, func_name, &actuals)?; // add implicit casts to inputs for (expr, t) in inputs.iter_mut().zip_eq_fast(&sig.inputs_type) { @@ -81,7 +85,7 @@ pub fn infer_some_all( (!inputs[0].is_untyped()).then_some(inputs[0].return_type()), element_type.clone(), ]; - let sig = infer_type_name(&FUNCTION_REGISTRY, final_type, &actuals)?; + let sig = infer_type_name(&FUNCTION_REGISTRY, final_type.into(), &actuals)?; if sig.ret_type != DataType::Boolean.into() { return Err(ErrorCode::BindError(format!( "op SOME/ANY/ALL (array) requires operator to yield boolean, but got {}", @@ -273,7 +277,7 @@ fn infer_struct_cast_target_type( (NestedType::Infer(l), NestedType::Infer(r)) => { // Both sides are *unknown*, using the sig_map to infer the return type. let actuals = vec![None, None]; - let sig = infer_type_name(&FUNCTION_REGISTRY, func_type, &actuals)?; + let sig = infer_type_name(&FUNCTION_REGISTRY, func_type.into(), &actuals)?; Ok(( sig.ret_type != l.into(), sig.ret_type != r.into(), @@ -584,7 +588,7 @@ fn infer_type_for_special( } /// From all available functions in `sig_map`, find and return the best matching `FuncSign` for the -/// provided `func_type` and `inputs`. 
This not only support exact function signature match, but can +/// provided `func_name` and `inputs`. This not only support exact function signature match, but can /// also match `substr(varchar, smallint)` or even `substr(varchar, unknown)` to `substr(varchar, /// int)`. /// @@ -593,7 +597,7 @@ fn infer_type_for_special( /// * /// /// To summarize, -/// 1. Find all functions with matching `func_type` and argument count. +/// 1. Find all functions with matching `func_name` and argument count. /// 2. For binary operator with unknown on exactly one side, try to find an exact match assuming /// both sides are same type. /// 3. Rank candidates based on most matching positions. This covers Rule 2, 4a, 4c and 4d in @@ -604,10 +608,10 @@ fn infer_type_for_special( /// 4f in `PostgreSQL`. See [`narrow_same_type`] for details. fn infer_type_name<'a>( sig_map: &'a FunctionRegistry, - func_type: ExprType, + func_name: FuncName, inputs: &[Option], ) -> Result<&'a FuncSign> { - let candidates = sig_map.get_with_arg_nums(func_type, inputs.len()); + let candidates = sig_map.get_with_arg_nums(func_name, inputs.len()); // Binary operators have a special `unknown` handling rule for exact match. We do not // distinguish operators from functions as of now. @@ -630,12 +634,12 @@ fn infer_type_name<'a>( let mut candidates = top_matches(&candidates, inputs); if candidates.is_empty() { - bail_not_implemented!( - issue = 112, - "{:?}{:?}", - func_type, - inputs.iter().map(TypeDebug).collect_vec() - ); + return Err(ErrorCode::NoFunction(format!( + "{}({})", + func_name, + inputs.iter().map(TypeDisplay).format(", ") + )) + .into()); } // After this line `candidates` will never be empty, as the narrow rules will retain original @@ -649,9 +653,9 @@ fn infer_type_name<'a>( [] => unreachable!(), [sig] => Ok(*sig), _ => Err(ErrorCode::BindError(format!( - "function {:?}{:?} is not unique\nHINT: Could not choose a best candidate function. 
You might need to add explicit type casts.", - func_type, - inputs.iter().map(TypeDebug).collect_vec(), + "function {}({}) is not unique\nHINT: Could not choose a best candidate function. You might need to add explicit type casts.", + func_name, + inputs.iter().map(TypeDisplay).format(", "), )) .into()), } @@ -881,8 +885,8 @@ fn narrow_same_type<'a>( } } -struct TypeDebug<'a>(&'a Option); -impl<'a> std::fmt::Debug for TypeDebug<'a> { +struct TypeDisplay<'a>(&'a Option); +impl<'a> std::fmt::Display for TypeDisplay<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.0 { Some(t) => t.fmt(f), @@ -917,7 +921,7 @@ mod tests { .into() }) .collect_vec(); - infer_type(func_type, &mut inputs) + infer_type(func_type.into(), &mut inputs) } fn test_simple_infer_type( @@ -1171,11 +1175,9 @@ mod tests { build: FuncBuilder::Scalar(|_, _| unreachable!()), type_infer: |_| unreachable!(), deprecated: false, - state_type: None, - append_only: false, }); } - let result = infer_type_name(&sig_map, ExprType::Add, inputs); + let result = infer_type_name(&sig_map, ExprType::Add.into(), inputs); match (expected, result) { (Ok(expected), Ok(found)) => { if !found.match_args(expected) { diff --git a/src/frontend/src/expr/window_function.rs b/src/frontend/src/expr/window_function.rs index d82cf4fb788d0..0a906379c5c21 100644 --- a/src/frontend/src/expr/window_function.rs +++ b/src/frontend/src/expr/window_function.rs @@ -16,9 +16,10 @@ use itertools::Itertools; use risingwave_common::bail_not_implemented; use risingwave_common::error::{ErrorCode, RwError}; use risingwave_common::types::DataType; +use risingwave_expr::sig::FUNCTION_REGISTRY; use risingwave_expr::window_function::{Frame, WindowFuncKind}; -use super::{AggCall, Expr, ExprImpl, OrderBy, RwResult}; +use super::{Expr, ExprImpl, OrderBy, RwResult}; /// A window function performs a calculation across a set of table rows that are somehow related to /// the current row, according to the window spec `OVER 
(PARTITION BY .. ORDER BY ..)`. @@ -87,7 +88,8 @@ impl WindowFunction { (Aggregate(agg_kind), args) => { let arg_types = args.iter().map(ExprImpl::return_type).collect::>(); - AggCall::infer_return_type(agg_kind, &arg_types) + let return_type = FUNCTION_REGISTRY.get_return_type(agg_kind, &arg_types)?; + Ok(return_type) } _ => { diff --git a/src/frontend/src/optimizer/plan_node/generic/agg.rs b/src/frontend/src/optimizer/plan_node/generic/agg.rs index 8cfaaff070554..75c86d82059da 100644 --- a/src/frontend/src/optimizer/plan_node/generic/agg.rs +++ b/src/frontend/src/optimizer/plan_node/generic/agg.rs @@ -24,7 +24,7 @@ use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common::util::sort_util::{ColumnOrder, ColumnOrderDisplay, OrderType}; use risingwave_common::util::value_encoding::DatumToProtoExt; use risingwave_expr::aggregate::{agg_kinds, AggKind}; -use risingwave_expr::sig::FUNCTION_REGISTRY; +use risingwave_expr::sig::{FuncBuilder, FUNCTION_REGISTRY}; use risingwave_pb::expr::{PbAggCall, PbConstant}; use risingwave_pb::stream_plan::{agg_call_state, AggCallState as AggCallStatePb}; @@ -521,7 +521,7 @@ impl Agg { .zip_eq_fast(&mut out_fields[self.group_key.len()..]) { let sig = FUNCTION_REGISTRY - .get_aggregate( + .get( agg_call.agg_kind, &agg_call .inputs @@ -529,15 +529,36 @@ impl Agg { .map(|input| input.data_type.clone()) .collect_vec(), &agg_call.return_type, - in_append_only, ) .expect("agg not found"); - if !in_append_only && sig.append_only { - // we use materialized input state for non-retractable aggregate function. - // for backward compatibility, the state type is same as the return type. - // its values in the intermediate state table are always null. 
- } else if let Some(state_type) = &sig.state_type { - field.data_type = state_type.clone(); + // in_append_only: whether the input is append-only + // sig.is_append_only(): whether the agg function has append-only version + match (in_append_only, sig.is_append_only()) { + (false, true) => { + // we use materialized input state for non-retractable aggregate function. + // for backward compatibility, the state type is same as the return type. + // its values in the intermediate state table are always null. + } + (true, true) => { + // use append-only version + if let FuncBuilder::Aggregate { + append_only_state_type: Some(state_type), + .. + } = &sig.build + { + field.data_type = state_type.clone(); + } + } + (_, false) => { + // there is only retractable version, use it + if let FuncBuilder::Aggregate { + retractable_state_type: Some(state_type), + .. + } = &sig.build + { + field.data_type = state_type.clone(); + } + } } } let in_dist_key = self.input.distribution().dist_column_indices().to_vec(); diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index d58becfedd46f..127968f279571 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -19,6 +19,7 @@ use risingwave_common::types::{DataType, Datum, ScalarImpl}; use risingwave_common::util::sort_util::ColumnOrder; use risingwave_common::{bail_not_implemented, not_implemented}; use risingwave_expr::aggregate::{agg_kinds, AggKind}; +use risingwave_expr::sig::FUNCTION_REGISTRY; use super::generic::{self, Agg, GenericPlanRef, PlanAggCall, ProjectBuilder}; use super::utils::impl_distill_by_unit; @@ -490,8 +491,9 @@ impl LogicalAggBuilder { AggKind::Avg => { assert_eq!(inputs.len(), 1); - let left_return_type = - AggCall::infer_return_type(AggKind::Sum, &[inputs[0].return_type()]).unwrap(); + let left_return_type = FUNCTION_REGISTRY + .get_return_type(AggKind::Sum, 
&[inputs[0].return_type()]) + .unwrap(); let left_ref = self.push_agg_call(PlanAggCall { agg_kind: AggKind::Sum, return_type: left_return_type, @@ -503,8 +505,9 @@ impl LogicalAggBuilder { }); let left = ExprImpl::from(left_ref).cast_explicit(return_type).unwrap(); - let right_return_type = - AggCall::infer_return_type(AggKind::Count, &[inputs[0].return_type()]).unwrap(); + let right_return_type = FUNCTION_REGISTRY + .get_return_type(AggKind::Count, &[inputs[0].return_type()]) + .unwrap(); let right_ref = self.push_agg_call(PlanAggCall { agg_kind: AggKind::Count, return_type: right_return_type, @@ -546,9 +549,9 @@ impl LogicalAggBuilder { .add_expr(&squared_input_expr) .unwrap(); - let sum_of_squares_return_type = - AggCall::infer_return_type(AggKind::Sum, &[squared_input_expr.return_type()]) - .unwrap(); + let sum_of_squares_return_type = FUNCTION_REGISTRY + .get_return_type(AggKind::Sum, &[squared_input_expr.return_type()]) + .unwrap(); let sum_of_squares_expr = ExprImpl::from(self.push_agg_call(PlanAggCall { agg_kind: AggKind::Sum, @@ -566,8 +569,9 @@ impl LogicalAggBuilder { .unwrap(); // after that, we compute sum - let sum_return_type = - AggCall::infer_return_type(AggKind::Sum, &[input.return_type()]).unwrap(); + let sum_return_type = FUNCTION_REGISTRY + .get_return_type(AggKind::Sum, &[input.return_type()]) + .unwrap(); let sum_expr = ExprImpl::from(self.push_agg_call(PlanAggCall { agg_kind: AggKind::Sum, @@ -582,8 +586,9 @@ impl LogicalAggBuilder { .unwrap(); // then, we compute count - let count_return_type = - AggCall::infer_return_type(AggKind::Count, &[input.return_type()]).unwrap(); + let count_return_type = FUNCTION_REGISTRY + .get_return_type(AggKind::Count, &[input.return_type()]) + .unwrap(); let count_expr = ExprImpl::from(self.push_agg_call(PlanAggCall { agg_kind: AggKind::Count,