From 90191c70624cbfe97ff2ff128cf08d243faf8497 Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Mon, 16 Oct 2023 11:40:38 -0500 Subject: [PATCH] feat(binder): support distinct on column aliases (#12699) Signed-off-by: Richard Chien Co-authored-by: xxchan --- .../tests/testdata/input/distinct_on.yaml | 23 +++++- .../tests/testdata/output/distinct_on.yaml | 44 ++++++++++++ .../tests/testdata/output/order_by.yaml | 2 +- src/frontend/src/binder/query.rs | 2 +- src/frontend/src/binder/select.rs | 72 ++++++++++++++++--- 5 files changed, 128 insertions(+), 15 deletions(-) diff --git a/src/frontend/planner_test/tests/testdata/input/distinct_on.yaml b/src/frontend/planner_test/tests/testdata/input/distinct_on.yaml index 5117c8d9da093..940ac9d803178 100644 --- a/src/frontend/planner_test/tests/testdata/input/distinct_on.yaml +++ b/src/frontend/planner_test/tests/testdata/input/distinct_on.yaml @@ -2,9 +2,28 @@ create table t1 (k int, v int) append only; select distinct on (k) k + v as sum from t1; expected_outputs: - - stream_plan + - stream_plan + - batch_plan - sql: | create table t2 (k int, v int); select distinct on (k) k + v as sum from t2; expected_outputs: - - stream_plan + - stream_plan + - batch_plan +- sql: | + create table t (a int, b int, c int); + select distinct on (foo, b) a as foo, b from t; + expected_outputs: + - stream_plan + - batch_plan +- sql: | + create table t (a int, b int, c int); + select distinct on (2) a as foo, b from t; + expected_outputs: + - stream_plan + - batch_plan +- sql: | + create table t (a int, b int, c int); + select distinct on (4) * from t; + expected_outputs: + - binder_error diff --git a/src/frontend/planner_test/tests/testdata/output/distinct_on.yaml b/src/frontend/planner_test/tests/testdata/output/distinct_on.yaml index 244379843d141..656f96ecd04ef 100644 --- a/src/frontend/planner_test/tests/testdata/output/distinct_on.yaml +++ b/src/frontend/planner_test/tests/testdata/output/distinct_on.yaml @@ -2,6 +2,13 @@ - sql: | create table t1 (k int, v int) append only; select distinct on (k) k + v as sum from t1; + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [$expr1] } + └─BatchGroupTopN { order: [], limit: 1, offset: 0, group_key: [t1.k] } + └─BatchExchange { order: [], dist: HashShard(t1.k) } + └─BatchProject { exprs: [(t1.k + t1.v) as $expr1, t1.k] } + └─BatchScan { table: t1, columns: [t1.k, t1.v], distribution: SomeShard } stream_plan: |- StreamMaterialize { columns: [sum, t1.k(hidden)], stream_key: [t1.k], pk_columns: [t1.k], pk_conflict: NoCheck } └─StreamProject { exprs: [$expr1, t1.k] } @@ -12,6 +19,13 @@ - sql: | create table t2 (k int, v int); select distinct on (k) k + v as sum from t2; + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [$expr1] } + └─BatchGroupTopN { order: [], limit: 1, offset: 0, group_key: [t2.k] } + └─BatchExchange { order: [], dist: HashShard(t2.k) } + └─BatchProject { exprs: [(t2.k + t2.v) as $expr1, t2.k] } + └─BatchScan { table: t2, columns: [t2.k, t2.v], distribution: SomeShard } stream_plan: |- StreamMaterialize { columns: [sum, t2.k(hidden)], stream_key: [t2.k], pk_columns: [t2.k], pk_conflict: NoCheck } └─StreamProject { exprs: [$expr1, t2.k] } @@ -19,3 +33,33 @@ └─StreamExchange { dist: HashShard(t2.k) } └─StreamProject { exprs: [(t2.k + t2.v) as $expr1, t2.k, t2._row_id] } └─StreamTableScan { table: t2, columns: [t2.k, t2.v, t2._row_id], pk: [t2._row_id], dist: UpstreamHashShard(t2._row_id) } +- sql: | + create table t (a int, b int, c int); + select distinct on (foo, b) a as foo, b from t; + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchGroupTopN { order: [], limit: 1, offset: 0, group_key: [t.a, t.b] } + └─BatchExchange { order: [], dist: HashShard(t.a, t.b) } + └─BatchScan { table: t, columns: [t.a, t.b], distribution: SomeShard } + stream_plan: |- + StreamMaterialize { columns: [foo, b, t._row_id(hidden)], stream_key: [foo, b], pk_columns: [foo, b], pk_conflict: NoCheck } + └─StreamGroupTopN { order: [], limit: 1, offset: 0, group_key: [t.a, t.b] } + └─StreamExchange { dist: HashShard(t.a, t.b) } + └─StreamTableScan { table: t, columns: [t.a, t.b, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } +- sql: | + create table t (a int, b int, c int); + select distinct on (2) a as foo, b from t; + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchGroupTopN { order: [], limit: 1, offset: 0, group_key: [t.b] } + └─BatchExchange { order: [], dist: HashShard(t.b) } + └─BatchScan { table: t, columns: [t.a, t.b], distribution: SomeShard } + stream_plan: |- + StreamMaterialize { columns: [foo, b, t._row_id(hidden)], stream_key: [b], pk_columns: [b], pk_conflict: NoCheck } + └─StreamGroupTopN { order: [], limit: 1, offset: 0, group_key: [t.b] } + └─StreamExchange { dist: HashShard(t.b) } + └─StreamTableScan { table: t, columns: [t.a, t.b, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } +- sql: | + create table t (a int, b int, c int); + select distinct on (4) * from t; + binder_error: 'Invalid input syntax: Invalid ordinal number in DISTINCT ON: 4' diff --git a/src/frontend/planner_test/tests/testdata/output/order_by.yaml b/src/frontend/planner_test/tests/testdata/output/order_by.yaml index 669691df6acc9..548259c834f47 100644 --- a/src/frontend/planner_test/tests/testdata/output/order_by.yaml +++ b/src/frontend/planner_test/tests/testdata/output/order_by.yaml @@ -153,7 +153,7 @@ sql: | create table t (x int, y int); select x from t order by 2; - binder_error: 'Invalid input syntax: Invalid value in ORDER BY: 2' + binder_error: 'Invalid input syntax: Invalid ordinal number in ORDER BY: 2' - name: an output column name cannot be used in an expression sql: | create table t (x int, y int); diff --git a/src/frontend/src/binder/query.rs b/src/frontend/src/binder/query.rs index bb1c30d22b3d2..a3b78343c6041 100644 --- a/src/frontend/src/binder/query.rs +++ b/src/frontend/src/binder/query.rs @@ -262,7 +262,7 @@ impl Binder { Ok(index) if 1 <= index && index <= visible_output_num => index - 1, _ => { return Err(ErrorCode::InvalidInputSyntax(format!( - "Invalid value in ORDER BY: {}", + "Invalid ordinal number in ORDER BY: {}", number )) .into()) diff --git a/src/frontend/src/binder/select.rs b/src/frontend/src/binder/select.rs index f2eb37867e74b..48c4290ee7e05 100644 --- a/src/frontend/src/binder/select.rs +++ b/src/frontend/src/binder/select.rs @@ -23,7 +23,7 @@ use risingwave_common::util::iter_util::ZipEqFast; use risingwave_expr::aggregate::AggKind; use risingwave_sqlparser::ast::{ BinaryOperator, DataType as AstDataType, Distinct, Expr, Ident, Join, JoinConstraint, - JoinOperator, ObjectName, Select, SelectItem, TableFactor, TableWithJoins, + JoinOperator, ObjectName, Select, SelectItem, TableFactor, TableWithJoins, Value, }; use super::bind_context::{Clause, ColumnBinding}; @@ -207,9 +207,10 @@ impl Binder { // Bind SELECT clause. let (select_items, aliases) = self.bind_select_list(select.projection)?; + let out_name_to_index = Self::build_name_to_index(aliases.iter().filter_map(Clone::clone)); // Bind DISTINCT ON. - let distinct = self.bind_distinct_on(select.distinct)?; + let distinct = self.bind_distinct_on(select.distinct, &out_name_to_index, &select_items)?; // Bind WHERE clause. self.context.clause = Some(Clause::Where); @@ -223,7 +224,6 @@ impl Binder { self.context.clause = None; // Bind GROUP BY clause. - let out_name_to_index = Self::build_name_to_index(aliases.iter().filter_map(Clone::clone)); self.context.clause = Some(Clause::GroupBy); // Only support one grouping item in group by clause @@ -360,6 +360,7 @@ impl Binder { } } } + assert_eq!(select_list.len(), aliases.len()); Ok((select_list, aliases)) } @@ -709,9 +710,7 @@ impl Binder { .expect("ExprImpl value is a Literal but cannot get ref to data") .as_utf8(); self.bind_cast( - Expr::Value(risingwave_sqlparser::ast::Value::SingleQuotedString( - table_name.to_string(), - )), + Expr::Value(Value::SingleQuotedString(table_name.to_string())), AstDataType::Regclass, ) } @@ -769,14 +768,67 @@ impl Binder { .unzip() } - fn bind_distinct_on(&mut self, distinct: Distinct) -> Result { + /// Bind `DISTINCT` clause in a [`Select`]. + /// Note that for `DISTINCT ON`, each expression is interpreted in the same way as `ORDER BY` + /// expression, which means it will be bound in the following order: + /// + /// * as an output-column name (can use aliases) + /// * as an index (from 1) of an output column + /// * as an arbitrary expression (cannot use aliases) + /// + /// See also the `bind_order_by_expr_in_query` method. + /// + /// # Arguments + /// + /// * `name_to_index` - output column name -> index. Ambiguous (duplicate) output names are + /// marked with `usize::MAX`. + fn bind_distinct_on( + &mut self, + distinct: Distinct, + name_to_index: &HashMap, + select_items: &[ExprImpl], + ) -> Result { Ok(match distinct { Distinct::All => BoundDistinct::All, Distinct::Distinct => BoundDistinct::Distinct, Distinct::DistinctOn(exprs) => { let mut bound_exprs = vec![]; for expr in exprs { - bound_exprs.push(self.bind_expr(expr)?); + let expr_impl = match expr { + Expr::Identifier(name) if let Some(index) = name_to_index.get(&name.real_value()) => { + match *index { + usize::MAX => { + return Err(ErrorCode::BindError(format!( + "DISTINCT ON \"{}\" is ambiguous", + name.real_value() + )) + .into()) + } + _ => { + InputRef::new(*index, select_items[*index].return_type()).into() + } + } + } + Expr::Value(Value::Number(number)) => { + match number.parse::() { + Ok(index) if 1 <= index && index <= select_items.len() => { + let idx_from_0 = index - 1; + InputRef::new(idx_from_0, select_items[idx_from_0].return_type()).into() + } + _ => { + return Err(ErrorCode::InvalidInputSyntax(format!( + "Invalid ordinal number in DISTINCT ON: {}", + number + )) + .into()) + } + } + } + expr => { + self.bind_expr(expr)? + } + }; + bound_exprs.push(expr_impl); } BoundDistinct::DistinctOn(bound_exprs) } @@ -822,9 +874,7 @@ fn derive_alias(expr: &Expr) -> Option { derive_alias(&expr).or_else(|| data_type_to_alias(&data_type)) } Expr::TypedString { data_type, .. } => data_type_to_alias(&data_type), - Expr::Value(risingwave_sqlparser::ast::Value::Interval { .. }) => { - Some("interval".to_string()) - } + Expr::Value(Value::Interval { .. }) => Some("interval".to_string()), Expr::Row(_) => Some("row".to_string()), Expr::Array(_) => Some("array".to_string()), Expr::ArrayIndex { obj, index: _ } => derive_alias(&obj),