diff --git a/e2e_test/batch/basic/dml.slt.part b/e2e_test/batch/basic/dml_basic.slt.part
similarity index 100%
rename from e2e_test/batch/basic/dml.slt.part
rename to e2e_test/batch/basic/dml_basic.slt.part
diff --git a/e2e_test/batch/basic/dml_update.slt.part b/e2e_test/batch/basic/dml_update.slt.part
new file mode 100644
index 0000000000000..fcc3bbdfce9a2
--- /dev/null
+++ b/e2e_test/batch/basic/dml_update.slt.part
@@ -0,0 +1,132 @@
+# Extension to `dml_basic.slt.part` for testing advanced `UPDATE` statements.
+
+statement ok
+SET RW_IMPLICIT_FLUSH TO true;
+
+statement ok
+create table t (v1 int default 1919, v2 int default 810);
+
+statement ok
+insert into t values (114, 514);
+
+
+# Single assignment, to subquery.
+statement ok
+update t set v1 = (select 666);
+
+query II
+select * from t;
+----
+666 514
+
+# Single assignment, to runtime-cardinality subquery returning 1 row.
+statement ok
+update t set v1 = (select generate_series(888, 888));
+
+query II
+select * from t;
+----
+888 514
+
+# Single assignment, to runtime-cardinality subquery returning 0 rows (set to NULL).
+statement ok
+update t set v1 = (select generate_series(1, 0));
+
+query II
+select * from t;
+----
+NULL 514
+
+# Single assignment, to runtime-cardinality subquery returning multiple rows.
+statement error Scalar subquery produced more than one row
+update t set v1 = (select generate_series(1, 2));
+
+# Single assignment, to correlated subquery.
+statement ok
+update t set v1 = (select count(*) from t as source where source.v2 = t.v2);
+
+query II
+select * from t;
+----
+1 514
+
+# Single assignment, to subquery with mismatched column count.
+statement error must return only one column
+update t set v1 = (select 666, 888);
+
+
+# Multiple assignment clauses.
+statement ok
+update t set v1 = 1919, v2 = 810;
+
+query II
+select * from t;
+----
+1919 810
+
+# Multiple assignments to the same column.
+statement error multiple assignments to the same column
+update t set v1 = 1, v1 = 2;
+
+statement error multiple assignments to the same column
+update t set (v1, v1) = (1, 2);
+
+statement error multiple assignments to the same column
+update t set (v1, v2) = (1, 2), v2 = 2;
+
+# Multiple assignments, to subquery.
+statement ok
+update t set (v1, v2) = (select 666, 888);
+
+query II
+select * from t;
+----
+666 888
+
+# Multiple assignments, to subquery with cast.
+statement ok
+update t set (v1, v2) = (select 888.88, 999);
+
+query II
+select * from t;
+----
+889 999
+
+# Multiple assignments, to subquery with cast failure.
+# TODO: this currently shows `cannot cast type "record" to "record"` because we wrap the subquery result
+# into a struct, which is not quite clear.
+statement error cannot cast type
+update t set (v1, v2) = (select '888.88', 999);
+
+# Multiple assignments, to subquery with mismatched column count.
+statement error number of columns does not match number of values
+update t set (v1, v2) = (select 666);
+
+# Multiple assignments, to scalar expression.
+statement error source for a multiple-column UPDATE item must be a sub-SELECT or ROW\(\) expression
+update t set (v1, v2) = v1 + 1;
+
+
+# Assignment to system columns.
+statement error update modifying column `_rw_timestamp` is unsupported
+update t set _rw_timestamp = _rw_timestamp + interval '1 second';
+
+
+# https://github.com/risingwavelabs/risingwave/pull/19402#pullrequestreview-2444427475
+# https://github.com/risingwavelabs/risingwave/pull/19452
+statement ok
+create table y (v1 int, v2 int);
+
+statement ok
+insert into y values (11, 11), (22, 22);
+
+statement error Scalar subquery produced more than one row
+update t set (v1, v2) = (select y.v1, y.v2 from y);
+
+statement ok
+drop table y;
+
+
+# Cleanup.
+statement ok
+drop table t;
diff --git a/proto/batch_plan.proto b/proto/batch_plan.proto
index b46230b2438d6..f10092d952ac7 100644
--- a/proto/batch_plan.proto
+++ b/proto/batch_plan.proto
@@ -173,11 +173,12 @@ message UpdateNode {
   // Id of the table to perform updating.
   uint32 table_id = 1;
   // Version of the table.
-  uint64 table_version_id = 4;
-  repeated expr.ExprNode exprs = 2;
-  bool returning = 3;
-  // The columns indices in the input schema, representing the columns need to send to streamDML exeuctor.
-  repeated uint32 update_column_indices = 5;
+  uint64 table_version_id = 2;
+  // Expressions to generate `U-` records.
+  repeated expr.ExprNode old_exprs = 3;
+  // Expressions to generate `U+` records.
+  repeated expr.ExprNode new_exprs = 4;
+  bool returning = 5;
   // Session id is used to ensure that dml data from the same session should be sent to a fixed worker node and channel.
   uint32 session_id = 6;
diff --git a/src/batch/src/executor/update.rs b/src/batch/src/executor/update.rs
index a753aef840f52..95f1963cf582e 100644
--- a/src/batch/src/executor/update.rs
+++ b/src/batch/src/executor/update.rs
@@ -42,13 +42,13 @@ pub struct UpdateExecutor {
     table_version_id: TableVersionId,
     dml_manager: DmlManagerRef,
     child: BoxedExecutor,
-    exprs: Vec<BoxedExpression>,
+    old_exprs: Vec<BoxedExpression>,
+    new_exprs: Vec<BoxedExpression>,
     chunk_size: usize,
     schema: Schema,
     identity: String,
     returning: bool,
     txn_id: TxnId,
-    update_column_indices: Vec<usize>,
     session_id: u32,
 }
@@ -59,11 +59,11 @@ impl UpdateExecutor {
         table_version_id: TableVersionId,
         dml_manager: DmlManagerRef,
         child: BoxedExecutor,
-        exprs: Vec<BoxedExpression>,
+        old_exprs: Vec<BoxedExpression>,
+        new_exprs: Vec<BoxedExpression>,
         chunk_size: usize,
         identity: String,
         returning: bool,
-        update_column_indices: Vec<usize>,
         session_id: u32,
     ) -> Self {
         let chunk_size = chunk_size.next_multiple_of(2);
@@ -75,7 +75,8 @@ impl UpdateExecutor {
             table_version_id,
             dml_manager,
             child,
-            exprs,
+            old_exprs,
+            new_exprs,
             chunk_size,
             schema: if returning {
                 table_schema
@@ -87,7 +88,6 @@ impl UpdateExecutor {
             identity,
             returning,
             txn_id,
-            update_column_indices,
             session_id,
         }
     }
@@ -109,7 +109,7 @@ impl Executor for UpdateExecutor {
 impl UpdateExecutor {
     #[try_stream(boxed, ok = DataChunk, error = BatchError)]
-    async fn do_execute(mut self: Box<Self>) {
+    async fn do_execute(self: Box<Self>) {
         let table_dml_handle = self
             .dml_manager
             .table_dml_handle(self.table_id, self.table_version_id)?;
@@ -122,15 +122,12 @@ impl UpdateExecutor {
         assert_eq!(
             data_types,
-            self.exprs.iter().map(|e| e.return_type()).collect_vec(),
+            self.new_exprs.iter().map(|e| e.return_type()).collect_vec(),
             "bad update schema"
         );
         assert_eq!(
             data_types,
-            self.update_column_indices
-                .iter()
-                .map(|i: &usize| self.child.schema()[*i].data_type.clone())
-                .collect_vec(),
+            self.old_exprs.iter().map(|e| e.return_type()).collect_vec(),
             "bad update schema"
         );
@@ -159,27 +156,35 @@ impl UpdateExecutor {
         let mut rows_updated = 0;
         #[for_await]
-        for data_chunk in self.child.execute() {
-            let data_chunk = data_chunk?;
+        for input in self.child.execute() {
+            let input = input?;
+
+            let old_data_chunk = {
+                let mut columns = Vec::with_capacity(self.old_exprs.len());
+                for expr in &self.old_exprs {
+                    let column = expr.eval(&input).await?;
+                    columns.push(column);
+                }
+
+                DataChunk::new(columns, input.visibility().clone())
+            };
             let updated_data_chunk = {
-                let mut columns = Vec::with_capacity(self.exprs.len());
-                for expr in &mut self.exprs {
-                    let column = expr.eval(&data_chunk).await?;
+                let mut columns = Vec::with_capacity(self.new_exprs.len());
+                for expr in &self.new_exprs {
+                    let column = expr.eval(&input).await?;
                     columns.push(column);
                 }
-                DataChunk::new(columns, data_chunk.visibility().clone())
+                DataChunk::new(columns, input.visibility().clone())
             };
             if self.returning {
                 yield updated_data_chunk.clone();
             }
-            for (row_delete, row_insert) in data_chunk
-                .project(&self.update_column_indices)
-                .rows()
-                .zip_eq_debug(updated_data_chunk.rows())
+            for (row_delete, row_insert) in
+                (old_data_chunk.rows()).zip_eq_debug(updated_data_chunk.rows())
             {
                 rows_updated += 1;
                 // If row_delete == row_insert, we don't need to do a actual update
@@ -227,34 +232,35 @@ impl BoxedExecutorBuilder for UpdateExecutor {
         let table_id = TableId::new(update_node.table_id);
-        let exprs: Vec<_> = update_node
-            .get_exprs()
+        let old_exprs: Vec<_> = update_node
+            .get_old_exprs()
             .iter()
             .map(build_from_prost)
             .try_collect()?;
-        let update_column_indices = update_node
-            .update_column_indices
+        let new_exprs: Vec<_> = update_node
+            .get_new_exprs()
             .iter()
-            .map(|x| *x as usize)
-            .collect_vec();
+            .map(build_from_prost)
+            .try_collect()?;
         Ok(Box::new(Self::new(
             table_id,
             update_node.table_version_id,
             source.context().dml_manager(),
             child,
-            exprs,
+            old_exprs,
+            new_exprs,
             source.context.get_config().developer.chunk_size,
             source.plan_node().get_identity().clone(),
             update_node.returning,
-            update_column_indices,
             update_node.session_id,
         )))
     }
 }
 #[cfg(test)]
+#[cfg(any())]
 mod tests {
     use std::sync::Arc;
diff --git a/src/frontend/planner_test/tests/testdata/input/update.yaml b/src/frontend/planner_test/tests/testdata/input/update.yaml
index 65c0f47eb4cd4..744735af843de 100644
--- a/src/frontend/planner_test/tests/testdata/input/update.yaml
+++ b/src/frontend/planner_test/tests/testdata/input/update.yaml
@@ -76,7 +76,7 @@
     update t set v2 = 3;
   expected_outputs:
   - binder_error
-- name: update subquery
+- name: update subquery selection
  sql: |
    create table t (a int, b int);
    update t set a = 777 where b not in (select a from t);
@@ -98,10 +98,27 @@
     update t set a = a + 1;
   expected_outputs:
   - batch_distributed_plan
-- name: update table with subquery in the set clause
+- name: update table to subquery
   sql: |
-    create table t1 (v1 int primary key, v2 int);
-    create table t2 (v1 int primary key, v2 int);
-    update t1 set v1 = (select v1 from t2 where t1.v2 = t2.v2);
+    create table t (v1 int, v2 int);
+    update t set v1 = (select 666);
+  expected_outputs:
+  - batch_plan
+- name: update table to subquery with runtime cardinality
+  sql: |
+    create table t (v1 int, v2 int);
+    update t set v1 = (select generate_series(888, 888));
+  expected_outputs:
+  - batch_plan
+- name: update table to correlated subquery
+  sql: |
+    create table t (v1 int, v2 int);
+    update t set v1 = (select count(*) from t as source where source.v2 = t.v2);
   expected_outputs:
-  - binder_error
+  - batch_plan
+- name: update table to subquery with multiple assignments
+  sql: |
+    create table t (v1 int, v2 int);
+    update t set (v1, v2) = (select 666.66, 777);
+  expected_outputs:
+  - batch_plan
diff --git a/src/frontend/planner_test/tests/testdata/output/index_selection.yaml b/src/frontend/planner_test/tests/testdata/output/index_selection.yaml
index a6240c69f395f..349c5f7d89012 100644
--- a/src/frontend/planner_test/tests/testdata/output/index_selection.yaml
+++ b/src/frontend/planner_test/tests/testdata/output/index_selection.yaml
@@ -213,16 +213,18 @@
     update t1 set c = 3 where a = 1 and b = 2;
   batch_plan: |-
     BatchExchange { order: [], dist: Single }
-    └─BatchUpdate { table: t1, exprs: [$0, $1, 3:Int64, $3] }
+    └─BatchUpdate { table: t1, exprs: [$0, $1, $5, $3] }
       └─BatchExchange { order: [], dist: Single }
-        └─BatchLookupJoin { type: Inner, predicate: idx2.t1._row_id IS NOT DISTINCT FROM t1._row_id, output: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp], lookup table: t1 }
-          └─BatchExchange { order: [], dist: UpstreamHashShard(idx2.t1._row_id) }
-            └─BatchScan { table: idx2, columns: [idx2.t1._row_id], scan_ranges: [idx2.b = Decimal(Normalized(2)) AND idx2.a = Int32(1)], distribution: SomeShard }
+        └─BatchProject { exprs: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp, 3:Int64] }
+          └─BatchLookupJoin { type: Inner, predicate: idx2.t1._row_id IS NOT DISTINCT FROM t1._row_id, output: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp], lookup table: t1 }
+            └─BatchExchange { order: [], dist: UpstreamHashShard(idx2.t1._row_id) }
+              └─BatchScan { table: idx2, columns: [idx2.t1._row_id], scan_ranges: [idx2.b = Decimal(Normalized(2)) AND idx2.a = Int32(1)], distribution: SomeShard }
   batch_local_plan: |-
-    BatchUpdate { table: t1, exprs: [$0, $1, 3:Int64, $3] }
-    └─BatchLookupJoin { type: Inner, predicate: idx2.t1._row_id IS NOT DISTINCT FROM t1._row_id, output: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp], lookup table: t1 }
-      └─BatchExchange { order: [], dist: Single }
-        └─BatchScan { table: idx2, columns: [idx2.t1._row_id], scan_ranges: [idx2.b = Decimal(Normalized(2)) AND idx2.a = Int32(1)], distribution: SomeShard }
+    BatchUpdate { table: t1, exprs: [$0, $1, $5, $3] }
+    └─BatchProject { exprs: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp, 3:Int64] }
+      └─BatchLookupJoin { type: Inner, predicate: idx2.t1._row_id IS NOT DISTINCT FROM t1._row_id, output: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp], lookup table: t1 }
+        └─BatchExchange { order: [], dist: Single }
+          └─BatchScan { table: idx2, columns: [idx2.t1._row_id], scan_ranges: [idx2.b = Decimal(Normalized(2)) AND idx2.a = Int32(1)], distribution: SomeShard }
 - sql: |
     create table t1 (a int, b numeric, c bigint, p int);
     create materialized view v as select count(*) as cnt, p from t1 group by p;
diff --git a/src/frontend/planner_test/tests/testdata/output/update.yaml b/src/frontend/planner_test/tests/testdata/output/update.yaml
index 19d6673d77f9a..4a12b492660ad 100644
--- a/src/frontend/planner_test/tests/testdata/output/update.yaml
+++ b/src/frontend/planner_test/tests/testdata/output/update.yaml
@@ -4,9 +4,10 @@
     update t set v1 = 0;
   batch_plan: |-
     BatchExchange { order: [], dist: Single }
-    └─BatchUpdate { table: t, exprs: [0:Int32, $1, $2] }
+    └─BatchUpdate { table: t, exprs: [$4, $1, $2] }
       └─BatchExchange { order: [], dist: Single }
-        └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
+        └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, 0:Int32] }
+          └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
 - sql: |
     create table t (v1 int, v2 int);
     update t set v1 = true;
@@ -16,72 +17,81 @@
     update t set v1 = v2 + 1;
   batch_plan: |-
     BatchExchange { order: [], dist: Single }
-    └─BatchUpdate { table: t, exprs: [($1 + 1:Int32), $1, $2] }
+    └─BatchUpdate { table: t, exprs: [$4, $1, $2] }
       └─BatchExchange { order: [], dist: Single }
-        └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
+        └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, (t.v2 + 1:Int32) as $expr1] }
+          └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
 - sql: |
     create table t (v1 int, v2 real);
     update t set v1 = v2;
   batch_plan: |-
     BatchExchange { order: [], dist: Single }
-    └─BatchUpdate { table: t, exprs: [$1::Int32, $1, $2] }
+    └─BatchUpdate { table: t, exprs: [$4, $1, $2] }
       └─BatchExchange { order: [], dist: Single }
-        └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
+        └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, t.v2::Int32 as $expr1] }
+          └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
 - sql: |
     create table t (v1 int, v2 real);
     update t set v1 = DEFAULT;
   batch_plan: |-
     BatchExchange { order: [], dist: Single }
-    └─BatchUpdate { table: t, exprs: [null:Int32, $1, $2] }
+    └─BatchUpdate { table: t, exprs: [$4, $1, $2] }
       └─BatchExchange { order: [], dist: Single }
-        └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
+        └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, null:Int32] }
+          └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
 - sql: |
     create table t (v1 int, v2 int);
     update t set v1 = v2 + 1 where v2 > 0;
   batch_plan: |-
     BatchExchange { order: [], dist: Single }
