From 9f9c1dab6049d696b0e023cd6108b7f49f9e4059 Mon Sep 17 00:00:00 2001 From: discord9 <55937128+discord9@users.noreply.github.com> Date: Thu, 29 Aug 2024 10:52:00 +0800 Subject: [PATCH] feat(flow): use DataFusion's optimizer (#4489) * feat: use datafusion optimization refactor: mv `sql_to_flow_plan` elsewhere feat(WIP): use df optimization WIP analyzer rule feat(WIP): avg expander fix: transform avg expander fix: avg expand feat: names from substrait fix: avg rewrite test: update `test_avg`&`test_avg_group_by` test: fix `test_sum` test: fix some tests chore: remove unused flow plan transform feat: tumble expander test: update tests * chore: clippy * fix: tumble lose `group expr` * test: sqlness test update * test: rm unused cast * test: simplify sqlness * refactor: per review * chore: after rebase * fix: remove a outdated test * test: add comment * fix: report error when not literal * chore: update sqlness test after rebase * refactor: per review --- src/flow/src/adapter.rs | 2 +- src/flow/src/df_optimizer.rs | 604 ++++++++++++++ src/flow/src/expr.rs | 3 + src/flow/src/expr/func.rs | 132 ++- src/flow/src/expr/scalar.rs | 80 +- src/flow/src/expr/signature.rs | 2 + src/flow/src/lib.rs | 1 + src/flow/src/plan.rs | 2 + src/flow/src/repr/relation.rs | 6 +- src/flow/src/transform.rs | 95 +-- src/flow/src/transform/aggr.rs | 749 ++++++++++-------- src/flow/src/transform/expr.rs | 115 +-- src/flow/src/transform/literal.rs | 4 +- src/flow/src/transform/plan.rs | 145 +--- .../standalone/common/flow/flow_basic.result | 54 +- .../standalone/common/flow/flow_basic.sql | 10 +- .../common/flow/flow_call_df_func.result | 114 +-- .../common/flow/flow_call_df_func.sql | 26 +- 18 files changed, 1324 insertions(+), 820 deletions(-) create mode 100644 src/flow/src/df_optimizer.rs diff --git a/src/flow/src/adapter.rs b/src/flow/src/adapter.rs index b45fc1e88916..04f7fd80b3f2 100644 --- a/src/flow/src/adapter.rs +++ b/src/flow/src/adapter.rs @@ -49,13 +49,13 @@ use crate::adapter::table_source::TableSource; use crate::adapter::util::column_schemas_to_proto; use crate::adapter::worker::{create_worker, Worker, WorkerHandle}; use crate::compute::ErrCollector; +use crate::df_optimizer::sql_to_flow_plan; use crate::error::{ExternalSnafu, InternalSnafu, TableNotFoundSnafu, UnexpectedSnafu}; use crate::expr::GlobalId; use crate::metrics::{ METRIC_FLOW_INPUT_BUF_SIZE, METRIC_FLOW_INSERT_ELAPSED, METRIC_FLOW_RUN_INTERVAL_MS, }; use crate::repr::{self, DiffRow, Row, BATCH_SIZE}; -use crate::transform::sql_to_flow_plan; mod flownode_impl; mod parse_expr; diff --git a/src/flow/src/df_optimizer.rs b/src/flow/src/df_optimizer.rs new file mode 100644 index 000000000000..d5368d5189e6 --- /dev/null +++ b/src/flow/src/df_optimizer.rs @@ -0,0 +1,604 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Datafusion optimizer for flow plan + +#![warn(unused)] + +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +use common_error::ext::BoxedError; +use common_telemetry::debug; +use datafusion::config::ConfigOptions; +use datafusion::error::DataFusionError; +use datafusion::optimizer::analyzer::type_coercion::TypeCoercion; +use datafusion::optimizer::common_subexpr_eliminate::CommonSubexprEliminate; +use datafusion::optimizer::optimize_projections::OptimizeProjections; +use datafusion::optimizer::simplify_expressions::SimplifyExpressions; +use datafusion::optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison; +use datafusion::optimizer::utils::NamePreserver; +use datafusion::optimizer::{Analyzer, AnalyzerRule, Optimizer, OptimizerContext}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, +}; +use datafusion_common::{Column, DFSchema, ScalarValue}; +use datafusion_expr::aggregate_function::AggregateFunction; +use datafusion_expr::expr::AggregateFunctionDefinition; +use datafusion_expr::utils::merge_schema; +use datafusion_expr::{ + BinaryExpr, Expr, Operator, Projection, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use query::parser::QueryLanguageParser; +use query::plan::LogicalPlan; +use query::query_engine::DefaultSerializer; +use query::QueryEngine; +use snafu::ResultExt; +/// note here we are using the `substrait_proto_df` crate from the `substrait` module and +/// rename it to `substrait_proto` +use substrait::DFLogicalSubstraitConvertor; + +use crate::adapter::FlownodeContext; +use crate::error::{DatafusionSnafu, Error, ExternalSnafu, UnexpectedSnafu}; +use crate::expr::{TUMBLE_END, TUMBLE_START}; +use crate::plan::TypedPlan; + +// TODO(discord9): use `Analyzer` to manage rules if more `AnalyzerRule` is needed +pub async fn apply_df_optimizer( + plan: datafusion_expr::LogicalPlan, +) -> Result { + let cfg = ConfigOptions::new(); + let analyzer = Analyzer::with_rules(vec![ + Arc::new(AvgExpandRule::new()), + Arc::new(TumbleExpandRule::new()), + Arc::new(CheckGroupByRule::new()), + Arc::new(TypeCoercion::new()), + ]); + let plan = analyzer + .execute_and_check(plan, &cfg, |p, r| { + debug!("After apply rule {}, get plan: \n{:?}", r.name(), p); + }) + .context(DatafusionSnafu { + context: "Fail to apply analyzer", + })?; + + let ctx = OptimizerContext::new(); + let optimizer = Optimizer::with_rules(vec![ + Arc::new(OptimizeProjections::new()), + Arc::new(CommonSubexprEliminate::new()), + Arc::new(SimplifyExpressions::new()), + Arc::new(UnwrapCastInComparison::new()), + ]); + let plan = optimizer + .optimize(plan, &ctx, |_, _| {}) + .context(DatafusionSnafu { + context: "Fail to apply optimizer", + })?; + + Ok(plan) +} + +/// To reuse existing code for parse sql, the sql is first parsed into a datafusion logical plan, +/// then to a substrait plan, and finally to a flow plan. +pub async fn sql_to_flow_plan( + ctx: &mut FlownodeContext, + engine: &Arc, + sql: &str, +) -> Result { + let query_ctx = ctx.query_context.clone().ok_or_else(|| { + UnexpectedSnafu { + reason: "Query context is missing", + } + .build() + })?; + let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx) + .map_err(BoxedError::new) + .context(ExternalSnafu)?; + let plan = engine + .planner() + .plan(stmt, query_ctx) + .await + .map_err(BoxedError::new) + .context(ExternalSnafu)?; + let LogicalPlan::DfPlan(plan) = plan; + + let opted_plan = apply_df_optimizer(plan).await?; + + // TODO(discord9): add df optimization + let sub_plan = DFLogicalSubstraitConvertor {} + .to_sub_plan(&opted_plan, DefaultSerializer) + .map_err(BoxedError::new) + .context(ExternalSnafu)?; + + let flow_plan = TypedPlan::from_substrait_plan(ctx, &sub_plan).await?; + + Ok(flow_plan) +} + +struct AvgExpandRule {} + +impl AvgExpandRule { + pub fn new() -> Self { + Self {} + } +} + +impl AnalyzerRule for AvgExpandRule { + fn analyze( + &self, + plan: datafusion_expr::LogicalPlan, + _config: &ConfigOptions, + ) -> datafusion_common::Result { + let transformed = plan + .transform_up_with_subqueries(expand_avg_analyzer)? + .data + .transform_down_with_subqueries(put_aggr_to_proj_analyzer)? + .data; + Ok(transformed) + } + + fn name(&self) -> &str { + "avg_expand" + } +} + +/// lift aggr's composite aggr_expr to outer proj, and leave aggr only with simple direct aggr expr +/// i.e. +/// ```ignore +/// proj: avg(x) +/// -- aggr: [sum(x)/count(x) as avg(x)] +/// ``` +/// becomes: +/// ```ignore +/// proj: sum(x)/count(x) as avg(x) +/// -- aggr: [sum(x), count(x)] +/// ``` +fn put_aggr_to_proj_analyzer( + plan: datafusion_expr::LogicalPlan, +) -> Result, DataFusionError> { + if let datafusion_expr::LogicalPlan::Projection(proj) = &plan { + if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() { + let mut replace_old_proj_exprs = HashMap::new(); + let mut expanded_aggr_exprs = vec![]; + for aggr_expr in &aggr.aggr_expr { + let mut is_composite = false; + if let Expr::AggregateFunction(_) = &aggr_expr { + expanded_aggr_exprs.push(aggr_expr.clone()); + } else { + let old_name = aggr_expr.name_for_alias()?; + let new_proj_expr = aggr_expr + .clone() + .transform(|ch| { + if let Expr::AggregateFunction(_) = &ch { + is_composite = true; + expanded_aggr_exprs.push(ch.clone()); + Ok(Transformed::yes(Expr::Column(Column::from_qualified_name( + ch.name_for_alias()?, + )))) + } else { + Ok(Transformed::no(ch)) + } + })? + .data; + replace_old_proj_exprs.insert(old_name, new_proj_expr); + } + } + + if expanded_aggr_exprs.len() > aggr.aggr_expr.len() { + let mut aggr = aggr.clone(); + aggr.aggr_expr = expanded_aggr_exprs; + let mut aggr_plan = datafusion_expr::LogicalPlan::Aggregate(aggr); + // important to recompute schema after changing aggr_expr + aggr_plan = aggr_plan.recompute_schema()?; + + // reconstruct proj with new proj_exprs + let mut new_proj_exprs = proj.expr.clone(); + for proj_expr in new_proj_exprs.iter_mut() { + if let Some(new_proj_expr) = + replace_old_proj_exprs.get(&proj_expr.name_for_alias()?) + { + *proj_expr = new_proj_expr.clone(); + } + *proj_expr = proj_expr + .clone() + .transform(|expr| { + if let Some(new_expr) = + replace_old_proj_exprs.get(&expr.name_for_alias()?) + { + Ok(Transformed::yes(new_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + })? + .data; + } + let proj = datafusion_expr::LogicalPlan::Projection(Projection::try_new( + new_proj_exprs, + Arc::new(aggr_plan), + )?); + return Ok(Transformed::yes(proj)); + } + } + } + Ok(Transformed::no(plan)) +} + +/// expand `avg()` function into `cast(sum(() AS f64)/count(()` +fn expand_avg_analyzer( + plan: datafusion_expr::LogicalPlan, +) -> Result, DataFusionError> { + let mut schema = merge_schema(plan.inputs()); + + if let datafusion_expr::LogicalPlan::TableScan(ts) = &plan { + let source_schema = + DFSchema::try_from_qualified_schema(ts.table_name.clone(), &ts.source.schema())?; + schema.merge(&source_schema); + } + + let mut expr_rewrite = ExpandAvgRewriter::new(&schema); + + let name_preserver = NamePreserver::new(&plan); + // apply coercion rewrite all expressions in the plan individually + plan.map_expressions(|expr| { + let original_name = name_preserver.save(&expr)?; + expr.rewrite(&mut expr_rewrite)? + .map_data(|expr| original_name.restore(expr)) + })? + .map_data(|plan| plan.recompute_schema()) +} + +/// rewrite `avg()` function into `CASE WHEN count() !=0 THEN cast(sum(() AS avg_return_type)/count(() ELSE 0` +/// +/// TODO(discord9): support avg return type decimal128 +/// +/// see impl details at https://github.com/apache/datafusion/blob/4ad4f90d86c57226a4e0fb1f79dfaaf0d404c273/datafusion/expr/src/type_coercion/aggregates.rs#L457-L462 +pub(crate) struct ExpandAvgRewriter<'a> { + /// schema of the plan + #[allow(unused)] + pub(crate) schema: &'a DFSchema, +} + +impl<'a> ExpandAvgRewriter<'a> { + fn new(schema: &'a DFSchema) -> Self { + Self { schema } + } +} + +impl<'a> TreeNodeRewriter for ExpandAvgRewriter<'a> { + type Node = Expr; + + fn f_up(&mut self, expr: Expr) -> Result, DataFusionError> { + if let Expr::AggregateFunction(aggr_func) = &expr { + if let AggregateFunctionDefinition::BuiltIn(AggregateFunction::Avg) = + &aggr_func.func_def + { + let sum_expr = { + let mut tmp = aggr_func.clone(); + tmp.func_def = AggregateFunctionDefinition::BuiltIn(AggregateFunction::Sum); + Expr::AggregateFunction(tmp) + }; + let sum_cast = { + let mut tmp = sum_expr.clone(); + tmp = Expr::Cast(datafusion_expr::Cast { + expr: Box::new(tmp), + data_type: arrow_schema::DataType::Float64, + }); + tmp + }; + + let count_expr = { + let mut tmp = aggr_func.clone(); + tmp.func_def = AggregateFunctionDefinition::BuiltIn(AggregateFunction::Count); + + Expr::AggregateFunction(tmp) + }; + let count_expr_ref = + Expr::Column(Column::from_qualified_name(count_expr.name_for_alias()?)); + + let div = + BinaryExpr::new(Box::new(sum_cast), Operator::Divide, Box::new(count_expr)); + let div_expr = Box::new(Expr::BinaryExpr(div)); + + let zero = Box::new(Expr::Literal(ScalarValue::Int64(Some(0)))); + let not_zero = + BinaryExpr::new(Box::new(count_expr_ref), Operator::NotEq, zero.clone()); + let not_zero = Box::new(Expr::BinaryExpr(not_zero)); + let null = Box::new(Expr::Literal(ScalarValue::Null)); + + let case_when = + datafusion_expr::Case::new(None, vec![(not_zero, div_expr)], Some(null)); + let case_when_expr = Expr::Case(case_when); + + return Ok(Transformed::yes(case_when_expr)); + } + } + + Ok(Transformed::no(expr)) + } +} + +/// expand tumble in aggr expr to tumble_start and tumble_end with column name like `window_start` +struct TumbleExpandRule {} + +impl TumbleExpandRule { + pub fn new() -> Self { + Self {} + } +} + +impl AnalyzerRule for TumbleExpandRule { + fn analyze( + &self, + plan: datafusion_expr::LogicalPlan, + _config: &ConfigOptions, + ) -> datafusion_common::Result { + let transformed = plan + .transform_up_with_subqueries(expand_tumble_analyzer)? + .data; + Ok(transformed) + } + + fn name(&self) -> &str { + "tumble_expand" + } +} + +/// expand `tumble` in aggr expr to `tumble_start` and `tumble_end`, also expand related alias and column ref +/// +/// will add `tumble_start` and `tumble_end` to outer projection if not exist before +fn expand_tumble_analyzer( + plan: datafusion_expr::LogicalPlan, +) -> Result, DataFusionError> { + if let datafusion_expr::LogicalPlan::Projection(proj) = &plan { + if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() { + let mut new_group_expr = vec![]; + let mut alias_to_expand = HashMap::new(); + let mut encountered_tumble = false; + for expr in aggr.group_expr.iter() { + match expr { + datafusion_expr::Expr::ScalarFunction(func) if func.name() == "tumble" => { + encountered_tumble = true; + + let tumble_start = TumbleExpand::new(TUMBLE_START); + let tumble_start = datafusion_expr::expr::ScalarFunction::new_udf( + Arc::new(tumble_start.into()), + func.args.clone(), + ); + let tumble_start = datafusion_expr::Expr::ScalarFunction(tumble_start); + let start_col_name = tumble_start.name_for_alias()?; + new_group_expr.push(tumble_start); + + let tumble_end = TumbleExpand::new(TUMBLE_END); + let tumble_end = datafusion_expr::expr::ScalarFunction::new_udf( + Arc::new(tumble_end.into()), + func.args.clone(), + ); + let tumble_end = datafusion_expr::Expr::ScalarFunction(tumble_end); + let end_col_name = tumble_end.name_for_alias()?; + new_group_expr.push(tumble_end); + + alias_to_expand + .insert(expr.name_for_alias()?, (start_col_name, end_col_name)); + } + _ => new_group_expr.push(expr.clone()), + } + } + if !encountered_tumble { + return Ok(Transformed::no(plan)); + } + let mut new_aggr = aggr.clone(); + new_aggr.group_expr = new_group_expr; + let new_aggr = datafusion_expr::LogicalPlan::Aggregate(new_aggr).recompute_schema()?; + // replace alias in projection if needed, and add new column ref if necessary + let mut new_proj_expr = vec![]; + let mut have_expanded = false; + + for proj_expr in proj.expr.iter() { + if let Some((start_col_name, end_col_name)) = + alias_to_expand.get(&proj_expr.name_for_alias()?) + { + let start_col = Column::from_qualified_name(start_col_name); + let end_col = Column::from_qualified_name(end_col_name); + new_proj_expr.push(datafusion_expr::Expr::Column(start_col)); + new_proj_expr.push(datafusion_expr::Expr::Column(end_col)); + have_expanded = true; + } else { + new_proj_expr.push(proj_expr.clone()); + } + } + + // append to end of projection if not exist + if !have_expanded { + for (start_col_name, end_col_name) in alias_to_expand.values() { + let start_col = Column::from_qualified_name(start_col_name); + let end_col = Column::from_qualified_name(end_col_name); + new_proj_expr + .push(datafusion_expr::Expr::Column(start_col).alias("window_start")); + new_proj_expr.push(datafusion_expr::Expr::Column(end_col).alias("window_end")); + } + } + + let new_proj = datafusion_expr::LogicalPlan::Projection(Projection::try_new( + new_proj_expr, + Arc::new(new_aggr), + )?); + return Ok(Transformed::yes(new_proj)); + } + } + + Ok(Transformed::no(plan)) +} + +/// This is a placeholder for tumble_start and tumble_end function, so that datafusion can +/// recognize them as scalar function +#[derive(Debug)] +pub struct TumbleExpand { + signature: Signature, + name: String, +} + +impl TumbleExpand { + pub fn new(name: &str) -> Self { + Self { + signature: Signature::new(TypeSignature::UserDefined, Volatility::Immutable), + name: name.to_string(), + } + } +} + +impl ScalarUDFImpl for TumbleExpand { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + /// elide the signature for now + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types( + &self, + arg_types: &[arrow_schema::DataType], + ) -> datafusion_common::Result> { + match (arg_types.first(), arg_types.get(1), arg_types.get(2)) { + (Some(ts), Some(window), opt) => { + use arrow_schema::DataType::*; + if !matches!(ts, Date32 | Date64 | Timestamp(_, _)) { + return Err(DataFusionError::Plan( + format!("Expect timestamp column as first arg for tumble_start, found {:?}", ts) + )); + } + if !matches!(window, Utf8 | Interval(_)) { + return Err(DataFusionError::Plan( + format!("Expect second arg for window size's type being interval for tumble_start, found {:?}", window), + )); + } + + if let Some(start_time) = opt{ + if !matches!(start_time, Utf8 | Date32 | Date64 | Timestamp(_, _)){ + return Err(DataFusionError::Plan( + format!("Expect start_time to either be date, timestampe or string, found {:?}", start_time) + )); + } + } + + Ok(arg_types.to_vec()) + } + _ => Err(DataFusionError::Plan( + "Expect tumble function have at least two arg(timestamp column and window size) and a third optional arg for starting time".to_string(), + )), + } + } + + fn return_type( + &self, + arg_types: &[arrow_schema::DataType], + ) -> Result { + arg_types.first().cloned().ok_or_else(|| { + DataFusionError::Plan( + "Expect tumble function have at least two arg(timestamp column and window size)" + .to_string(), + ) + }) + } + + fn invoke( + &self, + _args: &[datafusion_expr::ColumnarValue], + ) -> Result { + Err(DataFusionError::Plan( + "This function should not be executed by datafusion".to_string(), + )) + } +} + +/// This rule check all group by exprs, and make sure they are also in select clause in a aggr query +struct CheckGroupByRule {} + +impl CheckGroupByRule { + pub fn new() -> Self { + Self {} + } +} + +impl AnalyzerRule for CheckGroupByRule { + fn analyze( + &self, + plan: datafusion_expr::LogicalPlan, + _config: &ConfigOptions, + ) -> datafusion_common::Result { + let transformed = plan + .transform_up_with_subqueries(check_group_by_analyzer)? + .data; + Ok(transformed) + } + + fn name(&self) -> &str { + "check_groupby" + } +} + +/// make sure everything in group by's expr is in select +fn check_group_by_analyzer( + plan: datafusion_expr::LogicalPlan, +) -> Result, DataFusionError> { + if let datafusion_expr::LogicalPlan::Projection(proj) = &plan { + if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() { + let mut found_column_used = FindColumn::new(); + proj.expr + .iter() + .map(|i| i.visit(&mut found_column_used)) + .count(); + for expr in aggr.group_expr.iter() { + if !found_column_used + .names_for_alias + .contains(&expr.name_for_alias()?) + { + return Err(DataFusionError::Plan(format!("Expect {} expr in group by also exist in select list, but select list only contain {:?}",expr.name_for_alias()?, found_column_used.names_for_alias))); + } + } + } + } + + Ok(Transformed::no(plan)) +} + +/// Find all column names in a plan +#[derive(Debug, Default)] +struct FindColumn { + names_for_alias: HashSet, +} + +impl FindColumn { + fn new() -> Self { + Default::default() + } +} + +impl TreeNodeVisitor<'_> for FindColumn { + type Node = datafusion_expr::Expr; + fn f_down( + &mut self, + node: &datafusion_expr::Expr, + ) -> Result { + if let datafusion_expr::Expr::Column(_) = node { + self.names_for_alias.insert(node.name_for_alias()?); + } + Ok(TreeNodeRecursion::Continue) + } +} diff --git a/src/flow/src/expr.rs b/src/flow/src/expr.rs index 35f937cdc136..871b23c25dbc 100644 --- a/src/flow/src/expr.rs +++ b/src/flow/src/expr.rs @@ -37,6 +37,9 @@ use snafu::{ensure, ResultExt}; use crate::expr::error::DataTypeSnafu; +pub const TUMBLE_START: &str = "tumble_start"; +pub const TUMBLE_END: &str = "tumble_end"; + /// A batch of vectors with the same length but without schema, only useful in dataflow pub struct Batch { batch: Vec, diff --git a/src/flow/src/expr/func.rs b/src/flow/src/expr/func.rs index 143f1a82dda3..65da763e27d6 100644 --- a/src/flow/src/expr/func.rs +++ b/src/flow/src/expr/func.rs @@ -35,13 +35,13 @@ use snafu::{ensure, OptionExt, ResultExt}; use strum::{EnumIter, IntoEnumIterator}; use substrait::df_logical_plan::consumer::name_to_op; -use crate::error::{Error, ExternalSnafu, InvalidQuerySnafu, PlanSnafu}; +use crate::error::{Error, ExternalSnafu, InvalidQuerySnafu, PlanSnafu, UnexpectedSnafu}; use crate::expr::error::{ ArrowSnafu, CastValueSnafu, DataTypeSnafu, DivisionByZeroSnafu, EvalError, OverflowSnafu, TryFromValueSnafu, TypeMismatchSnafu, }; use crate::expr::signature::{GenericFn, Signature}; -use crate::expr::{Batch, InvalidArgumentSnafu, ScalarExpr, TypedExpr}; +use crate::expr::{Batch, InvalidArgumentSnafu, ScalarExpr, TypedExpr, TUMBLE_END, TUMBLE_START}; use crate::repr::{self, value_to_internal_ts}; /// UnmaterializableFunc is a function that can't be eval independently, @@ -87,42 +87,10 @@ impl UnmaterializableFunc { } /// Create a UnmaterializableFunc from a string of the function name - pub fn from_str_args(name: &str, args: Vec) -> Result { + pub fn from_str_args(name: &str, _args: Vec) -> Result { match name.to_lowercase().as_str() { "now" => Ok(Self::Now), "current_schema" => Ok(Self::CurrentSchema), - "tumble" => { - let ts = args.first().context(InvalidQuerySnafu { - reason: "Tumble window function requires a timestamp argument", - })?; - let window_size = args - .get(1) - .and_then(|expr| expr.expr.as_literal()) - .context(InvalidQuerySnafu { - reason: "Tumble window function requires a window size argument" - })?.as_string() // TODO(discord9): since df to substrait convertor does not support interval type yet, we need to take a string and cast it to interval instead - .map(|s|cast(Value::from(s), &ConcreteDataType::interval_month_day_nano_datatype())).transpose().map_err(BoxedError::new).context( - ExternalSnafu - )?.and_then(|v|v.as_interval()) - .with_context(||InvalidQuerySnafu { - reason: format!("Tumble window function requires window size argument to be a string describe a interval, found {:?}", args.get(1)) - })?; - let start_time = match args.get(2) { - Some(start_time) => start_time.expr.as_literal(), - None => None, - } - .map(|s| cast(s.clone(), &ConcreteDataType::datetime_datatype())).transpose().map_err(BoxedError::new).context(ExternalSnafu)?.map(|v|v.as_datetime().with_context( - ||InvalidQuerySnafu { - reason: format!("Tumble window function requires start time argument to be a datetime describe in string, found {:?}", args.get(2)) - } - )).transpose()?; - - Ok(Self::TumbleWindow { - ts: Box::new(ts.clone()), - window_size, - start_time, - }) - } _ => InvalidQuerySnafu { reason: format!("Unknown unmaterializable function: {}", name), } @@ -347,6 +315,96 @@ impl UnaryFunc { } } + pub fn from_tumble_func(name: &str, args: &[TypedExpr]) -> Result<(Self, TypedExpr), Error> { + match name.to_lowercase().as_str() { + TUMBLE_START | TUMBLE_END => { + let ts = args.first().context(InvalidQuerySnafu { + reason: "Tumble window function requires a timestamp argument", + })?; + let window_size = { + let window_size_untyped = args + .get(1) + .and_then(|expr| expr.expr.as_literal()) + .context(InvalidQuerySnafu { + reason: "Tumble window function requires a window size argument", + })?; + if let Some(window_size) = window_size_untyped.as_string() { + // cast as interval + cast( + Value::from(window_size), + &ConcreteDataType::interval_month_day_nano_datatype(), + ) + .map_err(BoxedError::new) + .context(ExternalSnafu)? + .as_interval() + .context(UnexpectedSnafu { + reason: "Expect window size arg to be interval after successful cast" + .to_string(), + })? + } else if let Some(interval) = window_size_untyped.as_interval() { + interval + } else { + InvalidQuerySnafu { + reason: format!( + "Tumble window function requires window size argument to be either a interval or a string describe a interval, found {:?}", + window_size_untyped + ) + }.fail()? + } + }; + + // start time argument is optional + let start_time = match args.get(2) { + Some(start_time) => { + if let Some(value) = start_time.expr.as_literal() { + // cast as DateTime + let ret = cast(value, &ConcreteDataType::datetime_datatype()) + .map_err(BoxedError::new) + .context(ExternalSnafu)? + .as_datetime() + .context(UnexpectedSnafu { + reason: + "Expect start time arg to be datetime after successful cast" + .to_string(), + })?; + Some(ret) + } else { + UnexpectedSnafu { + reason: "Expect start time arg to be literal", + } + .fail()? + } + } + None => None, + }; + + if name == TUMBLE_START { + Ok(( + Self::TumbleWindowFloor { + window_size, + start_time, + }, + ts.clone(), + )) + } else if name == TUMBLE_END { + Ok(( + Self::TumbleWindowCeiling { + window_size, + start_time, + }, + ts.clone(), + )) + } else { + unreachable!() + } + } + _ => crate::error::InternalSnafu { + reason: format!("Unknown tumble kind function: {}", name), + } + .fail()?, + } + } + /// Evaluate the function with given values and expression /// /// # Arguments @@ -712,8 +770,8 @@ impl BinaryFunc { t1 == t2, InvalidQuerySnafu { reason: format!( - "Binary function {:?} requires both arguments to have the same type", - generic + "Binary function {:?} requires both arguments to have the same type, left={:?}, right={:?}", + generic, t1, t2 ), } ); diff --git a/src/flow/src/expr/scalar.rs b/src/flow/src/expr/scalar.rs index b582c75114a1..c4d698529878 100644 --- a/src/flow/src/expr/scalar.rs +++ b/src/flow/src/expr/scalar.rs @@ -30,7 +30,7 @@ use crate::expr::error::{ }; use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc}; use crate::expr::{Batch, DfScalarFunction}; -use crate::repr::{ColumnType, RelationType}; +use crate::repr::ColumnType; /// A scalar expression with a known type. #[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Hash)] pub struct TypedExpr { @@ -46,77 +46,6 @@ impl TypedExpr { } } -impl TypedExpr { - /// expand multi-value expression to multiple expressions with new indices - /// - /// Currently it just mean expand `TumbleWindow` to `TumbleWindowFloor` and `TumbleWindowCeiling` - /// - /// TODO(discord9): test if nested reduce combine with df scalar function would cause problem - pub fn expand_multi_value( - input_typ: &RelationType, - exprs: &[TypedExpr], - ) -> Result, Error> { - // old indices in mfp, expanded expr - let mut ret = vec![]; - let input_arity = input_typ.column_types.len(); - for (old_idx, expr) in exprs.iter().enumerate() { - if let ScalarExpr::CallUnmaterializable(UnmaterializableFunc::TumbleWindow { - ts, - window_size, - start_time, - }) = &expr.expr - { - let floor = UnaryFunc::TumbleWindowFloor { - window_size: *window_size, - start_time: *start_time, - }; - let ceil = UnaryFunc::TumbleWindowCeiling { - window_size: *window_size, - start_time: *start_time, - }; - let floor = ScalarExpr::CallUnary { - func: floor, - expr: Box::new(ts.expr.clone()), - } - .with_type(ts.typ.clone()); - ret.push((None, floor)); - - let ceil = ScalarExpr::CallUnary { - func: ceil, - expr: Box::new(ts.expr.clone()), - } - .with_type(ts.typ.clone()); - ret.push((None, ceil)); - } else { - ret.push((Some(input_arity + old_idx), expr.clone())) - } - } - - // get shuffled index(old_idx -> new_idx) - // note index is offset by input_arity because mfp is designed to be first include input columns then intermediate columns - let shuffle = ret - .iter() - .map(|(old_idx, _)| *old_idx) // [Option] - .enumerate() - .map(|(new, old)| (old, new + input_arity)) - .flat_map(|(old, new)| old.map(|o| (o, new))) - .chain((0..input_arity).map(|i| (i, i))) // also remember to chain the input columns as not changed - .collect::>(); - - // shuffle expr's index - let exprs = ret - .into_iter() - .map(|(_, mut expr)| { - // invariant: it is expect that no expr will try to refer the column being expanded - expr.expr.permute_map(&shuffle)?; - Ok(expr) - }) - .collect::, _>>()?; - - Ok(exprs) - } -} - /// A scalar expression, which can be evaluated to a value. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum ScalarExpr { @@ -210,6 +139,13 @@ impl ScalarExpr { } impl ScalarExpr { + pub fn cast(self, typ: ConcreteDataType) -> Self { + ScalarExpr::CallUnary { + func: UnaryFunc::Cast(typ), + expr: Box::new(self), + } + } + /// apply optimization to the expression, like flatten variadic function pub fn optimize(&mut self) { self.flatten_varidic_fn(); diff --git a/src/flow/src/expr/signature.rs b/src/flow/src/expr/signature.rs index d61a60dea5e2..82506d1293c9 100644 --- a/src/flow/src/expr/signature.rs +++ b/src/flow/src/expr/signature.rs @@ -19,6 +19,8 @@ use serde::{Deserialize, Serialize}; use smallvec::SmallVec; /// Function signature +/// +/// TODO(discord9): use `common_query::signature::Signature` crate #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)] pub struct Signature { /// the input types, usually not great than two input arg diff --git a/src/flow/src/lib.rs b/src/flow/src/lib.rs index d01e326427bb..ec93d5812870 100644 --- a/src/flow/src/lib.rs +++ b/src/flow/src/lib.rs @@ -23,6 +23,7 @@ // allow unused for now because it should be use later mod adapter; mod compute; +mod df_optimizer; pub mod error; mod expr; pub mod heartbeat; diff --git a/src/flow/src/plan.rs b/src/flow/src/plan.rs index dec70324f9fd..dc86b984ed23 100644 --- a/src/flow/src/plan.rs +++ b/src/flow/src/plan.rs @@ -115,6 +115,8 @@ impl TypedPlan { /// TODO(discord9): support `TableFunc`(by define FlatMap that map 1 to n) /// Plan describe how to transform data in dataflow +/// +/// This can be considered as a physical plan in dataflow, which describe how to transform data in a streaming manner. #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] pub enum Plan { /// A constant collection of rows. diff --git a/src/flow/src/repr/relation.rs b/src/flow/src/repr/relation.rs index 65b75ffdcef8..54ad1c5e8ec4 100644 --- a/src/flow/src/repr/relation.rs +++ b/src/flow/src/repr/relation.rs @@ -374,10 +374,8 @@ impl RelationDesc { .collect(); let arrow_schema = arrow_schema::Schema::new(fields); - DFSchema::try_from(arrow_schema.clone()).context({ - DatafusionSnafu { - context: format!("Error when converting to DFSchema: {:?}", arrow_schema), - } + DFSchema::try_from(arrow_schema.clone()).with_context(|_e| DatafusionSnafu { + context: format!("Error when converting to DFSchema: {:?}", arrow_schema), }) } diff --git a/src/flow/src/transform.rs b/src/flow/src/transform.rs index 5441617b93ab..f6dff58856db 100644 --- a/src/flow/src/transform.rs +++ b/src/flow/src/transform.rs @@ -17,24 +17,19 @@ use std::collections::{BTreeMap, HashMap}; use std::sync::Arc; use common_error::ext::BoxedError; -use datafusion::optimizer::simplify_expressions::SimplifyExpressions; -use datafusion::optimizer::{OptimizerContext, OptimizerRule}; use datatypes::data_type::ConcreteDataType as CDT; -use query::parser::QueryLanguageParser; -use query::plan::LogicalPlan; -use query::query_engine::DefaultSerializer; use query::QueryEngine; use serde::{Deserialize, Serialize}; use snafu::ResultExt; /// note here we are using the `substrait_proto_df` crate from the `substrait` module and /// rename it to `substrait_proto` -use substrait::{substrait_proto_df as substrait_proto, DFLogicalSubstraitConvertor}; +use substrait::substrait_proto_df as substrait_proto; use substrait_proto::proto::extensions::simple_extension_declaration::MappingType; use substrait_proto::proto::extensions::SimpleExtensionDeclaration; use crate::adapter::FlownodeContext; -use crate::error::{DatafusionSnafu, Error, ExternalSnafu, NotImplementedSnafu, UnexpectedSnafu}; -use crate::plan::TypedPlan; +use crate::error::{Error, NotImplementedSnafu, UnexpectedSnafu}; +use crate::expr::{TUMBLE_END, TUMBLE_START}; /// a simple macro to generate a not implemented error macro_rules! not_impl_err { ($($arg:tt)*) => { @@ -102,68 +97,39 @@ impl FunctionExtensions { } } -/// To reuse existing code for parse sql, the sql is first parsed into a datafusion logical plan, -/// then to a substrait plan, and finally to a flow plan. -pub async fn sql_to_flow_plan( - ctx: &mut FlownodeContext, - engine: &Arc, - sql: &str, -) -> Result { - let query_ctx = ctx.query_context.clone().ok_or_else(|| { - UnexpectedSnafu { - reason: "Query context is missing", - } - .build() - })?; - let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx) - .map_err(BoxedError::new) - .context(ExternalSnafu)?; - let plan = engine - .planner() - .plan(stmt, query_ctx) - .await - .map_err(BoxedError::new) - .context(ExternalSnafu)?; - let LogicalPlan::DfPlan(plan) = plan; - let plan = SimplifyExpressions::new() - .rewrite(plan, &OptimizerContext::default()) - .context(DatafusionSnafu { - context: "Fail to apply `SimplifyExpressions` optimization", - })? - .data; - let sub_plan = DFLogicalSubstraitConvertor {} - .to_sub_plan(&plan, DefaultSerializer) - .map_err(BoxedError::new) - .context(ExternalSnafu)?; - - let flow_plan = TypedPlan::from_substrait_plan(ctx, &sub_plan).await?; - - Ok(flow_plan) -} - /// register flow-specific functions to the query engine pub fn register_function_to_query_engine(engine: &Arc) { - engine.register_function(Arc::new(TumbleFunction {})); + engine.register_function(Arc::new(TumbleFunction::new("tumble"))); + engine.register_function(Arc::new(TumbleFunction::new(TUMBLE_START))); + engine.register_function(Arc::new(TumbleFunction::new(TUMBLE_END))); } #[derive(Debug)] -pub struct TumbleFunction {} +pub struct TumbleFunction { + name: String, +} -const TUMBLE_NAME: &str = "tumble"; +impl TumbleFunction { + fn new(name: &str) -> Self { + Self { + name: name.to_string(), + } + } +} impl std::fmt::Display for TumbleFunction { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", TUMBLE_NAME.to_ascii_uppercase()) + write!(f, "{}", self.name.to_ascii_uppercase()) } } impl common_function::function::Function for TumbleFunction { fn name(&self) -> &str { - TUMBLE_NAME + &self.name } fn return_type(&self, _input_types: &[CDT]) -> common_query::error::Result { - Ok(CDT::datetime_datatype()) + Ok(CDT::timestamp_millisecond_datatype()) } fn signature(&self) -> common_query::prelude::Signature { @@ -198,6 +164,7 @@ mod test { use prost::Message; use query::parser::QueryLanguageParser; use query::plan::LogicalPlan; + use query::query_engine::DefaultSerializer; use query::QueryEngine; use session::context::QueryContext; use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan}; @@ -207,6 +174,7 @@ mod test { use super::*; use crate::adapter::node_context::IdToNameMap; + use crate::df_optimizer::apply_df_optimizer; use crate::expr::GlobalId; use crate::repr::{ColumnType, RelationType}; @@ -292,7 +260,7 @@ mod test { let factory = query::QueryEngineFactory::new(catalog_list, None, None, None, None, false); let engine = factory.query_engine(); - engine.register_function(Arc::new(TumbleFunction {})); + register_function_to_query_engine(&engine); assert_eq!("datafusion", engine.name()); engine @@ -307,6 +275,7 @@ mod test { .await .unwrap(); let LogicalPlan::DfPlan(plan) = plan; + let plan = apply_df_optimizer(plan).await.unwrap(); // encode then decode so to rely on the impl of conversion from logical plan to substrait plan let bytes = DFLogicalSubstraitConvertor {} @@ -315,4 +284,22 @@ mod test { proto::Plan::decode(bytes).unwrap() } + + /// TODO(discord9): add more illegal sql tests + #[tokio::test] + async fn test_missing_key_check() { + let engine = create_test_query_engine(); + let sql = "SELECT avg(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour'), number"; + + let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap(); + let plan = engine + .planner() + .plan(stmt, QueryContext::arc()) + .await + .unwrap(); + let LogicalPlan::DfPlan(plan) = plan; + let plan = apply_df_optimizer(plan).await; + + assert!(plan.is_err()); + } } diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs index c07338047fe0..3eec5242cc00 100644 --- a/src/flow/src/transform/aggr.rs +++ b/src/flow/src/transform/aggr.rs @@ -12,10 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::BTreeMap; - -use datatypes::data_type::DataType; -use datatypes::value::Value; use itertools::Itertools; use snafu::OptionExt; use substrait_proto::proto::aggregate_function::AggregationInvocation; @@ -25,7 +21,7 @@ use substrait_proto::proto::{self}; use crate::error::{Error, NotImplementedSnafu, PlanSnafu}; use crate::expr::{ - AggregateExpr, AggregateFunc, BinaryFunc, MapFilterProject, ScalarExpr, TypedExpr, UnaryFunc, + AggregateExpr, AggregateFunc, MapFilterProject, ScalarExpr, TypedExpr, UnaryFunc, }; use crate::plan::{AccumulablePlan, AggrWithIndex, KeyValPlan, Plan, ReducePlan, TypedPlan}; use crate::repr::{ColumnType, RelationDesc, RelationType}; @@ -66,10 +62,9 @@ impl AggregateExpr { measures: &[Measure], typ: &RelationDesc, extensions: &FunctionExtensions, - ) -> Result<(Vec, MapFilterProject), Error> { + ) -> Result, Error> { let _ = ctx; let mut all_aggr_exprs = vec![]; - let mut post_maps = vec![]; for m in measures { let filter = match m @@ -82,7 +77,7 @@ impl AggregateExpr { } .transpose()?; - let (aggr_expr, post_mfp) = match &m.measure { + let aggr_expr = match &m.measure { Some(f) => { let distinct = match f.invocation { _ if f.invocation == AggregationInvocation::Distinct as i32 => true, @@ -93,28 +88,17 @@ impl AggregateExpr { f, typ, extensions, &filter, // TODO(discord9): impl order_by &None, distinct, ) - .await + .await? } - None => not_impl_err!("Aggregate without aggregate function is not supported"), - }?; - // permute col index refer to the output of post_mfp, - // so to help construct a mfp at the end - let mut post_map = post_mfp.unwrap_or(ScalarExpr::Column(0)); - let cur_arity = all_aggr_exprs.len(); - let remap = (0..aggr_expr.len()).map(|i| i + cur_arity).collect_vec(); - post_map.permute(&remap)?; + None => { + return not_impl_err!("Aggregate without aggregate function is not supported") + } + }; all_aggr_exprs.extend(aggr_expr); - post_maps.push(post_map); } - let input_arity = all_aggr_exprs.len(); - let aggr_arity = post_maps.len(); - let post_mfp_final = MapFilterProject::new(all_aggr_exprs.len()) - .map(post_maps)? - .project(input_arity..input_arity + aggr_arity)?; - - Ok((all_aggr_exprs, post_mfp_final)) + Ok(all_aggr_exprs) } /// Convert AggregateFunction into Flow's AggregateExpr @@ -128,7 +112,7 @@ impl AggregateExpr { filter: &Option, order_by: &Option>, distinct: bool, - ) -> Result<(Vec, Option), Error> { + ) -> Result, Error> { // TODO(discord9): impl filter let _ = filter; let _ = order_by; @@ -159,7 +143,6 @@ impl AggregateExpr { .map(|s| s.to_lowercase()); match fn_name.as_ref().map(|s| s.as_ref()) { - Some(Self::AVG_NAME) => AggregateExpr::from_avg_aggr_func(arg), Some(function_name) => { let func = AggregateFunc::from_str_and_type( function_name, @@ -170,8 +153,7 @@ impl AggregateExpr { expr: arg.expr.clone(), distinct, }]; - let ret_mfp = None; - Ok((exprs, ret_mfp)) + Ok(exprs) } None => not_impl_err!( "Aggregated function not found: function anchor = {:?}", @@ -179,39 +161,6 @@ impl AggregateExpr { ), } } - const AVG_NAME: &'static str = "avg"; - /// convert `avg` function into `sum(x)/cast(count(x) as x_type)` - fn from_avg_aggr_func( - arg: &TypedExpr, - ) -> Result<(Vec, Option), Error> { - let arg_type = arg.typ.scalar_type.clone(); - let sum = AggregateExpr { - func: AggregateFunc::from_str_and_type("sum", Some(arg_type.clone()))?, - expr: arg.expr.clone(), - distinct: false, - }; - let sum_out_type = sum.func.signature().output.clone(); - let count = AggregateExpr { - func: AggregateFunc::Count, - expr: arg.expr.clone(), - distinct: false, - }; - let count_out_type = count.func.signature().output.clone(); - let avg_output = ScalarExpr::Column(0).call_binary( - ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(sum_out_type.clone())), - BinaryFunc::div(sum_out_type.clone())?, - ); - // make sure we wouldn't divide by zero - let zero = ScalarExpr::literal(count_out_type.default_value(), count_out_type.clone()); - let non_zero = ScalarExpr::If { - cond: Box::new(ScalarExpr::Column(1).call_binary(zero.clone(), BinaryFunc::NotEq)), - then: Box::new(avg_output), - els: Box::new(ScalarExpr::literal(Value::Null, sum_out_type.clone())), - }; - let ret_aggr_exprs = vec![sum, count]; - let ret_mfp = Some(non_zero); - Ok((ret_aggr_exprs, ret_mfp)) - } } impl KeyValPlan { @@ -297,21 +246,13 @@ impl TypedPlan { return not_impl_err!("Aggregate without an input is not supported"); }; - let group_exprs = { - let group_exprs = TypedExpr::from_substrait_agg_grouping( - ctx, - &agg.groupings, - &input.schema, - extensions, - ) - .await?; - - TypedExpr::expand_multi_value(&input.schema.typ, &group_exprs)? - }; + let group_exprs = + TypedExpr::from_substrait_agg_grouping(ctx, &agg.groupings, &input.schema, extensions) + .await?; let time_index = find_time_index_in_group_exprs(&group_exprs); - let (mut aggr_exprs, post_mfp) = AggregateExpr::from_substrait_agg_measures( + let mut aggr_exprs = AggregateExpr::from_substrait_agg_measures( ctx, &agg.measures, &input.schema, @@ -330,24 +271,13 @@ impl TypedPlan { let mut output_types = Vec::new(); // give best effort to get column name let mut output_names = Vec::new(); - // mark all auto added cols - let mut auto_cols = vec![]; + // first append group_expr as key, then aggr_expr as value - for (idx, expr) in group_exprs.iter().enumerate() { + for expr in group_exprs.iter() { output_types.push(expr.typ.clone()); let col_name = match &expr.expr { - ScalarExpr::CallUnary { - func: UnaryFunc::TumbleWindowFloor { .. }, - .. - } => Some("window_start".to_string()), - ScalarExpr::CallUnary { - func: UnaryFunc::TumbleWindowCeiling { .. }, - .. - } => { - auto_cols.push(idx); - Some("window_end".to_string()) - } ScalarExpr::Column(col) => input.schema.get_name(*col).clone(), + // TODO(discord9): impl& use ScalarExpr.display_name, which recursively build expr's name _ => None, }; output_names.push(col_name) @@ -367,7 +297,6 @@ impl TypedPlan { RelationType::new(output_types).with_key((0..group_exprs.len()).collect_vec()) } .with_time_index(time_index) - .with_autos(&auto_cols) .into_named(output_names) }; @@ -405,67 +334,30 @@ impl TypedPlan { reduce_plan: ReducePlan::Accumulable(accum_plan), }; // FIX(discord9): deal with key first - if post_mfp.is_identity() { - Ok(TypedPlan { - schema: output_type, - plan, - }) - } else { - // make post_mfp map identical mapping of keys - let input = TypedPlan { - schema: output_type.clone(), - plan, - }; - let key_arity = group_exprs.len(); - let mut post_mfp = post_mfp; - let val_arity = post_mfp.input_arity; - // offset post_mfp's col ref by `key_arity` - let shuffle = BTreeMap::from_iter((0..val_arity).map(|v| (v, v + key_arity))); - let new_arity = key_arity + val_arity; - post_mfp.permute(shuffle, new_arity)?; - // add key projection to post mfp - let (m, f, p) = post_mfp.into_map_filter_project(); - let p = (0..key_arity).chain(p).collect_vec(); - let post_mfp = MapFilterProject::new(new_arity) - .map(m)? - .filter(f)? - .project(p)?; - Ok(TypedPlan { - schema: output_type.apply_mfp(&post_mfp.clone().into_safe())?, - plan: Plan::Mfp { - input: Box::new(input), - mfp: post_mfp, - }, - }) - } + + return Ok(TypedPlan { + schema: output_type, + plan, + }); } } #[cfg(test)] mod test { + use std::collections::BTreeMap; + use bytes::BytesMut; use common_time::{DateTime, Interval}; use datatypes::prelude::ConcreteDataType; + use datatypes::value::Value; use pretty_assertions::assert_eq; use super::*; - use crate::expr::{DfScalarFunction, GlobalId, RawDfScalarFn}; + use crate::expr::{BinaryFunc, DfScalarFunction, GlobalId, RawDfScalarFn}; use crate::plan::{Plan, TypedPlan}; use crate::repr::{ColumnType, RelationType}; use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait}; use crate::transform::CDT; - /// TODO(discord9): add more illegal sql tests - #[tokio::test] - async fn test_missing_key_check() { - let engine = create_test_query_engine(); - let sql = "SELECT avg(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour'), number"; - let plan = sql_to_substrait(engine.clone(), sql).await; - - let mut ctx = create_test_ctx(); - assert!(TypedPlan::from_substrait_plan(&mut ctx, &plan) - .await - .is_err()); - } #[tokio::test] async fn test_df_func_basic() { @@ -479,21 +371,20 @@ mod test { .unwrap(); let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }; let expected = TypedPlan { schema: RelationType::new(vec![ ColumnType::new(CDT::uint64_datatype(), true), // sum(number) - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end ]) .with_key(vec![2]) .with_time_index(Some(1)) - .with_autos(&[2]) .into_named(vec![ - None, + Some("SUM(abs(numbers_with_ts.number))".to_string()), Some("window_start".to_string()), Some("window_end".to_string()), ]), @@ -513,7 +404,9 @@ mod test { Some("number".to_string()), Some("ts".to_string()), ]), - ), + ) + .mfp(MapFilterProject::new(2).into_safe()) + .unwrap(), ), key_val_plan: KeyValPlan { key_plan: MapFilterProject::new(2) @@ -548,7 +441,7 @@ mod test { df_scalar_fn: DfScalarFunction::try_from_raw_fn( RawDfScalarFn { f: BytesMut::from( - b"\x08\x01\"\x08\x1a\x06\x12\x04\n\x02\x12\0" + b"\x08\x02\"\x08\x1a\x06\x12\x04\n\x02\x12\0" .as_ref(), ), input_schema: RelationType::new(vec![ColumnType::new( @@ -558,9 +451,10 @@ mod test { .into_unnamed(), extensions: FunctionExtensions { anchor_to_name: BTreeMap::from([ - (0, "tumble".to_string()), - (1, "abs".to_string()), - (2, "sum".to_string()), + (0, "tumble_start".to_string()), + (1, "tumble_end".to_string()), + (2, "abs".to_string()), + (3, "sum".to_string()), ]), }, }, @@ -568,7 +462,8 @@ mod test { .await .unwrap(), exprs: vec![ScalarExpr::Column(0)], - }]) + } + .cast(CDT::uint64_datatype())]) .unwrap() .project(vec![2]) .unwrap() @@ -582,33 +477,27 @@ mod test { } .with_types( RelationType::new(vec![ - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end - ColumnType::new(CDT::uint64_datatype(), true), //sum(number) + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end + ColumnType::new(CDT::uint64_datatype(), true), //sum(number) ]) .with_key(vec![1]) .with_time_index(Some(0)) - .with_autos(&[1]) - .into_named(vec![ - Some("window_start".to_string()), - Some("window_end".to_string()), - None, - ]), + .into_unnamed(), ), ), mfp: MapFilterProject::new(3) .map(vec![ ScalarExpr::Column(2), - ScalarExpr::Column(3), ScalarExpr::Column(0), ScalarExpr::Column(1), ]) .unwrap() - .project(vec![4, 5, 6]) + .project(vec![3, 4, 5]) .unwrap(), }, }; - assert_eq!(expected, flow_plan); + assert_eq!(flow_plan, expected); } #[tokio::test] @@ -623,21 +512,20 @@ mod test { .unwrap(); let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }; let expected = TypedPlan { schema: RelationType::new(vec![ ColumnType::new(CDT::uint64_datatype(), true), // sum(number) - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end ]) .with_key(vec![2]) .with_time_index(Some(1)) - .with_autos(&[2]) .into_named(vec![ - None, + Some("abs(SUM(numbers_with_ts.number))".to_string()), Some("window_start".to_string()), Some("window_end".to_string()), ]), @@ -657,7 +545,9 @@ mod test { Some("number".to_string()), Some("ts".to_string()), ]), - ), + ) + .mfp(MapFilterProject::new(2).into_safe()) + .unwrap(), ), key_val_plan: KeyValPlan { key_plan: MapFilterProject::new(2) @@ -688,7 +578,9 @@ mod test { .unwrap() .into_safe(), val_plan: MapFilterProject::new(2) - .project(vec![0, 1]) + .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())]) + .unwrap() + .project(vec![2]) .unwrap() .into_safe(), }, @@ -700,23 +592,17 @@ mod test { } .with_types( RelationType::new(vec![ - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end - ColumnType::new(CDT::uint64_datatype(), true), //sum(number) + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end + ColumnType::new(CDT::uint64_datatype(), true), //sum(number) ]) .with_key(vec![1]) .with_time_index(Some(0)) - .with_autos(&[1]) - .into_named(vec![ - Some("window_start".to_string()), - Some("window_end".to_string()), - None, - ]), + .into_named(vec![None, None, None]), ), ), mfp: MapFilterProject::new(3) .map(vec![ - ScalarExpr::Column(2), ScalarExpr::CallDf { df_scalar_fn: DfScalarFunction::try_from_raw_fn(RawDfScalarFn { f: BytesMut::from(b"\"\x08\x1a\x06\x12\x04\n\x02\x12\0".as_ref()), @@ -728,24 +614,25 @@ mod test { extensions: FunctionExtensions { anchor_to_name: BTreeMap::from([ (0, "abs".to_string()), - (1, "tumble".to_string()), - (2, "sum".to_string()), + (1, "tumble_start".to_string()), + (2, "tumble_end".to_string()), + (3, "sum".to_string()), ]), }, }) .await .unwrap(), - exprs: vec![ScalarExpr::Column(3)], + exprs: vec![ScalarExpr::Column(2)], }, ScalarExpr::Column(0), ScalarExpr::Column(1), ]) .unwrap() - .project(vec![4, 5, 6]) + .project(vec![3, 4, 5]) .unwrap(), }, }; - assert_eq!(expected, flow_plan); + assert_eq!(flow_plan, expected); } /// TODO(discord9): add more illegal sql tests @@ -763,13 +650,13 @@ mod test { let aggr_exprs = vec![ AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }, AggregateExpr { func: AggregateFunc::Count, - expr: ScalarExpr::Column(0), + expr: ScalarExpr::Column(1), distinct: false, }, ]; @@ -778,11 +665,15 @@ mod test { ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()), BinaryFunc::NotEq, )), - then: Box::new(ScalarExpr::Column(3).call_binary( - ScalarExpr::Column(4).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())), - BinaryFunc::DivUInt64, - )), - els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())), + then: Box::new( + ScalarExpr::Column(3) + .cast(CDT::float64_datatype()) + .call_binary( + ScalarExpr::Column(4).cast(CDT::float64_datatype()), + BinaryFunc::DivFloat64, + ), + ), + els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())), }; let expected = TypedPlan { plan: Plan::Mfp { @@ -801,7 +692,9 @@ mod test { Some("number".to_string()), Some("ts".to_string()), ]), - ), + ) + .mfp(MapFilterProject::new(2).into_safe()) + .unwrap(), ), key_val_plan: KeyValPlan { key_plan: MapFilterProject::new(2) @@ -833,7 +726,12 @@ mod test { .unwrap() .into_safe(), val_plan: MapFilterProject::new(2) - .project(vec![0, 1]) + .map(vec![ + ScalarExpr::Column(0).cast(CDT::uint64_datatype()), + ScalarExpr::Column(0), + ]) + .unwrap() + .project(vec![2, 3]) .unwrap() .into_safe(), }, @@ -841,7 +739,7 @@ mod test { full_aggrs: aggr_exprs.clone(), simple_aggrs: vec![ AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0), - AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1), + AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1), ], distinct_aggrs: vec![], }), @@ -849,19 +747,18 @@ mod test { .with_types( RelationType::new(vec![ // keys - ColumnType::new(CDT::datetime_datatype(), false), // window start(time index) - ColumnType::new(CDT::datetime_datatype(), false), // window end(pk) - ColumnType::new(CDT::uint32_datatype(), false), // number(pk) + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start(time index) + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end(pk) + ColumnType::new(CDT::uint32_datatype(), false), // number(pk) // values ColumnType::new(CDT::uint64_datatype(), true), // avg.sum(number) ColumnType::new(CDT::int64_datatype(), true), // avg.count(number) ]) .with_key(vec![1, 2]) .with_time_index(Some(0)) - .with_autos(&[1]) .into_named(vec![ - Some("window_start".to_string()), - Some("window_end".to_string()), + None, + None, Some("number".to_string()), None, None, @@ -870,28 +767,26 @@ mod test { ), mfp: MapFilterProject::new(5) .map(vec![ - avg_expr, ScalarExpr::Column(2), // number(pk) - ScalarExpr::Column(5), // avg.sum(number) + avg_expr, ScalarExpr::Column(0), // window start ScalarExpr::Column(1), // window end ]) .unwrap() - .project(vec![6, 7, 8, 9]) + .project(vec![5, 6, 7, 8]) .unwrap(), }, schema: RelationType::new(vec![ ColumnType::new(CDT::uint32_datatype(), false), // number - ColumnType::new(CDT::uint64_datatype(), true), // avg(number) - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end + ColumnType::new(CDT::float64_datatype(), true), // avg(number) + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end ]) .with_key(vec![0, 3]) .with_time_index(Some(2)) - .with_autos(&[3]) .into_named(vec![ - Some("number".to_string()), - None, + Some("numbers_with_ts.number".to_string()), + Some("AVG(numbers_with_ts.number)".to_string()), Some("window_start".to_string()), Some("window_end".to_string()), ]), @@ -911,21 +806,20 @@ mod test { .unwrap(); let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }; let expected = TypedPlan { schema: RelationType::new(vec![ ColumnType::new(CDT::uint64_datatype(), true), // sum(number) - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end ]) .with_key(vec![2]) .with_time_index(Some(1)) - .with_autos(&[2]) .into_named(vec![ - None, + Some("SUM(numbers_with_ts.number)".to_string()), Some("window_start".to_string()), Some("window_end".to_string()), ]), @@ -945,7 +839,9 @@ mod test { Some("number".to_string()), Some("ts".to_string()), ]), - ), + ) + .mfp(MapFilterProject::new(2).into_safe()) + .unwrap(), ), key_val_plan: KeyValPlan { key_plan: MapFilterProject::new(2) @@ -976,7 +872,9 @@ mod test { .unwrap() .into_safe(), val_plan: MapFilterProject::new(2) - .project(vec![0, 1]) + .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())]) + .unwrap() + .project(vec![2]) .unwrap() .into_safe(), }, @@ -988,29 +886,23 @@ mod test { } .with_types( RelationType::new(vec![ - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end - ColumnType::new(CDT::uint64_datatype(), true), //sum(number) + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end + ColumnType::new(CDT::uint64_datatype(), true), //sum(number) ]) .with_key(vec![1]) .with_time_index(Some(0)) - .with_autos(&[1]) - .into_named(vec![ - Some("window_start".to_string()), - Some("window_end".to_string()), - None, - ]), + .into_named(vec![None, None, None]), ), ), mfp: MapFilterProject::new(3) .map(vec![ ScalarExpr::Column(2), - ScalarExpr::Column(3), ScalarExpr::Column(0), ScalarExpr::Column(1), ]) .unwrap() - .project(vec![4, 5, 6]) + .project(vec![3, 4, 5]) .unwrap(), }, }; @@ -1029,21 +921,20 @@ mod test { .unwrap(); let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }; let expected = TypedPlan { schema: RelationType::new(vec![ ColumnType::new(CDT::uint64_datatype(), true), // sum(number) - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end ]) .with_key(vec![2]) .with_time_index(Some(1)) - .with_autos(&[2]) .into_named(vec![ - None, + Some("SUM(numbers_with_ts.number)".to_string()), Some("window_start".to_string()), Some("window_end".to_string()), ]), @@ -1063,7 +954,9 @@ mod test { Some("number".to_string()), Some("ts".to_string()), ]), - ), + ) + .mfp(MapFilterProject::new(2).into_safe()) + .unwrap(), ), key_val_plan: KeyValPlan { key_plan: MapFilterProject::new(2) @@ -1094,7 +987,9 @@ mod test { .unwrap() .into_safe(), val_plan: MapFilterProject::new(2) - .project(vec![0, 1]) + .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())]) + .unwrap() + .project(vec![2]) .unwrap() .into_safe(), }, @@ -1106,29 +1001,23 @@ mod test { } .with_types( RelationType::new(vec![ - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end - ColumnType::new(CDT::uint64_datatype(), true), //sum(number) + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end + ColumnType::new(CDT::uint64_datatype(), true), //sum(number) ]) .with_key(vec![1]) .with_time_index(Some(0)) - .with_autos(&[1]) - .into_named(vec![ - Some("window_start".to_string()), - Some("window_end".to_string()), - None, - ]), + .into_unnamed(), ), ), mfp: MapFilterProject::new(3) .map(vec![ ScalarExpr::Column(2), - ScalarExpr::Column(3), ScalarExpr::Column(0), ScalarExpr::Column(1), ]) .unwrap() - .project(vec![4, 5, 6]) + .project(vec![3, 4, 5]) .unwrap(), }, }; @@ -1146,13 +1035,13 @@ mod test { let aggr_exprs = vec![ AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }, AggregateExpr { func: AggregateFunc::Count, - expr: ScalarExpr::Column(0), + expr: ScalarExpr::Column(1), distinct: false, }, ]; @@ -1161,19 +1050,26 @@ mod test { ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()), BinaryFunc::NotEq, )), - then: Box::new(ScalarExpr::Column(1).call_binary( - ScalarExpr::Column(2).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())), - BinaryFunc::DivUInt64, - )), - els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())), + then: Box::new( + ScalarExpr::Column(1) + .cast(CDT::float64_datatype()) + .call_binary( + ScalarExpr::Column(2).cast(CDT::float64_datatype()), + BinaryFunc::DivFloat64, + ), + ), + els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())), }; let expected = TypedPlan { schema: RelationType::new(vec![ - ColumnType::new(CDT::uint64_datatype(), true), // sum(number) -> u64 + ColumnType::new(CDT::float64_datatype(), true), // avg(number: u32) -> f64 ColumnType::new(CDT::uint32_datatype(), false), // number ]) .with_key(vec![1]) - .into_named(vec![None, Some("number".to_string())]), + .into_named(vec![ + Some("AVG(numbers.number)".to_string()), + Some("numbers.number".to_string()), + ]), plan: Plan::Mfp { input: Box::new( Plan::Reduce { @@ -1187,7 +1083,14 @@ mod test { false, )]) .into_named(vec![Some("number".to_string())]), - ), + ) + .mfp( + MapFilterProject::new(1) + .project(vec![0]) + .unwrap() + .into_safe(), + ) + .unwrap(), ), key_val_plan: KeyValPlan { key_plan: MapFilterProject::new(1) @@ -1197,7 +1100,12 @@ mod test { .unwrap() .into_safe(), val_plan: MapFilterProject::new(1) - .project(vec![0]) + .map(vec![ + ScalarExpr::Column(0).cast(CDT::uint64_datatype()), + ScalarExpr::Column(0), + ]) + .unwrap() + .project(vec![1, 2]) .unwrap() .into_safe(), }, @@ -1205,7 +1113,7 @@ mod test { full_aggrs: aggr_exprs.clone(), simple_aggrs: vec![ AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0), - AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1), + AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1), ], distinct_aggrs: vec![], }), @@ -1227,12 +1135,11 @@ mod test { mfp: MapFilterProject::new(3) .map(vec![ avg_expr, // col 3 + ScalarExpr::Column(0), // TODO(discord9): optimize mfp so to remove indirect ref - ScalarExpr::Column(3), // col 4 - ScalarExpr::Column(0), // col 5 ]) .unwrap() - .project(vec![4, 5]) + .project(vec![3, 4]) .unwrap(), }, }; @@ -1253,13 +1160,13 @@ mod test { let aggr_exprs = vec![ AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }, AggregateExpr { func: AggregateFunc::Count, - expr: ScalarExpr::Column(0), + expr: ScalarExpr::Column(1), distinct: false, }, ]; @@ -1268,25 +1175,42 @@ mod test { ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()), BinaryFunc::NotEq, )), - then: Box::new(ScalarExpr::Column(0).call_binary( - ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())), - BinaryFunc::DivUInt64, - )), - els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())), + then: Box::new( + ScalarExpr::Column(0) + .cast(CDT::float64_datatype()) + .call_binary( + ScalarExpr::Column(1).cast(CDT::float64_datatype()), + BinaryFunc::DivFloat64, + ), + ), + els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())), }; + let input = Box::new( + Plan::Get { + id: crate::expr::Id::Global(GlobalId::User(0)), + } + .with_types( + RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint32_datatype(), + false, + )]) + .into_named(vec![Some("number".to_string())]), + ), + ); let expected = TypedPlan { - schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]) - .into_named(vec![None]), + schema: RelationType::new(vec![ColumnType::new(CDT::float64_datatype(), true)]) + .into_named(vec![Some("AVG(numbers.number)".to_string())]), plan: Plan::Mfp { input: Box::new( Plan::Reduce { input: Box::new( - Plan::Get { - id: crate::expr::Id::Global(GlobalId::User(0)), + Plan::Mfp { + input: input.clone(), + mfp: MapFilterProject::new(1).project(vec![0]).unwrap(), } .with_types( RelationType::new(vec![ColumnType::new( - ConcreteDataType::uint32_datatype(), + CDT::uint32_datatype(), false, )]) .into_named(vec![Some("number".to_string())]), @@ -1298,7 +1222,12 @@ mod test { .unwrap() .into_safe(), val_plan: MapFilterProject::new(1) - .project(vec![0]) + .map(vec![ + ScalarExpr::Column(0).cast(CDT::uint64_datatype()), + ScalarExpr::Column(0), + ]) + .unwrap() + .project(vec![1, 2]) .unwrap() .into_safe(), }, @@ -1306,7 +1235,7 @@ mod test { full_aggrs: aggr_exprs.clone(), simple_aggrs: vec![ AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0), - AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1), + AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1), ], distinct_aggrs: vec![], }), @@ -1323,10 +1252,9 @@ mod test { .map(vec![ avg_expr, // TODO(discord9): optimize mfp so to remove indirect ref - ScalarExpr::Column(2), ]) .unwrap() - .project(vec![3]) + .project(vec![2]) .unwrap(), }, }; @@ -1341,56 +1269,48 @@ mod test { let mut ctx = create_test_ctx(); let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; - let typ = RelationType::new(vec![ColumnType::new( - ConcreteDataType::uint64_datatype(), - true, - )]); + let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]) - .into_unnamed(), - plan: Plan::Mfp { + .into_named(vec![Some("SUM(numbers.number)".to_string())]), + plan: Plan::Reduce { input: Box::new( - Plan::Reduce { - input: Box::new( - Plan::Get { - id: crate::expr::Id::Global(GlobalId::User(0)), - } - .with_types( - RelationType::new(vec![ColumnType::new( - ConcreteDataType::uint32_datatype(), - false, - )]) - .into_named(vec![Some("number".to_string())]), - ), - ), - key_val_plan: KeyValPlan { - key_plan: MapFilterProject::new(1) - .project(vec![]) - .unwrap() - .into_safe(), - val_plan: MapFilterProject::new(1) - .project(vec![0]) - .unwrap() - .into_safe(), - }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], - distinct_aggrs: vec![], - }), + Plan::Get { + id: crate::expr::Id::Global(GlobalId::User(0)), } - .with_types(typ.into_unnamed()), - ), - mfp: MapFilterProject::new(1) - .map(vec![ScalarExpr::Column(0), ScalarExpr::Column(1)]) - .unwrap() - .project(vec![2]) + .with_types( + RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint32_datatype(), + false, + )]) + .into_named(vec![Some("number".to_string())]), + ) + .mfp(MapFilterProject::new(1).into_safe()) .unwrap(), + ), + key_val_plan: KeyValPlan { + key_plan: MapFilterProject::new(1) + .project(vec![]) + .unwrap() + .into_safe(), + val_plan: MapFilterProject::new(1) + .map(vec![ScalarExpr::Column(0) + .call_unary(UnaryFunc::Cast(CDT::uint64_datatype()))]) + .unwrap() + .project(vec![1]) + .unwrap() + .into_safe(), + }, + reduce_plan: ReducePlan::Accumulable(AccumulablePlan { + full_aggrs: vec![aggr_expr.clone()], + simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], + distinct_aggrs: vec![], + }), }, }; assert_eq!(flow_plan.unwrap(), expected); @@ -1408,7 +1328,7 @@ mod test { .unwrap(); let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }; @@ -1418,7 +1338,10 @@ mod test { ColumnType::new(CDT::uint32_datatype(), false), // col number ]) .with_key(vec![1]) - .into_named(vec![None, Some("number".to_string())]), + .into_named(vec![ + Some("SUM(numbers.number)".to_string()), + Some("numbers.number".to_string()), + ]), plan: Plan::Mfp { input: Box::new( Plan::Reduce { @@ -1432,7 +1355,9 @@ mod test { false, )]) .into_named(vec![Some("number".to_string())]), - ), + ) + .mfp(MapFilterProject::new(1).into_safe()) + .unwrap(), ), key_val_plan: KeyValPlan { key_plan: MapFilterProject::new(1) @@ -1442,7 +1367,10 @@ mod test { .unwrap() .into_safe(), val_plan: MapFilterProject::new(1) - .project(vec![0]) + .map(vec![ScalarExpr::Column(0) + .call_unary(UnaryFunc::Cast(CDT::uint64_datatype()))]) + .unwrap() + .project(vec![1]) .unwrap() .into_safe(), }, @@ -1462,13 +1390,9 @@ mod test { ), ), mfp: MapFilterProject::new(2) - .map(vec![ - ScalarExpr::Column(1), - ScalarExpr::Column(2), - ScalarExpr::Column(0), - ]) + .map(vec![ScalarExpr::Column(1), ScalarExpr::Column(0)]) .unwrap() - .project(vec![3, 4]) + .project(vec![2, 3]) .unwrap(), }, }; @@ -1486,16 +1410,18 @@ mod test { let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]) - .into_unnamed(), - plan: Plan::Mfp { + .into_named(vec![Some( + "SUM(numbers.number + numbers.number)".to_string(), + )]), + plan: Plan::Reduce { input: Box::new( - Plan::Reduce { + Plan::Mfp { input: Box::new( Plan::Get { id: crate::expr::Id::Global(GlobalId::User(0)), @@ -1508,37 +1434,176 @@ mod test { .into_named(vec![Some("number".to_string())]), ), ), + mfp: MapFilterProject::new(1), + } + .with_types( + RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint32_datatype(), + false, + )]) + .into_named(vec![Some("number".to_string())]), + ), + ), + key_val_plan: KeyValPlan { + key_plan: MapFilterProject::new(1) + .project(vec![]) + .unwrap() + .into_safe(), + val_plan: MapFilterProject::new(1) + .map(vec![ScalarExpr::Column(0) + .call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32) + .call_unary(UnaryFunc::Cast(CDT::uint64_datatype()))]) + .unwrap() + .project(vec![1]) + .unwrap() + .into_safe(), + }, + reduce_plan: ReducePlan::Accumulable(AccumulablePlan { + full_aggrs: vec![aggr_expr.clone()], + simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], + distinct_aggrs: vec![], + }), + }, + }; + assert_eq!(flow_plan.unwrap(), expected); + } + + #[tokio::test] + async fn test_cast_max_min() { + let engine = create_test_query_engine(); + let sql = "SELECT (max(number) - min(number))/30.0, date_bin(INTERVAL '30 second', CAST(ts AS TimestampMillisecond)) as time_window from numbers_with_ts GROUP BY time_window"; + let plan = sql_to_substrait(engine.clone(), sql).await; + + let mut ctx = create_test_ctx(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; + + let aggr_exprs = vec![ + AggregateExpr { + func: AggregateFunc::MaxUInt32, + expr: ScalarExpr::Column(0), + distinct: false, + }, + AggregateExpr { + func: AggregateFunc::MinUInt32, + expr: ScalarExpr::Column(0), + distinct: false, + }, + ]; + let expected = TypedPlan { + schema: RelationType::new(vec![ + ColumnType::new(CDT::float64_datatype(), true), + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), + ]) + .with_key(vec![1]) + .into_named(vec![ + Some( + "MAX(numbers_with_ts.number) - MIN(numbers_with_ts.number) / Float64(30)" + .to_string(), + ), + Some("time_window".to_string()), + ]), + plan: Plan::Mfp { + input: Box::new( + Plan::Reduce { + input: Box::new( + Plan::Get { + id: crate::expr::Id::Global(GlobalId::User(1)), + } + .with_types( + RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ColumnType::new(ConcreteDataType::datetime_datatype(), false), + ]) + .into_named(vec![ + Some("number".to_string()), + Some("ts".to_string()), + ]), + ) + .mfp(MapFilterProject::new(2).into_safe()) + .unwrap(), + ), + key_val_plan: KeyValPlan { - key_plan: MapFilterProject::new(1) - .project(vec![]) - .unwrap() - .into_safe(), - val_plan: MapFilterProject::new(1) - .map(vec![ScalarExpr::Column(0) - .call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32)]) + key_plan: MapFilterProject::new(2) + .map(vec![ScalarExpr::CallDf { + df_scalar_fn: DfScalarFunction::try_from_raw_fn( + RawDfScalarFn { + f: BytesMut::from( + b"\x08\x02\"I\x1aG\nE\x8a\x02?\x08\x03\x12+\n\x17interval-month-day-nano\x12\x10\0\xac#\xfc\x06\0\0\0\0\0\0\0\0\0\0\0\x1a\x06\x12\x04:\x02\x10\x02\x1a\x06\x12\x04:\x02\x10\x02\x98\x03\x03\"\n\x1a\x08\x12\x06\n\x04\x12\x02\x08\x01".as_ref(), + ), + input_schema: RelationType::new(vec![ColumnType::new( + ConcreteDataType::interval_month_day_nano_datatype(), + true, + ),ColumnType::new( + ConcreteDataType::timestamp_millisecond_datatype(), + true, + )]) + .into_unnamed(), + extensions: FunctionExtensions { + anchor_to_name: BTreeMap::from([ + (0, "subtract".to_string()), + (1, "divide".to_string()), + (2, "date_bin".to_string()), + (3, "max".to_string()), + (4, "min".to_string()), + ]), + }, + }, + ) + .await + .unwrap(), + exprs: vec![ + ScalarExpr::Literal( + Value::Interval(Interval::from_month_day_nano(0, 0, 30000000000)), + CDT::interval_month_day_nano_datatype() + ), + ScalarExpr::Column(1).cast(CDT::timestamp_millisecond_datatype()) + ], + }]) .unwrap() - .project(vec![1]) + .project(vec![2]) .unwrap() .into_safe(), + val_plan: MapFilterProject::new(2) + .into_safe(), }, reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], + full_aggrs: aggr_exprs.clone(), + simple_aggrs: vec![AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0), + AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1)], distinct_aggrs: vec![], }), } .with_types( - RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]) - .into_unnamed(), + RelationType::new(vec![ + ColumnType::new( + ConcreteDataType::timestamp_millisecond_datatype(), + true, + ), // time_window + ColumnType::new(ConcreteDataType::uint32_datatype(), true), // max + ColumnType::new(ConcreteDataType::uint32_datatype(), true), // min + ]) + .with_key(vec![0]) + .into_unnamed(), ), ), - mfp: MapFilterProject::new(1) - .map(vec![ScalarExpr::Column(0), ScalarExpr::Column(1)]) + mfp: MapFilterProject::new(3) + .map(vec![ + ScalarExpr::Column(1) + .call_binary(ScalarExpr::Column(2), BinaryFunc::SubUInt32) + .cast(CDT::float64_datatype()) + .call_binary( + ScalarExpr::Literal(Value::from(30.0f64), CDT::float64_datatype()), + BinaryFunc::DivFloat64, + ), + ScalarExpr::Column(0), + ]) .unwrap() - .project(vec![2]) + .project(vec![3, 4]) .unwrap(), }, }; + assert_eq!(flow_plan.unwrap(), expected); } } diff --git a/src/flow/src/transform/expr.rs b/src/flow/src/transform/expr.rs index 5848dc66b674..de05b018ac51 100644 --- a/src/flow/src/transform/expr.rs +++ b/src/flow/src/transform/expr.rs @@ -20,7 +20,7 @@ use common_error::ext::BoxedError; use common_telemetry::debug; use datafusion_physical_expr::PhysicalExpr; use datatypes::data_type::ConcreteDataType as CDT; -use snafu::{OptionExt, ResultExt}; +use snafu::{ensure, OptionExt, ResultExt}; use substrait_proto::proto::expression::field_reference::ReferenceType::DirectReference; use substrait_proto::proto::expression::reference_segment::ReferenceType::StructField; use substrait_proto::proto::expression::{IfThen, RexType, ScalarFunction}; @@ -33,7 +33,7 @@ use crate::error::{ }; use crate::expr::{ BinaryFunc, DfScalarFunction, RawDfScalarFn, ScalarExpr, TypedExpr, UnaryFunc, - UnmaterializableFunc, VariadicFunc, + UnmaterializableFunc, VariadicFunc, TUMBLE_END, TUMBLE_START, }; use crate::repr::{ColumnType, RelationDesc, RelationType}; use crate::transform::literal::{ @@ -167,6 +167,16 @@ fn rewrite_scalar_function( arg_typed_exprs: &[TypedExpr], ) -> Result { let mut f_rewrite = f.clone(); + ensure!( + f_rewrite.arguments.len() == arg_typed_exprs.len(), + crate::error::InternalSnafu { + reason: format!( + "Expect `f_rewrite` and `arg_typed_expr` to be same length, found {} and {}", + f_rewrite.arguments.len(), + arg_typed_exprs.len() + ) + } + ); for (idx, raw_expr) in f_rewrite.arguments.iter_mut().enumerate() { // only replace it with col(idx) if it is not literal // will try best to determine if it is literal, i.e. for function like `cast()` will try @@ -351,7 +361,13 @@ impl TypedExpr { Ok(TypedExpr::new(ret_expr, ret_type)) } _var => { - if VariadicFunc::is_valid_func_name(fn_name) { + if fn_name == TUMBLE_START || fn_name == TUMBLE_END { + let (func, arg) = UnaryFunc::from_tumble_func(fn_name, &arg_typed_exprs)?; + + let ret_type = ColumnType::new_nullable(func.signature().output.clone()); + + Ok(TypedExpr::new(arg.expr.call_unary(func), ret_type)) + } else if VariadicFunc::is_valid_func_name(fn_name) { let func = VariadicFunc::from_str_and_types(fn_name, &arg_types)?; let ret_type = ColumnType::new_nullable(func.signature().output.clone()); let mut expr = ScalarExpr::CallVariadic { @@ -521,7 +537,6 @@ impl TypedExpr { #[cfg(test)] mod test { - use common_time::{DateTime, Interval}; use datatypes::prelude::ConcreteDataType; use datatypes::value::Value; use pretty_assertions::assert_eq; @@ -562,7 +577,7 @@ mod test { }; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]) - .into_named(vec![Some("number".to_string())]), + .into_named(vec![Some("numbers.number".to_string())]), plan: Plan::Mfp { input: Box::new( Plan::Get { @@ -576,13 +591,7 @@ mod test { .into_named(vec![Some("number".to_string())]), ), ), - mfp: MapFilterProject::new(1) - .map(vec![ScalarExpr::Column(0)]) - .unwrap() - .filter(vec![filter]) - .unwrap() - .project(vec![1]) - .unwrap(), + mfp: MapFilterProject::new(1).filter(vec![filter]).unwrap(), }, }; assert_eq!(flow_plan.unwrap(), expected); @@ -600,7 +609,7 @@ mod test { let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::boolean_datatype(), true)]) - .into_unnamed(), + .into_named(vec![Some("Int64(1) + Int64(1) * Int64(2) - Int64(1) / Int64(1) + Int64(1) % Int64(2) = Int64(3)".to_string())]), plan: Plan::Constant { rows: vec![( repr::Row::new(vec![Value::from(true)]), @@ -624,8 +633,8 @@ mod test { let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let expected = TypedPlan { - schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]) - .into_unnamed(), + schema: RelationType::new(vec![ColumnType::new(CDT::int64_datatype(), true)]) + .into_named(vec![Some("numbers.number + Int64(1)".to_string())]), plan: Plan::Mfp { input: Box::new( Plan::Get { @@ -640,10 +649,12 @@ mod test { ), ), mfp: MapFilterProject::new(1) - .map(vec![ScalarExpr::Column(0).call_binary( - ScalarExpr::Literal(Value::from(1u32), CDT::uint32_datatype()), - BinaryFunc::AddUInt32, - )]) + .map(vec![ScalarExpr::Column(0) + .call_unary(UnaryFunc::Cast(CDT::int64_datatype())) + .call_binary( + ScalarExpr::Literal(Value::from(1i64), CDT::int64_datatype()), + BinaryFunc::AddInt64, + )]) .unwrap() .project(vec![1]) .unwrap(), @@ -663,7 +674,9 @@ mod test { let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::int16_datatype(), true)]) - .into_unnamed(), + .into_named(vec![Some( + "arrow_cast(Int64(1),Utf8(\"Int16\"))".to_string(), + )]), plan: Plan::Constant { // cast of literal is constant folded rows: vec![(repr::Row::new(vec![Value::from(1i16)]), i64::MIN, 1)], @@ -683,7 +696,7 @@ mod test { let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]) - .into_unnamed(), + .into_named(vec![Some("numbers.number + numbers.number".to_string())]), plan: Plan::Mfp { input: Box::new( Plan::Get { @@ -780,65 +793,5 @@ mod test { }, } ); - - let f = substrait_proto::proto::expression::ScalarFunction { - function_reference: 0, - arguments: vec![proto_col(0), lit("1 second"), lit("2021-07-01 00:00:00")], - options: vec![], - output_type: None, - ..Default::default() - }; - let input_schema = RelationType::new(vec![ - ColumnType::new(CDT::timestamp_nanosecond_datatype(), false), - ColumnType::new(CDT::string_datatype(), false), - ]) - .into_unnamed(); - let extensions = FunctionExtensions::from_iter(vec![(0, "tumble".to_string())]); - let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions) - .await - .unwrap(); - - assert_eq!( - res, - ScalarExpr::CallUnmaterializable(UnmaterializableFunc::TumbleWindow { - ts: Box::new( - ScalarExpr::Column(0) - .with_type(ColumnType::new(CDT::timestamp_nanosecond_datatype(), false)) - ), - window_size: Interval::from_month_day_nano(0, 0, 1_000_000_000), - start_time: Some(DateTime::new(1625097600000)) - }) - .with_type(ColumnType::new(CDT::timestamp_millisecond_datatype(), true)), - ); - - let f = substrait_proto::proto::expression::ScalarFunction { - function_reference: 0, - arguments: vec![proto_col(0), lit("1 second")], - options: vec![], - output_type: None, - ..Default::default() - }; - let input_schema = RelationType::new(vec![ - ColumnType::new(CDT::timestamp_nanosecond_datatype(), false), - ColumnType::new(CDT::string_datatype(), false), - ]) - .into_unnamed(); - let extensions = FunctionExtensions::from_iter(vec![(0, "tumble".to_string())]); - let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions) - .await - .unwrap(); - - assert_eq!( - res, - ScalarExpr::CallUnmaterializable(UnmaterializableFunc::TumbleWindow { - ts: Box::new( - ScalarExpr::Column(0) - .with_type(ColumnType::new(CDT::timestamp_nanosecond_datatype(), false)) - ), - window_size: Interval::from_month_day_nano(0, 0, 1_000_000_000), - start_time: None - }) - .with_type(ColumnType::new(CDT::timestamp_millisecond_datatype(), true)), - ) } } diff --git a/src/flow/src/transform/literal.rs b/src/flow/src/transform/literal.rs index 01e06e96830e..f9dd8b953553 100644 --- a/src/flow/src/transform/literal.rs +++ b/src/flow/src/transform/literal.rs @@ -340,6 +340,8 @@ pub fn from_substrait_type(null_type: &substrait_proto::proto::Type) -> Result plan_err!("Cannot parse plan relation: None") @@ -115,17 +122,6 @@ impl TypedPlan { plan, }) } else { - match input.plan.clone() { - Plan::Reduce { key_val_plan, .. } => { - rewrite_projection_after_reduce(key_val_plan, &input.schema, &mut exprs)?; - } - Plan::Mfp { input, mfp: _ } => { - if let Plan::Reduce { key_val_plan, .. } = input.plan { - rewrite_projection_after_reduce(key_val_plan, &input.schema, &mut exprs)?; - } - } - _ => (), - } input.projection(exprs) } } @@ -233,120 +229,13 @@ impl TypedPlan { } } -/// if reduce_plan contains the special function like tumble floor/ceiling, add them to the proj_exprs -/// so the effect is the window_start, window_end column are auto added to output rows -/// -/// This is to fix a problem that we have certain functions that return two values, but since substrait doesn't know that, it will assume it return one value -/// this function fix that and rewrite `proj_exprs` to correct form -fn rewrite_projection_after_reduce( - key_val_plan: KeyValPlan, - reduce_output_type: &RelationDesc, - proj_exprs: &mut Vec, -) -> Result<(), Error> { - // TODO(discord9): get keys correctly - let key_exprs = key_val_plan - .key_plan - .projection - .clone() - .into_iter() - .map(|i| { - if i < key_val_plan.key_plan.input_arity { - ScalarExpr::Column(i) - } else { - key_val_plan.key_plan.expressions[i - key_val_plan.key_plan.input_arity].clone() - } - }) - .collect_vec(); - let mut shift_offset = 0; - let mut shuffle: BTreeMap = BTreeMap::new(); - let special_keys = key_exprs - .clone() - .into_iter() - .enumerate() - .filter(|(idx, p)| { - shuffle.insert(*idx, *idx + shift_offset); - if matches!( - p, - ScalarExpr::CallUnary { - func: UnaryFunc::TumbleWindowFloor { .. }, - .. - } | ScalarExpr::CallUnary { - func: UnaryFunc::TumbleWindowCeiling { .. }, - .. - } - ) { - if matches!( - p, - ScalarExpr::CallUnary { - func: UnaryFunc::TumbleWindowFloor { .. }, - .. - } - ) { - shift_offset += 1; - } - true - } else { - false - } - }) - .collect_vec(); - let spec_key_arity = special_keys.len(); - if spec_key_arity == 0 { - return Ok(()); - } - - // shuffle proj_exprs - // because substrait use offset while assume `tumble` only return one value - for proj_expr in proj_exprs.iter_mut() { - proj_expr.expr.permute_map(&shuffle)?; - } // add key to the end - for (key_idx, _key_expr) in special_keys { - // here we assume the output type of reduce operator(`reduce_output_type`) is just first keys columns, then append value columns - // so we can use `key_idx` to index `reduce_output_type` and get the keys we need to append to `proj_exprs` - proj_exprs.push( - ScalarExpr::Column(key_idx) - .with_type(reduce_output_type.typ().column_types[key_idx].clone()), - ); - } - - // check if normal expr in group exprs are all in proj_exprs - let all_cols_ref_in_proj: BTreeSet = proj_exprs - .iter() - .filter_map(|e| { - if let ScalarExpr::Column(i) = &e.expr { - Some(*i) - } else { - None - } - }) - .collect(); - for (key_idx, key_expr) in key_exprs.iter().enumerate() { - if let ScalarExpr::Column(_) = key_expr { - if !all_cols_ref_in_proj.contains(&key_idx) { - let err_msg = format!( - "Expect normal column in group by also appear in projection, but column {}(name is {}) is missing", - key_idx, - reduce_output_type - .get_name(key_idx) - .clone() - .map(|s|format!("'{}'",s)) - .unwrap_or("unknown".to_string()) - ); - return InvalidQuerySnafu { reason: err_msg }.fail(); - } - } - } - - Ok(()) -} - #[cfg(test)] mod test { use datatypes::prelude::ConcreteDataType; use pretty_assertions::assert_eq; use super::*; - use crate::expr::{GlobalId, ScalarExpr}; + use crate::expr::GlobalId; use crate::plan::{Plan, TypedPlan}; use crate::repr::{ColumnType, RelationType}; use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait}; @@ -363,7 +252,7 @@ mod test { let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]) - .into_named(vec![Some("number".to_string())]), + .into_named(vec![Some("numbers.number".to_string())]), plan: Plan::Mfp { input: Box::new( Plan::Get { @@ -377,11 +266,7 @@ mod test { .into_named(vec![Some("number".to_string())]), ), ), - mfp: MapFilterProject::new(1) - .map(vec![ScalarExpr::Column(0)]) - .unwrap() - .project(vec![1]) - .unwrap(), + mfp: MapFilterProject::new(1), }, }; diff --git a/tests/cases/standalone/common/flow/flow_basic.result b/tests/cases/standalone/common/flow/flow_basic.result index d4fb6276bd8a..db3d3c8c3b3d 100644 --- a/tests/cases/standalone/common/flow/flow_basic.result +++ b/tests/cases/standalone/common/flow/flow_basic.result @@ -40,13 +40,13 @@ admin flush_flow('test_numbers_basic'); | 1 | +----------------------------------------+ -SELECT col_0, window_start, window_end FROM out_num_cnt_basic; +SELECT "SUM(numbers_input_basic.number)", window_start, window_end FROM out_num_cnt_basic; -+-------+---------------------+---------------------+ -| col_0 | window_start | window_end | -+-------+---------------------+---------------------+ -| 42 | 2021-07-01T00:00:00 | 2021-07-01T00:00:01 | -+-------+---------------------+---------------------+ ++---------------------------------+---------------------+---------------------+ +| SUM(numbers_input_basic.number) | window_start | window_end | ++---------------------------------+---------------------+---------------------+ +| 42 | 2021-07-01T00:00:00 | 2021-07-01T00:00:01 | ++---------------------------------+---------------------+---------------------+ admin flush_flow('test_numbers_basic'); @@ -71,14 +71,15 @@ admin flush_flow('test_numbers_basic'); | 1 | +----------------------------------------+ -SELECT col_0, window_start, window_end FROM out_num_cnt_basic; +-- note that this quote-unquote column is a column-name, **not** a aggregation expr, generated by datafusion +SELECT "SUM(numbers_input_basic.number)", window_start, window_end FROM out_num_cnt_basic; -+-------+---------------------+---------------------+ -| col_0 | window_start | window_end | -+-------+---------------------+---------------------+ -| 42 | 2021-07-01T00:00:00 | 2021-07-01T00:00:01 | -| 47 | 2021-07-01T00:00:01 | 2021-07-01T00:00:02 | -+-------+---------------------+---------------------+ ++---------------------------------+---------------------+---------------------+ +| SUM(numbers_input_basic.number) | window_start | window_end | ++---------------------------------+---------------------+---------------------+ +| 42 | 2021-07-01T00:00:00 | 2021-07-01T00:00:01 | +| 47 | 2021-07-01T00:00:01 | 2021-07-01T00:00:02 | ++---------------------------------+---------------------+---------------------+ DROP FLOW test_numbers_basic; @@ -142,8 +143,9 @@ CREATE TABLE bytes_log ( Affected Rows: 0 +-- TODO(discord9): remove this after auto infer table's time index is impl CREATE TABLE approx_rate ( - rate FLOAT, + rate DOUBLE, time_window TIMESTAMP, update_at TIMESTAMP, TIME INDEX(time_window) @@ -154,7 +156,7 @@ Affected Rows: 0 CREATE FLOW find_approx_rate SINK TO approx_rate AS -SELECT CAST((max(byte) - min(byte)) AS FLOAT)/30.0, date_bin(INTERVAL '30 second', ts) as time_window from bytes_log GROUP BY time_window; +SELECT (max(byte) - min(byte))/30.0 as rate, date_bin(INTERVAL '30 second', ts) as time_window from bytes_log GROUP BY time_window; Affected Rows: 0 @@ -174,11 +176,11 @@ admin flush_flow('find_approx_rate'); SELECT rate, time_window FROM approx_rate; -+----------+---------------------+ -| rate | time_window | -+----------+---------------------+ -| 6.633333 | 2025-01-01T00:00:00 | -+----------+---------------------+ ++-------------------+---------------------+ +| rate | time_window | ++-------------------+---------------------+ +| 6.633333333333334 | 2025-01-01T00:00:00 | ++-------------------+---------------------+ INSERT INTO bytes_log VALUES (450, '2025-01-01 00:00:32'), @@ -196,12 +198,12 @@ admin flush_flow('find_approx_rate'); SELECT rate, time_window FROM approx_rate; -+-----------+---------------------+ -| rate | time_window | -+-----------+---------------------+ -| 6.633333 | 2025-01-01T00:00:00 | -| 1.6666666 | 2025-01-01T00:00:30 | -+-----------+---------------------+ ++--------------------+---------------------+ +| rate | time_window | ++--------------------+---------------------+ +| 6.633333333333334 | 2025-01-01T00:00:00 | +| 1.6666666666666667 | 2025-01-01T00:00:30 | ++--------------------+---------------------+ DROP TABLE bytes_log; diff --git a/tests/cases/standalone/common/flow/flow_basic.sql b/tests/cases/standalone/common/flow/flow_basic.sql index e60764463e68..b9ccc810585c 100644 --- a/tests/cases/standalone/common/flow/flow_basic.sql +++ b/tests/cases/standalone/common/flow/flow_basic.sql @@ -22,7 +22,7 @@ VALUES admin flush_flow('test_numbers_basic'); -SELECT col_0, window_start, window_end FROM out_num_cnt_basic; +SELECT "SUM(numbers_input_basic.number)", window_start, window_end FROM out_num_cnt_basic; admin flush_flow('test_numbers_basic'); @@ -33,7 +33,8 @@ VALUES admin flush_flow('test_numbers_basic'); -SELECT col_0, window_start, window_end FROM out_num_cnt_basic; +-- note that this quote-unquote column is a column-name, **not** a aggregation expr, generated by datafusion +SELECT "SUM(numbers_input_basic.number)", window_start, window_end FROM out_num_cnt_basic; DROP FLOW test_numbers_basic; DROP TABLE numbers_input_basic; @@ -67,8 +68,9 @@ CREATE TABLE bytes_log ( TIME INDEX(ts) ); +-- TODO(discord9): remove this after auto infer table's time index is impl CREATE TABLE approx_rate ( - rate FLOAT, + rate DOUBLE, time_window TIMESTAMP, update_at TIMESTAMP, TIME INDEX(time_window) @@ -77,7 +79,7 @@ CREATE TABLE approx_rate ( CREATE FLOW find_approx_rate SINK TO approx_rate AS -SELECT CAST((max(byte) - min(byte)) AS FLOAT)/30.0, date_bin(INTERVAL '30 second', ts) as time_window from bytes_log GROUP BY time_window; +SELECT (max(byte) - min(byte))/30.0 as rate, date_bin(INTERVAL '30 second', ts) as time_window from bytes_log GROUP BY time_window; INSERT INTO bytes_log VALUES (101, '2025-01-01 00:00:01'), diff --git a/tests/cases/standalone/common/flow/flow_call_df_func.result b/tests/cases/standalone/common/flow/flow_call_df_func.result index 00f659550fb2..0a8f4218bdea 100644 --- a/tests/cases/standalone/common/flow/flow_call_df_func.result +++ b/tests/cases/standalone/common/flow/flow_call_df_func.result @@ -39,13 +39,14 @@ admin flush_flow('test_numbers_df_func'); | 1 | +------------------------------------------+ -SELECT col_0, window_start, window_end FROM out_num_cnt_df_func; +-- note that this quote-unquote column is a column-name, **not** a aggregation expr, generated by datafusion +SELECT "SUM(abs(numbers_input_df_func.number))", window_start, window_end FROM out_num_cnt_df_func; -+-------+---------------------+---------------------+ -| col_0 | window_start | window_end | -+-------+---------------------+---------------------+ -| 42 | 2021-07-01T00:00:00 | 2021-07-01T00:00:01 | -+-------+---------------------+---------------------+ ++----------------------------------------+---------------------+---------------------+ +| SUM(abs(numbers_input_df_func.number)) | window_start | window_end | ++----------------------------------------+---------------------+---------------------+ +| 42 | 2021-07-01T00:00:00 | 2021-07-01T00:00:01 | ++----------------------------------------+---------------------+---------------------+ admin flush_flow('test_numbers_df_func'); @@ -70,14 +71,15 @@ admin flush_flow('test_numbers_df_func'); | 1 | +------------------------------------------+ -SELECT col_0, window_start, window_end FROM out_num_cnt_df_func; +-- note that this quote-unquote column is a column-name, **not** a aggregation expr, generated by datafusion +SELECT "SUM(abs(numbers_input_df_func.number))", window_start, window_end FROM out_num_cnt_df_func; -+-------+---------------------+---------------------+ -| col_0 | window_start | window_end | -+-------+---------------------+---------------------+ -| 42 | 2021-07-01T00:00:00 | 2021-07-01T00:00:01 | -| 47 | 2021-07-01T00:00:01 | 2021-07-01T00:00:02 | -+-------+---------------------+---------------------+ ++----------------------------------------+---------------------+---------------------+ +| SUM(abs(numbers_input_df_func.number)) | window_start | window_end | ++----------------------------------------+---------------------+---------------------+ +| 42 | 2021-07-01T00:00:00 | 2021-07-01T00:00:01 | +| 47 | 2021-07-01T00:00:01 | 2021-07-01T00:00:02 | ++----------------------------------------+---------------------+---------------------+ DROP FLOW test_numbers_df_func; @@ -132,13 +134,13 @@ admin flush_flow('test_numbers_df_func'); | 1 | +------------------------------------------+ -SELECT col_0, window_start, window_end FROM out_num_cnt_df_func; +SELECT "abs(SUM(numbers_input_df_func.number))", window_start, window_end FROM out_num_cnt_df_func; -+-------+---------------------+---------------------+ -| col_0 | window_start | window_end | -+-------+---------------------+---------------------+ -| 2 | 2021-07-01T00:00:00 | 2021-07-01T00:00:01 | -+-------+---------------------+---------------------+ ++----------------------------------------+---------------------+---------------------+ +| abs(SUM(numbers_input_df_func.number)) | window_start | window_end | ++----------------------------------------+---------------------+---------------------+ +| 2 | 2021-07-01T00:00:00 | 2021-07-01T00:00:01 | ++----------------------------------------+---------------------+---------------------+ admin flush_flow('test_numbers_df_func'); @@ -163,14 +165,14 @@ admin flush_flow('test_numbers_df_func'); | 1 | +------------------------------------------+ -SELECT col_0, window_start, window_end FROM out_num_cnt_df_func; +SELECT "abs(SUM(numbers_input_df_func.number))", window_start, window_end FROM out_num_cnt_df_func; -+-------+---------------------+---------------------+ -| col_0 | window_start | window_end | -+-------+---------------------+---------------------+ -| 2 | 2021-07-01T00:00:00 | 2021-07-01T00:00:01 | -| 1 | 2021-07-01T00:00:01 | 2021-07-01T00:00:02 | -+-------+---------------------+---------------------+ ++----------------------------------------+---------------------+---------------------+ +| abs(SUM(numbers_input_df_func.number)) | window_start | window_end | ++----------------------------------------+---------------------+---------------------+ +| 2 | 2021-07-01T00:00:00 | 2021-07-01T00:00:01 | +| 1 | 2021-07-01T00:00:01 | 2021-07-01T00:00:02 | ++----------------------------------------+---------------------+---------------------+ DROP FLOW test_numbers_df_func; @@ -196,8 +198,8 @@ Affected Rows: 0 CREATE FLOW test_numbers_df_func SINK TO out_num_cnt_df_func -AS -SELECT max(number) - min(number), date_bin(INTERVAL '1 second', ts, '2021-07-01 00:00:00'::TimestampNanosecond) FROM numbers_input_df_func GROUP BY date_bin(INTERVAL '1 second', ts, '2021-07-01 00:00:00'::TimestampNanosecond); +AS +SELECT max(number) - min(number) as maxmin, date_bin(INTERVAL '1 second', ts, '2021-07-01 00:00:00'::Timestamp) as time_window FROM numbers_input_df_func GROUP BY time_window; Affected Rows: 0 @@ -224,13 +226,13 @@ admin flush_flow('test_numbers_df_func'); | 1 | +------------------------------------------+ -SELECT col_0, col_1 FROM out_num_cnt_df_func; +SELECT maxmin, time_window FROM out_num_cnt_df_func; -+-------+---------------------+ -| col_0 | col_1 | -+-------+---------------------+ -| 2 | 2021-07-01T00:00:00 | -+-------+---------------------+ ++--------+---------------------+ +| maxmin | time_window | ++--------+---------------------+ +| 2 | 2021-07-01T00:00:00 | ++--------+---------------------+ admin flush_flow('test_numbers_df_func'); @@ -255,14 +257,14 @@ admin flush_flow('test_numbers_df_func'); | 1 | +------------------------------------------+ -SELECT col_0, col_1 FROM out_num_cnt_df_func; +SELECT maxmin, time_window FROM out_num_cnt_df_func; -+-------+---------------------+ -| col_0 | col_1 | -+-------+---------------------+ -| 2 | 2021-07-01T00:00:00 | -| 1 | 2021-07-01T00:00:01 | -+-------+---------------------+ ++--------+---------------------+ +| maxmin | time_window | ++--------+---------------------+ +| 2 | 2021-07-01T00:00:00 | +| 1 | 2021-07-01T00:00:01 | ++--------+---------------------+ DROP FLOW test_numbers_df_func; @@ -288,8 +290,8 @@ Affected Rows: 0 CREATE FLOW test_numbers_df_func SINK TO out_num_cnt -AS -SELECT date_trunc('second', ts), sum(number) FROM numbers_input_df_func GROUP BY date_trunc('second', ts); +AS +SELECT date_trunc('second', ts) as time_window, sum(number) as sum_num FROM numbers_input_df_func GROUP BY date_trunc('second', ts); Affected Rows: 0 @@ -316,13 +318,13 @@ admin flush_flow('test_numbers_df_func'); | 1 | +------------------------------------------+ -SELECT col_0, col_1 FROM out_num_cnt; +SELECT time_window, sum_num FROM out_num_cnt; -+---------------------+-------+ -| col_0 | col_1 | -+---------------------+-------+ -| 2021-07-01T00:00:00 | 42 | -+---------------------+-------+ ++---------------------+---------+ +| time_window | sum_num | ++---------------------+---------+ +| 2021-07-01T00:00:00 | 42 | ++---------------------+---------+ admin flush_flow('test_numbers_df_func'); @@ -347,14 +349,14 @@ admin flush_flow('test_numbers_df_func'); | 1 | +------------------------------------------+ -SELECT col_0, col_1 FROM out_num_cnt; +SELECT time_window, sum_num FROM out_num_cnt; -+---------------------+-------+ -| col_0 | col_1 | -+---------------------+-------+ -| 2021-07-01T00:00:00 | 42 | -| 2021-07-01T00:00:01 | 47 | -+---------------------+-------+ ++---------------------+---------+ +| time_window | sum_num | ++---------------------+---------+ +| 2021-07-01T00:00:00 | 42 | +| 2021-07-01T00:00:01 | 47 | ++---------------------+---------+ DROP FLOW test_numbers_df_func; diff --git a/tests/cases/standalone/common/flow/flow_call_df_func.sql b/tests/cases/standalone/common/flow/flow_call_df_func.sql index faa9ee1aabc2..389a0975c6e9 100644 --- a/tests/cases/standalone/common/flow/flow_call_df_func.sql +++ b/tests/cases/standalone/common/flow/flow_call_df_func.sql @@ -21,7 +21,8 @@ VALUES -- flush flow to make sure that table is created and data is inserted admin flush_flow('test_numbers_df_func'); -SELECT col_0, window_start, window_end FROM out_num_cnt_df_func; +-- note that this quote-unquote column is a column-name, **not** a aggregation expr, generated by datafusion +SELECT "SUM(abs(numbers_input_df_func.number))", window_start, window_end FROM out_num_cnt_df_func; admin flush_flow('test_numbers_df_func'); @@ -32,7 +33,8 @@ VALUES admin flush_flow('test_numbers_df_func'); -SELECT col_0, window_start, window_end FROM out_num_cnt_df_func; +-- note that this quote-unquote column is a column-name, **not** a aggregation expr, generated by datafusion +SELECT "SUM(abs(numbers_input_df_func.number))", window_start, window_end FROM out_num_cnt_df_func; DROP FLOW test_numbers_df_func; DROP TABLE numbers_input_df_func; @@ -61,7 +63,7 @@ VALUES -- flush flow to make sure that table is created and data is inserted admin flush_flow('test_numbers_df_func'); -SELECT col_0, window_start, window_end FROM out_num_cnt_df_func; +SELECT "abs(SUM(numbers_input_df_func.number))", window_start, window_end FROM out_num_cnt_df_func; admin flush_flow('test_numbers_df_func'); @@ -72,7 +74,7 @@ VALUES admin flush_flow('test_numbers_df_func'); -SELECT col_0, window_start, window_end FROM out_num_cnt_df_func; +SELECT "abs(SUM(numbers_input_df_func.number))", window_start, window_end FROM out_num_cnt_df_func; DROP FLOW test_numbers_df_func; DROP TABLE numbers_input_df_func; @@ -88,8 +90,8 @@ CREATE TABLE numbers_input_df_func ( CREATE FLOW test_numbers_df_func SINK TO out_num_cnt_df_func -AS -SELECT max(number) - min(number), date_bin(INTERVAL '1 second', ts, '2021-07-01 00:00:00'::TimestampNanosecond) FROM numbers_input_df_func GROUP BY date_bin(INTERVAL '1 second', ts, '2021-07-01 00:00:00'::TimestampNanosecond); +AS +SELECT max(number) - min(number) as maxmin, date_bin(INTERVAL '1 second', ts, '2021-07-01 00:00:00'::Timestamp) as time_window FROM numbers_input_df_func GROUP BY time_window; admin flush_flow('test_numbers_df_func'); @@ -100,7 +102,7 @@ VALUES admin flush_flow('test_numbers_df_func'); -SELECT col_0, col_1 FROM out_num_cnt_df_func; +SELECT maxmin, time_window FROM out_num_cnt_df_func; admin flush_flow('test_numbers_df_func'); @@ -111,7 +113,7 @@ VALUES admin flush_flow('test_numbers_df_func'); -SELECT col_0, col_1 FROM out_num_cnt_df_func; +SELECT maxmin, time_window FROM out_num_cnt_df_func; DROP FLOW test_numbers_df_func; DROP TABLE numbers_input_df_func; @@ -128,8 +130,8 @@ CREATE TABLE numbers_input_df_func ( CREATE FLOW test_numbers_df_func SINK TO out_num_cnt -AS -SELECT date_trunc('second', ts), sum(number) FROM numbers_input_df_func GROUP BY date_trunc('second', ts); +AS +SELECT date_trunc('second', ts) as time_window, sum(number) as sum_num FROM numbers_input_df_func GROUP BY date_trunc('second', ts); admin flush_flow('test_numbers_df_func'); @@ -140,7 +142,7 @@ VALUES admin flush_flow('test_numbers_df_func'); -SELECT col_0, col_1 FROM out_num_cnt; +SELECT time_window, sum_num FROM out_num_cnt; admin flush_flow('test_numbers_df_func'); @@ -151,7 +153,7 @@ VALUES admin flush_flow('test_numbers_df_func'); -SELECT col_0, col_1 FROM out_num_cnt; +SELECT time_window, sum_num FROM out_num_cnt; DROP FLOW test_numbers_df_func; DROP TABLE numbers_input_df_func;