diff --git a/src/frontend/planner_test/tests/testdata/output/agg.yaml b/src/frontend/planner_test/tests/testdata/output/agg.yaml index e618a58500783..9a07df7558d96 100644 --- a/src/frontend/planner_test/tests/testdata/output/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/output/agg.yaml @@ -1278,7 +1278,7 @@ logical_plan: |- LogicalProject { exprs: [Case((count(t.v1) <= 1:Int64), null:Decimal, Sqrt(((sum($expr1)::Decimal - ((sum(t.v1)::Decimal * sum(t.v1)::Decimal) / count(t.v1)::Decimal)) / (count(t.v1) - 1:Int64)::Decimal))) as $expr2, Sqrt(((sum($expr1)::Decimal - ((sum(t.v1)::Decimal * sum(t.v1)::Decimal) / count(t.v1)::Decimal)) / count(t.v1)::Decimal)) as $expr3] } └─LogicalAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] } - └─LogicalProject { exprs: [t.v1, (t.v1 * t.v1) as $expr1] } + └─LogicalProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1] } └─LogicalScan { table: t, columns: [t.v1, t._row_id] } batch_plan: |- BatchProject { exprs: [Case((sum0(count(t.v1)) <= 1:Int64), null:Decimal, Sqrt(((sum(sum($expr1))::Decimal - (($expr2 * $expr2) / $expr3)) / (sum0(count(t.v1)) - 1:Int64)::Decimal))) as $expr4, Sqrt(((sum(sum($expr1))::Decimal - (($expr2 * $expr2) / $expr3)) / $expr3)) as $expr5] } @@ -1286,14 +1286,14 @@ └─BatchSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1))] } └─BatchExchange { order: [], dist: Single } └─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] } - └─BatchProject { exprs: [t.v1, (t.v1 * t.v1) as $expr1] } + └─BatchProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1] } └─BatchScan { table: t, columns: [t.v1], distribution: SomeShard } batch_local_plan: |- BatchProject { exprs: [Case((count(t.v1) <= 1:Int64), null:Decimal, Sqrt(((sum($expr1)::Decimal - (($expr2 * $expr2) / $expr3)) / (count(t.v1) - 1:Int64)::Decimal))) as $expr4, Sqrt(((sum($expr1)::Decimal - (($expr2 * $expr2) / $expr3)) / $expr3)) as $expr5] } └─BatchProject { exprs: [sum($expr1), sum(t.v1), count(t.v1), sum(t.v1)::Decimal as $expr2, count(t.v1)::Decimal as $expr3] } └─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] } └─BatchExchange { order: [], dist: Single } - └─BatchProject { exprs: [t.v1, (t.v1 * t.v1) as $expr1] } + └─BatchProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1] } └─BatchScan { table: t, columns: [t.v1], distribution: SomeShard } stream_plan: |- StreamMaterialize { columns: [stddev_samp, stddev_pop], stream_key: [], pk_columns: [], pk_conflict: NoCheck } @@ -1302,7 +1302,7 @@ └─StreamSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1)), count] } └─StreamExchange { dist: Single } └─StreamStatelessSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] } - └─StreamProject { exprs: [t.v1, (t.v1 * t.v1) as $expr1, t._row_id] } + └─StreamProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1, t._row_id] } └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } - name: stddev_samp with other columns sql: | @@ -1310,7 +1310,7 @@ logical_plan: |- LogicalProject { exprs: [count('':Varchar), Case((count(1:Int32) <= 1:Int64), null:Decimal, Sqrt(((sum($expr1)::Decimal - ((sum(1:Int32)::Decimal * sum(1:Int32)::Decimal) / count(1:Int32)::Decimal)) / (count(1:Int32) - 1:Int64)::Decimal))) as $expr2] } └─LogicalAgg { aggs: [count('':Varchar), sum($expr1), sum(1:Int32), count(1:Int32)] } - └─LogicalProject { exprs: ['':Varchar, 1:Int32, (1:Int32 * 1:Int32) as $expr1] } + └─LogicalProject { exprs: ['':Varchar, (1:Int32 * 1:Int32) as $expr1, 1:Int32] } └─LogicalValues { rows: [[]], schema: Schema { fields: [] } } - name: stddev_samp with group sql: | @@ -1319,7 +1319,7 @@ logical_plan: |- LogicalProject { exprs: [Case((count(t.v) <= 1:Int64), null:Decimal, Sqrt(((sum($expr1)::Decimal - ((sum(t.v)::Decimal * sum(t.v)::Decimal) / count(t.v)::Decimal)) / (count(t.v) - 1:Int64)::Decimal))) as $expr2] } └─LogicalAgg { group_key: [t.w], aggs: [sum($expr1), sum(t.v), count(t.v)] } - └─LogicalProject { exprs: [t.w, t.v, (t.v * t.v) as $expr1] } + └─LogicalProject { exprs: [t.w, (t.v * t.v) as $expr1, t.v] } └─LogicalScan { table: t, columns: [t.v, t.w, t._row_id] } - name: force two phase aggregation should succeed with UpstreamHashShard and SomeShard (batch only). sql: | diff --git a/src/frontend/planner_test/tests/testdata/output/cse_expr.yaml b/src/frontend/planner_test/tests/testdata/output/cse_expr.yaml index ceb706446f986..0e5d72b3499a3 100644 --- a/src/frontend/planner_test/tests/testdata/output/cse_expr.yaml +++ b/src/frontend/planner_test/tests/testdata/output/cse_expr.yaml @@ -67,7 +67,7 @@ └─BatchSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v))] } └─BatchExchange { order: [], dist: Single } └─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v), count(t.v)] } - └─BatchProject { exprs: [t.v, (t.v * t.v) as $expr1] } + └─BatchProject { exprs: [(t.v * t.v) as $expr1, t.v] } └─BatchScan { table: t, columns: [t.v], distribution: SomeShard } stream_plan: |- StreamMaterialize { columns: [stddev_pop, stddev_samp, var_pop, var_samp], stream_key: [], pk_columns: [], pk_conflict: NoCheck } @@ -78,7 +78,7 @@ └─StreamSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), count] } └─StreamExchange { dist: Single } └─StreamStatelessSimpleAgg { aggs: [sum($expr1), sum(t.v), count(t.v)] } - └─StreamProject { exprs: [t.v, (t.v * t.v) as $expr1, t._row_id] } + └─StreamProject { exprs: [(t.v * t.v) as $expr1, t.v, t._row_id] } └─StreamTableScan { table: t, columns: [t.v, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } - name: Common sub expression shouldn't extract partial expression of `some`/`all`. See 11766 sql: | diff --git a/src/frontend/src/expr/agg_call.rs b/src/frontend/src/expr/agg_call.rs index 0f9493a694952..353f4416af1c0 100644 --- a/src/frontend/src/expr/agg_call.rs +++ b/src/frontend/src/expr/agg_call.rs @@ -21,13 +21,13 @@ use crate::utils::Condition; #[derive(Clone, Eq, PartialEq, Hash)] pub struct AggCall { - agg_kind: AggKind, - return_type: DataType, - args: Vec, - distinct: bool, - order_by: OrderBy, - filter: Condition, - direct_args: Vec, + pub agg_kind: AggKind, + pub return_type: DataType, + pub args: Vec, + pub distinct: bool, + pub order_by: OrderBy, + pub filter: Condition, + pub direct_args: Vec, } impl std::fmt::Debug for AggCall { diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index cad073386a42e..f3d7d8e4d7bdf 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -18,7 +18,6 @@ use risingwave_common::types::{DataType, Datum, ScalarImpl}; use risingwave_common::util::sort_util::ColumnOrder; use risingwave_common::{bail_not_implemented, not_implemented}; use risingwave_expr::aggregate::{agg_kinds, AggKind}; -use risingwave_expr::sig::FUNCTION_REGISTRY; use super::generic::{self, Agg, GenericPlanRef, PlanAggCall, ProjectBuilder}; use super::utils::impl_distill_by_unit; @@ -27,7 +26,7 @@ use super::{ PlanTreeNodeUnary, PredicatePushdown, StreamHashAgg, StreamProject, StreamSimpleAgg, StreamStatelessSimpleAgg, ToBatch, ToStream, }; -use crate::error::{ErrorCode, Result}; +use crate::error::{ErrorCode, Result, RwError}; use crate::expr::{ AggCall, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal, OrderBy, WindowFunction, @@ -262,7 +261,7 @@ pub struct LogicalAggBuilder { /// the agg calls agg_calls: Vec, /// the error during the expression rewriting - error: Option, + error: Option, /// If `is_in_filter_clause` is true, it means that /// we are processing filter clause. /// This field is needed because input refs in these clauses @@ -354,7 +353,7 @@ impl LogicalAggBuilder { fn rewrite_with_error(&mut self, expr: ExprImpl) -> Result { let rewritten_expr = self.rewrite_expr(expr); if let Some(error) = self.error.take() { - return Err(error.into()); + return Err(error); } Ok(rewritten_expr) } @@ -377,51 +376,151 @@ impl LogicalAggBuilder { self.group_key.len() } - /// Push a new planned agg call into the builder. - /// Return an `InputRef` to that agg call. - /// For existing agg calls, return an `InputRef` to the existing one. - fn push_agg_call(&mut self, agg_call: PlanAggCall) -> InputRef { - if let Some((pos, existing)) = self.agg_calls.iter().find_position(|&c| c == &agg_call) { - return InputRef::new( - self.schema_agg_start_offset() + pos, - existing.return_type.clone(), - ); - } - let index = self.schema_agg_start_offset() + self.agg_calls.len(); - let data_type = agg_call.return_type.clone(); - self.agg_calls.push(agg_call); - InputRef::new(index, data_type) - } - - /// When there is an agg call, there are 3 things to do: - /// 1. eval its inputs via project; - /// 2. add a `PlanAggCall` to agg; - /// 3. rewrite it as an `InputRef` to the agg result in select list. - /// - /// Note that the rewriter does not traverse into inputs of agg calls. - fn try_rewrite_agg_call( - &mut self, + /// Rewrite [`AggCall`] if needed, and push it into the builder using `push_agg_call`. + /// This is shared by [`LogicalAggBuilder`] and `LogicalOverWindowBuilder`. + pub(crate) fn general_rewrite_agg_call( agg_call: AggCall, - ) -> std::result::Result { - let return_type = agg_call.return_type(); - let (agg_kind, inputs, mut distinct, mut order_by, filter, direct_args) = - agg_call.decompose(); + mut push_agg_call: impl FnMut(AggCall) -> Result, + ) -> Result { + match agg_call.agg_kind { + // Rewrite avg to cast(sum as avg_return_type) / count. + AggKind::Avg => { + assert_eq!(agg_call.args.len(), 1); + + let sum = ExprImpl::from(push_agg_call(AggCall::new( + AggKind::Sum, + agg_call.args.clone(), + agg_call.distinct, + agg_call.order_by.clone(), + agg_call.filter.clone(), + agg_call.direct_args.clone(), + )?)?) + .cast_explicit(agg_call.return_type())?; + + let count = ExprImpl::from(push_agg_call(AggCall::new( + AggKind::Count, + agg_call.args.clone(), + agg_call.distinct, + agg_call.order_by.clone(), + agg_call.filter.clone(), + agg_call.direct_args.clone(), + )?)?); + + Ok(FunctionCall::new(ExprType::Divide, Vec::from([sum, count]))?.into()) + } + // We compute `var_samp` as + // (sum(sq) - sum * sum / count) / (count - 1) + // and `var_pop` as + // (sum(sq) - sum * sum / count) / count + // Since we don't have the square function, we use the plain Multiply for squaring, + // which is in a sense more general than the pow function, especially when calculating + // covariances in the future. Also we don't have the sqrt function for rooting, so we + // use pow(x, 0.5) to simulate + kind @ (AggKind::StddevPop + | AggKind::StddevSamp + | AggKind::VarPop + | AggKind::VarSamp) => { + let arg = agg_call.args().iter().exactly_one().unwrap(); + let squared_arg = ExprImpl::from(FunctionCall::new( + ExprType::Multiply, + vec![arg.clone(), arg.clone()], + )?); + + let sum_of_sq = ExprImpl::from(push_agg_call(AggCall::new( + AggKind::Sum, + vec![squared_arg], + agg_call.distinct, + agg_call.order_by.clone(), + agg_call.filter.clone(), + agg_call.direct_args.clone(), + )?)?) + .cast_explicit(agg_call.return_type())?; + + let sum = ExprImpl::from(push_agg_call(AggCall::new( + AggKind::Sum, + agg_call.args.clone(), + agg_call.distinct, + agg_call.order_by.clone(), + agg_call.filter.clone(), + agg_call.direct_args.clone(), + )?)?) + .cast_explicit(agg_call.return_type())?; + + let count = ExprImpl::from(push_agg_call(AggCall::new( + AggKind::Count, + agg_call.args.clone(), + agg_call.distinct, + agg_call.order_by.clone(), + agg_call.filter.clone(), + agg_call.direct_args.clone(), + )?)?); + + let one = ExprImpl::from(Literal::new( + Datum::from(ScalarImpl::Int64(1)), + DataType::Int64, + )); - if matches!(agg_kind, agg_kinds::must_have_order_by!()) && order_by.sort_exprs.is_empty() { - return Err(ErrorCode::InvalidInputSyntax(format!( - "Aggregation function {} requires ORDER BY clause", - agg_kind - ))); - } + let squared_sum = ExprImpl::from(FunctionCall::new( + ExprType::Multiply, + vec![sum.clone(), sum], + )?); + + let numerator = ExprImpl::from(FunctionCall::new( + ExprType::Subtract, + vec![ + sum_of_sq, + ExprImpl::from(FunctionCall::new( + ExprType::Divide, + vec![squared_sum, count.clone()], + )?), + ], + )?); + + let denominator = match kind { + AggKind::VarPop | AggKind::StddevPop => count.clone(), + AggKind::VarSamp | AggKind::StddevSamp => ExprImpl::from(FunctionCall::new( + ExprType::Subtract, + vec![count.clone(), one.clone()], + )?), + _ => unreachable!(), + }; - // try ignore ORDER BY if it doesn't affect the result - if matches!(agg_kind, agg_kinds::result_unaffected_by_order_by!()) { - order_by = OrderBy::any(); - } - // try ignore DISTINCT if it doesn't affect the result - if matches!(agg_kind, agg_kinds::result_unaffected_by_distinct!()) { - distinct = false; + let mut target = ExprImpl::from(FunctionCall::new( + ExprType::Divide, + vec![numerator, denominator], + )?); + + if matches!(kind, AggKind::StddevPop | AggKind::StddevSamp) { + target = ExprImpl::from(FunctionCall::new(ExprType::Sqrt, vec![target])?); + } + + match kind { + AggKind::VarPop | AggKind::StddevPop => Ok(target), + AggKind::StddevSamp | AggKind::VarSamp => { + let case_cond = ExprImpl::from(FunctionCall::new( + ExprType::LessThanOrEqual, + vec![count, one], + )?); + let null = ExprImpl::from(Literal::new(None, agg_call.return_type())); + + Ok(ExprImpl::from(FunctionCall::new( + ExprType::Case, + vec![case_cond, null, target], + )?)) + } + _ => unreachable!(), + } + } + _ => Ok(push_agg_call(agg_call)?.into()), } + } + + /// Push a new agg call into the builder. + /// Return an `InputRef` to that agg call. + /// For existing agg calls, return an `InputRef` to the existing one. + fn push_agg_call(&mut self, agg_call: AggCall) -> Result { + let return_type = agg_call.return_type(); + let (kind, args, distinct, order_by, filter, direct_args) = agg_call.decompose(); self.is_in_filter_clause = true; // filter expr is not added to `input_proj_builder` as a whole. Special exprs incl @@ -429,27 +528,7 @@ impl LogicalAggBuilder { let filter = filter.rewrite_expr(self); self.is_in_filter_clause = false; - if matches!(agg_kind, AggKind::Grouping) { - if self.grouping_sets.is_empty() { - return Err(ErrorCode::NotSupported( - "GROUPING must be used in a query with grouping sets".into(), - "try to use grouping sets instead".into(), - )); - } - if inputs.len() >= 32 { - return Err(ErrorCode::InvalidInputSyntax( - "GROUPING must have fewer than 32 arguments".into(), - )); - } - if inputs.iter().any(|x| self.try_as_group_expr(x).is_none()) { - return Err(ErrorCode::InvalidInputSyntax( - "arguments to GROUPING must be grouping expressions of the associated query level" - .into(), - )); - } - } - - let inputs: Vec<_> = inputs + let args: Vec<_> = args .iter() .map(|expr| { let index = self.input_proj_builder.add_expr(expr)?; @@ -470,220 +549,90 @@ impl LogicalAggBuilder { not_implemented!("{err} inside aggregation calls order by") })?; - match agg_kind { - // Rewrite avg to cast(sum as avg_return_type) / count. - AggKind::Avg => { - assert_eq!(inputs.len(), 1); - - let left_return_type = FUNCTION_REGISTRY - .get_return_type(AggKind::Sum, &[inputs[0].return_type()]) - .unwrap(); - let left_ref = self.push_agg_call(PlanAggCall { - agg_kind: AggKind::Sum, - return_type: left_return_type, - inputs: inputs.clone(), - distinct, - order_by: order_by.clone(), - filter: filter.clone(), - direct_args: direct_args.clone(), - }); - let left = ExprImpl::from(left_ref).cast_explicit(return_type).unwrap(); - - let right_return_type = FUNCTION_REGISTRY - .get_return_type(AggKind::Count, &[inputs[0].return_type()]) - .unwrap(); - let right_ref = self.push_agg_call(PlanAggCall { - agg_kind: AggKind::Count, - return_type: right_return_type, - inputs, - distinct, - order_by, - filter, - direct_args, - }); - - Ok(ExprImpl::from( - FunctionCall::new(ExprType::Divide, vec![left, right_ref.into()]).unwrap(), - )) - } + let plan_agg_call = PlanAggCall { + agg_kind: kind, + return_type: return_type.clone(), + inputs: args, + distinct, + order_by, + filter, + direct_args, + }; - // We compute `var_samp` as - // (sum(sq) - sum * sum / count) / (count - 1) - // and `var_pop` as - // (sum(sq) - sum * sum / count) / count - // Since we don't have the square function, we use the plain Multiply for squaring, - // which is in a sense more general than the pow function, especially when calculating - // covariances in the future. Also we don't have the sqrt function for rooting, so we - // use pow(x, 0.5) to simulate - AggKind::StddevPop | AggKind::StddevSamp | AggKind::VarPop | AggKind::VarSamp => { - let input = inputs.iter().exactly_one().unwrap(); - let pre_proj_input = self.input_proj_builder.get_expr(input.index).unwrap(); - - // first, we compute sum of squared as sum_sq - let squared_input_expr = ExprImpl::from( - FunctionCall::new( - ExprType::Multiply, - vec![pre_proj_input.clone(), pre_proj_input.clone()], - ) - .unwrap(), - ); - - let squared_input_proj_index = self - .input_proj_builder - .add_expr(&squared_input_expr) - .unwrap(); - - let sum_of_squares_return_type = FUNCTION_REGISTRY - .get_return_type(AggKind::Sum, &[squared_input_expr.return_type()]) - .unwrap(); - - let sum_of_squares_expr = ExprImpl::from(self.push_agg_call(PlanAggCall { - agg_kind: AggKind::Sum, - return_type: sum_of_squares_return_type, - inputs: vec![InputRef::new( - squared_input_proj_index, - squared_input_expr.return_type(), - )], - distinct, - order_by: order_by.clone(), - filter: filter.clone(), - direct_args: direct_args.clone(), - })) - .cast_explicit(return_type.clone()) - .unwrap(); - - // after that, we compute sum - let sum_return_type = FUNCTION_REGISTRY - .get_return_type(AggKind::Sum, &[input.return_type()]) - .unwrap(); - - let sum_expr = ExprImpl::from(self.push_agg_call(PlanAggCall { - agg_kind: AggKind::Sum, - return_type: sum_return_type, - inputs: inputs.clone(), - distinct, - order_by: order_by.clone(), - filter: filter.clone(), - direct_args: direct_args.clone(), - })) - .cast_explicit(return_type.clone()) - .unwrap(); - - // then, we compute count - let count_return_type = FUNCTION_REGISTRY - .get_return_type(AggKind::Count, &[input.return_type()]) - .unwrap(); - - let count_expr = ExprImpl::from(self.push_agg_call(PlanAggCall { - agg_kind: AggKind::Count, - return_type: count_return_type, - inputs, - distinct, - order_by, - filter, - direct_args, - })); - - // we start with variance - - // sum * sum - let square_of_sum_expr = ExprImpl::from( - FunctionCall::new(ExprType::Multiply, vec![sum_expr.clone(), sum_expr]) - .unwrap(), - ); - - // sum_sq - sum * sum / count - let numerator_expr = ExprImpl::from( - FunctionCall::new( - ExprType::Subtract, - vec![ - sum_of_squares_expr, - ExprImpl::from( - FunctionCall::new( - ExprType::Divide, - vec![square_of_sum_expr, count_expr.clone()], - ) - .unwrap(), - ), - ], - ) - .unwrap(), - ); - - // count or count - 1 - let denominator_expr = match agg_kind { - AggKind::StddevPop | AggKind::VarPop => count_expr.clone(), - AggKind::StddevSamp | AggKind::VarSamp => ExprImpl::from( - FunctionCall::new( - ExprType::Subtract, - vec![ - count_expr.clone(), - ExprImpl::from(Literal::new( - Datum::from(ScalarImpl::Int64(1)), - DataType::Int64, - )), - ], - ) - .unwrap(), - ), - _ => unreachable!(), - }; + if let Some((pos, existing)) = self + .agg_calls + .iter() + .find_position(|&c| c == &plan_agg_call) + { + return Ok(InputRef::new( + self.schema_agg_start_offset() + pos, + existing.return_type.clone(), + )); + } + let index = self.schema_agg_start_offset() + self.agg_calls.len(); + self.agg_calls.push(plan_agg_call); + Ok(InputRef::new(index, return_type)) + } - let mut target_expr = ExprImpl::from( - FunctionCall::new(ExprType::Divide, vec![numerator_expr, denominator_expr]) - .unwrap(), - ); + /// When there is an agg call, there are 3 things to do: + /// 1. Rewrite `avg`, `var_samp`, etc. into a combination of `sum`, `count`, etc.; + /// 2. Add exprs in arguments to input `Project`; + /// 2. Add the agg call to current `Agg`, and return an `InputRef` to it. + /// + /// Note that the rewriter does not traverse into inputs of agg calls. + fn try_rewrite_agg_call(&mut self, mut agg_call: AggCall) -> Result { + if matches!(agg_call.agg_kind, agg_kinds::must_have_order_by!()) + && agg_call.order_by.sort_exprs.is_empty() + { + return Err(ErrorCode::InvalidInputSyntax(format!( + "Aggregation function {} requires ORDER BY clause", + agg_call.agg_kind + )) + .into()); + } - // stddev = sqrt(variance) - if matches!(agg_kind, AggKind::StddevPop | AggKind::StddevSamp) { - target_expr = ExprImpl::from( - FunctionCall::new(ExprType::Sqrt, vec![target_expr]).unwrap(), - ); - } + // try ignore ORDER BY if it doesn't affect the result + if matches!( + agg_call.agg_kind, + agg_kinds::result_unaffected_by_order_by!() + ) { + agg_call.order_by = OrderBy::any(); + } + // try ignore DISTINCT if it doesn't affect the result + if matches!( + agg_call.agg_kind, + agg_kinds::result_unaffected_by_distinct!() + ) { + agg_call.distinct = false; + } - match agg_kind { - AggKind::VarPop | AggKind::StddevPop => Ok(target_expr), - AggKind::StddevSamp | AggKind::VarSamp => { - let less_than_expr = ExprImpl::from( - FunctionCall::new( - ExprType::LessThanOrEqual, - vec![ - count_expr, - ExprImpl::from(Literal::new( - Datum::from(ScalarImpl::Int64(1)), - DataType::Int64, - )), - ], - ) - .unwrap(), - ); - let null_expr = ExprImpl::from(Literal::new(None, return_type)); - - let case_expr = ExprImpl::from( - FunctionCall::new( - ExprType::Case, - vec![less_than_expr, null_expr, target_expr], - ) - .unwrap(), - ); - - Ok(case_expr) - } - _ => unreachable!(), - } + if matches!(agg_call.agg_kind, AggKind::Grouping) { + if self.grouping_sets.is_empty() { + return Err(ErrorCode::NotSupported( + "GROUPING must be used in a query with grouping sets".into(), + "try to use grouping sets instead".into(), + ) + .into()); + } + if agg_call.args.len() >= 32 { + return Err(ErrorCode::InvalidInputSyntax( + "GROUPING must have fewer than 32 arguments".into(), + ) + .into()); + } + if agg_call + .args + .iter() + .any(|x| self.try_as_group_expr(x).is_none()) + { + return Err(ErrorCode::InvalidInputSyntax( + "arguments to GROUPING must be grouping expressions of the associated query level" + .into(), + ).into()); } - _ => Ok(self - .push_agg_call(PlanAggCall { - agg_kind, - return_type, - inputs, - distinct, - order_by, - filter, - direct_args, - }) - .into()), } + + Self::general_rewrite_agg_call(agg_call, |agg_call| self.push_agg_call(agg_call)) } } @@ -754,10 +703,13 @@ impl ExprRewriter for LogicalAggBuilder { ) .into() } else { - self.error = Some(ErrorCode::InvalidInputSyntax( - "column must appear in the GROUP BY clause or be used in an aggregate function" - .into(), - )); + self.error = Some( + ErrorCode::InvalidInputSyntax( + "column must appear in the GROUP BY clause or be used in an aggregate function" + .into(), + ) + .into(), + ); expr } } diff --git a/src/frontend/src/optimizer/plan_node/logical_over_window.rs b/src/frontend/src/optimizer/plan_node/logical_over_window.rs index af1fc50c9057f..b5762a224d180 100644 --- a/src/frontend/src/optimizer/plan_node/logical_over_window.rs +++ b/src/frontend/src/optimizer/plan_node/logical_over_window.rs @@ -14,7 +14,7 @@ use fixedbitset::FixedBitSet; use itertools::Itertools; -use risingwave_common::types::{DataType, Datum, ScalarImpl}; +use risingwave_common::types::DataType; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use risingwave_common::{bail_not_implemented, not_implemented}; use risingwave_expr::aggregate::AggKind; @@ -29,9 +29,11 @@ use super::{ }; use crate::error::{ErrorCode, Result, RwError}; use crate::expr::{ - Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, WindowFunction, + AggCall, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, + WindowFunction, }; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::plan_node::logical_agg::LogicalAggBuilder; use crate::optimizer::plan_node::{ ColumnPruningContext, Literal, PredicatePushdownContext, RewriteStreamContext, ToStreamContext, }; @@ -103,7 +105,7 @@ impl<'a> LogicalOverWindowBuilder<'a> { window_func.frame, ); - if let WindowFuncKind::Aggregate(agg_kind) = kind + let new_expr = if let WindowFuncKind::Aggregate(agg_kind) = kind && matches!( agg_kind, AggKind::Avg @@ -111,148 +113,39 @@ impl<'a> LogicalOverWindowBuilder<'a> { | AggKind::StddevSamp | AggKind::VarPop | AggKind::VarSamp - ) - { - // Refer to `LogicalAggBuilder::try_rewrite_agg_call` - match agg_kind { - AggKind::Avg => { - assert_eq!(args.len(), 1); - let left_ref = ExprImpl::from(self.push_window_func(WindowFunction::new( - WindowFuncKind::Aggregate(AggKind::Sum), - partition_by.clone(), - order_by.clone(), - args.clone(), - frame.clone(), - )?)) - .cast_explicit(return_type)?; - let right_ref = ExprImpl::from(self.push_window_func(WindowFunction::new( - WindowFuncKind::Aggregate(AggKind::Count), - partition_by, - order_by, - args, - frame, - )?)); - - let new_expr = ExprImpl::from(FunctionCall::new( - ExprType::Divide, - vec![left_ref, right_ref], - )?); - Ok(new_expr) - } - AggKind::StddevPop | AggKind::StddevSamp | AggKind::VarPop | AggKind::VarSamp => { - let input = args.first().unwrap(); - let squared_input_expr = ExprImpl::from(FunctionCall::new( - ExprType::Multiply, - vec![input.clone(), input.clone()], - )?); - - let sum_of_squares_expr = - ExprImpl::from(self.push_window_func(WindowFunction::new( - WindowFuncKind::Aggregate(AggKind::Sum), - partition_by.clone(), - order_by.clone(), - vec![squared_input_expr], - frame.clone(), - )?)) - .cast_explicit(return_type.clone())?; - - let sum_expr = ExprImpl::from(self.push_window_func(WindowFunction::new( - WindowFuncKind::Aggregate(AggKind::Sum), + ) { + let agg_call = AggCall::new( + agg_kind, + args, + false, + order_by, + Condition::true_cond(), + vec![], + )?; + LogicalAggBuilder::general_rewrite_agg_call(agg_call, |agg_call| { + Ok(self.push_window_func( + // AggCall -> WindowFunction + WindowFunction::new( + WindowFuncKind::Aggregate(agg_call.agg_kind), partition_by.clone(), - order_by.clone(), - args.clone(), + agg_call.order_by.clone(), + agg_call.args.clone(), frame.clone(), - )?)) - .cast_explicit(return_type.clone())?; - - let count_expr = ExprImpl::from(self.push_window_func(WindowFunction::new( - WindowFuncKind::Aggregate(AggKind::Count), - partition_by, - order_by, - args.clone(), - frame, - )?)); - - let square_of_sum_expr = ExprImpl::from(FunctionCall::new( - ExprType::Multiply, - vec![sum_expr.clone(), sum_expr], - )?); - - let numerator_expr = ExprImpl::from(FunctionCall::new( - ExprType::Subtract, - vec![ - sum_of_squares_expr, - ExprImpl::from(FunctionCall::new( - ExprType::Divide, - vec![square_of_sum_expr, count_expr.clone()], - )?), - ], - )?); - - let denominator_expr = match agg_kind { - AggKind::StddevPop | AggKind::VarPop => count_expr.clone(), - AggKind::StddevSamp | AggKind::VarSamp => { - ExprImpl::from(FunctionCall::new( - ExprType::Subtract, - vec![ - count_expr.clone(), - ExprImpl::from(Literal::new( - Datum::from(ScalarImpl::Int64(1)), - DataType::Int64, - )), - ], - )?) - } - _ => unreachable!(), - }; - - let mut target_expr = ExprImpl::from(FunctionCall::new( - ExprType::Divide, - vec![numerator_expr, denominator_expr], - )?); - - if matches!(agg_kind, AggKind::StddevPop | AggKind::StddevSamp) { - target_expr = ExprImpl::from( - FunctionCall::new(ExprType::Sqrt, vec![target_expr]).unwrap(), - ); - } - - match agg_kind { - AggKind::VarPop | AggKind::StddevPop => Ok(target_expr), - AggKind::StddevSamp | AggKind::VarSamp => { - let less_than_expr = ExprImpl::from(FunctionCall::new( - ExprType::LessThanOrEqual, - vec![ - count_expr, - ExprImpl::from(Literal::new( - Datum::from(ScalarImpl::Int64(1)), - DataType::Int64, - )), - ], - )?); - let null_expr = ExprImpl::from(Literal::new(None, return_type)); - - let case_expr = ExprImpl::from(FunctionCall::new( - ExprType::Case, - vec![less_than_expr, null_expr, target_expr], - )?); - Ok(case_expr) - } - _ => unreachable!(), - } - } - _ => unreachable!(), - } + )?, + )) + })? } else { - let new_expr = ExprImpl::from(self.push_window_func(WindowFunction::new( + ExprImpl::from(self.push_window_func(WindowFunction::new( kind, partition_by, order_by, args, frame, - )?)); - Ok(new_expr) - } + )?)) + }; + + assert_eq!(new_expr.return_type(), return_type); + Ok(new_expr) } }