-    └─BatchUpdate { table: t, exprs: [($1 + 1:Int32), $1, $2] }
+    └─BatchUpdate { table: t, exprs: [$4, $1, $2] }
       └─BatchExchange { order: [], dist: Single }
-        └─BatchFilter { predicate: (t.v2 > 0:Int32) }
-          └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
+        └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, (t.v2 + 1:Int32) as $expr1] }
+          └─BatchFilter { predicate: (t.v2 > 0:Int32) }
+            └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
 - sql: |
     create table t (v1 int, v2 int);
     update t set (v1, v2) = (v2 + 1, v1 - 1) where v1 != v2;
   batch_plan: |-
     BatchExchange { order: [], dist: Single }
-    └─BatchUpdate { table: t, exprs: [($1 + 1:Int32), ($0 - 1:Int32), $2] }
+    └─BatchUpdate { table: t, exprs: [$4, $5, $2] }
       └─BatchExchange { order: [], dist: Single }
-        └─BatchFilter { predicate: (t.v1 <> t.v2) }
-          └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
+        └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, (t.v2 + 1:Int32) as $expr1, (t.v1 - 1:Int32) as $expr2] }
+          └─BatchFilter { predicate: (t.v1 <> t.v2) }
+            └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
 - sql: |
     create table t (v1 int, v2 int);
     update t set (v1, v2) = (v2 + 1, v1 - 1) where v1 != v2 returning *, v2+1, v1-1;
   logical_plan: |-
