From d53406bae577bf727ea0fe7ef014f58366963a62 Mon Sep 17 00:00:00 2001 From: Dylan Date: Mon, 11 Mar 2024 15:55:27 +0800 Subject: [PATCH] fix(binder): insert binding (#15597) --- .../tests/testdata/input/insert.yaml | 14 ++++++++ .../tests/testdata/output/insert.yaml | 34 ++++++++++++++++++- src/frontend/src/binder/bind_context.rs | 1 + src/frontend/src/binder/expr/column.rs | 11 +++++- src/frontend/src/binder/expr/function.rs | 3 ++ src/frontend/src/binder/insert.rs | 4 ++- 6 files changed, 64 insertions(+), 3 deletions(-) diff --git a/src/frontend/planner_test/tests/testdata/input/insert.yaml b/src/frontend/planner_test/tests/testdata/input/insert.yaml index 1b4bb3b7d8000..2ddca64d81523 100644 --- a/src/frontend/planner_test/tests/testdata/input/insert.yaml +++ b/src/frontend/planner_test/tests/testdata/input/insert.yaml @@ -264,3 +264,17 @@ insert into t select * from t; expected_outputs: - batch_distributed_plan +- name: test binding for insert select issue 15594 + sql: | + create table t (a int); + create table t2 (b int); + insert into t select 1 as a from t2 group by a; + expected_outputs: + - batch_plan +- name: test binding for insert select issue 15594 + sql: | + create table t (a int); + create table t2 (b int); + insert into t select 1 as a from t2 group by a returning a; + expected_outputs: + - batch_plan diff --git a/src/frontend/planner_test/tests/testdata/output/insert.yaml b/src/frontend/planner_test/tests/testdata/output/insert.yaml index 356975b2fdf58..8f51f07d458e7 100644 --- a/src/frontend/planner_test/tests/testdata/output/insert.yaml +++ b/src/frontend/planner_test/tests/testdata/output/insert.yaml @@ -244,7 +244,11 @@ sql: | create table t (a int, b int); insert into t values (0,1), (1,2) returning sum(a); - binder_error: 'Bind error: should not have agg/window in the `RETURNING` list' + binder_error: | + Failed to bind expression: sum(a) + + Caused by: + Invalid input syntax: aggregate functions are not allowed in INSERT - name: insert and specify all columns with values sql: | create table t (a int, b int); @@ -343,3 +347,31 @@ └─BatchInsert { table: t, mapping: [0:0, 1:1] } └─BatchExchange { order: [], dist: HashShard(t.a, t.b) } └─BatchScan { table: t, columns: [t.a, t.b], distribution: SomeShard } +- name: test binding for insert select issue 15594 + sql: | + create table t (a int); + create table t2 (b int); + insert into t select 1 as a from t2 group by a; + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchInsert { table: t, mapping: [0:0] } + └─BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [1:Int32] } + └─BatchHashAgg { group_key: [1:Int32], aggs: [] } + └─BatchExchange { order: [], dist: HashShard(1:Int32) } + └─BatchProject { exprs: [1:Int32] } + └─BatchScan { table: t2, columns: [], distribution: SomeShard } +- name: test binding for insert select issue 15594 + sql: | + create table t (a int); + create table t2 (b int); + insert into t select 1 as a from t2 group by a returning a; + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchInsert { table: t, returning: true, mapping: [0:0] } + └─BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [1:Int32] } + └─BatchHashAgg { group_key: [1:Int32], aggs: [] } + └─BatchExchange { order: [], dist: HashShard(1:Int32) } + └─BatchProject { exprs: [1:Int32] } + └─BatchScan { table: t2, columns: [], distribution: SomeShard } diff --git a/src/frontend/src/binder/bind_context.rs b/src/frontend/src/binder/bind_context.rs index 7a39ed1cee63b..0f545bab0ac26 100644 --- a/src/frontend/src/binder/bind_context.rs +++ b/src/frontend/src/binder/bind_context.rs @@ -56,6 +56,7 @@ pub enum Clause { Filter, From, GeneratedColumn, + Insert, } /// A `BindContext` that is only visible if the `LATERAL` keyword diff --git a/src/frontend/src/binder/expr/column.rs b/src/frontend/src/binder/expr/column.rs index 9fb17c6e43520..8d00fb47a1d8b 100644 --- a/src/frontend/src/binder/expr/column.rs +++ b/src/frontend/src/binder/expr/column.rs @@ -16,7 +16,7 @@ use risingwave_common::error::{ErrorCode, Result}; use risingwave_common::types::DataType; use risingwave_sqlparser::ast::Ident; -use crate::binder::Binder; +use crate::binder::{Binder, Clause}; use crate::expr::{CorrelatedInputRef, ExprImpl, ExprType, FunctionCall, InputRef, Literal}; impl Binder { @@ -103,6 +103,9 @@ impl Binder { for (i, lateral_context) in self.lateral_contexts.iter().rev().enumerate() { if lateral_context.is_visible { let context = &lateral_context.context; + if matches!(context.clause, Some(Clause::Insert)) { + continue; + } // input ref from lateral context `depth` starts from 1. let depth = i + 1; match context.get_column_binding_index(&table_name, &column_name) { @@ -125,6 +128,9 @@ impl Binder { for (i, (context, lateral_contexts)) in self.upper_subquery_contexts.iter().rev().enumerate() { + if matches!(context.clause, Some(Clause::Insert)) { + continue; + } // `depth` starts from 1. let depth = i + 1; match context.get_column_binding_index(&table_name, &column_name) { @@ -145,6 +151,9 @@ impl Binder { for (i, lateral_context) in lateral_contexts.iter().rev().enumerate() { if lateral_context.is_visible { let context = &lateral_context.context; + if matches!(context.clause, Some(Clause::Insert)) { + continue; + } // correlated input ref from lateral context `depth` starts from 1. let depth = i + 1; match context.get_column_binding_index(&table_name, &column_name) { diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index 6f07696544365..c804f41388473 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -1403,6 +1403,7 @@ impl Binder { | Clause::Filter | Clause::GeneratedColumn | Clause::From + | Clause::Insert | Clause::JoinOn => { return Err(ErrorCode::InvalidInputSyntax(format!( "window functions are not allowed in {}", @@ -1455,6 +1456,7 @@ impl Binder { | Clause::Values | Clause::From | Clause::GeneratedColumn + | Clause::Insert | Clause::JoinOn => { return Err(ErrorCode::InvalidInputSyntax(format!( "aggregate functions are not allowed in {}", @@ -1476,6 +1478,7 @@ impl Binder { | Clause::Having | Clause::Filter | Clause::Values + | Clause::Insert | Clause::GeneratedColumn => { return Err(ErrorCode::InvalidInputSyntax(format!( "table functions are not allowed in {}", diff --git a/src/frontend/src/binder/insert.rs b/src/frontend/src/binder/insert.rs index 492d0abb5d91f..f67fa14496e65 100644 --- a/src/frontend/src/binder/insert.rs +++ b/src/frontend/src/binder/insert.rs @@ -23,7 +23,7 @@ use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem}; use super::statement::RewriteExprsRecursive; use super::BoundQuery; -use crate::binder::Binder; +use crate::binder::{Binder, Clause}; use crate::catalog::TableId; use crate::expr::{ExprImpl, InputRef}; use crate::user::UserId; @@ -102,6 +102,8 @@ impl Binder { returning_items: Vec, ) -> Result { let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, name)?; + // bind insert table + self.context.clause = Some(Clause::Insert); self.bind_table(schema_name.as_deref(), &table_name, None)?; let table_catalog = self.resolve_dml_table(schema_name.as_deref(), &table_name, true)?;