refactor(frontend): rework UPDATE to support subqueries (#19402)
Signed-off-by: Bugen Zhao <[email protected]>
BugenZhao authored Nov 20, 2024
1 parent c1162ab commit 5f1a59b
Showing 17 changed files with 589 additions and 283 deletions.
132 changes: 132 additions & 0 deletions e2e_test/batch/basic/dml_update.slt.part
@@ -0,0 +1,132 @@
# Extension to `dml_basic.slt.part` for testing advanced `UPDATE` statements.

statement ok
SET RW_IMPLICIT_FLUSH TO true;

statement ok
create table t (v1 int default 1919, v2 int default 810);

statement ok
insert into t values (114, 514);


# Single assignment, to subquery.
statement ok
update t set v1 = (select 666);

query II
select * from t;
----
666 514

# Single assignment, to runtime-cardinality subquery returning 1 row.
statement ok
update t set v1 = (select generate_series(888, 888));

query II
select * from t;
----
888 514

# Single assignment, to runtime-cardinality subquery returning 0 rows (set to NULL).
statement ok
update t set v1 = (select generate_series(1, 0));

query II
select * from t;
----
NULL 514

# Single assignment, to runtime-cardinality subquery returning multiple rows.
statement error Scalar subquery produced more than one row
update t set v1 = (select generate_series(1, 2));

# Single assignment, to correlated subquery.
statement ok
update t set v1 = (select count(*) from t as source where source.v2 = t.v2);

query II
select * from t;
----
1 514

# Single assignment, to subquery with mismatched column count.
statement error must return only one column
update t set v1 = (select 666, 888);


# Multiple assignment clauses.
statement ok
update t set v1 = 1919, v2 = 810;

query II
select * from t;
----
1919 810

# Multiple assignments to the same column.
statement error multiple assignments to the same column
update t set v1 = 1, v1 = 2;

statement error multiple assignments to the same column
update t set (v1, v1) = (1, 2);

statement error multiple assignments to the same column
update t set (v1, v2) = (1, 2), v2 = 2;

# Multiple assignments, to subquery.
statement ok
update t set (v1, v2) = (select 666, 888);

query II
select * from t;
----
666 888

# Multiple assignments, to subquery with cast.
statement ok
update t set (v1, v2) = (select 888.88, 999);

query II
select * from t;
----
889 999

# Multiple assignments, to subquery with cast failure.
# TODO: this currently shows `cannot cast type "record" to "record"` because we wrap the subquery result
# into a struct, which is not quite clear.
statement error cannot cast type
update t set (v1, v2) = (select '888.88', 999);

# Multiple assignments, to subquery with mismatched column count.
statement error number of columns does not match number of values
update t set (v1, v2) = (select 666);

# Multiple assignments, to scalar expression.
statement error source for a multiple-column UPDATE item must be a sub-SELECT or ROW\(\) expression
update t set (v1, v2) = v1 + 1;


# Assignment to system columns.
statement error update modifying column `_rw_timestamp` is unsupported
update t set _rw_timestamp = _rw_timestamp + interval '1 second';


# https://github.com/risingwavelabs/risingwave/pull/19402#pullrequestreview-2444427475
# https://github.com/risingwavelabs/risingwave/pull/19452
statement ok
create table y (v1 int, v2 int);

statement ok
insert into y values (11, 11), (22, 22);

statement error Scalar subquery produced more than one row
update t set (v1, v2) = (select y.v1, y.v2 from y);

statement ok
drop table y;


# Cleanup.
statement ok
drop table t;
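
The tests above pin down the runtime semantics of a scalar subquery on the right-hand side of an assignment: it must produce exactly one column, zero rows assign NULL to the targets, and more than one row raises "Scalar subquery produced more than one row". A minimal, self-contained sketch of that cardinality rule, with a toy `Row` type; this is an illustration, not RisingWave's actual executor:

```rust
type Row = Vec<Option<i64>>;

/// Evaluate a scalar subquery result under the cardinality rule tested above.
fn eval_scalar_subquery(
    mut rows: impl Iterator<Item = Row>,
    arity: usize, // number of columns the subquery must return
) -> Result<Row, String> {
    match (rows.next(), rows.next()) {
        // No rows: the assignment targets are set to NULL.
        (None, _) => Ok(vec![None; arity]),
        // Exactly one row: use it as-is.
        (Some(row), None) => Ok(row),
        // More than one row: mirror the error in the tests above.
        (Some(_), Some(_)) => Err("Scalar subquery produced more than one row".to_string()),
    }
}

fn main() {
    assert_eq!(eval_scalar_subquery(std::iter::empty(), 2), Ok(vec![None, None]));
    assert_eq!(
        eval_scalar_subquery(vec![vec![Some(666)]].into_iter(), 1),
        Ok(vec![Some(666)])
    );
    assert!(eval_scalar_subquery(vec![vec![Some(1)], vec![Some(2)]].into_iter(), 1).is_err());
}
```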
11 changes: 6 additions & 5 deletions proto/batch_plan.proto
@@ -173,11 +173,12 @@ message UpdateNode {
// Id of the table to perform the update on.
uint32 table_id = 1;
// Version of the table.
uint64 table_version_id = 4;
repeated expr.ExprNode exprs = 2;
bool returning = 3;
// The column indices in the input schema, representing the columns that need to be sent to the stream DML executor.
repeated uint32 update_column_indices = 5;
uint64 table_version_id = 2;
// Expressions to generate `U-` records.
repeated expr.ExprNode old_exprs = 3;
// Expressions to generate `U+` records.
repeated expr.ExprNode new_exprs = 4;
bool returning = 5;

// Session id is used to ensure that dml data from the same session should be sent to a fixed worker node and channel.
uint32 session_id = 6;
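The schema change above replaces the single `exprs` list plus `update_column_indices` with two parallel expression lists: `old_exprs` reproduces the current row (the `U-` record) and `new_exprs` computes the updated row (the `U+` record). A toy sketch of that mapping, using a hypothetical `Expr` type in place of `expr.ExprNode`:

```rust
// Toy illustration (not RisingWave's ExprNode) of how one UPDATE maps onto the
// two lists: `old_exprs` reproduces the current row, `new_exprs` the updated one.
#[derive(Clone, Debug)]
enum Expr {
    InputRef(usize),     // read column i of the input row
    Add(Box<Expr>, i64), // toy arithmetic, enough for `v1 + 1`
}

fn eval(expr: &Expr, row: &[i64]) -> i64 {
    match expr {
        Expr::InputRef(i) => row[*i],
        Expr::Add(e, k) => eval(e, row) + k,
    }
}

fn main() {
    // UPDATE t SET v1 = v1 + 1, for a table t(v1, v2):
    let old_exprs = vec![Expr::InputRef(0), Expr::InputRef(1)];
    let new_exprs = vec![Expr::Add(Box::new(Expr::InputRef(0)), 1), Expr::InputRef(1)];

    let row = [114, 514];
    let u_minus: Vec<i64> = old_exprs.iter().map(|e| eval(e, &row)).collect();
    let u_plus: Vec<i64> = new_exprs.iter().map(|e| eval(e, &row)).collect();
    assert_eq!((u_minus, u_plus), (vec![114, 514], vec![115, 514]));
}
```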
66 changes: 36 additions & 30 deletions src/batch/src/executor/update.rs
@@ -42,13 +42,13 @@ pub struct UpdateExecutor {
table_version_id: TableVersionId,
dml_manager: DmlManagerRef,
child: BoxedExecutor,
exprs: Vec<BoxedExpression>,
old_exprs: Vec<BoxedExpression>,
new_exprs: Vec<BoxedExpression>,
chunk_size: usize,
schema: Schema,
identity: String,
returning: bool,
txn_id: TxnId,
update_column_indices: Vec<usize>,
session_id: u32,
}

@@ -59,11 +59,11 @@ impl UpdateExecutor {
table_version_id: TableVersionId,
dml_manager: DmlManagerRef,
child: BoxedExecutor,
exprs: Vec<BoxedExpression>,
old_exprs: Vec<BoxedExpression>,
new_exprs: Vec<BoxedExpression>,
chunk_size: usize,
identity: String,
returning: bool,
update_column_indices: Vec<usize>,
session_id: u32,
) -> Self {
let chunk_size = chunk_size.next_multiple_of(2);
@@ -75,7 +75,8 @@
table_version_id,
dml_manager,
child,
exprs,
old_exprs,
new_exprs,
chunk_size,
schema: if returning {
table_schema
@@ -87,7 +88,6 @@
identity,
returning,
txn_id,
update_column_indices,
session_id,
}
}
@@ -109,7 +109,7 @@ impl Executor for UpdateExecutor {

impl UpdateExecutor {
#[try_stream(boxed, ok = DataChunk, error = BatchError)]
async fn do_execute(mut self: Box<Self>) {
async fn do_execute(self: Box<Self>) {
let table_dml_handle = self
.dml_manager
.table_dml_handle(self.table_id, self.table_version_id)?;
@@ -122,15 +122,12 @@

assert_eq!(
data_types,
self.exprs.iter().map(|e| e.return_type()).collect_vec(),
self.new_exprs.iter().map(|e| e.return_type()).collect_vec(),
"bad update schema"
);
assert_eq!(
data_types,
self.update_column_indices
.iter()
.map(|i: &usize| self.child.schema()[*i].data_type.clone())
.collect_vec(),
self.old_exprs.iter().map(|e| e.return_type()).collect_vec(),
"bad update schema"
);

@@ -159,27 +156,35 @@
let mut rows_updated = 0;

#[for_await]
for data_chunk in self.child.execute() {
let data_chunk = data_chunk?;
for input in self.child.execute() {
let input = input?;

let old_data_chunk = {
let mut columns = Vec::with_capacity(self.old_exprs.len());
for expr in &self.old_exprs {
let column = expr.eval(&input).await?;
columns.push(column);
}

DataChunk::new(columns, input.visibility().clone())
};

let updated_data_chunk = {
let mut columns = Vec::with_capacity(self.exprs.len());
for expr in &mut self.exprs {
let column = expr.eval(&data_chunk).await?;
let mut columns = Vec::with_capacity(self.new_exprs.len());
for expr in &self.new_exprs {
let column = expr.eval(&input).await?;
columns.push(column);
}

DataChunk::new(columns, data_chunk.visibility().clone())
DataChunk::new(columns, input.visibility().clone())
};

if self.returning {
yield updated_data_chunk.clone();
}

for (row_delete, row_insert) in data_chunk
.project(&self.update_column_indices)
.rows()
.zip_eq_debug(updated_data_chunk.rows())
for (row_delete, row_insert) in
(old_data_chunk.rows()).zip_eq_debug(updated_data_chunk.rows())
{
rows_updated += 1;
// If row_delete == row_insert, we don't need to do an actual update
@@ -227,34 +232,35 @@ impl BoxedExecutorBuilder for UpdateExecutor {

let table_id = TableId::new(update_node.table_id);

let exprs: Vec<_> = update_node
.get_exprs()
let old_exprs: Vec<_> = update_node
.get_old_exprs()
.iter()
.map(build_from_prost)
.try_collect()?;

let update_column_indices = update_node
.update_column_indices
let new_exprs: Vec<_> = update_node
.get_new_exprs()
.iter()
.map(|x| *x as usize)
.collect_vec();
.map(build_from_prost)
.try_collect()?;

Ok(Box::new(Self::new(
table_id,
update_node.table_version_id,
source.context().dml_manager(),
child,
exprs,
old_exprs,
new_exprs,
source.context.get_config().developer.chunk_size,
source.plan_node().get_identity().clone(),
update_node.returning,
update_column_indices,
update_node.session_id,
)))
}
}

#[cfg(test)]
#[cfg(any())]
mod tests {
use std::sync::Arc;

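One detail worth noting in the executor above: `chunk_size` is rounded up with `next_multiple_of(2)` because every updated row emits a delete/insert pair that must land in the same chunk, and the old and new chunks are zipped row by row with `zip_eq_debug`. A minimal sketch of that pairing, using plain vectors in place of `DataChunk` (the helper is hypothetical, not the actual executor):

```rust
// Sketch of pairing U- (old values) and U+ (new values) records per updated row.
fn pair_updates(old_rows: Vec<Vec<i64>>, new_rows: Vec<Vec<i64>>) -> Vec<(char, Vec<i64>)> {
    assert_eq!(old_rows.len(), new_rows.len()); // mirrors zip_eq_debug above
    old_rows
        .into_iter()
        .zip(new_rows)
        .flat_map(|(del, ins)| [('-', del), ('+', ins)])
        .collect()
}

fn main() {
    // One updated row yields two records, which is why chunk sizes must be even.
    let ops = pair_updates(vec![vec![114, 514]], vec![vec![666, 514]]);
    assert_eq!(ops, vec![('-', vec![114, 514]), ('+', vec![666, 514])]);
}
```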
29 changes: 23 additions & 6 deletions src/frontend/planner_test/tests/testdata/input/update.yaml
@@ -76,7 +76,7 @@
update t set v2 = 3;
expected_outputs:
- binder_error
- name: update subquery
- name: update subquery selection
sql: |
create table t (a int, b int);
update t set a = 777 where b not in (select a from t);
@@ -98,10 +98,27 @@
update t set a = a + 1;
expected_outputs:
- batch_distributed_plan
- name: update table with subquery in the set clause
- name: update table to subquery
sql: |
create table t1 (v1 int primary key, v2 int);
create table t2 (v1 int primary key, v2 int);
update t1 set v1 = (select v1 from t2 where t1.v2 = t2.v2);
create table t (v1 int, v2 int);
update t set v1 = (select 666);
expected_outputs:
- batch_plan
- name: update table to subquery with runtime cardinality
sql: |
create table t (v1 int, v2 int);
update t set v1 = (select generate_series(888, 888));
expected_outputs:
- batch_plan
- name: update table to correlated subquery
sql: |
create table t (v1 int, v2 int);
update t set v1 = (select count(*) from t as source where source.v2 = t.v2);
expected_outputs:
- binder_error
- batch_plan
- name: update table to subquery with multiple assignments
sql: |
create table t (v1 int, v2 int);
update t set (v1, v2) = (select 666.66, 777);
expected_outputs:
- batch_plan
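
The new planner cases above include the correlated form, where the scalar subquery is re-evaluated against each outer row. A toy sketch of those semantics in plain Rust (the snapshot read reflects standard UPDATE visibility rules, an assumption rather than the actual plan):

```rust
// Toy semantics of `UPDATE t SET v1 = (SELECT count(*) FROM t AS source
// WHERE source.v2 = t.v2)`: for each outer row, count matching rows in
// the pre-update snapshot of the table.
fn correlated_update(t: &mut Vec<(i64, i64)>) {
    let snapshot = t.clone(); // the subquery reads the pre-update table state
    for row in t.iter_mut() {
        let count = snapshot.iter().filter(|(_, v2)| *v2 == row.1).count() as i64;
        row.0 = count;
    }
}

fn main() {
    // Matches the e2e test above: a single row with v2 = 514 yields count 1.
    let mut t = vec![(888, 514)];
    correlated_update(&mut t);
    assert_eq!(t, vec![(1, 514)]);
}
```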
src/frontend/planner_test/tests/testdata/output/update.yaml
@@ -213,16 +213,18 @@
update t1 set c = 3 where a = 1 and b = 2;
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchUpdate { table: t1, exprs: [$0, $1, 3:Int64, $3] }
└─BatchUpdate { table: t1, exprs: [$0, $1, $5, $3] }
└─BatchExchange { order: [], dist: Single }
└─BatchLookupJoin { type: Inner, predicate: idx2.t1._row_id IS NOT DISTINCT FROM t1._row_id, output: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp], lookup table: t1 }
└─BatchExchange { order: [], dist: UpstreamHashShard(idx2.t1._row_id) }
└─BatchScan { table: idx2, columns: [idx2.t1._row_id], scan_ranges: [idx2.b = Decimal(Normalized(2)) AND idx2.a = Int32(1)], distribution: SomeShard }
└─BatchProject { exprs: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp, 3:Int64] }
└─BatchLookupJoin { type: Inner, predicate: idx2.t1._row_id IS NOT DISTINCT FROM t1._row_id, output: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp], lookup table: t1 }
└─BatchExchange { order: [], dist: UpstreamHashShard(idx2.t1._row_id) }
└─BatchScan { table: idx2, columns: [idx2.t1._row_id], scan_ranges: [idx2.b = Decimal(Normalized(2)) AND idx2.a = Int32(1)], distribution: SomeShard }
batch_local_plan: |-
BatchUpdate { table: t1, exprs: [$0, $1, 3:Int64, $3] }
└─BatchLookupJoin { type: Inner, predicate: idx2.t1._row_id IS NOT DISTINCT FROM t1._row_id, output: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp], lookup table: t1 }
└─BatchExchange { order: [], dist: Single }
└─BatchScan { table: idx2, columns: [idx2.t1._row_id], scan_ranges: [idx2.b = Decimal(Normalized(2)) AND idx2.a = Int32(1)], distribution: SomeShard }
BatchUpdate { table: t1, exprs: [$0, $1, $5, $3] }
└─BatchProject { exprs: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp, 3:Int64] }
└─BatchLookupJoin { type: Inner, predicate: idx2.t1._row_id IS NOT DISTINCT FROM t1._row_id, output: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp], lookup table: t1 }
└─BatchExchange { order: [], dist: Single }
└─BatchScan { table: idx2, columns: [idx2.t1._row_id], scan_ranges: [idx2.b = Decimal(Normalized(2)) AND idx2.a = Int32(1)], distribution: SomeShard }
- sql: |
create table t1 (a int, b numeric, c bigint, p int);
create materialized view v as select count(*) as cnt, p from t1 group by p;
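The plan diffs above show the new shape: `BatchUpdate`'s `exprs` become pure column references (`$0, $1, $5, $3`), while the computed value (`3:Int64` from `set c = 3`) is evaluated by a `BatchProject` beneath it. A sketch of that division of labor with hypothetical helpers, with the column layout assumed from the plan text:

```rust
// BatchProject { exprs: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp, 3:Int64] }:
// computation happens here, appending the literal as an extra column.
fn project(row: &[i64]) -> Vec<i64> {
    let mut out = row.to_vec();
    out.push(3); // the literal from `set c = 3`
    out
}

// BatchUpdate { exprs: [$0, $1, $5, $3] }: pure column references, no computation.
fn update_exprs(projected: &[i64]) -> Vec<i64> {
    [0usize, 1, 5, 3].iter().map(|&i| projected[i]).collect()
}

fn main() {
    let row = [1, 2, 10, 42, 0]; // toy values for a, b, c, _row_id, _rw_timestamp
    let projected = project(&row);
    // The new row keeps a and b, takes c from the projected literal ($5),
    // and carries the row id through ($3).
    assert_eq!(update_exprs(&projected), vec![1, 2, 3, 42]);
}
```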