-    LogicalProject { exprs: [t.v1, t.v2, (t.v2 + 1:Int32) as $expr1, (t.v1 - 1:Int32) as $expr2] }
-    └─LogicalUpdate { table: t, exprs: [($1 + 1:Int32), ($0 - 1:Int32), $2], returning: true }
-      └─LogicalFilter { predicate: (t.v1 <> t.v2) }
-        └─LogicalScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp] }
+    LogicalProject { exprs: [, , ( + 1:Int32) as $expr3, ( - 1:Int32) as $expr4] }
+    └─LogicalUpdate { table: t, exprs: [$4, $5, $2], returning: true }
+      └─LogicalProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, (t.v2 + 1:Int32) as $expr1, (t.v1 - 1:Int32) as $expr2] }
+        └─LogicalFilter { predicate: (t.v1 <> t.v2) }
+          └─LogicalScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp] }
   batch_plan: |-
     BatchExchange { order: [], dist: Single }
-    └─BatchProject { exprs: [t.v1, t.v2, (t.v2 + 1:Int32) as $expr1, (t.v1 - 1:Int32) as $expr2] }
-      └─BatchUpdate { table: t, exprs: [($1 + 1:Int32), ($0 - 1:Int32), $2], returning: true }
+    └─BatchProject { exprs: [, , ( + 1:Int32) as $expr3, ( - 1:Int32) as $expr4] }
+      └─BatchUpdate { table: t, exprs: [$4, $5, $2], returning: true }
         └─BatchExchange { order: [], dist: Single }
-          └─BatchFilter { predicate: (t.v1 <> t.v2) }
-            └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
+          └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, (t.v2 + 1:Int32) as $expr1, (t.v1 - 1:Int32) as $expr2] }
+            └─BatchFilter { predicate: (t.v1 <> t.v2) }
+              └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
 - name: update with returning statement, should keep `Update`
   sql: |
     create table t (v int);
     update t set v = 114 returning 514;
   logical_plan: |-
     LogicalProject { exprs: [514:Int32] }
-    └─LogicalUpdate { table: t, exprs: [114:Int32, $1], returning: true }
-      └─LogicalScan { table: t, columns: [t.v, t._row_id, t._rw_timestamp] }
+    └─LogicalUpdate { table: t, exprs: [$3, $1], returning: true }
+      └─LogicalProject { exprs: [t.v, t._row_id, t._rw_timestamp, 114:Int32] }
+        └─LogicalScan { table: t, columns: [t.v, t._row_id, t._rw_timestamp] }
   batch_plan: |-
     BatchExchange { order: [], dist: Single }
     └─BatchProject { exprs: [514:Int32] }
-      └─BatchUpdate { table: t, exprs: [114:Int32, $1], returning: true }
+      └─BatchUpdate { table: t, exprs: [$3, $1], returning: true }
        └─BatchExchange { order: [], dist: Single }
-          └─BatchScan { table: t, columns: [t.v, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
+          └─BatchProject { exprs: [t.v, t._row_id, t._rw_timestamp, 114:Int32] }
+            └─BatchScan { table: t, columns: [t.v, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
 - sql: |
     create table t (v1 int primary key, v2 int);
     update t set (v2, v1) = (v1, v2);
@@ -90,22 +100,25 @@
     create table t (v1 int default 1+1, v2 int);
     update t set v1 = default;
   logical_plan: |-
-    LogicalUpdate { table: t, exprs: [(1:Int32 + 1:Int32), $1, $2] }
-    └─LogicalScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp] }
+    LogicalUpdate { table: t, exprs: [$4, $1, $2] }
+    └─LogicalProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, (1:Int32 + 1:Int32) as $expr1] }
+      └─LogicalScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp] }
   batch_plan: |-
     BatchExchange { order: [], dist: Single }
-    └─BatchUpdate { table: t, exprs: [2:Int32, $1, $2] }
+    └─BatchUpdate { table: t, exprs: [$4, $1, $2] }
       └─BatchExchange { order: [], dist: Single }
-        └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
+        └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, 2:Int32] }
+          └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
 - name: update table with generated columns
   sql: |
     create table t(v1 int as v2-1, v2 int, v3 int as v2+1);
     update t set v2 = 3;
   batch_plan: |-
     BatchExchange { order: [], dist: Single }
-    └─BatchUpdate { table: t, exprs: [3:Int32, $3] }
+    └─BatchUpdate { table: t, exprs: [$5, $3] }
       └─BatchExchange { order: [], dist: Single }
-        └─BatchScan { table: t, columns: [t.v1, t.v2, t.v3, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
+        └─BatchProject { exprs: [t.v1, t.v2, t.v3, t._row_id, t._rw_timestamp, 3:Int32] }
+          └─BatchScan { table: t, columns: [t.v1, t.v2, t.v3, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
 - name: update generated column
   sql: |
     create table t(v1 int as v2-1, v2 int, v3 int as v2+1);
@@ -116,25 +129,27 @@
     create table t(v1 int as v2-1, v2 int, v3 int as v2+1, primary key (v3));
     update t set v2 = 3;
   binder_error: 'Bind error: update modifying the column referenced by generated columns that are part of the primary key is not allowed'
-- name: update subquery
+- name: update subquery selection
   sql: |
     create table t (a int, b int);
     update t set a = 777 where b not in (select a from t);
   logical_plan: |-
-    LogicalUpdate { table: t, exprs: [777:Int32, $1, $2] }
-    └─LogicalApply { type: LeftAnti, on: (t.b = t.a), correlated_id: 1 }
-      ├─LogicalScan { table: t, columns: [t.a, t.b, t._row_id, t._rw_timestamp] }
-      └─LogicalProject { exprs: [t.a] }
-        └─LogicalScan { table: t, columns: [t.a, t.b, t._row_id, t._rw_timestamp] }
+    LogicalUpdate { table: t, exprs: [$4, $1, $2] }
+    └─LogicalProject { exprs: [t.a, t.b, t._row_id, t._rw_timestamp, 777:Int32] }
+      └─LogicalApply { type: LeftAnti, on: (t.b = t.a), correlated_id: 1 }
+        ├─LogicalScan { table: t, columns: [t.a, t.b, t._row_id, t._rw_timestamp] }
+        └─LogicalProject { exprs: [t.a] }
+          └─LogicalScan { table: t, columns: [t.a, t.b, t._row_id, t._rw_timestamp] }
   batch_plan: |-
     BatchExchange { order: [], dist: Single }
-    └─BatchUpdate { table: t, exprs: [777:Int32, $1, $2] }
+    └─BatchUpdate { table: t, exprs: [$4, $1, $2] }
       └─BatchExchange { order: [], dist: Single }
-        └─BatchHashJoin { type: LeftAnti, predicate: t.b = t.a, output: all }
-          ├─BatchExchange { order: [], dist: HashShard(t.b) }
-          │ └─BatchScan { table: t, columns: [t.a, t.b, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
-          └─BatchExchange { order: [], dist: HashShard(t.a) }
-            └─BatchScan { table: t, columns: [t.a], distribution: SomeShard }
+        └─BatchProject { exprs: [t.a, t.b, t._row_id, t._rw_timestamp, 777:Int32] }
+          └─BatchHashJoin { type: LeftAnti, predicate: t.b = t.a, output: all }
+            ├─BatchExchange { order: [], dist: HashShard(t.b) }
+            │ └─BatchScan { table: t, columns: [t.a, t.b, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
+            └─BatchExchange { order: [], dist: HashShard(t.a) }
+              └─BatchScan { table: t, columns: [t.a], distribution: SomeShard }
 - name: delete subquery
   sql: |
     create table t (a int, b int);
@@ -163,12 +178,65 @@
   batch_distributed_plan: |-
     BatchSimpleAgg { aggs: [sum()] }
     └─BatchExchange { order: [], dist: Single }
-      └─BatchUpdate { table: t, exprs: [($0 + 1:Int32), $1, $2] }
-        └─BatchExchange { order: [], dist: HashShard(t.a, t.b, t._row_id, t._rw_timestamp) }
-          └─BatchScan { table: t, columns: [t.a, t.b, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
-- name: update table with subquery in the set clause
+      └─BatchUpdate { table: t, exprs: [$4, $1, $2] }
+        └─BatchExchange { order: [], dist: HashShard(t.a, t.b, t._row_id, t._rw_timestamp, $expr1) }
+          └─BatchProject { exprs: [t.a, t.b, t._row_id, t._rw_timestamp, (t.a + 1:Int32) as $expr1] }
+            └─BatchScan { table: t, columns: [t.a, t.b, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
+- name: update table to subquery
+  sql: |
+    create table t (v1 int, v2 int);
+    update t set v1 = (select 666);
+  batch_plan: |-
+    BatchExchange { order: [], dist: Single }
+    └─BatchUpdate { table: t, exprs: [$4, $1, $2] }
+      └─BatchNestedLoopJoin { type: LeftOuter, predicate: true, output: all }
+        ├─BatchExchange { order: [], dist: Single }
+        │ └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
+        └─BatchValues { rows: [[666:Int32]] }
+- name: update table to subquery with runtime cardinality
+  sql: |
+    create table t (v1 int, v2 int);
+    update t set v1 = (select generate_series(888, 888));
+  batch_plan: |-
+    BatchExchange { order: [], dist: Single }
+    └─BatchUpdate { table: t, exprs: [$4, $1, $2] }
+      └─BatchNestedLoopJoin { type: LeftOuter, predicate: true, output: all }
+        ├─BatchExchange { order: [], dist: Single }
+        │ └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
+        └─BatchMaxOneRow
+          └─BatchProject { exprs: [GenerateSeries(888:Int32, 888:Int32)] }
+            └─BatchProjectSet { select_list: [GenerateSeries(888:Int32, 888:Int32)] }
+              └─BatchValues { rows: [[]] }
+- name: update table to correlated subquery
   sql: |
-    create table t1 (v1 int primary key, v2 int);
-    create table t2 (v1 int primary key, v2 int);
-    update t1 set v1 = (select v1 from t2 where t1.v2 = t2.v2);
-  binder_error: 'Bind error: subquery on the right side of assignment is unsupported'
+    create table t (v1 int, v2 int);
+    update t set v1 = (select count(*) from t as source where source.v2 = t.v2);
+  batch_plan: |-
+    BatchExchange { order: [], dist: Single }
+    └─BatchUpdate { table: t, exprs: [$4, $1, $2] }
+      └─BatchExchange { order: [], dist: Single }
+        └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, count(1:Int32)::Int32 as $expr1] }
+          └─BatchHashJoin { type: LeftOuter, predicate: t.v2 IS NOT DISTINCT FROM t.v2, output: [t.v1, t.v2, t._row_id, t._rw_timestamp, count(1:Int32)] }
+            ├─BatchExchange { order: [], dist: HashShard(t.v2) }
+            │ └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
+            └─BatchHashAgg { group_key: [t.v2], aggs: [count(1:Int32)] }
+              └─BatchHashJoin { type: LeftOuter, predicate: t.v2 IS NOT DISTINCT FROM t.v2, output: [t.v2, 1:Int32] }
+                ├─BatchHashAgg { group_key: [t.v2], aggs: [] }
+                │ └─BatchExchange { order: [], dist: HashShard(t.v2) }
+                │   └─BatchScan { table: t, columns: [t.v2], distribution: SomeShard }
+                └─BatchExchange { order: [], dist: HashShard(t.v2) }
+                  └─BatchProject { exprs: [t.v2, 1:Int32] }
+                    └─BatchFilter { predicate: IsNotNull(t.v2) }
+                      └─BatchScan { table: t, columns: [t.v2], distribution: SomeShard }
+- name: update table to subquery with multiple assignments
+  sql: |
+    create table t (v1 int, v2 int);
+    update t set (v1, v2) = (select 666.66, 777);
+  batch_plan: |-
+    BatchExchange { order: [], dist: Single }
+    └─BatchUpdate { table: t, exprs: [Field($4, 0:Int32), Field($4, 1:Int32), $2] }
+      └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, $expr10011::Struct(StructType { field_names: [], field_types: [Int32, Int32] }) as $expr1] }
+        └─BatchNestedLoopJoin { type: LeftOuter, predicate: true, output: all }
+          ├─BatchExchange { order: [], dist: Single }
+          │ └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
+          └─BatchValues { rows: [['(666.66,777)':Struct(StructType { field_names: [], field_types: [Decimal, Int32] })]] }
diff --git a/src/frontend/src/binder/expr/subquery.rs b/src/frontend/src/binder/expr/subquery.rs
index 51819116771f1..c31a5d653aeb5 100644
--- a/src/frontend/src/binder/expr/subquery.rs
+++ b/src/frontend/src/binder/expr/subquery.rs
@@ -15,20 +15,16 @@
 use risingwave_sqlparser::ast::Query;
 use crate::binder::Binder;
-use crate::error::{ErrorCode, Result};
+use crate::error::{bail_bind_error, Result};
 use crate::expr::{ExprImpl, Subquery, SubqueryKind};
 impl Binder {
-    pub(super) fn bind_subquery_expr(
-        &mut self,
-        query: Query,
-        kind: SubqueryKind,
-    ) -> Result<ExprImpl> {
+    pub fn bind_subquery_expr(&mut self, query: Query, kind: SubqueryKind) -> Result<ExprImpl> {
         let query = self.bind_query(query)?;
-        if !matches!(kind, SubqueryKind::Existential) && query.data_types().len() != 1 {
-            return Err(
-                ErrorCode::BindError("Subquery must return only one column".to_string()).into(),
-            );
+        if !matches!(kind, SubqueryKind::Existential | SubqueryKind::UpdateSet)
+            && query.data_types().len() != 1
+        {
+            bail_bind_error!("Subquery must return only one column");
         }
         Ok(Subquery::new(query, kind).into())
     }
diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs
index b346dc45ca2d0..4560e51bd6562 100644
--- a/src/frontend/src/binder/mod.rs
+++ b/src/frontend/src/binder/mod.rs
@@ -58,7 +58,7 @@ pub use relation::{
 pub use select::{BoundDistinct, BoundSelect};
 pub use set_expr::*;
 pub use statement::BoundStatement;
-pub use update::BoundUpdate;
+pub use update::{BoundUpdate, UpdateProject};
 pub use values::BoundValues;
 use crate::catalog::catalog_service::CatalogReadGuard;
diff --git a/src/frontend/src/binder/update.rs b/src/frontend/src/binder/update.rs
index 9cc80dbde4471..f57ad1d197982 100644
--- a/src/frontend/src/binder/update.rs
+++ b/src/frontend/src/binder/update.rs
@@ -12,23 +12,42 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-use std::collections::hash_map::Entry;
 use std::collections::{BTreeMap, HashMap};
 use fixedbitset::FixedBitSet;
 use itertools::Itertools;
 use risingwave_common::catalog::{Schema, TableVersionId};
+use risingwave_common::types::DataType;
 use risingwave_common::util::iter_util::ZipEqFast;
 use risingwave_sqlparser::ast::{Assignment, AssignmentValue, Expr, ObjectName, SelectItem};
 use super::statement::RewriteExprsRecursive;
 use super::{Binder, BoundBaseTable};
 use crate::catalog::TableId;
-use crate::error::{ErrorCode, Result, RwError};
-use crate::expr::{Expr as _, ExprImpl, InputRef};
+use crate::error::{bail_bind_error, bind_error, ErrorCode, Result, RwError};
+use crate::expr::{Expr as _, ExprImpl, SubqueryKind};
 use crate::user::UserId;
 use crate::TableCatalog;
+/// Project into `exprs` in `BoundUpdate` to get the new values for updating.
+#[derive(Debug, Clone, Copy)]
+pub enum UpdateProject {
+    /// Use the expression at the given index in `exprs`.
+    Simple(usize),
+    /// Use the `i`-th field of the expression (returning a struct) at the given index in `exprs`.
+    Composite(usize, usize),
+}
+
+impl UpdateProject {
+    /// Offset the index by `i`.
+    pub fn offset(self, i: usize) -> Self {
+        match self {
+            UpdateProject::Simple(index) => UpdateProject::Simple(index + i),
+            UpdateProject::Composite(index, j) => UpdateProject::Composite(index + i, j),
+        }
+    }
+}
 #[derive(Debug, Clone)]
 pub struct BoundUpdate {
     /// Id of the table to perform updating.
@@ -48,10 +67,14 @@ pub struct BoundUpdate {
     pub selection: Option<ExprImpl>,
-    /// Expression used to project to the updated row. The assigned columns will use the new
-    /// expression, and the other columns will be simply `InputRef`.
+    /// Expression used to evaluate the new values for the columns.
     pub exprs: Vec<ExprImpl>,
+    /// Mapping from the index of the column to be updated, to the index of the expression in `exprs`.
+    ///
+    /// By constructing two `Project` nodes with `exprs` and `projects`, we can get the new values.
+    pub projects: HashMap<usize, UpdateProject>,
+
     // used for the 'RETURNING" keyword to indicate the returning items and schema
     // if the list is empty and the schema is None, the output schema will be a INT64 as the
     // affected row cnt
@@ -124,107 +147,112 @@ impl Binder {
         let selection = selection.map(|expr| self.bind_expr(expr)).transpose()?;
-        let mut assignment_exprs = HashMap::new();
-        for Assignment { id, value } in assignments {
-            // FIXME: Parsing of `id` is not strict. It will even treat `a.b` as `(a, b)`.
-            let assignments = match (id.as_slice(), value) {
-                // _ = (subquery)
-                (_ids, AssignmentValue::Expr(Expr::Subquery(_))) => {
-                    return Err(ErrorCode::BindError(
-                        "subquery on the right side of assignment is unsupported".to_owned(),
-                    )
-                    .into())
-                }
-                // col = expr
-                ([id], value) => {
-                    vec![(id.clone(), value)]
-                }
-                // (col1, col2) = (expr1, expr2)
-                // TODO: support `DEFAULT` in multiple assignments
-                (ids, AssignmentValue::Expr(Expr::Row(values))) if ids.len() == values.len() => id
-                    .into_iter()
-                    .zip_eq_fast(values.into_iter().map(AssignmentValue::Expr))
-                    .collect(),
-                // (col1, col2) =
-                _ => {
-                    return Err(ErrorCode::BindError(
-                        "number of columns does not match number of values".to_owned(),
-                    )
-                    .into())
-                }
+        let mut exprs = Vec::new();
+        let mut projects = HashMap::new();
+
+        macro_rules! record {
+            ($id:expr, $project:expr) => {
+                let id_index = $id.as_input_ref().unwrap().index;
+                projects
+                    .try_insert(id_index, $project)
+                    .map_err(|_e| bind_error!("multiple assignments to the same column"))?;
             };
+        }
-            for (id, value) in assignments {
-                let id_expr = self.bind_expr(Expr::Identifier(id.clone()))?;
-                let id_index = if let Some(id_input_ref) = id_expr.clone().as_input_ref() {
-                    let id_index = id_input_ref.index;
-                    if table
-                        .table_catalog
-                        .pk()
-                        .iter()
-                        .any(|k| k.column_index == id_index)
-                    {
-                        return Err(ErrorCode::BindError(
-                            "update modifying the PK column is unsupported".to_owned(),
-                        )
-                        .into());
-                    }
-                    if table
-                        .table_catalog
-                        .generated_col_idxes()
-                        .contains(&id_index)
-                    {
-                        return Err(ErrorCode::BindError(
-                            "update modifying the generated column is unsupported".to_owned(),
-                        )
-                        .into());
+        for Assignment { id, value } in assignments {
+            let ids: Vec<_> = id
+                .into_iter()
+                .map(|id| self.bind_expr(Expr::Identifier(id)))
+                .try_collect()?;
+
+            match (ids.as_slice(), value) {
+                // `SET col1 = DEFAULT`, `SET (col1, col2, ...) = DEFAULT`
+                (ids, AssignmentValue::Default) => {
+                    for id in ids {
+                        let id_index = id.as_input_ref().unwrap().index;
+                        let expr = default_columns_from_catalog
+                            .get(&id_index)
+                            .cloned()
+                            .unwrap_or_else(|| ExprImpl::literal_null(id.return_type()));
+
+                        exprs.push(expr);
+                        record!(id, UpdateProject::Simple(exprs.len() - 1));
                     }
-                    if cols_refed_by_generated_pk.contains(id_index) {
-                        return Err(ErrorCode::BindError(
-                            "update modifying the column referenced by generated columns that are part of the primary key is not allowed".to_owned(),
-                        )
-                        .into());
+                }
+
+                // `SET col1 = expr`
+                ([id], AssignmentValue::Expr(expr)) => {
+                    let expr = self.bind_expr(expr)?.cast_assign(id.return_type())?;
+                    exprs.push(expr);
+                    record!(id, UpdateProject::Simple(exprs.len() - 1));
+                }
+                // `SET (col1, col2, ...) = (val1, val2, ...)`
+                (ids, AssignmentValue::Expr(Expr::Row(values))) => {
+                    if ids.len() != values.len() {
+                        bail_bind_error!("number of columns does not match number of values");
                    }
-                    id_index
-                } else {
-                    unreachable!()
-                };
-
-                let value_expr = match value {
-                    AssignmentValue::Expr(expr) => {
-                        self.bind_expr(expr)?.cast_assign(id_expr.return_type())?
+
+                    for (id, value) in ids.iter().zip_eq_fast(values) {
+                        let expr = self.bind_expr(value)?.cast_assign(id.return_type())?;
+                        exprs.push(expr);
+                        record!(id, UpdateProject::Simple(exprs.len() - 1));
                     }
-                    AssignmentValue::Default => default_columns_from_catalog
-                        .get(&id_index)
-                        .cloned()
-                        .unwrap_or_else(|| ExprImpl::literal_null(id_expr.return_type())),
-                };
-
-                match assignment_exprs.entry(id_expr) {
-                    Entry::Occupied(_) => {
-                        return Err(ErrorCode::BindError(
-                            "multiple assignments to same column".to_owned(),
-                        )
-                        .into())
+                }
+                // `SET (col1, col2, ...) = (SELECT ...)`
+                (ids, AssignmentValue::Expr(Expr::Subquery(subquery))) => {
+                    let expr = self.bind_subquery_expr(*subquery, SubqueryKind::UpdateSet)?;
+
+                    if expr.return_type().as_struct().len() != ids.len() {
+                        bail_bind_error!("number of columns does not match number of values");
                    }
-                    Entry::Vacant(v) => {
-                        v.insert(value_expr);
+
+                    let target_type = DataType::new_unnamed_struct(
+                        ids.iter().map(|id| id.return_type()).collect(),
+                    );
+                    let expr = expr.cast_assign(target_type)?;
+
+                    exprs.push(expr);
+
+                    for (i, id) in ids.iter().enumerate() {
+                        record!(id, UpdateProject::Composite(exprs.len() - 1, i));
                    }
                 }
+
+                (_ids, AssignmentValue::Expr(_expr)) => {
+                    bail_bind_error!("source for a multiple-column UPDATE item must be a sub-SELECT or ROW() expression");
+                }
             }
         }
-        let exprs = table
-            .table_catalog
-            .columns()
-            .iter()
-            .enumerate()
-            .filter_map(|(i, c)| {
-                c.can_dml()
-                    .then_some(InputRef::new(i, c.data_type().clone()).into())
-            })
-            .map(|c| assignment_exprs.remove(&c).unwrap_or(c))
-            .collect_vec();
+        // Check whether updating these columns is allowed.
+ for &id_index in projects.keys() { + if (table.table_catalog.pk()) + .iter() + .any(|k| k.column_index == id_index) + { + return Err(ErrorCode::BindError( + "update modifying the PK column is unsupported".to_owned(), + ) + .into()); + } + if (table.table_catalog.generated_col_idxes()).contains(&id_index) { + return Err(ErrorCode::BindError( + "update modifying the generated column is unsupported".to_owned(), + ) + .into()); + } + if cols_refed_by_generated_pk.contains(id_index) { + return Err(ErrorCode::BindError( + "update modifying the column referenced by generated columns that are part of the primary key is not allowed".to_owned(), + ) + .into()); + } + + let col = &table.table_catalog.columns()[id_index]; + if !col.can_dml() { + bail_bind_error!("update modifying column `{}` is unsupported", col.name()); + } + } let (returning_list, fields) = self.bind_returning_list(returning_items)?; let returning = !returning_list.is_empty(); @@ -236,6 +264,7 @@ impl Binder { owner, table, selection, + projects, exprs, returning_list, returning_schema: if returning { diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs index 3092c9bee91a9..f0cf35e859664 100644 --- a/src/frontend/src/error.rs +++ b/src/frontend/src/error.rs @@ -33,8 +33,8 @@ use tokio::task::JoinError; // - Some variants are never constructed. // - Some variants store a type-erased `BoxedError` to resolve the reverse dependency. // It's not necessary anymore as the error type is now defined at the top-level. -#[derive(Error, thiserror_ext::ReportDebug, thiserror_ext::Box)] -#[thiserror_ext(newtype(name = RwError, backtrace))] +#[derive(Error, thiserror_ext::ReportDebug, thiserror_ext::Box, thiserror_ext::Macro)] +#[thiserror_ext(newtype(name = RwError, backtrace), macro(path = "crate::error"))] pub enum ErrorCode { #[error("internal error: {0}")] InternalError(String), @@ -105,7 +105,7 @@ pub enum ErrorCode { // TODO: use a new type for bind error // TODO(error-handling): should prefer use error types than strings. #[error("Bind error: {0}")] - BindError(String), + BindError(#[message] String), // TODO: only keep this one #[error("Failed to bind expression: {expr}: {error}")] BindErrorRoot { diff --git a/src/frontend/src/expr/subquery.rs b/src/frontend/src/expr/subquery.rs index 62f59c934dd6d..8460f73d5fbba 100644 --- a/src/frontend/src/expr/subquery.rs +++ b/src/frontend/src/expr/subquery.rs @@ -24,6 +24,9 @@ use crate::expr::{CorrelatedId, Depth}; pub enum SubqueryKind { /// Returns a scalar value (single column single row). Scalar, + /// Returns a scalar struct value composed of multiple columns. + /// Used in `UPDATE SET (col1, col2) = (SELECT ...)`. + UpdateSet, /// `EXISTS` | `NOT EXISTS` subquery (semi/anti-semi join). Returns a boolean. Existential, /// `IN` subquery. @@ -88,6 +91,7 @@ impl Expr for Subquery { assert_eq!(types.len(), 1, "Subquery with more than one column"); types[0].clone() } + SubqueryKind::UpdateSet => DataType::new_unnamed_struct(self.query.data_types()), SubqueryKind::Array => { let types = self.query.data_types(); assert_eq!(types.len(), 1, "Subquery with more than one column"); diff --git a/src/frontend/src/optimizer/plan_node/batch_update.rs b/src/frontend/src/optimizer/plan_node/batch_update.rs index d0351e6fdec2e..28dfa79916cc9 100644 --- a/src/frontend/src/optimizer/plan_node/batch_update.rs +++ b/src/frontend/src/optimizer/plan_node/batch_update.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use itertools::Itertools;
 use risingwave_common::catalog::Schema;
 use risingwave_pb::batch_plan::plan_node::NodeBody;
 use risingwave_pb::batch_plan::UpdateNode;
@@ -84,20 +83,21 @@ impl ToDistributedBatch for BatchUpdate {
 impl ToBatchPb for BatchUpdate {
     fn to_batch_prost_body(&self) -> NodeBody {
-        let exprs = self.core.exprs.iter().map(|x| x.to_expr_proto()).collect();
-
-        let update_column_indices = self
-            .core
-            .update_column_indices
+        let old_exprs = (self.core.old_exprs)
+            .iter()
+            .map(|x| x.to_expr_proto())
+            .collect();
+        let new_exprs = (self.core.new_exprs)
             .iter()
-            .map(|i| *i as _)
-            .collect_vec();
+            .map(|x| x.to_expr_proto())
+            .collect();
+
         NodeBody::Update(UpdateNode {
-            exprs,
             table_id: self.core.table_id.table_id(),
             table_version_id: self.core.table_version_id,
             returning: self.core.returning,
-            update_column_indices,
+            old_exprs,
+            new_exprs,
             session_id: self.base.ctx().session_ctx().session_id().0 as u32,
         })
     }
@@ -125,6 +125,6 @@ impl ExprRewritable for BatchUpdate {
 impl ExprVisitable for BatchUpdate {
     fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
-        self.core.exprs.iter().for_each(|e| v.visit_expr(e));
+        self.core.visit_exprs(v);
     }
 }
diff --git a/src/frontend/src/optimizer/plan_node/generic/update.rs b/src/frontend/src/optimizer/plan_node/generic/update.rs
index 61d044f53c998..d68af1a01ae3f 100644
--- a/src/frontend/src/optimizer/plan_node/generic/update.rs
+++ b/src/frontend/src/optimizer/plan_node/generic/update.rs
@@ -21,7 +21,7 @@ use risingwave_common::types::DataType;
 use super::{DistillUnit, GenericPlanNode, GenericPlanRef};
 use crate::catalog::TableId;
-use crate::expr::{ExprImpl, ExprRewriter};
+use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprVisitor};
 use crate::optimizer::plan_node::utils::childless_record;
 use crate::optimizer::property::FunctionalDependencySet;
 use crate::OptimizerContextRef;
@@ -35,15 +35,15 @@ pub struct Update {
     pub table_id: TableId,
     pub table_version_id: TableVersionId,
     pub input: PlanRef,
-    pub exprs: Vec<ExprImpl>,
+    pub old_exprs: Vec<ExprImpl>,
+    pub new_exprs: Vec<ExprImpl>,
     pub returning: bool,
-    pub update_column_indices: Vec<usize>,
 }
 impl Update {
     pub fn output_len(&self) -> usize {
         if self.returning {
-            self.input.schema().len()
+            self.new_exprs.len()
         } else {
             1
         }
@@ -56,18 +56,19 @@ impl GenericPlanNode for Update {
     fn schema(&self) -> Schema {
         if self.returning {
-            self.input.schema().clone()
+            Schema::new(
+                self.new_exprs
+                    .iter()
+                    .map(|e| Field::unnamed(e.return_type()))
+                    .collect(),
+            )
         } else {
             Schema::new(vec![Field::unnamed(DataType::Int64)])
         }
     }
     fn stream_key(&self) -> Option<Vec<usize>> {
-        if self.returning {
-            Some(self.input.stream_key()?.to_vec())
-        } else {
-            Some(vec![])
-        }
+        None
     }
     fn ctx(&self) -> OptimizerContextRef {
@@ -81,27 +82,31 @@ impl Update {
         table_name: String,
         table_id: TableId,
         table_version_id: TableVersionId,
-        exprs: Vec<ExprImpl>,
+        old_exprs: Vec<ExprImpl>,
+        new_exprs: Vec<ExprImpl>,
         returning: bool,
-        update_column_indices: Vec<usize>,
     ) -> Self {
         Self {
             table_name,
             table_id,
             table_version_id,
             input,
-            exprs,
+            old_exprs,
+            new_exprs,
             returning,
-            update_column_indices,
         }
     }
     pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
-        self.exprs = self
-            .exprs
-            .iter()
-            .map(|e| r.rewrite_expr(e.clone()))
-            .collect();
+        for exprs in [&mut self.old_exprs, &mut self.new_exprs] {
+            *exprs = exprs.iter().map(|e| r.rewrite_expr(e.clone())).collect();
+        }
+    }
+
+    pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
+        for exprs in [&self.old_exprs, &self.new_exprs] {
+            exprs.iter().for_each(|e| v.visit_expr(e));
+        }
     }
 }
@@ -109,7 +114,7 @@ impl DistillUnit for Update {
     fn distill_with_name<'a>(&self, name: impl Into<Str<'a>>) -> XmlNode<'a> {
         let mut vec = Vec::with_capacity(if self.returning { 3 } else { 2 });
         vec.push(("table", Pretty::from(self.table_name.clone())));
-        vec.push(("exprs", Pretty::debug(&self.exprs)));
+        vec.push(("exprs", Pretty::debug(&self.new_exprs)));
         if self.returning {
             vec.push(("returning", Pretty::display(&true)));
         }
diff --git a/src/frontend/src/optimizer/plan_node/logical_update.rs b/src/frontend/src/optimizer/plan_node/logical_update.rs
index 127b6ed8b317b..a5590501715b9 100644
--- a/src/frontend/src/optimizer/plan_node/logical_update.rs
+++ b/src/frontend/src/optimizer/plan_node/logical_update.rs
@@ -12,17 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-use risingwave_common::catalog::TableVersionId;
-
 use super::generic::GenericPlanRef;
 use super::utils::impl_distill_by_unit;
 use super::{
     gen_filter_and_pushdown, generic, BatchUpdate, ColPrunable, ExprRewritable, Logical,
     LogicalProject, PlanBase, PlanRef, PlanTreeNodeUnary, PredicatePushdown, ToBatch, ToStream,
 };
-use crate::catalog::TableId;
 use crate::error::Result;
-use crate::expr::{ExprImpl, ExprRewriter, ExprVisitor};
+use crate::expr::{ExprRewriter, ExprVisitor};
 use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
 use crate::optimizer::plan_node::{
     ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext, ToStreamContext,
@@ -46,25 +43,6 @@ impl From<generic::Update<PlanRef>> for LogicalUpdate {
     }
 }
-impl LogicalUpdate {
-    #[must_use]
-    pub fn table_id(&self) -> TableId {
-        self.core.table_id
-    }
-
-    pub fn exprs(&self) -> &[ExprImpl] {
-        self.core.exprs.as_ref()
-    }
-
-    pub fn has_returning(&self) -> bool {
-        self.core.returning
-    }
-
-    pub fn table_version_id(&self) -> TableVersionId {
-        self.core.table_version_id
-    }
-}
-
 impl PlanTreeNodeUnary for LogicalUpdate {
     fn input(&self) -> PlanRef {
         self.core.input.clone()
@@ -86,15 +64,15 @@ impl ExprRewritable for LogicalUpdate {
     }
     fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
-        let mut new = self.core.clone();
-        new.exprs = new.exprs.into_iter().map(|e| r.rewrite_expr(e)).collect();
-        Self::from(new).into()
+        let mut core = self.core.clone();
+        core.rewrite_exprs(r);
+        Self::from(core).into()
     }
 }
 impl ExprVisitable for LogicalUpdate {
     fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
-        self.core.exprs.iter().for_each(|e| v.visit_expr(e));
+        self.core.visit_exprs(v);
     }
 }
diff --git a/src/frontend/src/planner/select.rs b/src/frontend/src/planner/select.rs
index a9e7dd3526ed1..ebed01351f7d1 100644
--- a/src/frontend/src/planner/select.rs
+++ b/src/frontend/src/planner/select.rs
@@ -320,7 +320,7 @@ impl Planner {
     ///
     /// The [`InputRef`]s' indexes start from `root.schema().len()`,
     /// which means they are additional columns beyond the original `root`.
-    fn substitute_subqueries(
+    pub(super) fn substitute_subqueries(
         &mut self,
         mut root: PlanRef,
         mut exprs: Vec<ExprImpl>,
@@ -366,10 +366,27 @@ impl Planner {
             .zip_eq_fast(rewriter.correlated_indices_collection)
             .zip_eq_fast(rewriter.correlated_ids)
         {
+            let return_type = subquery.return_type();
             let subroot = self.plan_query(subquery.query)?;
             let right = match subquery.kind {
                 SubqueryKind::Scalar => subroot.into_unordered_subplan(),
+                SubqueryKind::UpdateSet => {
+                    let plan = subroot.into_unordered_subplan();
+
+                    // Compose all input columns into a struct with `ROW` function.
+                    let all_input_refs = plan
+                        .schema()
+                        .data_types()
+                        .into_iter()
+                        .enumerate()
+                        .map(|(i, data_type)| InputRef::new(i, data_type).into())
+                        .collect::<Vec<_>>();
+                    let call =
+                        FunctionCall::new_unchecked(ExprType::Row, all_input_refs, return_type);
+
+                    LogicalProject::create(plan, vec![call.into()])
+                }
                 SubqueryKind::Existential => {
                     self.create_exists(subroot.into_unordered_subplan())?
                 }
diff --git a/src/frontend/src/planner/update.rs b/src/frontend/src/planner/update.rs
index ddf9ab0bdf9ae..2db18ac0e2924 100644
--- a/src/frontend/src/planner/update.rs
+++ b/src/frontend/src/planner/update.rs
@@ -13,41 +13,92 @@
 // limitations under the License.
 use fixedbitset::FixedBitSet;
-use itertools::Itertools;
+use risingwave_common::types::{DataType, Scalar};
+use risingwave_pb::expr::expr_node::Type;
 use super::Planner;
-use crate::binder::BoundUpdate;
+use crate::binder::{BoundUpdate, UpdateProject};
 use crate::error::Result;
+use crate::expr::{ExprImpl, FunctionCall, InputRef, Literal};
+use crate::optimizer::plan_node::generic::GenericPlanRef;
 use crate::optimizer::plan_node::{generic, LogicalProject, LogicalUpdate};
 use crate::optimizer::property::{Order, RequiredDist};
 use crate::optimizer::{PlanRef, PlanRoot};
 impl Planner {
     pub(super) fn plan_update(&mut self, update: BoundUpdate) -> Result<PlanRoot> {
+        let returning = !update.returning_list.is_empty();
+
         let scan = self.plan_base_table(&update.table)?;
         let input = if let Some(expr) = update.selection {
             self.plan_where(scan, expr)?
         } else {
             scan
         };
-        let returning = !update.returning_list.is_empty();
-        let update_column_indices = update
-            .table
-            .table_catalog
-            .columns()
-            .iter()
-            .enumerate()
-            .filter_map(|(i, c)| c.can_dml().then_some(i))
-            .collect_vec();
+        let old_schema_len = input.schema().len();
+
+        // Extend table scan with updated columns.
+        let with_new: PlanRef = {
+            let mut plan = input;
+
+            let mut exprs: Vec<ExprImpl> = plan
+                .schema()
+                .data_types()
+                .into_iter()
+                .enumerate()
+                .map(|(index, data_type)| InputRef::new(index, data_type).into())
+                .collect();
+
+            exprs.extend(update.exprs);
+
+            // Substitute subqueries into `LogicalApply`s.
+            if exprs.iter().any(|e| e.has_subquery()) {
+                (plan, exprs) = self.substitute_subqueries(plan, exprs)?;
+            }
+
+            LogicalProject::new(plan, exprs).into()
+        };
+
+        let mut olds = Vec::new();
+        let mut news = Vec::new();
+
+        for (i, col) in update.table.table_catalog.columns().iter().enumerate() {
+            // Skip generated columns and system columns.
+            if !col.can_dml() {
+                continue;
+            }
+            let data_type = col.data_type();
+
+            let old: ExprImpl = InputRef::new(i, data_type.clone()).into();
+
+            let new: ExprImpl = match (update.projects.get(&i)).map(|p| p.offset(old_schema_len)) {
+                Some(UpdateProject::Simple(j)) => InputRef::new(j, data_type.clone()).into(),
+                Some(UpdateProject::Composite(j, field)) => FunctionCall::new_unchecked(
+                    Type::Field,
+                    vec![
+                        InputRef::new(j, with_new.schema().data_types()[j].clone()).into(), // struct
+                        Literal::new(Some((field as i32).to_scalar_value()), DataType::Int32)
+                            .into(),
+                    ],
+                    data_type.clone(),
+                )
+                .into(),
+
+                None => old.clone(),
+            };
+
+            olds.push(old);
+            news.push(new);
+        }
         let mut plan: PlanRef = LogicalUpdate::from(generic::Update::new(
-            input,
+            with_new,
             update.table_name.clone(),
             update.table_id,
             update.table_version_id,
-            update.exprs,
+            olds,
+            news,
             returning,
-            update_column_indices,
         ))
         .into();