From 185243618b23fea634237a70a438e3c3c7da7e2d Mon Sep 17 00:00:00 2001 From: Shanicky Chen Date: Wed, 20 Nov 2024 14:53:13 +0800 Subject: [PATCH 01/11] chore: more logs for offline scaling (#19407) Signed-off-by: Shanicky Chen --- src/meta/src/barrier/context/recovery.rs | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/meta/src/barrier/context/recovery.rs b/src/meta/src/barrier/context/recovery.rs index 83e7cd13919a..344d8af0a1e6 100644 --- a/src/meta/src/barrier/context/recovery.rs +++ b/src/meta/src/barrier/context/recovery.rs @@ -142,6 +142,12 @@ impl GlobalBarrierWorkerContextImpl { .list_background_creating_jobs() .await?; + info!( + "background streaming jobs: {:?} total {}", + background_streaming_jobs, + background_streaming_jobs.len() + ); + // Resolve actor info for recovery. If there's no actor to recover, most of the // following steps will be no-op, while the compute nodes will still be reset. // FIXME: Transactions should be used. @@ -149,6 +155,7 @@ impl GlobalBarrierWorkerContextImpl { let mut info = if !self.env.opts.disable_automatic_parallelism_control && background_streaming_jobs.is_empty() { + info!("trigger offline scaling"); self.scale_actors(&active_streaming_nodes) .await .inspect_err(|err| { @@ -159,6 +166,7 @@ impl GlobalBarrierWorkerContextImpl { warn!(error = %err.as_report(), "resolve actor info failed"); })? } else { + info!("trigger actor migration"); // Migrate actors in expired CN to newly joined one. self.migrate_actors(&mut active_streaming_nodes) .await @@ -376,7 +384,7 @@ impl GlobalBarrierWorkerContextImpl { mgr.catalog_controller.migrate_actors(plan).await?; - debug!("migrate actors succeed."); + info!("migrate actors succeed."); self.resolve_graph_info().await } @@ -447,6 +455,11 @@ impl GlobalBarrierWorkerContextImpl { result }; + info!( + "target table parallelisms for offline scaling: {:?}", + table_parallelisms + ); + let schedulable_worker_ids = active_nodes .current() .values() @@ -460,6 +473,11 @@ impl GlobalBarrierWorkerContextImpl { .map(|worker| worker.id as WorkerId) .collect(); + info!( + "target worker ids for offline scaling: {:?}", + schedulable_worker_ids + ); + let plan = self .scale_controller .generate_table_resize_plan(TableResizePolicy { @@ -497,6 +515,8 @@ impl GlobalBarrierWorkerContextImpl { // Because custom parallelism doesn't exist, this function won't result in a no-shuffle rewrite for table parallelisms. 
debug_assert_eq!(compared_table_parallelisms, table_parallelisms); + info!("post applying reschedule for offline scaling"); + if let Err(e) = self .scale_controller .post_apply_reschedule(&reschedule_fragment, &table_parallelisms) @@ -510,7 +530,7 @@ impl GlobalBarrierWorkerContextImpl { return Err(e); } - debug!("scaling actors succeed."); + info!("scaling actors succeed."); Ok(()) } From b37f048e33722fc922d8bd3adce6c45fdeae42af Mon Sep 17 00:00:00 2001 From: Shanicky Chen Date: Wed, 20 Nov 2024 14:58:28 +0800 Subject: [PATCH 02/11] fix: Ensure non-negative variance in stddev calculations (#19448) Signed-off-by: Shanicky Chen --- .../tests/testdata/output/agg.yaml | 41 ++++++++--------- .../tests/testdata/output/cse_expr.yaml | 30 +++++-------- .../testdata/output/over_window_function.yaml | 38 ++++++++-------- .../src/optimizer/plan_node/logical_agg.rs | 44 ++++++++++--------- 4 files changed, 73 insertions(+), 80 deletions(-) diff --git a/src/frontend/planner_test/tests/testdata/output/agg.yaml b/src/frontend/planner_test/tests/testdata/output/agg.yaml index 9a6b00360374..80a4de55b6f4 100644 --- a/src/frontend/planner_test/tests/testdata/output/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/output/agg.yaml @@ -1299,39 +1299,36 @@ create table t (v1 int); select stddev_samp(v1), stddev_pop(v1) from t; logical_plan: |- - LogicalProject { exprs: [Case((count(t.v1) <= 1:Int64), null:Decimal, Sqrt(((sum($expr1)::Decimal - ((sum(t.v1)::Decimal * sum(t.v1)::Decimal) / count(t.v1)::Decimal)) / (count(t.v1) - 1:Int64)::Decimal))) as $expr2, Sqrt(((sum($expr1)::Decimal - ((sum(t.v1)::Decimal * sum(t.v1)::Decimal) / count(t.v1)::Decimal)) / count(t.v1)::Decimal)) as $expr3] } + LogicalProject { exprs: [Case((count(t.v1) <= 1:Int32), null:Decimal, Sqrt((Greatest((sum($expr1)::Decimal - ((sum(t.v1)::Decimal * sum(t.v1)::Decimal) / count(t.v1)::Decimal)), 0:Int32::Decimal) / (count(t.v1) - 1:Int32)::Decimal))) as $expr2, Case((count(t.v1) = 0:Int32), null:Decimal, Sqrt((Greatest((sum($expr1)::Decimal - ((sum(t.v1)::Decimal * sum(t.v1)::Decimal) / count(t.v1)::Decimal)), 0:Int32::Decimal) / count(t.v1)::Decimal))) as $expr3] } └─LogicalAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] } └─LogicalProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1] } └─LogicalScan { table: t, columns: [t.v1, t._row_id, t._rw_timestamp] } batch_plan: |- - BatchProject { exprs: [Case((sum0(count(t.v1)) <= 1:Int64), null:Decimal, Sqrt(((sum(sum($expr1))::Decimal - (($expr2 * $expr2) / $expr3)) / (sum0(count(t.v1)) - 1:Int64)::Decimal))) as $expr4, Sqrt(((sum(sum($expr1))::Decimal - (($expr2 * $expr2) / $expr3)) / $expr3)) as $expr5] } - └─BatchProject { exprs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1)), sum(sum(t.v1))::Decimal as $expr2, sum0(count(t.v1))::Decimal as $expr3] } - └─BatchSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1))] } - └─BatchExchange { order: [], dist: Single } - └─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] } - └─BatchProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1] } - └─BatchScan { table: t, columns: [t.v1], distribution: SomeShard } - batch_local_plan: |- - BatchProject { exprs: [Case((count(t.v1) <= 1:Int64), null:Decimal, Sqrt(((sum($expr1)::Decimal - (($expr2 * $expr2) / $expr3)) / (count(t.v1) - 1:Int64)::Decimal))) as $expr4, Sqrt(((sum($expr1)::Decimal - (($expr2 * $expr2) / $expr3)) / $expr3)) as $expr5] } - └─BatchProject { exprs: [sum($expr1), sum(t.v1), count(t.v1), sum(t.v1)::Decimal as $expr2, count(t.v1)::Decimal as $expr3] } - 
└─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] } - └─BatchExchange { order: [], dist: Single } + BatchProject { exprs: [Case((sum0(count(t.v1)) <= 1:Int32), null:Decimal, Sqrt((Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v1))::Decimal * sum(sum(t.v1))::Decimal) / sum0(count(t.v1))::Decimal)), 0:Decimal) / (sum0(count(t.v1)) - 1:Int32)::Decimal))) as $expr2, Case((sum0(count(t.v1)) = 0:Int32), null:Decimal, Sqrt((Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v1))::Decimal * sum(sum(t.v1))::Decimal) / sum0(count(t.v1))::Decimal)), 0:Decimal) / sum0(count(t.v1))::Decimal))) as $expr3] } + └─BatchSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1))] } + └─BatchExchange { order: [], dist: Single } + └─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] } └─BatchProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1] } └─BatchScan { table: t, columns: [t.v1], distribution: SomeShard } + batch_local_plan: |- + BatchProject { exprs: [Case((count(t.v1) <= 1:Int32), null:Decimal, Sqrt((Greatest((sum($expr1)::Decimal - ((sum(t.v1)::Decimal * sum(t.v1)::Decimal) / count(t.v1)::Decimal)), 0:Decimal) / (count(t.v1) - 1:Int32)::Decimal))) as $expr2, Case((count(t.v1) = 0:Int32), null:Decimal, Sqrt((Greatest((sum($expr1)::Decimal - ((sum(t.v1)::Decimal * sum(t.v1)::Decimal) / count(t.v1)::Decimal)), 0:Decimal) / count(t.v1)::Decimal))) as $expr3] } + └─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] } + └─BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1] } + └─BatchScan { table: t, columns: [t.v1], distribution: SomeShard } stream_plan: |- StreamMaterialize { columns: [stddev_samp, stddev_pop], stream_key: [], pk_columns: [], pk_conflict: NoCheck } - └─StreamProject { exprs: [Case((sum0(count(t.v1)) <= 1:Int64), null:Decimal, Sqrt(((sum(sum($expr1))::Decimal - (($expr2 * $expr2) / $expr3)) / (sum0(count(t.v1)) - 1:Int64)::Decimal))) as $expr4, Sqrt(((sum(sum($expr1))::Decimal - (($expr2 * $expr2) / $expr3)) / $expr3)) as $expr5] } - └─StreamProject { exprs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1)), sum(sum(t.v1))::Decimal as $expr2, sum0(count(t.v1))::Decimal as $expr3] } - └─StreamSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1)), count] } - └─StreamExchange { dist: Single } - └─StreamStatelessSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] } - └─StreamProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1, t._row_id] } - └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } + └─StreamProject { exprs: [Case((sum0(count(t.v1)) <= 1:Int32), null:Decimal, Sqrt((Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v1))::Decimal * sum(sum(t.v1))::Decimal) / sum0(count(t.v1))::Decimal)), 0:Decimal) / (sum0(count(t.v1)) - 1:Int32)::Decimal))) as $expr2, Case((sum0(count(t.v1)) = 0:Int32), null:Decimal, Sqrt((Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v1))::Decimal * sum(sum(t.v1))::Decimal) / sum0(count(t.v1))::Decimal)), 0:Decimal) / sum0(count(t.v1))::Decimal))) as $expr3] } + └─StreamSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1)), count] } + └─StreamExchange { dist: Single } + └─StreamStatelessSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] } + └─StreamProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1, t._row_id] } + └─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, 
stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } - name: stddev_samp with other columns sql: | select count(''), stddev_samp(1); logical_plan: |- - LogicalProject { exprs: [count('':Varchar), Case((count(1:Int32) <= 1:Int64), null:Decimal, Sqrt(((sum($expr1)::Decimal - ((sum(1:Int32)::Decimal * sum(1:Int32)::Decimal) / count(1:Int32)::Decimal)) / (count(1:Int32) - 1:Int64)::Decimal))) as $expr2] } + LogicalProject { exprs: [count('':Varchar), Case((count(1:Int32) <= 1:Int32), null:Decimal, Sqrt((Greatest((sum($expr1)::Decimal - ((sum(1:Int32)::Decimal * sum(1:Int32)::Decimal) / count(1:Int32)::Decimal)), 0:Int32::Decimal) / (count(1:Int32) - 1:Int32)::Decimal))) as $expr2] } └─LogicalAgg { aggs: [count('':Varchar), sum($expr1), sum(1:Int32), count(1:Int32)] } └─LogicalProject { exprs: ['':Varchar, (1:Int32 * 1:Int32) as $expr1, 1:Int32] } └─LogicalValues { rows: [[]], schema: Schema { fields: [] } } @@ -1340,7 +1337,7 @@ create table t(v int, w float); select stddev_samp(v) from t group by w; logical_plan: |- - LogicalProject { exprs: [Case((count(t.v) <= 1:Int64), null:Decimal, Sqrt(((sum($expr1)::Decimal - ((sum(t.v)::Decimal * sum(t.v)::Decimal) / count(t.v)::Decimal)) / (count(t.v) - 1:Int64)::Decimal))) as $expr2] } + LogicalProject { exprs: [Case((count(t.v) <= 1:Int32), null:Decimal, Sqrt((Greatest((sum($expr1)::Decimal - ((sum(t.v)::Decimal * sum(t.v)::Decimal) / count(t.v)::Decimal)), 0:Int32::Decimal) / (count(t.v) - 1:Int32)::Decimal))) as $expr2] } └─LogicalAgg { group_key: [t.w], aggs: [sum($expr1), sum(t.v), count(t.v)] } └─LogicalProject { exprs: [t.w, (t.v * t.v) as $expr1, t.v] } └─LogicalScan { table: t, columns: [t.v, t.w, t._row_id, t._rw_timestamp] } diff --git a/src/frontend/planner_test/tests/testdata/output/cse_expr.yaml b/src/frontend/planner_test/tests/testdata/output/cse_expr.yaml index 9372c837324f..4eb6a85e5eb8 100644 --- a/src/frontend/planner_test/tests/testdata/output/cse_expr.yaml +++ b/src/frontend/planner_test/tests/testdata/output/cse_expr.yaml @@ -60,26 +60,20 @@ create table t(v int); select stddev_pop(v), stddev_samp(v), var_pop(v), var_samp(v) from t; batch_plan: |- - BatchProject { exprs: [Sqrt($expr5) as $expr6, Case((sum0(count(t.v)) <= 1:Int64), null:Decimal, Sqrt(($expr4 / (sum0(count(t.v)) - 1:Int64)::Decimal))) as $expr7, $expr5, Case((sum0(count(t.v)) <= 1:Int64), null:Decimal, ($expr4 / (sum0(count(t.v)) - 1:Int64)::Decimal)) as $expr8] } - └─BatchProject { exprs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), ($expr4 / $expr3) as $expr5, $expr4] } - └─BatchProject { exprs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), (sum(sum($expr1))::Decimal - (($expr2 * $expr2) / $expr3)) as $expr4, $expr3] } - └─BatchProject { exprs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), sum(sum(t.v))::Decimal as $expr2, sum0(count(t.v))::Decimal as $expr3] } - └─BatchSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v))] } - └─BatchExchange { order: [], dist: Single } - └─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v), count(t.v)] } - └─BatchProject { exprs: [(t.v * t.v) as $expr1, t.v] } - └─BatchScan { table: t, columns: [t.v], distribution: SomeShard } + BatchProject { exprs: [Case((sum0(count(t.v)) = 0:Int32), null:Decimal, Sqrt((Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v))::Decimal * sum(sum(t.v))::Decimal) / sum0(count(t.v))::Decimal)), 0:Decimal) / sum0(count(t.v))::Decimal))) as $expr2, Case((sum0(count(t.v)) <= 1:Int32), null:Decimal, Sqrt((Greatest((sum(sum($expr1))::Decimal - 
((sum(sum(t.v))::Decimal * sum(sum(t.v))::Decimal) / sum0(count(t.v))::Decimal)), 0:Decimal) / (sum0(count(t.v)) - 1:Int32)::Decimal))) as $expr3, Case((sum0(count(t.v)) = 0:Int32), null:Decimal, (Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v))::Decimal * sum(sum(t.v))::Decimal) / sum0(count(t.v))::Decimal)), 0:Decimal) / sum0(count(t.v))::Decimal)) as $expr4, Case((sum0(count(t.v)) <= 1:Int32), null:Decimal, (Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v))::Decimal * sum(sum(t.v))::Decimal) / sum0(count(t.v))::Decimal)), 0:Decimal) / (sum0(count(t.v)) - 1:Int32)::Decimal)) as $expr5] } + └─BatchSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v))] } + └─BatchExchange { order: [], dist: Single } + └─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v), count(t.v)] } + └─BatchProject { exprs: [(t.v * t.v) as $expr1, t.v] } + └─BatchScan { table: t, columns: [t.v], distribution: SomeShard } stream_plan: |- StreamMaterialize { columns: [stddev_pop, stddev_samp, var_pop, var_samp], stream_key: [], pk_columns: [], pk_conflict: NoCheck } - └─StreamProject { exprs: [Sqrt($expr5) as $expr6, Case((sum0(count(t.v)) <= 1:Int64), null:Decimal, Sqrt(($expr4 / (sum0(count(t.v)) - 1:Int64)::Decimal))) as $expr7, $expr5, Case((sum0(count(t.v)) <= 1:Int64), null:Decimal, ($expr4 / (sum0(count(t.v)) - 1:Int64)::Decimal)) as $expr8] } - └─StreamProject { exprs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), ($expr4 / $expr3) as $expr5, $expr4] } - └─StreamProject { exprs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), (sum(sum($expr1))::Decimal - (($expr2 * $expr2) / $expr3)) as $expr4, $expr3] } - └─StreamProject { exprs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), sum(sum(t.v))::Decimal as $expr2, sum0(count(t.v))::Decimal as $expr3] } - └─StreamSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), count] } - └─StreamExchange { dist: Single } - └─StreamStatelessSimpleAgg { aggs: [sum($expr1), sum(t.v), count(t.v)] } - └─StreamProject { exprs: [(t.v * t.v) as $expr1, t.v, t._row_id] } - └─StreamTableScan { table: t, columns: [t.v, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } + └─StreamProject { exprs: [Case((sum0(count(t.v)) = 0:Int32), null:Decimal, Sqrt((Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v))::Decimal * sum(sum(t.v))::Decimal) / sum0(count(t.v))::Decimal)), 0:Decimal) / sum0(count(t.v))::Decimal))) as $expr2, Case((sum0(count(t.v)) <= 1:Int32), null:Decimal, Sqrt((Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v))::Decimal * sum(sum(t.v))::Decimal) / sum0(count(t.v))::Decimal)), 0:Decimal) / (sum0(count(t.v)) - 1:Int32)::Decimal))) as $expr3, Case((sum0(count(t.v)) = 0:Int32), null:Decimal, (Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v))::Decimal * sum(sum(t.v))::Decimal) / sum0(count(t.v))::Decimal)), 0:Decimal) / sum0(count(t.v))::Decimal)) as $expr4, Case((sum0(count(t.v)) <= 1:Int32), null:Decimal, (Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v))::Decimal * sum(sum(t.v))::Decimal) / sum0(count(t.v))::Decimal)), 0:Decimal) / (sum0(count(t.v)) - 1:Int32)::Decimal)) as $expr5] } + └─StreamSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), count] } + └─StreamExchange { dist: Single } + └─StreamStatelessSimpleAgg { aggs: [sum($expr1), sum(t.v), count(t.v)] } + └─StreamProject { exprs: [(t.v * t.v) as $expr1, t.v, t._row_id] } + └─StreamTableScan { table: t, columns: [t.v, t._row_id], stream_scan_type: ArrangementBackfill, 
stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } - name: Common sub expression shouldn't extract partial expression of `some`/`all`. See 11766 sql: | with t(v, arr) as (select 1, array[2, 3]) select v < all(arr), v < some(arr) from t; diff --git a/src/frontend/planner_test/tests/testdata/output/over_window_function.yaml b/src/frontend/planner_test/tests/testdata/output/over_window_function.yaml index 114d9fead0f3..a6a2c284beb0 100644 --- a/src/frontend/planner_test/tests/testdata/output/over_window_function.yaml +++ b/src/frontend/planner_test/tests/testdata/output/over_window_function.yaml @@ -325,33 +325,31 @@ logical_plan: |- LogicalProject { exprs: [t.x, t.y, t.z, $expr4, $expr5] } └─LogicalFilter { predicate: (t.z > 0:Int32) AND (t.y > 0:Int32) AND (t.x > 0:Int32) AND ($expr4 <= 3.0:Decimal) AND ($expr5 > 1.0:Decimal) } - └─LogicalProject { exprs: [t.x, t.y, t.z, Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / count::Decimal)) as $expr4, Case((count <= 1:Int64), null:Decimal, Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / (count - 1:Int64)::Decimal))) as $expr5] } + └─LogicalProject { exprs: [t.x, t.y, t.z, Case((count = 0:Int32), null:Decimal, Sqrt((Greatest((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)), 0:Int32::Decimal) / count::Decimal))) as $expr4, Case((count <= 1:Int32), null:Decimal, Sqrt((Greatest((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)), 0:Int32::Decimal) / (count - 1:Int32)::Decimal))) as $expr5] } └─LogicalOverWindow { window_functions: [sum($expr1) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum($expr2) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), count($expr2) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum($expr3) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum(t.x) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), count(t.x) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] } └─LogicalProject { exprs: [t.x, t.y, t.z, t.w, t._row_id, t._rw_timestamp, ((t.x - t.y) * (t.x - t.y)) as $expr1, (t.x - t.y) as $expr2, (t.x * t.x) as $expr3] } └─LogicalScan { table: t, columns: [t.x, t.y, t.z, t.w, t._row_id, t._rw_timestamp] } batch_plan: |- BatchExchange { order: [], dist: Single } - └─BatchProject { exprs: [t.x, t.y, t.z, Sqrt(((sum::Decimal - (($expr4 * $expr4) / $expr5)) / $expr5)) as $expr6, Case((count <= 1:Int64), null:Decimal, Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / (count - 1:Int64)::Decimal))) as $expr7] } - └─BatchProject { exprs: [t.x, t.y, t.z, $expr2, $expr1, $expr3, sum, sum, count, sum, sum, count, sum::Decimal as $expr4, count::Decimal as $expr5] } - └─BatchFilter { predicate: (t.y > 0:Int32) AND (t.x > 0:Int32) AND (Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / count::Decimal)) <= 3.0:Decimal) AND (Case((count <= 1:Int64), null:Decimal, Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / (count - 1:Int64)::Decimal))) > 1.0:Decimal) } - └─BatchOverWindow { window_functions: [sum($expr2) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum($expr1) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), count($expr1) 
OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum($expr3) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum(t.x) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), count(t.x) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] } - └─BatchExchange { order: [t.z ASC, t.x ASC], dist: HashShard(t.z) } - └─BatchSort { order: [t.z ASC, t.x ASC] } - └─BatchProject { exprs: [t.x, t.y, t.z, ($expr1 * $expr1) as $expr2, $expr1, (t.x * t.x) as $expr3] } - └─BatchProject { exprs: [t.x, t.y, t.z, (t.x - t.y) as $expr1] } - └─BatchFilter { predicate: (t.z > 0:Int32) } - └─BatchScan { table: t, columns: [t.x, t.y, t.z], distribution: SomeShard } + └─BatchProject { exprs: [t.x, t.y, t.z, Case((count = 0:Int32), null:Decimal, Sqrt((Greatest((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)), 0:Decimal) / count::Decimal))) as $expr4, Case((count <= 1:Int32), null:Decimal, Sqrt((Greatest((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)), 0:Decimal) / (count - 1:Int32)::Decimal))) as $expr5] } + └─BatchFilter { predicate: (t.y > 0:Int32) AND (t.x > 0:Int32) AND (Case((count = 0:Int32), null:Decimal, Sqrt((Greatest((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)), 0:Decimal) / count::Decimal))) <= 3.0:Decimal) AND (Case((count <= 1:Int32), null:Decimal, Sqrt((Greatest((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)), 0:Decimal) / (count - 1:Int32)::Decimal))) > 1.0:Decimal) } + └─BatchOverWindow { window_functions: [sum($expr2) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum($expr1) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), count($expr1) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum($expr3) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum(t.x) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), count(t.x) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] } + └─BatchExchange { order: [t.z ASC, t.x ASC], dist: HashShard(t.z) } + └─BatchSort { order: [t.z ASC, t.x ASC] } + └─BatchProject { exprs: [t.x, t.y, t.z, ($expr1 * $expr1) as $expr2, $expr1, (t.x * t.x) as $expr3] } + └─BatchProject { exprs: [t.x, t.y, t.z, (t.x - t.y) as $expr1] } + └─BatchFilter { predicate: (t.z > 0:Int32) } + └─BatchScan { table: t, columns: [t.x, t.y, t.z], distribution: SomeShard } stream_plan: |- StreamMaterialize { columns: [x, y, z, res0, res1, t._row_id(hidden)], stream_key: [t._row_id, z], pk_columns: [t._row_id, z], pk_conflict: NoCheck } - └─StreamProject { exprs: [t.x, t.y, t.z, Sqrt(((sum::Decimal - (($expr4 * $expr4) / $expr5)) / $expr5)) as $expr6, Case((count <= 1:Int64), null:Decimal, Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / (count - 1:Int64)::Decimal))) as $expr7, t._row_id] } - └─StreamProject { exprs: [t.x, t.y, t.z, $expr2, $expr1, $expr3, sum, sum, count, sum, sum, count, sum::Decimal as $expr4, count::Decimal as $expr5, t._row_id] } - └─StreamFilter { predicate: (t.y > 0:Int32) AND (t.x > 0:Int32) AND (Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / count::Decimal)) <= 3.0:Decimal) AND (Case((count <= 1:Int64), null:Decimal, Sqrt(((sum::Decimal - ((sum::Decimal * 
sum::Decimal) / count::Decimal)) / (count - 1:Int64)::Decimal))) > 1.0:Decimal) } - └─StreamOverWindow { window_functions: [sum($expr2) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum($expr1) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), count($expr1) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum($expr3) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum(t.x) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), count(t.x) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] } - └─StreamExchange { dist: HashShard(t.z) } - └─StreamProject { exprs: [t.x, t.y, t.z, ($expr1 * $expr1) as $expr2, $expr1, (t.x * t.x) as $expr3, t._row_id] } - └─StreamProject { exprs: [t.x, t.y, t.z, (t.x - t.y) as $expr1, t._row_id] } - └─StreamFilter { predicate: (t.z > 0:Int32) } - └─StreamTableScan { table: t, columns: [t.x, t.y, t.z, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } + └─StreamProject { exprs: [t.x, t.y, t.z, Case((count = 0:Int32), null:Decimal, Sqrt((Greatest((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)), 0:Decimal) / count::Decimal))) as $expr4, Case((count <= 1:Int32), null:Decimal, Sqrt((Greatest((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)), 0:Decimal) / (count - 1:Int32)::Decimal))) as $expr5, t._row_id] } + └─StreamFilter { predicate: (t.y > 0:Int32) AND (t.x > 0:Int32) AND (Case((count = 0:Int32), null:Decimal, Sqrt((Greatest((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)), 0:Decimal) / count::Decimal))) <= 3.0:Decimal) AND (Case((count <= 1:Int32), null:Decimal, Sqrt((Greatest((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)), 0:Decimal) / (count - 1:Int32)::Decimal))) > 1.0:Decimal) } + └─StreamOverWindow { window_functions: [sum($expr2) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum($expr1) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), count($expr1) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum($expr3) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum(t.x) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), count(t.x) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] } + └─StreamExchange { dist: HashShard(t.z) } + └─StreamProject { exprs: [t.x, t.y, t.z, ($expr1 * $expr1) as $expr2, $expr1, (t.x * t.x) as $expr3, t._row_id] } + └─StreamProject { exprs: [t.x, t.y, t.z, (t.x - t.y) as $expr1, t._row_id] } + └─StreamFilter { predicate: (t.z > 0:Int32) } + └─StreamTableScan { table: t, columns: [t.x, t.y, t.z, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } - id: aggregate with expression in func arguments and over clause sql: | create table t(x int, y int, z int, w int); diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index 7f2b52797924..5a146c37398a 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -14,7 
+14,7 @@ use fixedbitset::FixedBitSet; use itertools::Itertools; -use risingwave_common::types::{DataType, Datum, ScalarImpl}; +use risingwave_common::types::{DataType, ScalarImpl}; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use risingwave_common::{bail, bail_not_implemented, not_implemented}; use risingwave_expr::aggregate::{agg_types, AggType, PbAggKind}; @@ -684,17 +684,15 @@ impl LogicalAggBuilder { agg_call.direct_args.clone(), )?)?); - let one = ExprImpl::from(Literal::new( - Datum::from(ScalarImpl::Int64(1)), - DataType::Int64, - )); + let zero = ExprImpl::literal_int(0); + let one = ExprImpl::literal_int(1); let squared_sum = ExprImpl::from(FunctionCall::new( ExprType::Multiply, vec![sum.clone(), sum], )?); - let numerator = ExprImpl::from(FunctionCall::new( + let raw_numerator = ExprImpl::from(FunctionCall::new( ExprType::Subtract, vec![ sum_of_sq, @@ -705,6 +703,13 @@ impl LogicalAggBuilder { ], )?); + // We need to check for potential accuracy issues that may occasionally lead to results less than 0. + let numerator_type = raw_numerator.return_type(); + let numerator = ExprImpl::from(FunctionCall::new( + ExprType::Greatest, + vec![raw_numerator, zero.clone().cast_explicit(numerator_type)?], + )?); + let denominator = match kind { PbAggKind::VarPop | PbAggKind::StddevPop => count.clone(), PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from( @@ -722,22 +727,21 @@ impl LogicalAggBuilder { target = ExprImpl::from(FunctionCall::new(ExprType::Sqrt, vec![target])?); } - match kind { - PbAggKind::VarPop | PbAggKind::StddevPop => Ok(target), - PbAggKind::StddevSamp | PbAggKind::VarSamp => { - let case_cond = ExprImpl::from(FunctionCall::new( - ExprType::LessThanOrEqual, - vec![count, one], - )?); - let null = ExprImpl::from(Literal::new(None, agg_call.return_type())); - - Ok(ExprImpl::from(FunctionCall::new( - ExprType::Case, - vec![case_cond, null, target], - )?)) + let null = ExprImpl::from(Literal::new(None, agg_call.return_type())); + let case_cond = match kind { + PbAggKind::VarPop | PbAggKind::StddevPop => { + ExprImpl::from(FunctionCall::new(ExprType::Equal, vec![count, zero])?) } + PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from( + FunctionCall::new(ExprType::LessThanOrEqual, vec![count, one])?, + ), _ => unreachable!(), - } + }; + + Ok(ExprImpl::from(FunctionCall::new( + ExprType::Case, + vec![case_cond, null, target], + )?)) } AggType::Builtin(PbAggKind::ApproxPercentile) => { if agg_call.order_by.sort_exprs[0].order_type == OrderType::descending() { From c8e96b9634dc38bcd5bbea74e23aa7ef1974922c Mon Sep 17 00:00:00 2001 From: William Wen <44139337+wenym1@users.noreply.github.com> Date: Wed, 20 Nov 2024 15:32:25 +0800 Subject: [PATCH 03/11] refactor(barrier): decouple barrier collect and sync in local barrier manager (#19393) --- src/stream/src/task/barrier_manager.rs | 169 ++++++++- .../src/task/barrier_manager/managed_state.rs | 347 +++++------------- 2 files changed, 240 insertions(+), 276 deletions(-) diff --git a/src/stream/src/task/barrier_manager.rs b/src/stream/src/task/barrier_manager.rs index 93d107c073ed..fec0d74ab6d5 100644 --- a/src/stream/src/task/barrier_manager.rs +++ b/src/stream/src/task/barrier_manager.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::collections::BTreeSet; +use std::collections::{BTreeSet, HashSet}; use std::fmt::Display; use std::future::pending; use std::iter::once; @@ -20,10 +20,13 @@ use std::sync::Arc; use std::time::Duration; use anyhow::anyhow; -use futures::stream::BoxStream; -use futures::StreamExt; +use await_tree::InstrumentAwait; +use futures::future::BoxFuture; +use futures::stream::{BoxStream, FuturesOrdered}; +use futures::{FutureExt, StreamExt, TryFutureExt}; use itertools::Itertools; use risingwave_common::error::tonic::extra::Score; +use risingwave_pb::stream_plan::barrier::BarrierKind; use risingwave_pb::stream_service::barrier_complete_response::{ PbCreateMviewProgress, PbLocalSstableInfo, }; @@ -264,6 +267,9 @@ pub(super) struct LocalBarrierWorker { /// Current barrier collection state. pub(super) state: ManagedBarrierState, + /// Futures will be finished in the order of epoch in ascending order. + await_epoch_completed_futures: FuturesOrdered, + control_stream_handle: ControlStreamHandle, pub(super) actor_manager: Arc, @@ -295,6 +301,7 @@ impl LocalBarrierWorker { shared_context.clone(), initial_partial_graphs, ), + await_epoch_completed_futures: Default::default(), control_stream_handle: ControlStreamHandle::empty(), actor_manager, current_shared_context: shared_context, @@ -315,10 +322,17 @@ impl LocalBarrierWorker { loop { select! { biased; - (partial_graph_id, completed_epoch) = self.state.next_completed_epoch() => { - let result = self.on_epoch_completed(partial_graph_id, completed_epoch); - if let Err(err) = result { - self.notify_other_failure(err, "failed to complete epoch").await; + (partial_graph_id, barrier) = self.state.next_collected_epoch() => { + self.complete_barrier(partial_graph_id, barrier.epoch.prev); + } + (partial_graph_id, barrier, result) = rw_futures_util::pending_on_none(self.await_epoch_completed_futures.next()) => { + match result { + Ok(result) => { + self.on_epoch_completed(partial_graph_id, barrier.epoch.prev, result); + } + Err(err) => { + self.notify_other_failure(err, "failed to complete epoch").await; + } } }, event = self.barrier_event_rx.recv() => { @@ -453,23 +467,139 @@ impl LocalBarrierWorker { } } -// event handler +mod await_epoch_completed_future { + use std::future::Future; + + use futures::future::BoxFuture; + use futures::FutureExt; + use risingwave_hummock_sdk::SyncResult; + use risingwave_pb::stream_service::barrier_complete_response::PbCreateMviewProgress; + + use crate::error::StreamResult; + use crate::executor::Barrier; + use crate::task::{await_tree_key, BarrierCompleteResult, PartialGraphId}; + + pub(super) type AwaitEpochCompletedFuture = impl Future)> + + 'static; + + pub(super) fn instrument_complete_barrier_future( + partial_graph_id: PartialGraphId, + complete_barrier_future: Option>>, + barrier: Barrier, + barrier_await_tree_reg: Option<&await_tree::Registry>, + create_mview_progress: Vec, + ) -> AwaitEpochCompletedFuture { + let prev_epoch = barrier.epoch.prev; + let future = async move { + if let Some(future) = complete_barrier_future { + let result = future.await; + result.map(Some) + } else { + Ok(None) + } + } + .map(move |result| { + ( + partial_graph_id, + barrier, + result.map(|sync_result| BarrierCompleteResult { + sync_result, + create_mview_progress, + }), + ) + }); + if let Some(reg) = barrier_await_tree_reg { + reg.register( + await_tree_key::BarrierAwait { prev_epoch }, + format!("SyncEpoch({})", prev_epoch), + ) + .instrument(future) + .left_future() + } else { + future.right_future() + } + } +} + +use 
await_epoch_completed_future::*; +use risingwave_common::catalog::TableId; +use risingwave_storage::StateStoreImpl; + +fn sync_epoch( + state_store: &StateStoreImpl, + streaming_metrics: &StreamingMetrics, + prev_epoch: u64, + table_ids: HashSet, +) -> BoxFuture<'static, StreamResult> { + let timer = streaming_metrics.barrier_sync_latency.start_timer(); + let hummock = state_store.as_hummock().cloned(); + let future = async move { + if let Some(hummock) = hummock { + hummock.sync(vec![(prev_epoch, table_ids)]).await + } else { + Ok(SyncResult::default()) + } + }; + future + .instrument_await(format!("sync_epoch (epoch {})", prev_epoch)) + .inspect_ok(move |_| { + timer.observe_duration(); + }) + .map_err(move |e| { + tracing::error!( + prev_epoch, + error = %e.as_report(), + "Failed to sync state store", + ); + e.into() + }) + .boxed() +} + impl LocalBarrierWorker { + fn complete_barrier(&mut self, partial_graph_id: PartialGraphId, prev_epoch: u64) { + { + let (barrier, table_ids, create_mview_progress) = self + .state + .pop_barrier_to_complete(partial_graph_id, prev_epoch); + + let complete_barrier_future = match &barrier.kind { + BarrierKind::Unspecified => unreachable!(), + BarrierKind::Initial => { + tracing::info!( + epoch = prev_epoch, + "ignore sealing data for the first barrier" + ); + tracing::info!(?prev_epoch, "ignored syncing data for the first barrier"); + None + } + BarrierKind::Barrier => None, + BarrierKind::Checkpoint => Some(sync_epoch( + &self.actor_manager.env.state_store(), + &self.actor_manager.streaming_metrics, + prev_epoch, + table_ids.expect("should be Some on BarrierKind::Checkpoint"), + )), + }; + + self.await_epoch_completed_futures.push_back({ + instrument_complete_barrier_future( + partial_graph_id, + complete_barrier_future, + barrier, + self.actor_manager.await_tree_reg.as_ref(), + create_mview_progress, + ) + }); + } + } + fn on_epoch_completed( &mut self, partial_graph_id: PartialGraphId, epoch: u64, - ) -> StreamResult<()> { - let state = self - .state - .graph_states - .get_mut(&partial_graph_id) - .expect("should exist"); - let result = state - .pop_completed_epoch(epoch) - .expect("should exist") - .expect("should have completed")?; - + result: BarrierCompleteResult, + ) { let BarrierCompleteResult { create_mview_progress, sync_result, @@ -523,7 +653,6 @@ impl LocalBarrierWorker { }; self.control_stream_handle.send_response(result); - Ok(()) } /// Broadcast a barrier to all senders. Save a receiver which will get notified when this diff --git a/src/stream/src/task/barrier_manager/managed_state.rs b/src/stream/src/task/barrier_manager/managed_state.rs index e555c0ff037f..bd5c92570f13 100644 --- a/src/stream/src/task/barrier_manager/managed_state.rs +++ b/src/stream/src/task/barrier_manager/managed_state.rs @@ -12,58 +12,39 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::assert_matches::assert_matches; use std::cell::LazyCell; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::fmt::{Debug, Display, Formatter}; use std::future::{pending, poll_fn, Future}; use std::mem::replace; use std::sync::Arc; -use std::task::{ready, Context, Poll}; +use std::task::Poll; -use anyhow::anyhow; -use await_tree::InstrumentAwait; -use futures::future::BoxFuture; -use futures::stream::FuturesOrdered; -use futures::{FutureExt, StreamExt, TryFutureExt}; use prometheus::HistogramTimer; use risingwave_common::catalog::TableId; -use risingwave_common::must_match; use risingwave_common::util::epoch::EpochPair; -use risingwave_hummock_sdk::SyncResult; use risingwave_pb::stream_plan::barrier::BarrierKind; use risingwave_storage::StateStoreImpl; -use thiserror_ext::AsReport; use tokio::sync::mpsc; use tokio::task::JoinHandle; use super::progress::BackfillState; -use super::BarrierCompleteResult; use crate::error::{StreamError, StreamResult}; use crate::executor::monitor::StreamingMetrics; -use crate::executor::{Barrier, Mutation}; +use crate::executor::Barrier; use crate::task::{ActorId, PartialGraphId, SharedContext, StreamActorManager}; struct IssuedState { - pub mutation: Option>, /// Actor ids remaining to be collected. pub remaining_actors: BTreeSet, pub barrier_inflight_latency: HistogramTimer, - - /// Only be `Some(_)` when `kind` is `Checkpoint` - pub table_ids: Option>, - - pub kind: BarrierKind, } impl Debug for IssuedState { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("IssuedState") - .field("mutation", &self.mutation) .field("remaining_actors", &self.remaining_actors) - .field("table_ids", &self.table_ids) - .field("kind", &self.kind) .finish() } } @@ -75,107 +56,23 @@ enum ManagedBarrierStateInner { Issued(IssuedState), /// The barrier has been collected by all remaining actors - AllCollected, - - /// The barrier has been completed, which means the barrier has been collected by all actors and - /// synced in state store - Completed(StreamResult), + AllCollected(Vec), } #[derive(Debug)] -pub(super) struct BarrierState { +struct BarrierState { barrier: Barrier, + /// Only be `Some(_)` when `barrier.kind` is `Checkpoint` + table_ids: Option>, inner: ManagedBarrierStateInner, } -mod await_epoch_completed_future { - use std::future::Future; - - use futures::future::BoxFuture; - use futures::FutureExt; - use risingwave_hummock_sdk::SyncResult; - use risingwave_pb::stream_service::barrier_complete_response::PbCreateMviewProgress; - - use crate::error::StreamResult; - use crate::executor::Barrier; - use crate::task::{await_tree_key, BarrierCompleteResult}; - - pub(super) type AwaitEpochCompletedFuture = - impl Future)> + 'static; - - pub(super) fn instrument_complete_barrier_future( - complete_barrier_future: Option>>, - barrier: Barrier, - barrier_await_tree_reg: Option<&await_tree::Registry>, - create_mview_progress: Vec, - ) -> AwaitEpochCompletedFuture { - let prev_epoch = barrier.epoch.prev; - let future = async move { - if let Some(future) = complete_barrier_future { - let result = future.await; - result.map(Some) - } else { - Ok(None) - } - } - .map(move |result| { - ( - barrier, - result.map(|sync_result| BarrierCompleteResult { - sync_result, - create_mview_progress, - }), - ) - }); - if let Some(reg) = barrier_await_tree_reg { - reg.register( - await_tree_key::BarrierAwait { prev_epoch }, - format!("SyncEpoch({})", prev_epoch), - ) - .instrument(future) - .left_future() - } else { - future.right_future() - } - 
} -} - -use await_epoch_completed_future::*; +use risingwave_common::must_match; use risingwave_pb::stream_plan::SubscriptionUpstreamInfo; +use risingwave_pb::stream_service::barrier_complete_response::PbCreateMviewProgress; use risingwave_pb::stream_service::streaming_control_stream_request::InitialPartialGraph; use risingwave_pb::stream_service::InjectBarrierRequest; -fn sync_epoch( - state_store: &StateStoreImpl, - streaming_metrics: &StreamingMetrics, - prev_epoch: u64, - table_ids: HashSet, -) -> BoxFuture<'static, StreamResult> { - let timer = streaming_metrics.barrier_sync_latency.start_timer(); - let hummock = state_store.as_hummock().cloned(); - let future = async move { - if let Some(hummock) = hummock { - hummock.sync(vec![(prev_epoch, table_ids)]).await - } else { - Ok(SyncResult::default()) - } - }; - future - .instrument_await(format!("sync_epoch (epoch {})", prev_epoch)) - .inspect_ok(move |_| { - timer.observe_duration(); - }) - .map_err(move |e| { - tracing::error!( - prev_epoch, - error = %e.as_report(), - "Failed to sync state store", - ); - e.into() - }) - .boxed() -} - pub(super) struct ManagedBarrierStateDebugInfo<'a> { graph_states: &'a HashMap, } @@ -197,7 +94,11 @@ impl Display for &'_ PartialGraphManagedBarrierState { write!(f, "> Epoch {}: ", epoch)?; match &barrier_state.inner { ManagedBarrierStateInner::Issued(state) => { - write!(f, "Issued [{:?}]. Remaining actors: [", state.kind)?; + write!( + f, + "Issued [{:?}]. Remaining actors: [", + barrier_state.barrier.kind + )?; let mut is_prev_epoch_issued = false; if prev_epoch != 0 { let bs = &self.epoch_barrier_state_map[&prev_epoch]; @@ -228,12 +129,9 @@ impl Display for &'_ PartialGraphManagedBarrierState { } write!(f, "]")?; } - ManagedBarrierStateInner::AllCollected => { + ManagedBarrierStateInner::AllCollected(_) => { write!(f, "AllCollected")?; } - ManagedBarrierStateInner::Completed(_) => { - write!(f, "Completed")?; - } } prev_epoch = *epoch; writeln!(f)?; @@ -385,18 +283,12 @@ pub(super) struct PartialGraphManagedBarrierState { /// Record the progress updates of creating mviews for each epoch of concurrent checkpoints. /// /// This is updated by [`super::CreateMviewProgressReporter::update`] and will be reported to meta - /// in [`BarrierCompleteResult`]. + /// in [`crate::task::barrier_manager::BarrierCompleteResult`]. pub(super) create_mview_progress: HashMap>, - pub(super) state_store: StateStoreImpl, + state_store: StateStoreImpl, - pub(super) streaming_metrics: Arc, - - /// Futures will be finished in the order of epoch in ascending order. - await_epoch_completed_futures: FuturesOrdered, - - /// Manages the await-trees of all barriers. 
- barrier_await_tree_reg: Option, + streaming_metrics: Arc, } impl PartialGraphManagedBarrierState { @@ -404,24 +296,17 @@ impl PartialGraphManagedBarrierState { Self::new_inner( actor_manager.env.state_store(), actor_manager.streaming_metrics.clone(), - actor_manager.await_tree_reg.clone(), ) } - fn new_inner( - state_store: StateStoreImpl, - streaming_metrics: Arc, - barrier_await_tree_reg: Option, - ) -> Self { + fn new_inner(state_store: StateStoreImpl, streaming_metrics: Arc) -> Self { Self { epoch_barrier_state_map: Default::default(), prev_barrier_table_ids: None, mv_depended_subscriptions: Default::default(), create_mview_progress: Default::default(), - await_epoch_completed_futures: Default::default(), state_store, streaming_metrics, - barrier_await_tree_reg, } } @@ -430,7 +315,6 @@ impl PartialGraphManagedBarrierState { Self::new_inner( StateStoreImpl::for_test(), Arc::new(StreamingMetrics::unused()), - None, ) } @@ -683,23 +567,26 @@ impl ManagedBarrierState { Ok(()) } - pub(super) fn next_completed_epoch( + pub(super) fn next_collected_epoch( &mut self, - ) -> impl Future + '_ { - poll_fn(|cx| { + ) -> impl Future + '_ { + poll_fn(|_| { + let mut output = None; for (partial_graph_id, graph_state) in &mut self.graph_states { - if let Poll::Ready(barrier) = graph_state.poll_next_completed_barrier(cx) { + if let Some(barrier) = graph_state.may_have_collected_all() { if let Some(actors_to_stop) = barrier.all_stop_actors() { self.current_shared_context.drop_actors(actors_to_stop); } - let partial_graph_id = *partial_graph_id; - return Poll::Ready((partial_graph_id, barrier.epoch.prev)); + output = Some((*partial_graph_id, barrier)); + break; } } - Poll::Pending + output.map(Poll::Ready).unwrap_or(Poll::Pending) }) } +} +impl ManagedBarrierState { pub(super) fn collect(&mut self, actor_id: ActorId, epoch: EpochPair) { let (prev_partial_graph_id, is_finished) = self .actor_states @@ -718,25 +605,34 @@ impl ManagedBarrierState { .expect("should exist"); prev_graph_state.collect(actor_id, epoch); } + + pub(super) fn pop_barrier_to_complete( + &mut self, + partial_graph_id: PartialGraphId, + prev_epoch: u64, + ) -> ( + Barrier, + Option>, + Vec, + ) { + self.graph_states + .get_mut(&partial_graph_id) + .expect("should exist") + .pop_barrier_to_complete(prev_epoch) + } } impl PartialGraphManagedBarrierState { /// This method is called when barrier state is modified in either `Issued` or `Stashed` /// to transform the state to `AllCollected` and start state store `sync` when the barrier /// has been collected from all actors for an `Issued` barrier. - fn may_have_collected_all(&mut self, prev_epoch: u64) { - // Report if there's progress on the earliest in-flight barrier. - if self.epoch_barrier_state_map.keys().next() == Some(&prev_epoch) { - self.streaming_metrics.barrier_manager_progress.inc(); - } - - for (prev_epoch, barrier_state) in &mut self.epoch_barrier_state_map { - let prev_epoch = *prev_epoch; + fn may_have_collected_all(&mut self) -> Option { + for barrier_state in self.epoch_barrier_state_map.values_mut() { match &barrier_state.inner { ManagedBarrierStateInner::Issued(IssuedState { remaining_actors, .. 
}) if remaining_actors.is_empty() => {} - ManagedBarrierStateInner::AllCollected | ManagedBarrierStateInner::Completed(_) => { + ManagedBarrierStateInner::AllCollected(_) => { continue; } ManagedBarrierStateInner::Issued(_) => { @@ -744,61 +640,60 @@ impl PartialGraphManagedBarrierState { } } + self.streaming_metrics.barrier_manager_progress.inc(); + + let create_mview_progress = self + .create_mview_progress + .remove(&barrier_state.barrier.epoch.curr) + .unwrap_or_default() + .into_iter() + .map(|(actor, state)| state.to_pb(actor)) + .collect(); + let prev_state = replace( &mut barrier_state.inner, - ManagedBarrierStateInner::AllCollected, + ManagedBarrierStateInner::AllCollected(create_mview_progress), ); - let (kind, table_ids) = must_match!(prev_state, ManagedBarrierStateInner::Issued(IssuedState { + must_match!(prev_state, ManagedBarrierStateInner::Issued(IssuedState { barrier_inflight_latency: timer, - kind, - table_ids, .. }) => { timer.observe_duration(); - (kind, table_ids) }); - let create_mview_progress = self - .create_mview_progress - .remove(&barrier_state.barrier.epoch.curr) - .unwrap_or_default() - .into_iter() - .map(|(actor, state)| state.to_pb(actor)) - .collect(); + return Some(barrier_state.barrier.clone()); + } + None + } - let complete_barrier_future = match kind { - BarrierKind::Unspecified => unreachable!(), - BarrierKind::Initial => { - tracing::info!( - epoch = prev_epoch, - "ignore sealing data for the first barrier" - ); - tracing::info!(?prev_epoch, "ignored syncing data for the first barrier"); - None - } - BarrierKind::Barrier => None, - BarrierKind::Checkpoint => Some(sync_epoch( - &self.state_store, - &self.streaming_metrics, - prev_epoch, - table_ids.expect("should be Some on BarrierKind::Checkpoint"), - )), - }; + fn pop_barrier_to_complete( + &mut self, + prev_epoch: u64, + ) -> ( + Barrier, + Option>, + Vec, + ) { + let (popped_prev_epoch, barrier_state) = self + .epoch_barrier_state_map + .pop_first() + .expect("should exist"); - let barrier = barrier_state.barrier.clone(); + assert_eq!(prev_epoch, popped_prev_epoch); - self.await_epoch_completed_futures.push_back({ - instrument_complete_barrier_future( - complete_barrier_future, - barrier, - self.barrier_await_tree_reg.as_ref(), - create_mview_progress, - ) - }); - } + let create_mview_progress = must_match!(barrier_state.inner, ManagedBarrierStateInner::AllCollected(create_mview_progress) => { + create_mview_progress + }); + ( + barrier_state.barrier, + barrier_state.table_ids, + create_mview_progress, + ) } +} +impl PartialGraphManagedBarrierState { /// Collect a `barrier` from the actor with `actor_id`. pub(super) fn collect(&mut self, actor_id: ActorId, epoch: EpochPair) { tracing::debug!( @@ -833,7 +728,6 @@ impl PartialGraphManagedBarrierState { actor_id, epoch.curr ); assert_eq!(barrier.epoch.curr, epoch.curr); - self.may_have_collected_all(epoch.prev); } Some(BarrierState { inner, .. }) => { panic!( @@ -917,79 +811,20 @@ impl PartialGraphManagedBarrierState { barrier: barrier.clone(), inner: ManagedBarrierStateInner::Issued(IssuedState { remaining_actors: BTreeSet::from_iter(actor_ids_to_collect), - mutation: barrier.mutation.clone(), barrier_inflight_latency: timer, - kind: barrier.kind, - table_ids, }), + table_ids, }, ); - self.may_have_collected_all(barrier.epoch.prev); - } - - /// Return a future that yields the next completed epoch. The future is cancellation safe. 
- pub(crate) fn poll_next_completed_barrier(&mut self, cx: &mut Context<'_>) -> Poll { - ready!(self.await_epoch_completed_futures.next().poll_unpin(cx)) - .map(|(barrier, result)| { - let state = self - .epoch_barrier_state_map - .get_mut(&barrier.epoch.prev) - .expect("should exist"); - // sanity check on barrier state - assert_matches!(&state.inner, ManagedBarrierStateInner::AllCollected); - state.inner = ManagedBarrierStateInner::Completed(result); - barrier - }) - .map(Poll::Ready) - .unwrap_or(Poll::Pending) - } - - /// Pop the completion result of an completed epoch. - /// Return: - /// - `Err(_)` `prev_epoch` is not an epoch to be collected. - /// - `Ok(None)` when `prev_epoch` exists but has not completed. - /// - `Ok(Some(_))` when `prev_epoch` has completed but not been reclaimed yet. - /// The `BarrierCompleteResult` will be popped out. - pub(crate) fn pop_completed_epoch( - &mut self, - prev_epoch: u64, - ) -> StreamResult>> { - let state = self - .epoch_barrier_state_map - .get(&prev_epoch) - .ok_or_else(|| { - // It's still possible that `collect_complete_receiver` does not contain the target epoch - // when receiving collect_barrier request. Because `collect_complete_receiver` could - // be cleared when CN is under recovering. We should return error rather than panic. - anyhow!( - "barrier collect complete receiver for prev epoch {} not exists", - prev_epoch - ) - })?; - match &state.inner { - ManagedBarrierStateInner::Completed(_) => { - match self - .epoch_barrier_state_map - .remove(&prev_epoch) - .expect("should exists") - .inner - { - ManagedBarrierStateInner::Completed(result) => Ok(Some(result)), - _ => unreachable!(), - } - } - _ => Ok(None), - } } #[cfg(test)] async fn pop_next_completed_epoch(&mut self) -> u64 { - let barrier = poll_fn(|cx| self.poll_next_completed_barrier(cx)).await; - let _ = self - .pop_completed_epoch(barrier.epoch.prev) - .unwrap() - .unwrap(); - barrier.epoch.prev + if let Some(barrier) = self.may_have_collected_all() { + self.pop_barrier_to_complete(barrier.epoch.prev); + return barrier.epoch.prev; + } + pending().await } } From c1162ab701d3df0b8a0a3f0bc8564508b0090330 Mon Sep 17 00:00:00 2001 From: Eric Fu Date: Wed, 20 Nov 2024 15:44:57 +0800 Subject: [PATCH 04/11] test: fix test cases of `batch/types` (#19441) --- e2e_test/batch/distribution_mode.slt | 2 +- e2e_test/batch/local_mode.slt | 2 +- e2e_test/batch/types/list/list_case.slt.part | 4 +-- e2e_test/batch/types/list/list_cast.slt.part | 34 +++++++++---------- .../batch/types/list/list_storage.slt.part | 8 ++--- .../list/multi-dimentional_list_cast.slt.part | 19 ++--------- 6 files changed, 28 insertions(+), 41 deletions(-) diff --git a/e2e_test/batch/distribution_mode.slt b/e2e_test/batch/distribution_mode.slt index f101e5c5446e..6f46a69a4352 100644 --- a/e2e_test/batch/distribution_mode.slt +++ b/e2e_test/batch/distribution_mode.slt @@ -10,7 +10,7 @@ include ./order/*.slt.part include ./join/*.slt.part include ./join/*/*.slt.part include ./aggregate/*.slt.part -include ./types/*.slt.part +include ./types/**/*.slt.part include ./functions/*.slt.part include ./over_window/main.slt.part include ./subquery/**/*.slt.part diff --git a/e2e_test/batch/local_mode.slt b/e2e_test/batch/local_mode.slt index a64fd49b85c8..9194ddfb1a83 100644 --- a/e2e_test/batch/local_mode.slt +++ b/e2e_test/batch/local_mode.slt @@ -10,7 +10,7 @@ include ./order/*.slt.part include ./join/*.slt.part include ./join/*/*.slt.part include ./aggregate/*.slt.part -include ./types/*.slt.part +include 
./types/**/*.slt.part include ./catalog/*.slt.part include ./functions/*.slt.part include ./over_window/main.slt.part diff --git a/e2e_test/batch/types/list/list_case.slt.part b/e2e_test/batch/types/list/list_case.slt.part index dab3ece05cdc..20c87b8d0624 100644 --- a/e2e_test/batch/types/list/list_case.slt.part +++ b/e2e_test/batch/types/list/list_case.slt.part @@ -20,7 +20,7 @@ SELECT case when i%2=0 then ARRAY[i] else ARRAY[-i] end from (select generate_se {-3} {4} -query I +query T SELECT case when i%2=0 then NULL else ARRAY[i] end from (select generate_series as i from generate_series(0,9,1)) as t; ---- NULL @@ -34,7 +34,7 @@ NULL NULL {9} -query I +query T with a as ( SELECT (case when i%2=0 then NULL else ARRAY[i] end) as i from (select generate_series as i from generate_series(0,9,1)) as t ) diff --git a/e2e_test/batch/types/list/list_cast.slt.part b/e2e_test/batch/types/list/list_cast.slt.part index f968011cff7e..842a1349f7a3 100644 --- a/e2e_test/batch/types/list/list_cast.slt.part +++ b/e2e_test/batch/types/list/list_cast.slt.part @@ -1,45 +1,45 @@ statement ok SET RW_IMPLICIT_FLUSH TO true; -query I -select {1,2,3}::double[]; +query T +select array[1,2,3]::double[]; ---- {1,2,3} -query I -select {1.4,2.5,3.6}::int[]; +query T +select array[1.4,2.5,3.6]::int[]; ---- {1,3,4} -query I -select {'1','2','3'}::int[]; +query T +select array['1','2','3']::int[]; ---- {1,2,3} -statement error -select {'1','2','a'}::int[]; +statement error invalid digit +select array['1','2','a']::int[]; -query I -select {{1,2.4},{3,4.7},null,{null}::int[]}::int[][]; +query T +select array[array[1,2.4],array[3,4.7],null,array[null]::int[]]::int[][]; ---- {{1,2},{3,5},NULL,{NULL}} statement ok create table t (a double[]); -statement error -insert into t values ({null}); +statement error cannot cast +insert into t values (array[null]); statement ok -insert into t values ({null::double}); +insert into t values (array[null::double]); statement ok -insert into t values ({null}::double[]); +insert into t values (array[null]::double[]); statement ok insert into t values (null); -query I +query T select * from t order by 1; ---- {NULL} @@ -47,9 +47,9 @@ select * from t order by 1; NULL statement ok -insert into t values ({3.4, 4.3}); +insert into t values (array[3.4, 4.3]); -query I +query T select a::int[] from t order by 1; ---- {3,4} diff --git a/e2e_test/batch/types/list/list_storage.slt.part b/e2e_test/batch/types/list/list_storage.slt.part index 2c993e1c79bc..70fb9bc1fe4b 100644 --- a/e2e_test/batch/types/list/list_storage.slt.part +++ b/e2e_test/batch/types/list/list_storage.slt.part @@ -10,11 +10,11 @@ CREATE TABLE a(b INTEGER[]); statement ok INSERT INTO a VALUES (ARRAY[1, 2]), (NULL), (ARRAY[3, 4, 5, 6]), (ARRAY[NULL, 7]); -query I rowsort +query T rowsort SELECT * FROM a ---- -{1,2} NULL +{1,2} {3,4,5,6} {NULL,7} @@ -24,13 +24,13 @@ CREATE TABLE c(b VARCHAR[]); statement ok INSERT INTO c VALUES (ARRAY['hello', 'world']), (NULL), (ARRAY['fejwfoaejwfoijwafew', 'b', 'c']), (ARRAY[NULL, 'XXXXXXXXXXXXXXXXXXXXXXXX']); -query I rowsort +query T rowsort SELECT * FROM c ---- -{hello,world} NULL {NULL,XXXXXXXXXXXXXXXXXXXXXXXX} {fejwfoaejwfoijwafew,b,c} +{hello,world} statement ok drop table a; diff --git a/e2e_test/batch/types/list/multi-dimentional_list_cast.slt.part b/e2e_test/batch/types/list/multi-dimentional_list_cast.slt.part index 8a67840a6c20..9345e30dec65 100644 --- a/e2e_test/batch/types/list/multi-dimentional_list_cast.slt.part +++ b/e2e_test/batch/types/list/multi-dimentional_list_cast.slt.part @@ 
-1,25 +1,12 @@ -query I +query T select array[array[1, 2], array[3, 4]]; ---- {{1,2},{3,4}} -query I +query T select array[[1, 2], [3, 4]]; ---- {{1,2},{3,4}} -query I +query error sql parser error select array[[array[1, 2]], [[3, 4]]]; ----- -{{{1,2}},{{3,4}}} - -query I -select array[[[1, 2]], [array[3, 4]]]; ----- -{{{1,2}},{{3,4}}} - -statement error syntax error at or near -select array[array[1, 2], [3, 4]]; - -statement error syntax error at or near -select array[[1, 2], array[3, 4]]; \ No newline at end of file From 5f1a59b000b92a2c81163c520564827a0a798ff5 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Wed, 20 Nov 2024 16:27:26 +0800 Subject: [PATCH 05/11] refactor(frontend): rework `UPDATE` to support subqueries (#19402) Signed-off-by: Bugen Zhao --- .../{dml.slt.part => dml_basic.slt.part} | 0 e2e_test/batch/basic/dml_update.slt.part | 132 +++++++++++ proto/batch_plan.proto | 11 +- src/batch/src/executor/update.rs | 66 +++--- .../tests/testdata/input/update.yaml | 29 ++- .../testdata/output/index_selection.yaml | 18 +- .../tests/testdata/output/update.yaml | 172 +++++++++----- src/frontend/src/binder/expr/subquery.rs | 16 +- src/frontend/src/binder/mod.rs | 2 +- src/frontend/src/binder/update.rs | 219 ++++++++++-------- src/frontend/src/error.rs | 6 +- src/frontend/src/expr/subquery.rs | 4 + .../src/optimizer/plan_node/batch_update.rs | 22 +- .../src/optimizer/plan_node/generic/update.rs | 45 ++-- .../src/optimizer/plan_node/logical_update.rs | 32 +-- src/frontend/src/planner/select.rs | 19 +- src/frontend/src/planner/update.rs | 79 +++++-- 17 files changed, 589 insertions(+), 283 deletions(-) rename e2e_test/batch/basic/{dml.slt.part => dml_basic.slt.part} (100%) create mode 100644 e2e_test/batch/basic/dml_update.slt.part diff --git a/e2e_test/batch/basic/dml.slt.part b/e2e_test/batch/basic/dml_basic.slt.part similarity index 100% rename from e2e_test/batch/basic/dml.slt.part rename to e2e_test/batch/basic/dml_basic.slt.part diff --git a/e2e_test/batch/basic/dml_update.slt.part b/e2e_test/batch/basic/dml_update.slt.part new file mode 100644 index 000000000000..fcc3bbdfce9a --- /dev/null +++ b/e2e_test/batch/basic/dml_update.slt.part @@ -0,0 +1,132 @@ +# Extension to `dml_basic.slt.part` for testing advanced `UPDATE` statements. + +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +statement ok +create table t (v1 int default 1919, v2 int default 810); + +statement ok +insert into t values (114, 514); + + +# Single assignment, to subquery. +statement ok +update t set v1 = (select 666); + +query II +select * from t; +---- +666 514 + +# Single assignment, to runtime-cardinality subquery returning 1 row. +statement ok +update t set v1 = (select generate_series(888, 888)); + +query II +select * from t; +---- +888 514 + +# Single assignment, to runtime-cardinality subquery returning 0 rows (set to NULL). +statement ok +update t set v1 = (select generate_series(1, 0)); + +query II +select * from t; +---- +NULL 514 + +# Single assignment, to runtime-cardinality subquery returning multiple rows. +statement error Scalar subquery produced more than one row +update t set v1 = (select generate_series(1, 2)); + +# Single assignment, to correlated subquery. +statement ok +update t set v1 = (select count(*) from t as source where source.v2 = t.v2); + +query II +select * from t; +---- +1 514 + +# Single assignment, to subquery with mismatched column count. +statement error must return only one column +update t set v1 = (select 666, 888); + + +# Multiple assignment clauses. 
+statement ok +update t set v1 = 1919, v2 = 810; + +query II +select * from t; +---- +1919 810 + +# Multiple assignments to the same column. +statement error multiple assignments to the same column +update t set v1 = 1, v1 = 2; + +statement error multiple assignments to the same column +update t set (v1, v1) = (1, 2); + +statement error multiple assignments to the same column +update t set (v1, v2) = (1, 2), v2 = 2; + +# Multiple assignments, to subquery. +statement ok +update t set (v1, v2) = (select 666, 888); + +query II +select * from t; +---- +666 888 + +# Multiple assignments, to subquery with cast. +statement ok +update t set (v1, v2) = (select 888.88, 999); + +query II +select * from t; +---- +889 999 + +# Multiple assignments, to subquery with cast failure. +# TODO: this currently shows `cannot cast type "record" to "record"` because we wrap the subquery result +# into a struct, which is not quite clear. +statement error cannot cast type +update t set (v1, v2) = (select '888.88', 999); + +# Multiple assignments, to subquery with mismatched column count. +statement error number of columns does not match number of values +update t set (v1, v2) = (select 666); + +# Multiple assignments, to scalar expression. +statement error source for a multiple-column UPDATE item must be a sub-SELECT or ROW\(\) expression +update t set (v1, v2) = v1 + 1; + + +# Assignment to system columns. +statement error update modifying column `_rw_timestamp` is unsupported +update t set _rw_timestamp = _rw_timestamp + interval '1 second'; + + +# https://github.com/risingwavelabs/risingwave/pull/19402#pullrequestreview-2444427475 +# https://github.com/risingwavelabs/risingwave/pull/19452 +statement ok +create table y (v1 int, v2 int); + +statement ok +insert into y values (11, 11), (22, 22); + +statement error Scalar subquery produced more than one row +update t set (v1, v2) = (select y.v1, y.v2 from y); + +statement ok +drop table y; + + +# Cleanup. +statement ok +drop table t; diff --git a/proto/batch_plan.proto b/proto/batch_plan.proto index b46230b2438d..f10092d952ac 100644 --- a/proto/batch_plan.proto +++ b/proto/batch_plan.proto @@ -173,11 +173,12 @@ message UpdateNode { // Id of the table to perform updating. uint32 table_id = 1; // Version of the table. - uint64 table_version_id = 4; - repeated expr.ExprNode exprs = 2; - bool returning = 3; - // The columns indices in the input schema, representing the columns need to send to streamDML exeuctor. - repeated uint32 update_column_indices = 5; + uint64 table_version_id = 2; + // Expressions to generate `U-` records. + repeated expr.ExprNode old_exprs = 3; + // Expressions to generate `U+` records. + repeated expr.ExprNode new_exprs = 4; + bool returning = 5; // Session id is used to ensure that dml data from the same session should be sent to a fixed worker node and channel. 
uint32 session_id = 6; diff --git a/src/batch/src/executor/update.rs b/src/batch/src/executor/update.rs index a753aef840f5..95f1963cf582 100644 --- a/src/batch/src/executor/update.rs +++ b/src/batch/src/executor/update.rs @@ -42,13 +42,13 @@ pub struct UpdateExecutor { table_version_id: TableVersionId, dml_manager: DmlManagerRef, child: BoxedExecutor, - exprs: Vec, + old_exprs: Vec, + new_exprs: Vec, chunk_size: usize, schema: Schema, identity: String, returning: bool, txn_id: TxnId, - update_column_indices: Vec, session_id: u32, } @@ -59,11 +59,11 @@ impl UpdateExecutor { table_version_id: TableVersionId, dml_manager: DmlManagerRef, child: BoxedExecutor, - exprs: Vec, + old_exprs: Vec, + new_exprs: Vec, chunk_size: usize, identity: String, returning: bool, - update_column_indices: Vec, session_id: u32, ) -> Self { let chunk_size = chunk_size.next_multiple_of(2); @@ -75,7 +75,8 @@ impl UpdateExecutor { table_version_id, dml_manager, child, - exprs, + old_exprs, + new_exprs, chunk_size, schema: if returning { table_schema @@ -87,7 +88,6 @@ impl UpdateExecutor { identity, returning, txn_id, - update_column_indices, session_id, } } @@ -109,7 +109,7 @@ impl Executor for UpdateExecutor { impl UpdateExecutor { #[try_stream(boxed, ok = DataChunk, error = BatchError)] - async fn do_execute(mut self: Box) { + async fn do_execute(self: Box) { let table_dml_handle = self .dml_manager .table_dml_handle(self.table_id, self.table_version_id)?; @@ -122,15 +122,12 @@ impl UpdateExecutor { assert_eq!( data_types, - self.exprs.iter().map(|e| e.return_type()).collect_vec(), + self.new_exprs.iter().map(|e| e.return_type()).collect_vec(), "bad update schema" ); assert_eq!( data_types, - self.update_column_indices - .iter() - .map(|i: &usize| self.child.schema()[*i].data_type.clone()) - .collect_vec(), + self.old_exprs.iter().map(|e| e.return_type()).collect_vec(), "bad update schema" ); @@ -159,27 +156,35 @@ impl UpdateExecutor { let mut rows_updated = 0; #[for_await] - for data_chunk in self.child.execute() { - let data_chunk = data_chunk?; + for input in self.child.execute() { + let input = input?; + + let old_data_chunk = { + let mut columns = Vec::with_capacity(self.old_exprs.len()); + for expr in &self.old_exprs { + let column = expr.eval(&input).await?; + columns.push(column); + } + + DataChunk::new(columns, input.visibility().clone()) + }; let updated_data_chunk = { - let mut columns = Vec::with_capacity(self.exprs.len()); - for expr in &mut self.exprs { - let column = expr.eval(&data_chunk).await?; + let mut columns = Vec::with_capacity(self.new_exprs.len()); + for expr in &self.new_exprs { + let column = expr.eval(&input).await?; columns.push(column); } - DataChunk::new(columns, data_chunk.visibility().clone()) + DataChunk::new(columns, input.visibility().clone()) }; if self.returning { yield updated_data_chunk.clone(); } - for (row_delete, row_insert) in data_chunk - .project(&self.update_column_indices) - .rows() - .zip_eq_debug(updated_data_chunk.rows()) + for (row_delete, row_insert) in + (old_data_chunk.rows()).zip_eq_debug(updated_data_chunk.rows()) { rows_updated += 1; // If row_delete == row_insert, we don't need to do a actual update @@ -227,34 +232,35 @@ impl BoxedExecutorBuilder for UpdateExecutor { let table_id = TableId::new(update_node.table_id); - let exprs: Vec<_> = update_node - .get_exprs() + let old_exprs: Vec<_> = update_node + .get_old_exprs() .iter() .map(build_from_prost) .try_collect()?; - let update_column_indices = update_node - .update_column_indices + let new_exprs: Vec<_> 
= update_node + .get_new_exprs() .iter() - .map(|x| *x as usize) - .collect_vec(); + .map(build_from_prost) + .try_collect()?; Ok(Box::new(Self::new( table_id, update_node.table_version_id, source.context().dml_manager(), child, - exprs, + old_exprs, + new_exprs, source.context.get_config().developer.chunk_size, source.plan_node().get_identity().clone(), update_node.returning, - update_column_indices, update_node.session_id, ))) } } #[cfg(test)] +#[cfg(any())] mod tests { use std::sync::Arc; diff --git a/src/frontend/planner_test/tests/testdata/input/update.yaml b/src/frontend/planner_test/tests/testdata/input/update.yaml index 65c0f47eb4cd..744735af843d 100644 --- a/src/frontend/planner_test/tests/testdata/input/update.yaml +++ b/src/frontend/planner_test/tests/testdata/input/update.yaml @@ -76,7 +76,7 @@ update t set v2 = 3; expected_outputs: - binder_error -- name: update subquery +- name: update subquery selection sql: | create table t (a int, b int); update t set a = 777 where b not in (select a from t); @@ -98,10 +98,27 @@ update t set a = a + 1; expected_outputs: - batch_distributed_plan -- name: update table with subquery in the set clause +- name: update table to subquery sql: | - create table t1 (v1 int primary key, v2 int); - create table t2 (v1 int primary key, v2 int); - update t1 set v1 = (select v1 from t2 where t1.v2 = t2.v2); + create table t (v1 int, v2 int); + update t set v1 = (select 666); + expected_outputs: + - batch_plan +- name: update table to subquery with runtime cardinality + sql: | + create table t (v1 int, v2 int); + update t set v1 = (select generate_series(888, 888)); + expected_outputs: + - batch_plan +- name: update table to correlated subquery + sql: | + create table t (v1 int, v2 int); + update t set v1 = (select count(*) from t as source where source.v2 = t.v2); expected_outputs: - - binder_error + - batch_plan +- name: update table to subquery with multiple assignments + sql: | + create table t (v1 int, v2 int); + update t set (v1, v2) = (select 666.66, 777); + expected_outputs: + - batch_plan diff --git a/src/frontend/planner_test/tests/testdata/output/index_selection.yaml b/src/frontend/planner_test/tests/testdata/output/index_selection.yaml index a6240c69f395..349c5f7d8901 100644 --- a/src/frontend/planner_test/tests/testdata/output/index_selection.yaml +++ b/src/frontend/planner_test/tests/testdata/output/index_selection.yaml @@ -213,16 +213,18 @@ update t1 set c = 3 where a = 1 and b = 2; batch_plan: |- BatchExchange { order: [], dist: Single } - └─BatchUpdate { table: t1, exprs: [$0, $1, 3:Int64, $3] } + └─BatchUpdate { table: t1, exprs: [$0, $1, $5, $3] } └─BatchExchange { order: [], dist: Single } - └─BatchLookupJoin { type: Inner, predicate: idx2.t1._row_id IS NOT DISTINCT FROM t1._row_id, output: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp], lookup table: t1 } - └─BatchExchange { order: [], dist: UpstreamHashShard(idx2.t1._row_id) } - └─BatchScan { table: idx2, columns: [idx2.t1._row_id], scan_ranges: [idx2.b = Decimal(Normalized(2)) AND idx2.a = Int32(1)], distribution: SomeShard } + └─BatchProject { exprs: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp, 3:Int64] } + └─BatchLookupJoin { type: Inner, predicate: idx2.t1._row_id IS NOT DISTINCT FROM t1._row_id, output: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp], lookup table: t1 } + └─BatchExchange { order: [], dist: UpstreamHashShard(idx2.t1._row_id) } + └─BatchScan { table: idx2, columns: [idx2.t1._row_id], scan_ranges: [idx2.b = Decimal(Normalized(2)) AND idx2.a = Int32(1)], 
distribution: SomeShard } batch_local_plan: |- - BatchUpdate { table: t1, exprs: [$0, $1, 3:Int64, $3] } - └─BatchLookupJoin { type: Inner, predicate: idx2.t1._row_id IS NOT DISTINCT FROM t1._row_id, output: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp], lookup table: t1 } - └─BatchExchange { order: [], dist: Single } - └─BatchScan { table: idx2, columns: [idx2.t1._row_id], scan_ranges: [idx2.b = Decimal(Normalized(2)) AND idx2.a = Int32(1)], distribution: SomeShard } + BatchUpdate { table: t1, exprs: [$0, $1, $5, $3] } + └─BatchProject { exprs: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp, 3:Int64] } + └─BatchLookupJoin { type: Inner, predicate: idx2.t1._row_id IS NOT DISTINCT FROM t1._row_id, output: [t1.a, t1.b, t1.c, t1._row_id, t1._rw_timestamp], lookup table: t1 } + └─BatchExchange { order: [], dist: Single } + └─BatchScan { table: idx2, columns: [idx2.t1._row_id], scan_ranges: [idx2.b = Decimal(Normalized(2)) AND idx2.a = Int32(1)], distribution: SomeShard } - sql: | create table t1 (a int, b numeric, c bigint, p int); create materialized view v as select count(*) as cnt, p from t1 group by p; diff --git a/src/frontend/planner_test/tests/testdata/output/update.yaml b/src/frontend/planner_test/tests/testdata/output/update.yaml index 19d6673d77f9..4a12b492660a 100644 --- a/src/frontend/planner_test/tests/testdata/output/update.yaml +++ b/src/frontend/planner_test/tests/testdata/output/update.yaml @@ -4,9 +4,10 @@ update t set v1 = 0; batch_plan: |- BatchExchange { order: [], dist: Single } - └─BatchUpdate { table: t, exprs: [0:Int32, $1, $2] } + └─BatchUpdate { table: t, exprs: [$4, $1, $2] } └─BatchExchange { order: [], dist: Single } - └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } + └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, 0:Int32] } + └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } - sql: | create table t (v1 int, v2 int); update t set v1 = true; @@ -16,72 +17,81 @@ update t set v1 = v2 + 1; batch_plan: |- BatchExchange { order: [], dist: Single } - └─BatchUpdate { table: t, exprs: [($1 + 1:Int32), $1, $2] } + └─BatchUpdate { table: t, exprs: [$4, $1, $2] } └─BatchExchange { order: [], dist: Single } - └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } + └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, (t.v2 + 1:Int32) as $expr1] } + └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } - sql: | create table t (v1 int, v2 real); update t set v1 = v2; batch_plan: |- BatchExchange { order: [], dist: Single } - └─BatchUpdate { table: t, exprs: [$1::Int32, $1, $2] } + └─BatchUpdate { table: t, exprs: [$4, $1, $2] } └─BatchExchange { order: [], dist: Single } - └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } + └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, t.v2::Int32 as $expr1] } + └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } - sql: | create table t (v1 int, v2 real); update t set v1 = DEFAULT; batch_plan: |- BatchExchange { order: [], dist: Single } - └─BatchUpdate { table: t, exprs: [null:Int32, $1, $2] } + └─BatchUpdate { table: t, exprs: [$4, $1, $2] } └─BatchExchange { order: [], dist: Single } - 
└─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } + └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, null:Int32] } + └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } - sql: | create table t (v1 int, v2 int); update t set v1 = v2 + 1 where v2 > 0; batch_plan: |- BatchExchange { order: [], dist: Single } - └─BatchUpdate { table: t, exprs: [($1 + 1:Int32), $1, $2] } + └─BatchUpdate { table: t, exprs: [$4, $1, $2] } └─BatchExchange { order: [], dist: Single } - └─BatchFilter { predicate: (t.v2 > 0:Int32) } - └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } + └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, (t.v2 + 1:Int32) as $expr1] } + └─BatchFilter { predicate: (t.v2 > 0:Int32) } + └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } - sql: | create table t (v1 int, v2 int); update t set (v1, v2) = (v2 + 1, v1 - 1) where v1 != v2; batch_plan: |- BatchExchange { order: [], dist: Single } - └─BatchUpdate { table: t, exprs: [($1 + 1:Int32), ($0 - 1:Int32), $2] } + └─BatchUpdate { table: t, exprs: [$4, $5, $2] } └─BatchExchange { order: [], dist: Single } - └─BatchFilter { predicate: (t.v1 <> t.v2) } - └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } + └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, (t.v2 + 1:Int32) as $expr1, (t.v1 - 1:Int32) as $expr2] } + └─BatchFilter { predicate: (t.v1 <> t.v2) } + └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } - sql: | create table t (v1 int, v2 int); update t set (v1, v2) = (v2 + 1, v1 - 1) where v1 != v2 returning *, v2+1, v1-1; logical_plan: |- - LogicalProject { exprs: [t.v1, t.v2, (t.v2 + 1:Int32) as $expr1, (t.v1 - 1:Int32) as $expr2] } - └─LogicalUpdate { table: t, exprs: [($1 + 1:Int32), ($0 - 1:Int32), $2], returning: true } - └─LogicalFilter { predicate: (t.v1 <> t.v2) } - └─LogicalScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp] } + LogicalProject { exprs: [, , ( + 1:Int32) as $expr3, ( - 1:Int32) as $expr4] } + └─LogicalUpdate { table: t, exprs: [$4, $5, $2], returning: true } + └─LogicalProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, (t.v2 + 1:Int32) as $expr1, (t.v1 - 1:Int32) as $expr2] } + └─LogicalFilter { predicate: (t.v1 <> t.v2) } + └─LogicalScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp] } batch_plan: |- BatchExchange { order: [], dist: Single } - └─BatchProject { exprs: [t.v1, t.v2, (t.v2 + 1:Int32) as $expr1, (t.v1 - 1:Int32) as $expr2] } - └─BatchUpdate { table: t, exprs: [($1 + 1:Int32), ($0 - 1:Int32), $2], returning: true } + └─BatchProject { exprs: [, , ( + 1:Int32) as $expr3, ( - 1:Int32) as $expr4] } + └─BatchUpdate { table: t, exprs: [$4, $5, $2], returning: true } └─BatchExchange { order: [], dist: Single } - └─BatchFilter { predicate: (t.v1 <> t.v2) } - └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } + └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, (t.v2 + 1:Int32) as $expr1, (t.v1 - 1:Int32) as $expr2] } + └─BatchFilter { predicate: (t.v1 <> t.v2) } + └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], 
distribution: UpstreamHashShard(t._row_id) } - name: update with returning statement, should keep `Update` sql: | create table t (v int); update t set v = 114 returning 514; logical_plan: |- LogicalProject { exprs: [514:Int32] } - └─LogicalUpdate { table: t, exprs: [114:Int32, $1], returning: true } - └─LogicalScan { table: t, columns: [t.v, t._row_id, t._rw_timestamp] } + └─LogicalUpdate { table: t, exprs: [$3, $1], returning: true } + └─LogicalProject { exprs: [t.v, t._row_id, t._rw_timestamp, 114:Int32] } + └─LogicalScan { table: t, columns: [t.v, t._row_id, t._rw_timestamp] } batch_plan: |- BatchExchange { order: [], dist: Single } └─BatchProject { exprs: [514:Int32] } - └─BatchUpdate { table: t, exprs: [114:Int32, $1], returning: true } + └─BatchUpdate { table: t, exprs: [$3, $1], returning: true } └─BatchExchange { order: [], dist: Single } - └─BatchScan { table: t, columns: [t.v, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } + └─BatchProject { exprs: [t.v, t._row_id, t._rw_timestamp, 114:Int32] } + └─BatchScan { table: t, columns: [t.v, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } - sql: | create table t (v1 int primary key, v2 int); update t set (v2, v1) = (v1, v2); @@ -90,22 +100,25 @@ create table t (v1 int default 1+1, v2 int); update t set v1 = default; logical_plan: |- - LogicalUpdate { table: t, exprs: [(1:Int32 + 1:Int32), $1, $2] } - └─LogicalScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp] } + LogicalUpdate { table: t, exprs: [$4, $1, $2] } + └─LogicalProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, (1:Int32 + 1:Int32) as $expr1] } + └─LogicalScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp] } batch_plan: |- BatchExchange { order: [], dist: Single } - └─BatchUpdate { table: t, exprs: [2:Int32, $1, $2] } + └─BatchUpdate { table: t, exprs: [$4, $1, $2] } └─BatchExchange { order: [], dist: Single } - └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } + └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, 2:Int32] } + └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } - name: update table with generated columns sql: | create table t(v1 int as v2-1, v2 int, v3 int as v2+1); update t set v2 = 3; batch_plan: |- BatchExchange { order: [], dist: Single } - └─BatchUpdate { table: t, exprs: [3:Int32, $3] } + └─BatchUpdate { table: t, exprs: [$5, $3] } └─BatchExchange { order: [], dist: Single } - └─BatchScan { table: t, columns: [t.v1, t.v2, t.v3, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } + └─BatchProject { exprs: [t.v1, t.v2, t.v3, t._row_id, t._rw_timestamp, 3:Int32] } + └─BatchScan { table: t, columns: [t.v1, t.v2, t.v3, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } - name: update generated column sql: | create table t(v1 int as v2-1, v2 int, v3 int as v2+1); @@ -116,25 +129,27 @@ create table t(v1 int as v2-1, v2 int, v3 int as v2+1, primary key (v3)); update t set v2 = 3; binder_error: 'Bind error: update modifying the column referenced by generated columns that are part of the primary key is not allowed' -- name: update subquery +- name: update subquery selection sql: | create table t (a int, b int); update t set a = 777 where b not in (select a from t); logical_plan: |- - LogicalUpdate { table: t, exprs: [777:Int32, $1, $2] } - └─LogicalApply { type: LeftAnti, on: (t.b = 
t.a), correlated_id: 1 } - ├─LogicalScan { table: t, columns: [t.a, t.b, t._row_id, t._rw_timestamp] } - └─LogicalProject { exprs: [t.a] } - └─LogicalScan { table: t, columns: [t.a, t.b, t._row_id, t._rw_timestamp] } + LogicalUpdate { table: t, exprs: [$4, $1, $2] } + └─LogicalProject { exprs: [t.a, t.b, t._row_id, t._rw_timestamp, 777:Int32] } + └─LogicalApply { type: LeftAnti, on: (t.b = t.a), correlated_id: 1 } + ├─LogicalScan { table: t, columns: [t.a, t.b, t._row_id, t._rw_timestamp] } + └─LogicalProject { exprs: [t.a] } + └─LogicalScan { table: t, columns: [t.a, t.b, t._row_id, t._rw_timestamp] } batch_plan: |- BatchExchange { order: [], dist: Single } - └─BatchUpdate { table: t, exprs: [777:Int32, $1, $2] } + └─BatchUpdate { table: t, exprs: [$4, $1, $2] } └─BatchExchange { order: [], dist: Single } - └─BatchHashJoin { type: LeftAnti, predicate: t.b = t.a, output: all } - ├─BatchExchange { order: [], dist: HashShard(t.b) } - │ └─BatchScan { table: t, columns: [t.a, t.b, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } - └─BatchExchange { order: [], dist: HashShard(t.a) } - └─BatchScan { table: t, columns: [t.a], distribution: SomeShard } + └─BatchProject { exprs: [t.a, t.b, t._row_id, t._rw_timestamp, 777:Int32] } + └─BatchHashJoin { type: LeftAnti, predicate: t.b = t.a, output: all } + ├─BatchExchange { order: [], dist: HashShard(t.b) } + │ └─BatchScan { table: t, columns: [t.a, t.b, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } + └─BatchExchange { order: [], dist: HashShard(t.a) } + └─BatchScan { table: t, columns: [t.a], distribution: SomeShard } - name: delete subquery sql: | create table t (a int, b int); @@ -163,12 +178,65 @@ batch_distributed_plan: |- BatchSimpleAgg { aggs: [sum()] } └─BatchExchange { order: [], dist: Single } - └─BatchUpdate { table: t, exprs: [($0 + 1:Int32), $1, $2] } - └─BatchExchange { order: [], dist: HashShard(t.a, t.b, t._row_id, t._rw_timestamp) } - └─BatchScan { table: t, columns: [t.a, t.b, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } -- name: update table with subquery in the set clause + └─BatchUpdate { table: t, exprs: [$4, $1, $2] } + └─BatchExchange { order: [], dist: HashShard(t.a, t.b, t._row_id, t._rw_timestamp, $expr1) } + └─BatchProject { exprs: [t.a, t.b, t._row_id, t._rw_timestamp, (t.a + 1:Int32) as $expr1] } + └─BatchScan { table: t, columns: [t.a, t.b, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } +- name: update table to subquery + sql: | + create table t (v1 int, v2 int); + update t set v1 = (select 666); + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchUpdate { table: t, exprs: [$4, $1, $2] } + └─BatchNestedLoopJoin { type: LeftOuter, predicate: true, output: all } + ├─BatchExchange { order: [], dist: Single } + │ └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } + └─BatchValues { rows: [[666:Int32]] } +- name: update table to subquery with runtime cardinality + sql: | + create table t (v1 int, v2 int); + update t set v1 = (select generate_series(888, 888)); + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchUpdate { table: t, exprs: [$4, $1, $2] } + └─BatchNestedLoopJoin { type: LeftOuter, predicate: true, output: all } + ├─BatchExchange { order: [], dist: Single } + │ └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } + └─BatchMaxOneRow + 
└─BatchProject { exprs: [GenerateSeries(888:Int32, 888:Int32)] } + └─BatchProjectSet { select_list: [GenerateSeries(888:Int32, 888:Int32)] } + └─BatchValues { rows: [[]] } +- name: update table to correlated subquery sql: | - create table t1 (v1 int primary key, v2 int); - create table t2 (v1 int primary key, v2 int); - update t1 set v1 = (select v1 from t2 where t1.v2 = t2.v2); - binder_error: 'Bind error: subquery on the right side of assignment is unsupported' + create table t (v1 int, v2 int); + update t set v1 = (select count(*) from t as source where source.v2 = t.v2); + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchUpdate { table: t, exprs: [$4, $1, $2] } + └─BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, count(1:Int32)::Int32 as $expr1] } + └─BatchHashJoin { type: LeftOuter, predicate: t.v2 IS NOT DISTINCT FROM t.v2, output: [t.v1, t.v2, t._row_id, t._rw_timestamp, count(1:Int32)] } + ├─BatchExchange { order: [], dist: HashShard(t.v2) } + │ └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } + └─BatchHashAgg { group_key: [t.v2], aggs: [count(1:Int32)] } + └─BatchHashJoin { type: LeftOuter, predicate: t.v2 IS NOT DISTINCT FROM t.v2, output: [t.v2, 1:Int32] } + ├─BatchHashAgg { group_key: [t.v2], aggs: [] } + │ └─BatchExchange { order: [], dist: HashShard(t.v2) } + │ └─BatchScan { table: t, columns: [t.v2], distribution: SomeShard } + └─BatchExchange { order: [], dist: HashShard(t.v2) } + └─BatchProject { exprs: [t.v2, 1:Int32] } + └─BatchFilter { predicate: IsNotNull(t.v2) } + └─BatchScan { table: t, columns: [t.v2], distribution: SomeShard } +- name: update table to subquery with multiple assignments + sql: | + create table t (v1 int, v2 int); + update t set (v1, v2) = (select 666.66, 777); + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchUpdate { table: t, exprs: [Field($4, 0:Int32), Field($4, 1:Int32), $2] } + └─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, $expr10011::Struct(StructType { field_names: [], field_types: [Int32, Int32] }) as $expr1] } + └─BatchNestedLoopJoin { type: LeftOuter, predicate: true, output: all } + ├─BatchExchange { order: [], dist: Single } + │ └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) } + └─BatchValues { rows: [['(666.66,777)':Struct(StructType { field_names: [], field_types: [Decimal, Int32] })]] } diff --git a/src/frontend/src/binder/expr/subquery.rs b/src/frontend/src/binder/expr/subquery.rs index 51819116771f..c31a5d653aeb 100644 --- a/src/frontend/src/binder/expr/subquery.rs +++ b/src/frontend/src/binder/expr/subquery.rs @@ -15,20 +15,16 @@ use risingwave_sqlparser::ast::Query; use crate::binder::Binder; -use crate::error::{ErrorCode, Result}; +use crate::error::{bail_bind_error, Result}; use crate::expr::{ExprImpl, Subquery, SubqueryKind}; impl Binder { - pub(super) fn bind_subquery_expr( - &mut self, - query: Query, - kind: SubqueryKind, - ) -> Result { + pub fn bind_subquery_expr(&mut self, query: Query, kind: SubqueryKind) -> Result { let query = self.bind_query(query)?; - if !matches!(kind, SubqueryKind::Existential) && query.data_types().len() != 1 { - return Err( - ErrorCode::BindError("Subquery must return only one column".to_string()).into(), - ); + if !matches!(kind, SubqueryKind::Existential | SubqueryKind::UpdateSet) + && query.data_types().len() != 1 + { + 
bail_bind_error!("Subquery must return only one column"); } Ok(Subquery::new(query, kind).into()) } diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index b346dc45ca2d..4560e51bd656 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -58,7 +58,7 @@ pub use relation::{ pub use select::{BoundDistinct, BoundSelect}; pub use set_expr::*; pub use statement::BoundStatement; -pub use update::BoundUpdate; +pub use update::{BoundUpdate, UpdateProject}; pub use values::BoundValues; use crate::catalog::catalog_service::CatalogReadGuard; diff --git a/src/frontend/src/binder/update.rs b/src/frontend/src/binder/update.rs index 9cc80dbde447..f57ad1d19798 100644 --- a/src/frontend/src/binder/update.rs +++ b/src/frontend/src/binder/update.rs @@ -12,23 +12,42 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::hash_map::Entry; use std::collections::{BTreeMap, HashMap}; use fixedbitset::FixedBitSet; use itertools::Itertools; use risingwave_common::catalog::{Schema, TableVersionId}; +use risingwave_common::types::DataType; use risingwave_common::util::iter_util::ZipEqFast; use risingwave_sqlparser::ast::{Assignment, AssignmentValue, Expr, ObjectName, SelectItem}; use super::statement::RewriteExprsRecursive; use super::{Binder, BoundBaseTable}; use crate::catalog::TableId; -use crate::error::{ErrorCode, Result, RwError}; -use crate::expr::{Expr as _, ExprImpl, InputRef}; +use crate::error::{bail_bind_error, bind_error, ErrorCode, Result, RwError}; +use crate::expr::{Expr as _, ExprImpl, SubqueryKind}; use crate::user::UserId; use crate::TableCatalog; +/// Project into `exprs` in `BoundUpdate` to get the new values for updating. +#[derive(Debug, Clone, Copy)] +pub enum UpdateProject { + /// Use the expression at the given index in `exprs`. + Simple(usize), + /// Use the `i`-th field of the expression (returning a struct) at the given index in `exprs`. + Composite(usize, usize), +} + +impl UpdateProject { + /// Offset the index by `i`. + pub fn offset(self, i: usize) -> Self { + match self { + UpdateProject::Simple(index) => UpdateProject::Simple(index + i), + UpdateProject::Composite(index, j) => UpdateProject::Composite(index + i, j), + } + } +} + #[derive(Debug, Clone)] pub struct BoundUpdate { /// Id of the table to perform updating. @@ -48,10 +67,14 @@ pub struct BoundUpdate { pub selection: Option, - /// Expression used to project to the updated row. The assigned columns will use the new - /// expression, and the other columns will be simply `InputRef`. + /// Expression used to evaluate the new values for the columns. pub exprs: Vec, + /// Mapping from the index of the column to be updated, to the index of the expression in `exprs`. + /// + /// By constructing two `Project` nodes with `exprs` and `projects`, we can get the new values. + pub projects: HashMap, + // used for the 'RETURNING" keyword to indicate the returning items and schema // if the list is empty and the schema is None, the output schema will be a INT64 as the // affected row cnt @@ -124,107 +147,112 @@ impl Binder { let selection = selection.map(|expr| self.bind_expr(expr)).transpose()?; - let mut assignment_exprs = HashMap::new(); - for Assignment { id, value } in assignments { - // FIXME: Parsing of `id` is not strict. It will even treat `a.b` as `(a, b)`. 
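            // The `match` below is the old per-assignment handling: each value was bound
            // eagerly and collected into a `HashMap<ExprImpl, ExprImpl>` keyed by the bound
            // target column. The replacement later in this hunk instead pushes each bound
            // value into `exprs` and records one `UpdateProject::Simple` or
            // `UpdateProject::Composite` entry per target column, which is what lets
            // `SET (v1, v2) = (SELECT ...)` stay a single struct-typed expression with two
            // `Composite` projections into it.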
- let assignments = match (id.as_slice(), value) { - // _ = (subquery) - (_ids, AssignmentValue::Expr(Expr::Subquery(_))) => { - return Err(ErrorCode::BindError( - "subquery on the right side of assignment is unsupported".to_owned(), - ) - .into()) - } - // col = expr - ([id], value) => { - vec![(id.clone(), value)] - } - // (col1, col2) = (expr1, expr2) - // TODO: support `DEFAULT` in multiple assignments - (ids, AssignmentValue::Expr(Expr::Row(values))) if ids.len() == values.len() => id - .into_iter() - .zip_eq_fast(values.into_iter().map(AssignmentValue::Expr)) - .collect(), - // (col1, col2) = - _ => { - return Err(ErrorCode::BindError( - "number of columns does not match number of values".to_owned(), - ) - .into()) - } + let mut exprs = Vec::new(); + let mut projects = HashMap::new(); + + macro_rules! record { + ($id:expr, $project:expr) => { + let id_index = $id.as_input_ref().unwrap().index; + projects + .try_insert(id_index, $project) + .map_err(|_e| bind_error!("multiple assignments to the same column"))?; }; + } - for (id, value) in assignments { - let id_expr = self.bind_expr(Expr::Identifier(id.clone()))?; - let id_index = if let Some(id_input_ref) = id_expr.clone().as_input_ref() { - let id_index = id_input_ref.index; - if table - .table_catalog - .pk() - .iter() - .any(|k| k.column_index == id_index) - { - return Err(ErrorCode::BindError( - "update modifying the PK column is unsupported".to_owned(), - ) - .into()); - } - if table - .table_catalog - .generated_col_idxes() - .contains(&id_index) - { - return Err(ErrorCode::BindError( - "update modifying the generated column is unsupported".to_owned(), - ) - .into()); + for Assignment { id, value } in assignments { + let ids: Vec<_> = id + .into_iter() + .map(|id| self.bind_expr(Expr::Identifier(id))) + .try_collect()?; + + match (ids.as_slice(), value) { + // `SET col1 = DEFAULT`, `SET (col1, col2, ...) = DEFAULT` + (ids, AssignmentValue::Default) => { + for id in ids { + let id_index = id.as_input_ref().unwrap().index; + let expr = default_columns_from_catalog + .get(&id_index) + .cloned() + .unwrap_or_else(|| ExprImpl::literal_null(id.return_type())); + + exprs.push(expr); + record!(id, UpdateProject::Simple(exprs.len() - 1)); } - if cols_refed_by_generated_pk.contains(id_index) { - return Err(ErrorCode::BindError( - "update modifying the column referenced by generated columns that are part of the primary key is not allowed".to_owned(), - ) - .into()); + } + + // `SET col1 = expr` + ([id], AssignmentValue::Expr(expr)) => { + let expr = self.bind_expr(expr)?.cast_assign(id.return_type())?; + exprs.push(expr); + record!(id, UpdateProject::Simple(exprs.len() - 1)); + } + // `SET (col1, col2, ...) = (val1, val2, ...)` + (ids, AssignmentValue::Expr(Expr::Row(values))) => { + if ids.len() != values.len() { + bail_bind_error!("number of columns does not match number of values"); } - id_index - } else { - unreachable!() - }; - - let value_expr = match value { - AssignmentValue::Expr(expr) => { - self.bind_expr(expr)?.cast_assign(id_expr.return_type())? 
+ + for (id, value) in ids.iter().zip_eq_fast(values) { + let expr = self.bind_expr(value)?.cast_assign(id.return_type())?; + exprs.push(expr); + record!(id, UpdateProject::Simple(exprs.len() - 1)); } - AssignmentValue::Default => default_columns_from_catalog - .get(&id_index) - .cloned() - .unwrap_or_else(|| ExprImpl::literal_null(id_expr.return_type())), - }; - - match assignment_exprs.entry(id_expr) { - Entry::Occupied(_) => { - return Err(ErrorCode::BindError( - "multiple assignments to same column".to_owned(), - ) - .into()) + } + // `SET (col1, col2, ...) = (SELECT ...)` + (ids, AssignmentValue::Expr(Expr::Subquery(subquery))) => { + let expr = self.bind_subquery_expr(*subquery, SubqueryKind::UpdateSet)?; + + if expr.return_type().as_struct().len() != ids.len() { + bail_bind_error!("number of columns does not match number of values"); } - Entry::Vacant(v) => { - v.insert(value_expr); + + let target_type = DataType::new_unnamed_struct( + ids.iter().map(|id| id.return_type()).collect(), + ); + let expr = expr.cast_assign(target_type)?; + + exprs.push(expr); + + for (i, id) in ids.iter().enumerate() { + record!(id, UpdateProject::Composite(exprs.len() - 1, i)); } } + + (_ids, AssignmentValue::Expr(_expr)) => { + bail_bind_error!("source for a multiple-column UPDATE item must be a sub-SELECT or ROW() expression"); + } } } - let exprs = table - .table_catalog - .columns() - .iter() - .enumerate() - .filter_map(|(i, c)| { - c.can_dml() - .then_some(InputRef::new(i, c.data_type().clone()).into()) - }) - .map(|c| assignment_exprs.remove(&c).unwrap_or(c)) - .collect_vec(); + // Check whether updating these columns is allowed. + for &id_index in projects.keys() { + if (table.table_catalog.pk()) + .iter() + .any(|k| k.column_index == id_index) + { + return Err(ErrorCode::BindError( + "update modifying the PK column is unsupported".to_owned(), + ) + .into()); + } + if (table.table_catalog.generated_col_idxes()).contains(&id_index) { + return Err(ErrorCode::BindError( + "update modifying the generated column is unsupported".to_owned(), + ) + .into()); + } + if cols_refed_by_generated_pk.contains(id_index) { + return Err(ErrorCode::BindError( + "update modifying the column referenced by generated columns that are part of the primary key is not allowed".to_owned(), + ) + .into()); + } + + let col = &table.table_catalog.columns()[id_index]; + if !col.can_dml() { + bail_bind_error!("update modifying column `{}` is unsupported", col.name()); + } + } let (returning_list, fields) = self.bind_returning_list(returning_items)?; let returning = !returning_list.is_empty(); @@ -236,6 +264,7 @@ impl Binder { owner, table, selection, + projects, exprs, returning_list, returning_schema: if returning { diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs index 3092c9bee91a..f0cf35e85966 100644 --- a/src/frontend/src/error.rs +++ b/src/frontend/src/error.rs @@ -33,8 +33,8 @@ use tokio::task::JoinError; // - Some variants are never constructed. // - Some variants store a type-erased `BoxedError` to resolve the reverse dependency. // It's not necessary anymore as the error type is now defined at the top-level. 
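// A sketch of what the `thiserror_ext::Macro` derive added below provides: combined with the
// `#[message]` attribute on the `BindError` variant, it generates per-variant constructor
// macros such as `bind_error!` and `bail_bind_error!`, which the reworked binder imports from
// `crate::error` instead of spelling out `ErrorCode::BindError(...).into()`. Hypothetical
// usage, for illustration only:
//
//     use crate::error::{bail_bind_error, Result};
//
//     fn check_assignment_arity(n_cols: usize, n_vals: usize) -> Result<()> {
//         if n_cols != n_vals {
//             bail_bind_error!("number of columns does not match number of values");
//         }
//         Ok(())
//     }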
-#[derive(Error, thiserror_ext::ReportDebug, thiserror_ext::Box)] -#[thiserror_ext(newtype(name = RwError, backtrace))] +#[derive(Error, thiserror_ext::ReportDebug, thiserror_ext::Box, thiserror_ext::Macro)] +#[thiserror_ext(newtype(name = RwError, backtrace), macro(path = "crate::error"))] pub enum ErrorCode { #[error("internal error: {0}")] InternalError(String), @@ -105,7 +105,7 @@ pub enum ErrorCode { // TODO: use a new type for bind error // TODO(error-handling): should prefer use error types than strings. #[error("Bind error: {0}")] - BindError(String), + BindError(#[message] String), // TODO: only keep this one #[error("Failed to bind expression: {expr}: {error}")] BindErrorRoot { diff --git a/src/frontend/src/expr/subquery.rs b/src/frontend/src/expr/subquery.rs index 62f59c934dd6..8460f73d5fbb 100644 --- a/src/frontend/src/expr/subquery.rs +++ b/src/frontend/src/expr/subquery.rs @@ -24,6 +24,9 @@ use crate::expr::{CorrelatedId, Depth}; pub enum SubqueryKind { /// Returns a scalar value (single column single row). Scalar, + /// Returns a scalar struct value composed of multiple columns. + /// Used in `UPDATE SET (col1, col2) = (SELECT ...)`. + UpdateSet, /// `EXISTS` | `NOT EXISTS` subquery (semi/anti-semi join). Returns a boolean. Existential, /// `IN` subquery. @@ -88,6 +91,7 @@ impl Expr for Subquery { assert_eq!(types.len(), 1, "Subquery with more than one column"); types[0].clone() } + SubqueryKind::UpdateSet => DataType::new_unnamed_struct(self.query.data_types()), SubqueryKind::Array => { let types = self.query.data_types(); assert_eq!(types.len(), 1, "Subquery with more than one column"); diff --git a/src/frontend/src/optimizer/plan_node/batch_update.rs b/src/frontend/src/optimizer/plan_node/batch_update.rs index d0351e6fdec2..28dfa79916cc 100644 --- a/src/frontend/src/optimizer/plan_node/batch_update.rs +++ b/src/frontend/src/optimizer/plan_node/batch_update.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use itertools::Itertools; use risingwave_common::catalog::Schema; use risingwave_pb::batch_plan::plan_node::NodeBody; use risingwave_pb::batch_plan::UpdateNode; @@ -84,20 +83,21 @@ impl ToDistributedBatch for BatchUpdate { impl ToBatchPb for BatchUpdate { fn to_batch_prost_body(&self) -> NodeBody { - let exprs = self.core.exprs.iter().map(|x| x.to_expr_proto()).collect(); - - let update_column_indices = self - .core - .update_column_indices + let old_exprs = (self.core.old_exprs) + .iter() + .map(|x| x.to_expr_proto()) + .collect(); + let new_exprs = (self.core.new_exprs) .iter() - .map(|i| *i as _) - .collect_vec(); + .map(|x| x.to_expr_proto()) + .collect(); + NodeBody::Update(UpdateNode { - exprs, table_id: self.core.table_id.table_id(), table_version_id: self.core.table_version_id, returning: self.core.returning, - update_column_indices, + old_exprs, + new_exprs, session_id: self.base.ctx().session_ctx().session_id().0 as u32, }) } @@ -125,6 +125,6 @@ impl ExprRewritable for BatchUpdate { impl ExprVisitable for BatchUpdate { fn visit_exprs(&self, v: &mut dyn ExprVisitor) { - self.core.exprs.iter().for_each(|e| v.visit_expr(e)); + self.core.visit_exprs(v); } } diff --git a/src/frontend/src/optimizer/plan_node/generic/update.rs b/src/frontend/src/optimizer/plan_node/generic/update.rs index 61d044f53c99..d68af1a01ae3 100644 --- a/src/frontend/src/optimizer/plan_node/generic/update.rs +++ b/src/frontend/src/optimizer/plan_node/generic/update.rs @@ -21,7 +21,7 @@ use risingwave_common::types::DataType; use super::{DistillUnit, GenericPlanNode, GenericPlanRef}; use crate::catalog::TableId; -use crate::expr::{ExprImpl, ExprRewriter}; +use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprVisitor}; use crate::optimizer::plan_node::utils::childless_record; use crate::optimizer::property::FunctionalDependencySet; use crate::OptimizerContextRef; @@ -35,15 +35,15 @@ pub struct Update { pub table_id: TableId, pub table_version_id: TableVersionId, pub input: PlanRef, - pub exprs: Vec, + pub old_exprs: Vec, + pub new_exprs: Vec, pub returning: bool, - pub update_column_indices: Vec, } impl Update { pub fn output_len(&self) -> usize { if self.returning { - self.input.schema().len() + self.new_exprs.len() } else { 1 } @@ -56,18 +56,19 @@ impl GenericPlanNode for Update { fn schema(&self) -> Schema { if self.returning { - self.input.schema().clone() + Schema::new( + self.new_exprs + .iter() + .map(|e| Field::unnamed(e.return_type())) + .collect(), + ) } else { Schema::new(vec![Field::unnamed(DataType::Int64)]) } } fn stream_key(&self) -> Option> { - if self.returning { - Some(self.input.stream_key()?.to_vec()) - } else { - Some(vec![]) - } + None } fn ctx(&self) -> OptimizerContextRef { @@ -81,27 +82,31 @@ impl Update { table_name: String, table_id: TableId, table_version_id: TableVersionId, - exprs: Vec, + old_exprs: Vec, + new_exprs: Vec, returning: bool, - update_column_indices: Vec, ) -> Self { Self { table_name, table_id, table_version_id, input, - exprs, + old_exprs, + new_exprs, returning, - update_column_indices, } } pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) { - self.exprs = self - .exprs - .iter() - .map(|e| r.rewrite_expr(e.clone())) - .collect(); + for exprs in [&mut self.old_exprs, &mut self.new_exprs] { + *exprs = exprs.iter().map(|e| r.rewrite_expr(e.clone())).collect(); + } + } + + pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) { + for exprs in [&self.old_exprs, &self.new_exprs] { + exprs.iter().for_each(|e| v.visit_expr(e)); + } } } @@ -109,7 +114,7 @@ 
impl DistillUnit for Update { fn distill_with_name<'a>(&self, name: impl Into>) -> XmlNode<'a> { let mut vec = Vec::with_capacity(if self.returning { 3 } else { 2 }); vec.push(("table", Pretty::from(self.table_name.clone()))); - vec.push(("exprs", Pretty::debug(&self.exprs))); + vec.push(("exprs", Pretty::debug(&self.new_exprs))); if self.returning { vec.push(("returning", Pretty::display(&true))); } diff --git a/src/frontend/src/optimizer/plan_node/logical_update.rs b/src/frontend/src/optimizer/plan_node/logical_update.rs index 127b6ed8b317..a5590501715b 100644 --- a/src/frontend/src/optimizer/plan_node/logical_update.rs +++ b/src/frontend/src/optimizer/plan_node/logical_update.rs @@ -12,17 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -use risingwave_common::catalog::TableVersionId; - use super::generic::GenericPlanRef; use super::utils::impl_distill_by_unit; use super::{ gen_filter_and_pushdown, generic, BatchUpdate, ColPrunable, ExprRewritable, Logical, LogicalProject, PlanBase, PlanRef, PlanTreeNodeUnary, PredicatePushdown, ToBatch, ToStream, }; -use crate::catalog::TableId; use crate::error::Result; -use crate::expr::{ExprImpl, ExprRewriter, ExprVisitor}; +use crate::expr::{ExprRewriter, ExprVisitor}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::{ ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext, ToStreamContext, @@ -46,25 +43,6 @@ impl From> for LogicalUpdate { } } -impl LogicalUpdate { - #[must_use] - pub fn table_id(&self) -> TableId { - self.core.table_id - } - - pub fn exprs(&self) -> &[ExprImpl] { - self.core.exprs.as_ref() - } - - pub fn has_returning(&self) -> bool { - self.core.returning - } - - pub fn table_version_id(&self) -> TableVersionId { - self.core.table_version_id - } -} - impl PlanTreeNodeUnary for LogicalUpdate { fn input(&self) -> PlanRef { self.core.input.clone() @@ -86,15 +64,15 @@ impl ExprRewritable for LogicalUpdate { } fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef { - let mut new = self.core.clone(); - new.exprs = new.exprs.into_iter().map(|e| r.rewrite_expr(e)).collect(); - Self::from(new).into() + let mut core = self.core.clone(); + core.rewrite_exprs(r); + Self::from(core).into() } } impl ExprVisitable for LogicalUpdate { fn visit_exprs(&self, v: &mut dyn ExprVisitor) { - self.core.exprs.iter().for_each(|e| v.visit_expr(e)); + self.core.visit_exprs(v); } } diff --git a/src/frontend/src/planner/select.rs b/src/frontend/src/planner/select.rs index a9e7dd3526ed..ebed01351f7d 100644 --- a/src/frontend/src/planner/select.rs +++ b/src/frontend/src/planner/select.rs @@ -320,7 +320,7 @@ impl Planner { /// /// The [`InputRef`]s' indexes start from `root.schema().len()`, /// which means they are additional columns beyond the original `root`. - fn substitute_subqueries( + pub(super) fn substitute_subqueries( &mut self, mut root: PlanRef, mut exprs: Vec, @@ -366,10 +366,27 @@ impl Planner { .zip_eq_fast(rewriter.correlated_indices_collection) .zip_eq_fast(rewriter.correlated_ids) { + let return_type = subquery.return_type(); let subroot = self.plan_query(subquery.query)?; let right = match subquery.kind { SubqueryKind::Scalar => subroot.into_unordered_subplan(), + SubqueryKind::UpdateSet => { + let plan = subroot.into_unordered_subplan(); + + // Compose all input columns into a struct with `ROW` function. 
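                    // For example, for `UPDATE t SET (v1, v2) = (SELECT 666.66, 777)` the
                    // subquery plan has two output columns, so this builds a single
                    // `Row($0, $1)` column of type `struct<decimal, int>`, matching the
                    // struct return type reported by `SubqueryKind::UpdateSet` in the binder;
                    // the planner later pulls the fields back out with `Field(.., 0)` and
                    // `Field(.., 1)`, as shown in the `update.yaml` planner tests above.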
+ let all_input_refs = plan + .schema() + .data_types() + .into_iter() + .enumerate() + .map(|(i, data_type)| InputRef::new(i, data_type).into()) + .collect::>(); + let call = + FunctionCall::new_unchecked(ExprType::Row, all_input_refs, return_type); + + LogicalProject::create(plan, vec![call.into()]) + } SubqueryKind::Existential => { self.create_exists(subroot.into_unordered_subplan())? } diff --git a/src/frontend/src/planner/update.rs b/src/frontend/src/planner/update.rs index ddf9ab0bdf9a..2db18ac0e292 100644 --- a/src/frontend/src/planner/update.rs +++ b/src/frontend/src/planner/update.rs @@ -13,41 +13,92 @@ // limitations under the License. use fixedbitset::FixedBitSet; -use itertools::Itertools; +use risingwave_common::types::{DataType, Scalar}; +use risingwave_pb::expr::expr_node::Type; use super::Planner; -use crate::binder::BoundUpdate; +use crate::binder::{BoundUpdate, UpdateProject}; use crate::error::Result; +use crate::expr::{ExprImpl, FunctionCall, InputRef, Literal}; +use crate::optimizer::plan_node::generic::GenericPlanRef; use crate::optimizer::plan_node::{generic, LogicalProject, LogicalUpdate}; use crate::optimizer::property::{Order, RequiredDist}; use crate::optimizer::{PlanRef, PlanRoot}; impl Planner { pub(super) fn plan_update(&mut self, update: BoundUpdate) -> Result { + let returning = !update.returning_list.is_empty(); + let scan = self.plan_base_table(&update.table)?; let input = if let Some(expr) = update.selection { self.plan_where(scan, expr)? } else { scan }; - let returning = !update.returning_list.is_empty(); - let update_column_indices = update - .table - .table_catalog - .columns() - .iter() - .enumerate() - .filter_map(|(i, c)| c.can_dml().then_some(i)) - .collect_vec(); + let old_schema_len = input.schema().len(); + + // Extend table scan with updated columns. + let with_new: PlanRef = { + let mut plan = input; + + let mut exprs: Vec = plan + .schema() + .data_types() + .into_iter() + .enumerate() + .map(|(index, data_type)| InputRef::new(index, data_type).into()) + .collect(); + + exprs.extend(update.exprs); + + // Substitute subqueries into `LogicalApply`s. + if exprs.iter().any(|e| e.has_subquery()) { + (plan, exprs) = self.substitute_subqueries(plan, exprs)?; + } + + LogicalProject::new(plan, exprs).into() + }; + + let mut olds = Vec::new(); + let mut news = Vec::new(); + + for (i, col) in update.table.table_catalog.columns().iter().enumerate() { + // Skip generated columns and system columns. 
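            // For each DML-able column this loop pushes a pair: `olds[i]` is an `InputRef`
            // to the column's current value, and `news[i]` is either the same `InputRef`
            // (column untouched), an `InputRef` into the projected assignment expression
            // (`UpdateProject::Simple`), or a `Field` access into a struct-typed subquery
            // result (`UpdateProject::Composite`). For instance, `UPDATE t SET v1 = v2 + 1`
            // plans as `BatchUpdate { exprs: [$4, $1, $2] }` over a `Project` that appends
            // `(t.v2 + 1)` as column 4, so `news` picks up the appended expression while
            // `olds` keeps referring to the original columns.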
+ if !col.can_dml() { + continue; + } + let data_type = col.data_type(); + + let old: ExprImpl = InputRef::new(i, data_type.clone()).into(); + + let new: ExprImpl = match (update.projects.get(&i)).map(|p| p.offset(old_schema_len)) { + Some(UpdateProject::Simple(j)) => InputRef::new(j, data_type.clone()).into(), + Some(UpdateProject::Composite(j, field)) => FunctionCall::new_unchecked( + Type::Field, + vec![ + InputRef::new(j, with_new.schema().data_types()[j].clone()).into(), // struct + Literal::new(Some((field as i32).to_scalar_value()), DataType::Int32) + .into(), + ], + data_type.clone(), + ) + .into(), + + None => old.clone(), + }; + + olds.push(old); + news.push(new); + } let mut plan: PlanRef = LogicalUpdate::from(generic::Update::new( - input, + with_new, update.table_name.clone(), update.table_id, update.table_version_id, - update.exprs, + olds, + news, returning, - update_column_indices, )) .into(); From c1e8f9a2d177d8a9c9733e06f1c9ce77517582d3 Mon Sep 17 00:00:00 2001 From: xiangjinwu <17769960+xiangjinwu@users.noreply.github.com> Date: Wed, 20 Nov 2024 17:21:39 +0800 Subject: [PATCH 06/11] fix(source): `REFRESH SCHEMA` shall keep `INCLUDE` pk for `UPSERT` (#19384) --- .../source_inline/kafka/avro/alter_table.slt | 17 +++++++++++++++++ src/frontend/src/handler/alter_table_column.rs | 2 ++ src/frontend/src/handler/create_sink.rs | 2 ++ src/frontend/src/handler/create_table.rs | 3 ++- 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/e2e_test/source_inline/kafka/avro/alter_table.slt b/e2e_test/source_inline/kafka/avro/alter_table.slt index 330cdc490cdb..08a98c2cca4c 100644 --- a/e2e_test/source_inline/kafka/avro/alter_table.slt +++ b/e2e_test/source_inline/kafka/avro/alter_table.slt @@ -78,3 +78,20 @@ ABC statement ok drop table t; + +statement ok +create table t (primary key (kafka_key)) +INCLUDE key as kafka_key +WITH ( + ${RISEDEV_KAFKA_WITH_OPTIONS_COMMON}, + topic = 'avro_alter_table_test' +) +FORMAT UPSERT ENCODE AVRO ( + schema.registry = '${RISEDEV_SCHEMA_REGISTRY_URL}' +); + +statement ok +ALTER TABLE t REFRESH SCHEMA; + +statement ok +drop table t; diff --git a/src/frontend/src/handler/alter_table_column.rs b/src/frontend/src/handler/alter_table_column.rs index cf55b82a4750..19f8355a77bc 100644 --- a/src/frontend/src/handler/alter_table_column.rs +++ b/src/frontend/src/handler/alter_table_column.rs @@ -180,6 +180,7 @@ pub async fn get_replace_table_plan( wildcard_idx, cdc_table_info, format_encode, + include_column_options, .. } = new_definition else { @@ -206,6 +207,7 @@ pub async fn get_replace_table_plan( with_version_column, cdc_table_info, new_version_columns, + include_column_options, ) .await?; diff --git a/src/frontend/src/handler/create_sink.rs b/src/frontend/src/handler/create_sink.rs index a7c997d6232e..e280f9090926 100644 --- a/src/frontend/src/handler/create_sink.rs +++ b/src/frontend/src/handler/create_sink.rs @@ -559,6 +559,7 @@ pub(crate) async fn reparse_table_for_sink( append_only, on_conflict, with_version_column, + include_column_options, .. 
} = definition else { @@ -581,6 +582,7 @@ pub(crate) async fn reparse_table_for_sink( with_version_column, None, None, + include_column_options, ) .await?; diff --git a/src/frontend/src/handler/create_table.rs b/src/frontend/src/handler/create_table.rs index 6118ba5ccd36..ff2e41037078 100644 --- a/src/frontend/src/handler/create_table.rs +++ b/src/frontend/src/handler/create_table.rs @@ -1310,6 +1310,7 @@ pub async fn generate_stream_graph_for_replace_table( with_version_column: Option, cdc_table_info: Option, new_version_columns: Option>, + include_column_options: IncludeOption, ) -> Result<(StreamFragmentGraph, Table, Option, TableJobType)> { use risingwave_pb::catalog::table::OptionalAssociatedSourceId; @@ -1328,7 +1329,7 @@ pub async fn generate_stream_graph_for_replace_table( append_only, on_conflict, with_version_column, - vec![], + include_column_options, ) .await?, TableJobType::General, From a0b65fd618483e38ed351a4d691639ce65c8fa6c Mon Sep 17 00:00:00 2001 From: Eric Fu Date: Wed, 20 Nov 2024 17:24:01 +0800 Subject: [PATCH 07/11] feat(pgwire): support struct type in extended mode (#19450) --- ci/scripts/run-e2e-test.sh | 4 ++++ e2e_test/python_client/main.py | 19 ++++++++++++++++++ src/common/src/types/postgres_type.rs | 5 ++--- src/common/src/types/to_binary.rs | 29 +++++++++++++++++++++++++-- 4 files changed, 52 insertions(+), 5 deletions(-) create mode 100644 e2e_test/python_client/main.py diff --git a/ci/scripts/run-e2e-test.sh b/ci/scripts/run-e2e-test.sh index a8601fbb0ebe..e84ead4a81df 100755 --- a/ci/scripts/run-e2e-test.sh +++ b/ci/scripts/run-e2e-test.sh @@ -109,6 +109,10 @@ sqllogictest -p 4566 -d dev './e2e_test/ttl/ttl.slt' sqllogictest -p 4566 -d dev './e2e_test/database/prepare.slt' sqllogictest -p 4566 -d test './e2e_test/database/test.slt' +echo "--- e2e, $mode, python_client" +python3 -m pip install --break-system-packages psycopg +python3 ./e2e_test/python_client/main.py + echo "--- e2e, $mode, subscription" python3 -m pip install --break-system-packages psycopg2-binary sqllogictest -p 4566 -d dev './e2e_test/subscription/check_sql_statement.slt' diff --git a/e2e_test/python_client/main.py b/e2e_test/python_client/main.py new file mode 100644 index 000000000000..bb41ba6c38f3 --- /dev/null +++ b/e2e_test/python_client/main.py @@ -0,0 +1,19 @@ +import psycopg + +def test_psycopg_extended_mode(): + conn = psycopg.connect(host='localhost', port='4566', dbname='dev', user='root') + with conn.cursor() as cur: + cur.execute("select Array[1::bigint, 2::bigint, 3::bigint]", binary=True) + assert cur.fetchone() == ([1, 2, 3],) + + cur.execute("select Array['foo', null, 'bar']", binary=True) + assert cur.fetchone() == (['foo', None, 'bar'],) + + cur.execute("select ROW('123 Main St', 'New York', '10001')", binary=True) + assert cur.fetchone() == (('123 Main St', 'New York', '10001'),) + + cur.execute("select array[ROW('123 Main St', 'New York', '10001'), ROW('234 Main St', null, '10001')]", binary=True) + assert cur.fetchone() == ([('123 Main St', 'New York', '10001'), ('234 Main St', None, '10001')],) + +if __name__ == '__main__': + test_psycopg_extended_mode() diff --git a/src/common/src/types/postgres_type.rs b/src/common/src/types/postgres_type.rs index d85f08ed59cc..c84f3e19f309 100644 --- a/src/common/src/types/postgres_type.rs +++ b/src/common/src/types/postgres_type.rs @@ -116,7 +116,7 @@ impl DataType { )* DataType::Int256 => 1302, DataType::Serial => 1016, - DataType::Struct(_) => -1, + DataType::Struct(_) => 2287, // pseudo-type of array[struct] (see 
`pg_type.dat`) DataType::List { .. } => unreachable!("Never reach here!"), DataType::Map(_) => 1304, } @@ -125,8 +125,7 @@ impl DataType { DataType::Int256 => 1301, DataType::Map(_) => 1303, // TODO: Support to give a new oid for custom struct type. #9434 - // 1043 is varchar - DataType::Struct(_) => 1043, + DataType::Struct(_) => 2249, // pseudo-type of struct (see `pg_type.dat`) } } } diff --git a/src/common/src/types/to_binary.rs b/src/common/src/types/to_binary.rs index 7c5e88dbc10c..294f96bc7045 100644 --- a/src/common/src/types/to_binary.rs +++ b/src/common/src/types/to_binary.rs @@ -14,12 +14,13 @@ use bytes::{BufMut, Bytes, BytesMut}; use postgres_types::{ToSql, Type}; +use rw_iter_util::ZipEqFast; use super::{ DataType, Date, Decimal, Interval, ScalarRefImpl, Serial, Time, Timestamp, Timestamptz, F32, F64, }; -use crate::array::ListRef; +use crate::array::{ListRef, StructRef}; use crate::error::NotImplemented; /// Error type for [`ToBinary`] trait. @@ -116,6 +117,29 @@ impl ToBinary for ListRef<'_> { } } +impl ToBinary for StructRef<'_> { + fn to_binary_with_type(&self, ty: &DataType) -> Result { + // Reference: Postgres code `src/backend/utils/adt/rowtypes.c` + // https://github.com/postgres/postgres/blob/a3699daea2026de324ed7cc7115c36d3499010d3/src/backend/utils/adt/rowtypes.c#L687 + let mut buf = BytesMut::new(); + buf.put_i32(ty.as_struct().len() as i32); // number of columns + for (datum, field_ty) in self.iter_fields_ref().zip_eq_fast(ty.as_struct().types()) { + buf.put_i32(field_ty.to_oid()); // column type + match datum { + None => { + buf.put_i32(-1); // -1 length means a NULL + } + Some(value) => { + let data = value.to_binary_with_type(field_ty)?; + buf.put_i32(data.len() as i32); // Length of element + buf.put(data); + } + } + } + Ok(buf.into()) + } +} + impl ToBinary for ScalarRefImpl<'_> { fn to_binary_with_type(&self, ty: &DataType) -> Result { match self { @@ -137,7 +161,8 @@ impl ToBinary for ScalarRefImpl<'_> { ScalarRefImpl::Bytea(v) => v.to_binary_with_type(ty), ScalarRefImpl::Jsonb(v) => v.to_binary_with_type(ty), ScalarRefImpl::List(v) => v.to_binary_with_type(ty), - ScalarRefImpl::Struct(_) | ScalarRefImpl::Map(_) => { + ScalarRefImpl::Struct(v) => v.to_binary_with_type(ty), + ScalarRefImpl::Map(_) => { bail_not_implemented!( issue = 7949, "the pgwire extended-mode encoding for {ty} is unsupported" From 34bb3cb0a19edddd629a9a9eadc3ed7f57256dd3 Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 20 Nov 2024 17:49:36 +0800 Subject: [PATCH 08/11] fix(iceberg): fix iceberg parquet file size in bytes statistics (#19471) --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3cc97411dd88..c8be3e154ddd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6329,7 +6329,7 @@ dependencies = [ [[package]] name = "icelake" version = "0.3.141592654" -source = "git+https://github.com/risingwavelabs/icelake.git?rev=68bc9654ea1a1f696b79718b7bcbd79f43488186#68bc9654ea1a1f696b79718b7bcbd79f43488186" +source = "git+https://github.com/risingwavelabs/icelake.git?rev=0ec44fa826c91139c9cf459b005741df990ae9da#0ec44fa826c91139c9cf459b005741df990ae9da" dependencies = [ "anyhow", "apache-avro 0.17.0 (git+https://github.com/apache/avro.git)", diff --git a/Cargo.toml b/Cargo.toml index bb697d6d0d33..c260bf8c5293 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -138,7 +138,7 @@ otlp-embedded = { git = "https://github.com/risingwavelabs/otlp-embedded", rev = prost = { version = "0.13" } prost-build = { version = "0.13" } # branch 
rw_patch -icelake = { git = "https://github.com/risingwavelabs/icelake.git", rev = "68bc9654ea1a1f696b79718b7bcbd79f43488186", features = [ +icelake = { git = "https://github.com/risingwavelabs/icelake.git", rev = "0ec44fa826c91139c9cf459b005741df990ae9da", features = [ "prometheus", ] } # branch dev-rebase-main-20241030 From b9c3f709cc927682738ffc6fee1528be04012cd6 Mon Sep 17 00:00:00 2001 From: xxchan Date: Wed, 20 Nov 2024 18:01:04 +0800 Subject: [PATCH 09/11] fix: fix potential data loss for shared source (#19443) --- .../cdc/mysql/mysql_create_drop.slt.serial | 6 -- .../kafka/shared_source.slt.serial | 69 +++++++++---- src/batch/src/executor/source.rs | 4 +- src/connector/src/source/base.rs | 8 +- src/connector/src/source/cdc/mod.rs | 2 +- .../src/source/kafka/enumerator/client.rs | 2 - .../src/source/kafka/source/reader.rs | 34 ++++++- src/connector/src/source/kafka/split.rs | 12 --- src/connector/src/source/reader/reader.rs | 28 ++++-- .../src/executor/source/fetch_executor.rs | 4 +- .../source/source_backfill_executor.rs | 1 - .../src/executor/source/source_executor.rs | 97 +++++++++---------- .../src/from_proto/source/trad_source.rs | 2 +- .../src/task/barrier_manager/progress.rs | 1 + 14 files changed, 161 insertions(+), 109 deletions(-) diff --git a/e2e_test/source_inline/cdc/mysql/mysql_create_drop.slt.serial b/e2e_test/source_inline/cdc/mysql/mysql_create_drop.slt.serial index fde008079dc6..2766f37fefe1 100644 --- a/e2e_test/source_inline/cdc/mysql/mysql_create_drop.slt.serial +++ b/e2e_test/source_inline/cdc/mysql/mysql_create_drop.slt.serial @@ -49,12 +49,6 @@ create source s with ( sleep 2s -# At the beginning, the source is paused. It will resume after a downstream is created. -system ok -internal_table.mjs --name s --type '' --count ----- -count: 0 - statement ok create table tt1_shared (v1 int, diff --git a/e2e_test/source_inline/kafka/shared_source.slt.serial b/e2e_test/source_inline/kafka/shared_source.slt.serial index 3397f90f081d..af6b371d21c4 100644 --- a/e2e_test/source_inline/kafka/shared_source.slt.serial +++ b/e2e_test/source_inline/kafka/shared_source.slt.serial @@ -59,11 +59,17 @@ select count(*) from rw_internal_tables where name like '%s0%'; sleep 1s -# SourceExecutor's ingestion does not start (state table is empty), even after sleep +statement ok +flush; + +# SourceExecutor's starts from latest. system ok internal_table.mjs --name s0 --type source ---- -(empty) +0,"{""split_info"": {""partition"": 0, ""start_offset"": 0, ""stop_offset"": null, ""topic"": ""shared_source""}, ""split_type"": ""kafka""}" +1,"{""split_info"": {""partition"": 1, ""start_offset"": 0, ""stop_offset"": null, ""topic"": ""shared_source""}, ""split_type"": ""kafka""}" +2,"{""split_info"": {""partition"": 2, ""start_offset"": 1, ""stop_offset"": null, ""topic"": ""shared_source""}, ""split_type"": ""kafka""}" +3,"{""split_info"": {""partition"": 3, ""start_offset"": 2, ""stop_offset"": null, ""topic"": ""shared_source""}, ""split_type"": ""kafka""}" statement ok @@ -72,12 +78,6 @@ create materialized view mv_1 as select * from s0; # Wait enough time to ensure SourceExecutor consumes all Kafka data. sleep 2s -# SourceExecutor's ingestion started, but it only starts from latest (offset 1). -system ok -internal_table.mjs --name s0 --type source ----- -(empty) - # SourceBackfill starts from offset 0, with backfill_info: HasDataToBackfill { latest_offset: "0" } (decided by kafka high watermark). 
# (meaning upstream already consumed offset 0, so we only need to backfill to offset 0) @@ -144,7 +144,7 @@ EOF sleep 2s -# SourceExecutor's finally got new data now. +# SourceExecutor's got new data. system ok internal_table.mjs --name s0 --type source ---- @@ -185,16 +185,6 @@ select v1, v2 from mv_1; 4 dd -# start_offset changed to 1 -system ok -internal_table.mjs --name s0 --type source ----- -0,"{""split_info"": {""partition"": 0, ""start_offset"": 1, ""stop_offset"": null, ""topic"": ""shared_source""}, ""split_type"": ""kafka""}" -1,"{""split_info"": {""partition"": 1, ""start_offset"": 1, ""stop_offset"": null, ""topic"": ""shared_source""}, ""split_type"": ""kafka""}" -2,"{""split_info"": {""partition"": 2, ""start_offset"": 2, ""stop_offset"": null, ""topic"": ""shared_source""}, ""split_type"": ""kafka""}" -3,"{""split_info"": {""partition"": 3, ""start_offset"": 3, ""stop_offset"": null, ""topic"": ""shared_source""}, ""split_type"": ""kafka""}" - - # Transition from SourceCachingUp to Finished after consuming one upstream message. system ok internal_table.mjs --name mv_1 --type sourcebackfill @@ -334,6 +324,47 @@ internal_table.mjs --name s0 --type source # # risedev psql -c "select name, flags, parallelism from rw_fragments JOIN rw_relations ON rw_fragments.table_id = rw_relations.id order by name;" +# Test: rate limit and resume won't lose data + +statement ok +alter source s0 set source_rate_limit to 0; + + +system ok +cat < @@ -108,7 +108,7 @@ impl TryFromBTreeMap for P { } } -pub async fn create_split_reader( +pub async fn create_split_reader( prop: P, splits: Vec, parser_config: ParserConfig, @@ -375,6 +375,10 @@ pub trait SplitReader: Sized + Send { fn backfill_info(&self) -> HashMap { HashMap::new() } + + async fn seek_to_latest(&mut self) -> Result> { + Err(anyhow!("seek_to_latest is not supported for this connector").into()) + } } /// Information used to determine whether we should start and finish source backfill. 
diff --git a/src/connector/src/source/cdc/mod.rs b/src/connector/src/source/cdc/mod.rs index 3f3626449153..9e99c7db9e5e 100644 --- a/src/connector/src/source/cdc/mod.rs +++ b/src/connector/src/source/cdc/mod.rs @@ -58,7 +58,7 @@ pub fn build_cdc_table_id(source_id: u32, external_table_name: &str) -> String { format!("{}.{}", source_id, external_table_name) } -pub trait CdcSourceTypeTrait: Send + Sync + Clone + 'static { +pub trait CdcSourceTypeTrait: Send + Sync + Clone + std::fmt::Debug + 'static { const CDC_CONNECTOR_NAME: &'static str; fn source_type() -> CdcSourceType; } diff --git a/src/connector/src/source/kafka/enumerator/client.rs b/src/connector/src/source/kafka/enumerator/client.rs index efcd2c297348..541c3757c27c 100644 --- a/src/connector/src/source/kafka/enumerator/client.rs +++ b/src/connector/src/source/kafka/enumerator/client.rs @@ -191,7 +191,6 @@ impl SplitEnumerator for KafkaSplitEnumerator { partition, start_offset: start_offsets.remove(&partition).unwrap(), stop_offset: stop_offsets.remove(&partition).unwrap(), - hack_seek_to_latest: false, }) .collect(); @@ -299,7 +298,6 @@ impl KafkaSplitEnumerator { partition: *partition, start_offset: Some(start_offset), stop_offset: Some(stop_offset), - hack_seek_to_latest:false } }) .collect::>()) diff --git a/src/connector/src/source/kafka/source/reader.rs b/src/connector/src/source/kafka/source/reader.rs index b9523eca98b5..20fde897ceb4 100644 --- a/src/connector/src/source/kafka/source/reader.rs +++ b/src/connector/src/source/kafka/source/reader.rs @@ -37,13 +37,15 @@ use crate::source::kafka::{ }; use crate::source::{ into_chunk_stream, BackfillInfo, BoxChunkSourceStream, Column, SourceContextRef, SplitId, - SplitMetaData, SplitReader, + SplitImpl, SplitMetaData, SplitReader, }; pub struct KafkaSplitReader { consumer: StreamConsumer, offsets: HashMap, Option)>, backfill_info: HashMap, + splits: Vec, + sync_call_timeout: Duration, bytes_per_second: usize, max_num_messages: usize, parser_config: ParserConfig, @@ -110,12 +112,10 @@ impl SplitReader for KafkaSplitReader { let mut offsets = HashMap::new(); let mut backfill_info = HashMap::new(); - for split in splits { + for split in splits.clone() { offsets.insert(split.id(), (split.start_offset, split.stop_offset)); - if split.hack_seek_to_latest { - tpl.add_partition_offset(split.topic.as_str(), split.partition, Offset::End)?; - } else if let Some(offset) = split.start_offset { + if let Some(offset) = split.start_offset { tpl.add_partition_offset( split.topic.as_str(), split.partition, @@ -168,8 +168,10 @@ impl SplitReader for KafkaSplitReader { Ok(Self { consumer, offsets, + splits, backfill_info, bytes_per_second, + sync_call_timeout: properties.common.sync_call_timeout, max_num_messages, parser_config, source_ctx, @@ -185,6 +187,28 @@ impl SplitReader for KafkaSplitReader { fn backfill_info(&self) -> HashMap { self.backfill_info.clone() } + + async fn seek_to_latest(&mut self) -> Result> { + let mut latest_splits: Vec = Vec::new(); + let mut tpl = TopicPartitionList::with_capacity(self.splits.len()); + for mut split in self.splits.clone() { + // we can't get latest offset if we use Offset::End, so we just fetch watermark here. 
+ let (_low, high) = self + .consumer + .fetch_watermarks( + split.topic.as_str(), + split.partition, + self.sync_call_timeout, + ) + .await?; + tpl.add_partition_offset(split.topic.as_str(), split.partition, Offset::Offset(high))?; + split.start_offset = Some(high - 1); + latest_splits.push(split.into()); + } + // replace the previous assignment + self.consumer.assign(&tpl)?; + Ok(latest_splits) + } } impl KafkaSplitReader { diff --git a/src/connector/src/source/kafka/split.rs b/src/connector/src/source/kafka/split.rs index 791836ac2c85..fa969bb37111 100644 --- a/src/connector/src/source/kafka/split.rs +++ b/src/connector/src/source/kafka/split.rs @@ -32,12 +32,6 @@ pub struct KafkaSplit { /// A better approach would be to make it **inclusive**. pub(crate) start_offset: Option, pub(crate) stop_offset: Option, - #[serde(skip)] - /// Used by shared source to hackily seek to the latest offset without fetching start offset first. - /// XXX: But why do we fetch low watermark for latest start offset..? - /// - /// When this is `true`, `start_offset` will be ignored. - pub(crate) hack_seek_to_latest: bool, } impl SplitMetaData for KafkaSplit { @@ -72,16 +66,10 @@ impl KafkaSplit { partition, start_offset, stop_offset, - hack_seek_to_latest: false, } } pub fn get_topic_and_partition(&self) -> (String, i32) { (self.topic.clone(), self.partition) } - - /// This should only be used for a fresh split, not persisted in state table yet. - pub fn seek_to_latest_offset(&mut self) { - self.hack_seek_to_latest = true; - } } diff --git a/src/connector/src/source/reader/reader.rs b/src/connector/src/source/reader/reader.rs index 89335f8f0d80..f849e7ba21aa 100644 --- a/src/connector/src/source/reader/reader.rs +++ b/src/connector/src/source/reader/reader.rs @@ -37,8 +37,8 @@ use crate::source::filesystem::opendal_source::{ use crate::source::filesystem::{FsPageItem, OpendalFsSplit}; use crate::source::{ create_split_reader, BackfillInfo, BoxChunkSourceStream, BoxTryStream, Column, - ConnectorProperties, ConnectorState, SourceColumnDesc, SourceContext, SplitId, SplitReader, - WaitCheckpointTask, + ConnectorProperties, ConnectorState, SourceColumnDesc, SourceContext, SplitId, SplitImpl, + SplitReader, WaitCheckpointTask, }; use crate::{dispatch_source_prop, WithOptionsSecResolved}; @@ -211,14 +211,17 @@ impl SourceReader { } /// Build `SplitReader`s and then `BoxChunkSourceStream` from the given `ConnectorState` (`SplitImpl`s). + /// + /// If `seek_to_latest` is true, will also return the latest splits after seek. pub async fn build_stream( &self, state: ConnectorState, column_ids: Vec, source_ctx: Arc, - ) -> ConnectorResult { + seek_to_latest: bool, + ) -> ConnectorResult<(BoxChunkSourceStream, Option>)> { let Some(splits) = state else { - return Ok(pending().boxed()); + return Ok((pending().boxed(), None)); }; let config = self.config.clone(); let columns = self.get_target_columns(column_ids)?; @@ -243,7 +246,7 @@ impl SourceReader { let support_multiple_splits = config.support_multiple_splits(); dispatch_source_prop!(config, prop, { - let readers = if support_multiple_splits { + let mut readers = if support_multiple_splits { tracing::debug!( "spawning connector split reader for multiple splits {:?}", splits @@ -268,7 +271,20 @@ impl SourceReader { .await? 
}; - Ok(select_all(readers.into_iter().map(|r| r.into_stream())).boxed()) + let latest_splits = if seek_to_latest { + let mut latest_splits = Vec::new(); + for reader in &mut readers { + latest_splits.extend(reader.seek_to_latest().await?); + } + Some(latest_splits) + } else { + None + }; + + Ok(( + select_all(readers.into_iter().map(|r| r.into_stream())).boxed(), + latest_splits, + )) }) } } diff --git a/src/stream/src/executor/source/fetch_executor.rs b/src/stream/src/executor/source/fetch_executor.rs index 13bbac436d36..8964eaecff45 100644 --- a/src/stream/src/executor/source/fetch_executor.rs +++ b/src/stream/src/executor/source/fetch_executor.rs @@ -160,9 +160,9 @@ impl FsFetchExecutor { batch: SplitBatch, rate_limit_rps: Option, ) -> StreamExecutorResult { - let stream = source_desc + let (stream, _) = source_desc .source - .build_stream(batch, column_ids, Arc::new(source_ctx)) + .build_stream(batch, column_ids, Arc::new(source_ctx), false) .await .map_err(StreamExecutorError::connector_error)?; Ok(apply_rate_limit(stream, rate_limit_rps).boxed()) diff --git a/src/stream/src/executor/source/source_backfill_executor.rs b/src/stream/src/executor/source/source_backfill_executor.rs index 9df74a719d46..6ada0f2b62eb 100644 --- a/src/stream/src/executor/source/source_backfill_executor.rs +++ b/src/stream/src/executor/source/source_backfill_executor.rs @@ -609,7 +609,6 @@ impl SourceBackfillExecutorInner { .await?; if self.should_report_finished(&backfill_stage.states) { - tracing::debug!("progress finish"); self.progress.finish( barrier.epoch, backfill_stage.total_backfilled_rows(), diff --git a/src/stream/src/executor/source/source_executor.rs b/src/stream/src/executor/source/source_executor.rs index e0bbe3d1f6d9..118b33c08ae5 100644 --- a/src/stream/src/executor/source/source_executor.rs +++ b/src/stream/src/executor/source/source_executor.rs @@ -71,7 +71,7 @@ pub struct SourceExecutor { /// Rate limit in rows/s. rate_limit_rps: Option, - is_shared: bool, + is_shared_non_cdc: bool, } impl SourceExecutor { @@ -82,7 +82,7 @@ impl SourceExecutor { barrier_receiver: UnboundedReceiver, system_params: SystemParamsReaderRef, rate_limit_rps: Option, - is_shared: bool, + is_shared_non_cdc: bool, ) -> Self { Self { actor_ctx, @@ -91,7 +91,7 @@ impl SourceExecutor { barrier_receiver: Some(barrier_receiver), system_params, rate_limit_rps, - is_shared, + is_shared_non_cdc, } } @@ -116,11 +116,13 @@ impl SourceExecutor { })) } + /// If `seek_to_latest` is true, will also return the latest splits after seek. pub async fn build_stream_source_reader( &self, source_desc: &SourceDesc, state: ConnectorState, - ) -> StreamExecutorResult { + seek_to_latest: bool, + ) -> StreamExecutorResult<(BoxChunkSourceStream, Option>)> { let column_ids = source_desc .columns .iter() @@ -183,13 +185,16 @@ impl SourceExecutor { source_desc.source.config.clone(), schema_change_tx, ); - let stream = source_desc + let (stream, latest_splits) = source_desc .source - .build_stream(state, column_ids, Arc::new(source_ctx)) + .build_stream(state, column_ids, Arc::new(source_ctx), seek_to_latest) .await - .map_err(StreamExecutorError::connector_error); + .map_err(StreamExecutorError::connector_error)?; - Ok(apply_rate_limit(stream?, self.rate_limit_rps).boxed()) + Ok(( + apply_rate_limit(stream, self.rate_limit_rps).boxed(), + latest_splits, + )) } fn is_auto_schema_change_enable(&self) -> bool { @@ -367,10 +372,10 @@ impl SourceExecutor { ); // Replace the source reader with a new one of the new state. 
- let reader = self - .build_stream_source_reader(source_desc, Some(target_state.clone())) - .await? - .map_err(StreamExecutorError::connector_error); + let (reader, _) = self + .build_stream_source_reader(source_desc, Some(target_state.clone()), false) + .await?; + let reader = reader.map_err(StreamExecutorError::connector_error); stream.replace_data_stream(reader); @@ -459,7 +464,7 @@ impl SourceExecutor { }; core.split_state_store.init_epoch(first_epoch).await?; - + let mut is_uninitialized = self.actor_ctx.initial_dispatch_num == 0; for ele in &mut boot_state { if let Some(recover_state) = core .split_state_store @@ -467,42 +472,47 @@ impl SourceExecutor { .await? { *ele = recover_state; + // if state store is non-empty, we consider it's initialized. + is_uninitialized = false; } else { // This is a new split, not in state table. - if self.is_shared { - // For shared source, we start from latest and let the downstream SourceBackfillExecutors to read historical data. - // It's highly probable that the work of scanning historical data cannot be shared, - // so don't waste work on it. - // For more details, see https://github.com/risingwavelabs/risingwave/issues/16576#issuecomment-2095413297 - if ele.is_cdc_split() { - // shared CDC source already starts from latest. - continue; - } - match ele { - SplitImpl::Kafka(split) => { - split.seek_to_latest_offset(); - } - _ => unreachable!("only kafka source can be shared, got {:?}", ele), - } - } + // make sure it is written to state table later. + // Then even it receives no messages, we can observe it in state table. + core.updated_splits_in_epoch.insert(ele.id(), ele.clone()); } } // init in-memory split states with persisted state if any core.init_split_state(boot_state.clone()); - let mut is_uninitialized = self.actor_ctx.initial_dispatch_num == 0; // Return the ownership of `stream_source_core` to the source executor. self.stream_source_core = Some(core); let recover_state: ConnectorState = (!boot_state.is_empty()).then_some(boot_state); tracing::debug!(state = ?recover_state, "start with state"); - let source_chunk_reader = self - .build_stream_source_reader(&source_desc, recover_state) + let (source_chunk_reader, latest_splits) = self + .build_stream_source_reader( + &source_desc, + recover_state, + // For shared source, we start from latest and let the downstream SourceBackfillExecutors to read historical data. + // It's highly probable that the work of scanning historical data cannot be shared, + // so don't waste work on it. + // For more details, see https://github.com/risingwavelabs/risingwave/issues/16576#issuecomment-2095413297 + // Note that shared CDC source is special. It already starts from latest. + self.is_shared_non_cdc && is_uninitialized, + ) .instrument_await("source_build_reader") - .await? - .map_err(StreamExecutorError::connector_error); - + .await?; + let source_chunk_reader = source_chunk_reader.map_err(StreamExecutorError::connector_error); + if let Some(latest_splits) = latest_splits { + // make sure it is written to state table later. + // Then even it receives no messages, we can observe it in state table. + self.stream_source_core + .as_mut() + .unwrap() + .updated_splits_in_epoch + .extend(latest_splits.into_iter().map(|s| (s.id(), s))); + } // Merge the chunks from source and the barriers into a single stream. We prioritize // barriers over source data chunks here. 
let barrier_stream = barrier_to_message_stream(barrier_receiver).boxed(); @@ -510,14 +520,9 @@ impl SourceExecutor { StreamReaderWithPause::::new(barrier_stream, source_chunk_reader); let mut command_paused = false; - // - For shared source, pause until there's a MV. // - If the first barrier requires us to pause on startup, pause the stream. - if (self.is_shared && is_uninitialized) || is_pause_on_startup { - tracing::info!( - is_shared = self.is_shared, - is_uninitialized = is_uninitialized, - "source paused on startup" - ); + if is_pause_on_startup { + tracing::info!("source paused on startup"); stream.pause_stream(); command_paused = true; } @@ -562,14 +567,6 @@ impl SourceExecutor { let epoch = barrier.epoch; - if self.is_shared - && is_uninitialized - && barrier.has_more_downstream_fragments(self.actor_ctx.id) - { - stream.resume_stream(); - is_uninitialized = false; - } - if let Some(mutation) = barrier.mutation.as_deref() { match mutation { Mutation::Pause => { diff --git a/src/stream/src/from_proto/source/trad_source.rs b/src/stream/src/from_proto/source/trad_source.rs index 98746a672e43..4d4786eea3bf 100644 --- a/src/stream/src/from_proto/source/trad_source.rs +++ b/src/stream/src/from_proto/source/trad_source.rs @@ -232,7 +232,7 @@ impl ExecutorBuilder for SourceExecutorBuilder { barrier_receiver, system_params, source.rate_limit, - is_shared, + is_shared && !source.with_properties.is_cdc_connector(), ) .boxed() } diff --git a/src/stream/src/task/barrier_manager/progress.rs b/src/stream/src/task/barrier_manager/progress.rs index dba8f5050627..c860b8f430fa 100644 --- a/src/stream/src/task/barrier_manager/progress.rs +++ b/src/stream/src/task/barrier_manager/progress.rs @@ -250,6 +250,7 @@ impl CreateMviewProgressReporter { if let Some(BackfillState::DoneConsumingUpstreamTableOrSource(_)) = self.state { return; } + tracing::debug!("progress finish"); self.update_inner( epoch, BackfillState::DoneConsumingUpstreamTableOrSource(current_consumed_rows), From fe69ce3093b33e16e8f1682eddeac28789190764 Mon Sep 17 00:00:00 2001 From: xxchan Date: Wed, 20 Nov 2024 19:40:23 +0800 Subject: [PATCH 10/11] feat: support backfill_rate_limit for source backfill (#19445) --- .../alter/rate_limit_source_kafka_shared.slt | 134 ++++++++++++++++++ proto/stream_plan.proto | 10 +- src/connector/src/sink/encoder/json.rs | 2 +- src/meta/service/src/stream_service.rs | 4 +- src/meta/src/controller/streaming_job.rs | 13 +- src/meta/src/manager/metadata.rs | 4 +- .../source/source_backfill_executor.rs | 27 ++++ .../src/executor/source/source_executor.rs | 2 +- 8 files changed, 177 insertions(+), 19 deletions(-) create mode 100644 e2e_test/source_inline/kafka/alter/rate_limit_source_kafka_shared.slt diff --git a/e2e_test/source_inline/kafka/alter/rate_limit_source_kafka_shared.slt b/e2e_test/source_inline/kafka/alter/rate_limit_source_kafka_shared.slt new file mode 100644 index 000000000000..29c0b83aa40d --- /dev/null +++ b/e2e_test/source_inline/kafka/alter/rate_limit_source_kafka_shared.slt @@ -0,0 +1,134 @@ +control substitution on + +############## Create kafka seed data + +statement ok +create table kafka_seed_data (v1 int); + +statement ok +insert into kafka_seed_data select * from generate_series(1, 1000); + +############## Sink into kafka + +statement ok +create sink kafka_sink +from + kafka_seed_data with ( + ${RISEDEV_KAFKA_WITH_OPTIONS_COMMON}, + topic = 'test_rate_limit_shared', + type = 'append-only', + force_append_only='true' +); + +############## Source from kafka (rate_limit = 0) + +# Wait for 
the topic to create +skipif in-memory +sleep 5s + +statement ok +create source kafka_source (v1 int) with ( + ${RISEDEV_KAFKA_WITH_OPTIONS_COMMON}, + topic = 'test_rate_limit_shared', + source_rate_limit = 0, +) FORMAT PLAIN ENCODE JSON + +statement ok +flush; + +############## Check data + +skipif in-memory +sleep 3s + +############## Create MV on source + +statement ok +create materialized view rl_mv1 as select count(*) from kafka_source; + +############## Although source is rate limited, the MV's SourceBackfill is not. + +statement ok +flush; + +query I +select * from rl_mv1; +---- +1000 + +############## Insert more data. They will not go into the MV. + +statement ok +insert into kafka_seed_data select * from generate_series(1, 1000); + +sleep 3s + +query I +select * from rl_mv1; +---- +1000 + +statement ok +SET BACKGROUND_DDL=true; + +statement ok +SET BACKFILL_RATE_LIMIT=0; + +statement ok +create materialized view rl_mv2 as select count(*) from kafka_source; + +sleep 1s + +query T +SELECT progress from rw_ddl_progress; +---- +0 rows consumed + +############## Alter Source (rate_limit = 0 --> rate_limit = 1000) + +statement ok +alter source kafka_source set source_rate_limit to 1000; + +sleep 3s + +query I +select * from rl_mv1; +---- +2000 + +query T +SELECT progress from rw_ddl_progress; +---- +0 rows consumed + + + +statement error +alter materialized view rl_mv2 set source_rate_limit = 1000; +---- +db error: ERROR: Failed to run the query + +Caused by: + sql parser error: expected SCHEMA/PARALLELISM/BACKFILL_RATE_LIMIT after SET, found: source_rate_limit +LINE 1: alter materialized view rl_mv2 set source_rate_limit = 1000; + ^ + + +statement ok +alter materialized view rl_mv2 set backfill_rate_limit = 2000; + +sleep 3s + +query ? +select * from rl_mv2; +---- +2000 + + +############## Cleanup + +statement ok +drop source kafka_source cascade; + +statement ok +drop table kafka_seed_data cascade; diff --git a/proto/stream_plan.proto b/proto/stream_plan.proto index ad0601189ec9..70c0d229394b 100644 --- a/proto/stream_plan.proto +++ b/proto/stream_plan.proto @@ -189,7 +189,7 @@ message StreamSource { map with_properties = 6; catalog.StreamSourceInfo info = 7; string source_name = 8; - // Streaming rate limit + // Source rate limit optional uint32 rate_limit = 9; map secret_refs = 10; } @@ -205,7 +205,7 @@ message StreamFsFetch { map with_properties = 6; catalog.StreamSourceInfo info = 7; string source_name = 8; - // Streaming rate limit + // Source rate limit optional uint32 rate_limit = 9; map secret_refs = 10; } @@ -231,7 +231,7 @@ message SourceBackfillNode { catalog.StreamSourceInfo info = 4; string source_name = 5; map with_properties = 6; - // Streaming rate limit + // Backfill rate limit optional uint32 rate_limit = 7; // fields above are the same as StreamSource @@ -609,7 +609,7 @@ message StreamScanNode { // Used iff `ChainType::Backfill`. plan_common.StorageTableDesc table_desc = 7; - // The rate limit for the stream scan node. + // The backfill rate limit for the stream scan node. optional uint32 rate_limit = 8; // Snapshot read every N barriers @@ -646,7 +646,7 @@ message StreamCdcScanNode { // The external table that will be backfilled for CDC. plan_common.ExternalTableDesc cdc_table_desc = 5; - // The rate limit for the stream cdc scan node. + // The backfill rate limit for the stream cdc scan node. optional uint32 rate_limit = 6; // Whether skip the backfill and only consume from upstream. 
diff --git a/src/connector/src/sink/encoder/json.rs b/src/connector/src/sink/encoder/json.rs index 7691b3de5f44..b3c0580a5a78 100644 --- a/src/connector/src/sink/encoder/json.rs +++ b/src/connector/src/sink/encoder/json.rs @@ -204,7 +204,7 @@ fn datum_to_json_object( let data_type = field.data_type(); - tracing::debug!("datum_to_json_object: {:?}, {:?}", data_type, scalar_ref); + tracing::trace!("datum_to_json_object: {:?}, {:?}", data_type, scalar_ref); let value = match (data_type, scalar_ref) { (DataType::Boolean, ScalarRefImpl::Bool(v)) => { diff --git a/src/meta/service/src/stream_service.rs b/src/meta/service/src/stream_service.rs index 91f73a292025..dfd8ec21187f 100644 --- a/src/meta/service/src/stream_service.rs +++ b/src/meta/service/src/stream_service.rs @@ -114,12 +114,12 @@ impl StreamManagerService for StreamServiceImpl { } ThrottleTarget::Mv => { self.metadata_manager - .update_mv_rate_limit_by_table_id(TableId::from(request.id), request.rate) + .update_backfill_rate_limit_by_table_id(TableId::from(request.id), request.rate) .await? } ThrottleTarget::CdcTable => { self.metadata_manager - .update_mv_rate_limit_by_table_id(TableId::from(request.id), request.rate) + .update_backfill_rate_limit_by_table_id(TableId::from(request.id), request.rate) .await? } ThrottleTarget::Unspecified => { diff --git a/src/meta/src/controller/streaming_job.rs b/src/meta/src/controller/streaming_job.rs index d3e170aff75a..5139b5069d9d 100644 --- a/src/meta/src/controller/streaming_job.rs +++ b/src/meta/src/controller/streaming_job.rs @@ -1317,7 +1317,6 @@ impl CatalogController { .map(|(id, mask, stream_node)| (id, mask, stream_node.to_protobuf())) .collect_vec(); - // TODO: limit source backfill? fragments.retain_mut(|(_, fragment_type_mask, stream_node)| { let mut found = false; if *fragment_type_mask & PbFragmentTypeFlag::Source as i32 != 0 { @@ -1384,7 +1383,7 @@ impl CatalogController { // edit the `rate_limit` of the `Chain` node in given `table_id`'s fragments // return the actor_ids to be applied - pub async fn update_mv_rate_limit_by_job_id( + pub async fn update_backfill_rate_limit_by_job_id( &self, job_id: ObjectId, rate_limit: Option, @@ -1411,7 +1410,7 @@ impl CatalogController { fragments.retain_mut(|(_, fragment_type_mask, stream_node)| { let mut found = false; if (*fragment_type_mask & PbFragmentTypeFlag::StreamScan as i32 != 0) - || (*fragment_type_mask & PbFragmentTypeFlag::Source as i32 != 0) + || (*fragment_type_mask & PbFragmentTypeFlag::SourceScan as i32 != 0) { visit_stream_node(stream_node, |node| match node { PbNodeBody::StreamCdcScan(node) => { @@ -1422,11 +1421,9 @@ impl CatalogController { node.rate_limit = rate_limit; found = true; } - PbNodeBody::Source(node) => { - if let Some(inner) = node.source_inner.as_mut() { - inner.rate_limit = rate_limit; - found = true; - } + PbNodeBody::SourceBackfill(node) => { + node.rate_limit = rate_limit; + found = true; } _ => {} }); diff --git a/src/meta/src/manager/metadata.rs b/src/meta/src/manager/metadata.rs index 40d3c025c0c8..db53f5fb8b6b 100644 --- a/src/meta/src/manager/metadata.rs +++ b/src/meta/src/manager/metadata.rs @@ -657,14 +657,14 @@ impl MetadataManager { .collect()) } - pub async fn update_mv_rate_limit_by_table_id( + pub async fn update_backfill_rate_limit_by_table_id( &self, table_id: TableId, rate_limit: Option, ) -> MetaResult>> { let fragment_actors = self .catalog_controller - .update_mv_rate_limit_by_job_id(table_id.table_id as _, rate_limit) + .update_backfill_rate_limit_by_job_id(table_id.table_id as _, 
rate_limit) .await?; Ok(fragment_actors .into_iter() diff --git a/src/stream/src/executor/source/source_backfill_executor.rs b/src/stream/src/executor/source/source_backfill_executor.rs index 6ada0f2b62eb..bbf71b281d3e 100644 --- a/src/stream/src/executor/source/source_backfill_executor.rs +++ b/src/stream/src/executor/source/source_backfill_executor.rs @@ -570,6 +570,33 @@ impl SourceBackfillExecutorInner { ) .await?; } + Mutation::Throttle(actor_to_apply) => { + if let Some(new_rate_limit) = + actor_to_apply.get(&self.actor_ctx.id) + && *new_rate_limit != self.rate_limit_rps + { + tracing::info!( + "updating rate limit from {:?} to {:?}", + self.rate_limit_rps, + *new_rate_limit + ); + self.rate_limit_rps = *new_rate_limit; + // rebuild reader + let (reader, _backfill_info) = self + .build_stream_source_reader( + &source_desc, + backfill_stage + .get_latest_unfinished_splits()?, + ) + .await?; + + backfill_stream = select_with_strategy( + input.by_ref().map(Either::Left), + reader.map(Either::Right), + select_strategy, + ); + } + } _ => {} } } diff --git a/src/stream/src/executor/source/source_executor.rs b/src/stream/src/executor/source/source_executor.rs index 118b33c08ae5..80b252014d28 100644 --- a/src/stream/src/executor/source/source_executor.rs +++ b/src/stream/src/executor/source/source_executor.rs @@ -608,7 +608,7 @@ impl SourceExecutor { if let Some(new_rate_limit) = actor_to_apply.get(&self.actor_ctx.id) && *new_rate_limit != self.rate_limit_rps { - tracing::debug!( + tracing::info!( "updating rate limit from {:?} to {:?}", self.rate_limit_rps, *new_rate_limit From 3fdd6a5cc0c0789118ab726cbb776bfefb41a557 Mon Sep 17 00:00:00 2001 From: xxchan Date: Wed, 20 Nov 2024 21:41:47 +0800 Subject: [PATCH 11/11] feat: add rw_rate_limit system catalog (#19466) Signed-off-by: xxchan --- e2e_test/source_inline/fs/posix_fs.slt | 44 ++++++- ...slt => rate_limit_source_kafka.slt.serial} | 26 +++- ...rate_limit_source_kafka_shared.slt.serial} | 34 ++++- ....slt => rate_limit_table_kafka.slt.serial} | 2 +- proto/meta.proto | 14 ++ proto/stream_plan.proto | 2 + .../catalog/system_catalog/rw_catalog/mod.rs | 1 + .../system_catalog/rw_catalog/rw_fragments.rs | 2 +- .../rw_catalog/rw_rate_limit.rs | 50 ++++++++ src/frontend/src/meta_client.rs | 7 + src/frontend/src/stream_fragmenter/mod.rs | 4 + src/frontend/src/test_utils.rs | 5 + src/meta/service/src/stream_service.rs | 12 ++ src/meta/src/controller/streaming_job.rs | 121 ++++++++++++++++-- src/meta/src/manager/metadata.rs | 6 + src/prost/src/lib.rs | 19 +++ src/rpc_client/src/meta_client.rs | 9 ++ src/sqlparser/src/parser.rs | 2 +- 18 files changed, 340 insertions(+), 20 deletions(-) rename e2e_test/source_inline/kafka/alter/{rate_limit_source_kafka.slt => rate_limit_source_kafka.slt.serial} (80%) rename e2e_test/source_inline/kafka/alter/{rate_limit_source_kafka_shared.slt => rate_limit_source_kafka_shared.slt.serial} (74%) rename e2e_test/source_inline/kafka/alter/{rate_limit_table_kafka.slt => rate_limit_table_kafka.slt.serial} (99%) create mode 100644 src/frontend/src/catalog/system_catalog/rw_catalog/rw_rate_limit.rs diff --git a/e2e_test/source_inline/fs/posix_fs.slt b/e2e_test/source_inline/fs/posix_fs.slt index da56502e417e..5408daf28321 100644 --- a/e2e_test/source_inline/fs/posix_fs.slt +++ b/e2e_test/source_inline/fs/posix_fs.slt @@ -33,21 +33,36 @@ create materialized view diamonds_mv as select * from diamonds_source; sleep 1s # no output due to rate limit -query TTTT rowsort +statement count 0 select * from diamonds; ----- -query 
TTTT rowsort + +statement count 0 select * from diamonds_mv; + + +query T +select name, node_name, fragment_type, rate_limit from rw_rate_limit join rw_relations on table_id=id +order by name, node_name; ---- +diamonds FS_FETCH {FS_FETCH} 0 +diamonds SOURCE {SOURCE} 0 +diamonds_mv FS_FETCH {MVIEW,FS_FETCH} 0 +diamonds_mv SOURCE {SOURCE} 0 statement ok ALTER TABLE diamonds SET source_rate_limit TO DEFAULT; -statement ok -ALTER source diamonds_source SET source_rate_limit TO DEFAULT; -sleep 10s +query T +select name, node_name, fragment_type, rate_limit from rw_rate_limit join rw_relations on table_id=id +order by name, node_name; +---- +diamonds_mv FS_FETCH {MVIEW,FS_FETCH} 0 +diamonds_mv SOURCE {SOURCE} 0 + + +sleep 3s query TTTT rowsort select * from diamonds; @@ -63,6 +78,23 @@ select * from diamonds; 1.28 Good J 63.1 1.3 Fair E 64.7 + +statement count 0 +select * from diamonds_mv; + + + +statement ok +ALTER SOURCE diamonds_source SET source_rate_limit TO DEFAULT; + +query T +select name, node_name, fragment_type, rate_limit from rw_rate_limit join rw_relations on table_id=id +order by name, node_name; +---- + + +sleep 3s + query TTTT rowsort select * from diamonds_mv; ---- diff --git a/e2e_test/source_inline/kafka/alter/rate_limit_source_kafka.slt b/e2e_test/source_inline/kafka/alter/rate_limit_source_kafka.slt.serial similarity index 80% rename from e2e_test/source_inline/kafka/alter/rate_limit_source_kafka.slt rename to e2e_test/source_inline/kafka/alter/rate_limit_source_kafka.slt.serial index 96fd016c5812..8353166b5a87 100644 --- a/e2e_test/source_inline/kafka/alter/rate_limit_source_kafka.slt +++ b/e2e_test/source_inline/kafka/alter/rate_limit_source_kafka.slt.serial @@ -80,16 +80,38 @@ select * from rl_mv3; ---- 0 +query T +select name, node_name, fragment_type, rate_limit from rw_rate_limit join rw_relations on table_id=id +order by name; +---- +rl_mv1 SOURCE {SOURCE} 0 +rl_mv2 SOURCE {SOURCE} 0 +rl_mv3 SOURCE {SOURCE} 0 + ############## Alter Source (rate_limit = 0 --> rate_limit = 1000) skipif in-memory -query I +statement count 0 alter source kafka_source set source_rate_limit to 1000; +query T +select name, node_name, fragment_type, rate_limit from rw_rate_limit join rw_relations on table_id=id +order by name; +---- +rl_mv1 SOURCE {SOURCE} 1000 +rl_mv2 SOURCE {SOURCE} 1000 +rl_mv3 SOURCE {SOURCE} 1000 + skipif in-memory -query I +statement count 0 alter source kafka_source set source_rate_limit to default; +# rate limit becomes None +query T +select count(*) from rw_rate_limit; +---- +0 + skipif in-memory sleep 3s diff --git a/e2e_test/source_inline/kafka/alter/rate_limit_source_kafka_shared.slt b/e2e_test/source_inline/kafka/alter/rate_limit_source_kafka_shared.slt.serial similarity index 74% rename from e2e_test/source_inline/kafka/alter/rate_limit_source_kafka_shared.slt rename to e2e_test/source_inline/kafka/alter/rate_limit_source_kafka_shared.slt.serial index 29c0b83aa40d..a9a730930b1b 100644 --- a/e2e_test/source_inline/kafka/alter/rate_limit_source_kafka_shared.slt +++ b/e2e_test/source_inline/kafka/alter/rate_limit_source_kafka_shared.slt.serial @@ -84,11 +84,26 @@ SELECT progress from rw_ddl_progress; ---- 0 rows consumed +query T +select name, node_name, fragment_type, rate_limit from rw_rate_limit join rw_relations on table_id=id +order by name; +---- +kafka_source SOURCE {SOURCE} 0 +rl_mv2 SOURCE_BACKFILL {SOURCE_SCAN} 0 + + ############## Alter Source (rate_limit = 0 --> rate_limit = 1000) statement ok alter source kafka_source set source_rate_limit to 1000; 
+query T +select name, node_name, fragment_type, rate_limit from rw_rate_limit join rw_relations on table_id=id +order by name; +---- +kafka_source SOURCE {SOURCE} 1000 +rl_mv2 SOURCE_BACKFILL {SOURCE_SCAN} 0 + sleep 3s query I @@ -114,17 +129,34 @@ LINE 1: alter materialized view rl_mv2 set source_rate_limit = 1000; ^ +query T +select name, node_name, fragment_type, rate_limit from rw_rate_limit join rw_relations on table_id=id +order by name; +---- +kafka_source SOURCE {SOURCE} 1000 +rl_mv2 SOURCE_BACKFILL {SOURCE_SCAN} 0 + + statement ok alter materialized view rl_mv2 set backfill_rate_limit = 2000; + +query T +select name, node_name, fragment_type, rate_limit from rw_rate_limit join rw_relations on table_id=id +order by name; +---- +kafka_source SOURCE {SOURCE} 1000 +rl_mv2 SOURCE_BACKFILL {SOURCE_SCAN} 2000 + sleep 3s -query ? +query T select * from rl_mv2; ---- 2000 + ############## Cleanup statement ok diff --git a/e2e_test/source_inline/kafka/alter/rate_limit_table_kafka.slt b/e2e_test/source_inline/kafka/alter/rate_limit_table_kafka.slt.serial similarity index 99% rename from e2e_test/source_inline/kafka/alter/rate_limit_table_kafka.slt rename to e2e_test/source_inline/kafka/alter/rate_limit_table_kafka.slt.serial index ac2a665fd10c..5d22fc85dea4 100644 --- a/e2e_test/source_inline/kafka/alter/rate_limit_table_kafka.slt +++ b/e2e_test/source_inline/kafka/alter/rate_limit_table_kafka.slt.serial @@ -63,7 +63,7 @@ select count(*) from kafka_source; ############## Alter source (rate_limit = 0 --> rate_limit = 1000) skipif in-memory -query I +statement ok alter table kafka_source set source_rate_limit to 1000; skipif in-memory diff --git a/proto/meta.proto b/proto/meta.proto index 15a16f36bddd..5c6d1c64274f 100644 --- a/proto/meta.proto +++ b/proto/meta.proto @@ -335,6 +335,7 @@ service StreamManagerService { rpc ListObjectDependencies(ListObjectDependenciesRequest) returns (ListObjectDependenciesResponse); rpc ApplyThrottle(ApplyThrottleRequest) returns (ApplyThrottleResponse); rpc Recover(RecoverRequest) returns (RecoverResponse); + rpc ListRateLimits(ListRateLimitsRequest) returns (ListRateLimitsResponse); } // Below for cluster service. @@ -862,3 +863,16 @@ message GetClusterLimitsResponse { service ClusterLimitService { rpc GetClusterLimits(GetClusterLimitsRequest) returns (GetClusterLimitsResponse); } + +message ListRateLimitsRequest {} + +message ListRateLimitsResponse { + message RateLimitInfo { + uint32 fragment_id = 1; + uint32 job_id = 2; + uint32 fragment_type_mask = 3; + uint32 rate_limit = 4; + string node_name = 5; + } + repeated RateLimitInfo rate_limits = 1; +} diff --git a/proto/stream_plan.proto b/proto/stream_plan.proto index 70c0d229394b..d5a47b53b6af 100644 --- a/proto/stream_plan.proto +++ b/proto/stream_plan.proto @@ -985,6 +985,8 @@ enum FragmentTypeFlag { FRAGMENT_TYPE_FLAG_CDC_FILTER = 256; FRAGMENT_TYPE_FLAG_SOURCE_SCAN = 1024; FRAGMENT_TYPE_FLAG_SNAPSHOT_BACKFILL_STREAM_SCAN = 2048; + // Note: this flag is not available in old fragments, so only suitable for debugging purpose. 
+ FRAGMENT_TYPE_FLAG_FS_FETCH = 4096; } // The streaming context associated with a stream plan diff --git a/src/frontend/src/catalog/system_catalog/rw_catalog/mod.rs b/src/frontend/src/catalog/system_catalog/rw_catalog/mod.rs index 9c546f1ec729..947560e44e62 100644 --- a/src/frontend/src/catalog/system_catalog/rw_catalog/mod.rs +++ b/src/frontend/src/catalog/system_catalog/rw_catalog/mod.rs @@ -39,6 +39,7 @@ mod rw_indexes; mod rw_internal_tables; mod rw_materialized_views; mod rw_meta_snapshot; +mod rw_rate_limit; mod rw_relation_info; mod rw_relations; mod rw_schemas; diff --git a/src/frontend/src/catalog/system_catalog/rw_catalog/rw_fragments.rs b/src/frontend/src/catalog/system_catalog/rw_catalog/rw_fragments.rs index 91f818e7919f..75a040f2733c 100644 --- a/src/frontend/src/catalog/system_catalog/rw_catalog/rw_fragments.rs +++ b/src/frontend/src/catalog/system_catalog/rw_catalog/rw_fragments.rs @@ -32,7 +32,7 @@ struct RwFragment { max_parallelism: i32, } -fn extract_fragment_type_flag(mask: u32) -> Vec { +pub(super) fn extract_fragment_type_flag(mask: u32) -> Vec { let mut result = vec![]; for i in 0..32 { let bit = 1 << i; diff --git a/src/frontend/src/catalog/system_catalog/rw_catalog/rw_rate_limit.rs b/src/frontend/src/catalog/system_catalog/rw_catalog/rw_rate_limit.rs new file mode 100644 index 000000000000..34602461ca3b --- /dev/null +++ b/src/frontend/src/catalog/system_catalog/rw_catalog/rw_rate_limit.rs @@ -0,0 +1,50 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use risingwave_common::types::Fields; +use risingwave_frontend_macro::system_catalog; + +use super::rw_fragments::extract_fragment_type_flag; +use crate::catalog::system_catalog::SysCatalogReaderImpl; +use crate::error::Result; + +#[derive(Fields)] +#[primary_key(fragment_id, node_name)] +struct RwRateLimit { + fragment_id: i32, + fragment_type: Vec, + node_name: String, + table_id: i32, + rate_limit: i32, +} + +#[system_catalog(table, "rw_catalog.rw_rate_limit")] +async fn read_rw_rate_limit(reader: &SysCatalogReaderImpl) -> Result> { + let rate_limits = reader.meta_client.list_rate_limits().await?; + + Ok(rate_limits + .into_iter() + .map(|info| RwRateLimit { + fragment_id: info.fragment_id as i32, + fragment_type: extract_fragment_type_flag(info.fragment_type_mask) + .into_iter() + .flat_map(|t| t.as_str_name().strip_prefix("FRAGMENT_TYPE_FLAG_")) + .map(|s| s.into()) + .collect(), + table_id: info.job_id as i32, + rate_limit: info.rate_limit as i32, + node_name: info.node_name, + }) + .collect()) +} diff --git a/src/frontend/src/meta_client.rs b/src/frontend/src/meta_client.rs index a91a0d8abc87..760c7bd450e1 100644 --- a/src/frontend/src/meta_client.rs +++ b/src/frontend/src/meta_client.rs @@ -33,6 +33,7 @@ use risingwave_pb::meta::list_actor_splits_response::ActorSplit; use risingwave_pb::meta::list_actor_states_response::ActorState; use risingwave_pb::meta::list_fragment_distribution_response::FragmentDistribution; use risingwave_pb::meta::list_object_dependencies_response::PbObjectDependencies; +use risingwave_pb::meta::list_rate_limits_response::RateLimitInfo; use risingwave_pb::meta::list_table_fragment_states_response::TableFragmentState; use risingwave_pb::meta::list_table_fragments_response::TableFragmentInfo; use risingwave_pb::meta::{EventLog, PbThrottleTarget, RecoveryStatus}; @@ -125,6 +126,8 @@ pub trait FrontendMetaClient: Send + Sync { async fn get_cluster_recovery_status(&self) -> Result; async fn get_cluster_limits(&self) -> Result>; + + async fn list_rate_limits(&self) -> Result>; } pub struct FrontendMetaClientImpl(pub MetaClient); @@ -300,4 +303,8 @@ impl FrontendMetaClient for FrontendMetaClientImpl { async fn get_cluster_limits(&self) -> Result> { self.0.get_cluster_limits().await } + + async fn list_rate_limits(&self) -> Result> { + self.0.list_rate_limits().await + } } diff --git a/src/frontend/src/stream_fragmenter/mod.rs b/src/frontend/src/stream_fragmenter/mod.rs index daa48d99969c..f30b0abf5b4c 100644 --- a/src/frontend/src/stream_fragmenter/mod.rs +++ b/src/frontend/src/stream_fragmenter/mod.rs @@ -361,6 +361,10 @@ fn build_fragment( current_fragment.requires_singleton = true; } + NodeBody::StreamFsFetch(_) => { + current_fragment.fragment_type_mask |= FragmentTypeFlag::FsFetch as u32; + } + _ => {} }; diff --git a/src/frontend/src/test_utils.rs b/src/frontend/src/test_utils.rs index d94b1dd2652d..15a5281dec5e 100644 --- a/src/frontend/src/test_utils.rs +++ b/src/frontend/src/test_utils.rs @@ -56,6 +56,7 @@ use risingwave_pb::meta::list_actor_splits_response::ActorSplit; use risingwave_pb::meta::list_actor_states_response::ActorState; use risingwave_pb::meta::list_fragment_distribution_response::FragmentDistribution; use risingwave_pb::meta::list_object_dependencies_response::PbObjectDependencies; +use risingwave_pb::meta::list_rate_limits_response::RateLimitInfo; use risingwave_pb::meta::list_table_fragment_states_response::TableFragmentState; use risingwave_pb::meta::list_table_fragments_response::TableFragmentInfo; use risingwave_pb::meta::{ @@ 
-1065,6 +1066,10 @@ impl FrontendMetaClient for MockFrontendMetaClient { async fn get_cluster_limits(&self) -> RpcResult> { Ok(vec![]) } + + async fn list_rate_limits(&self) -> RpcResult> { + Ok(vec![]) + } } #[cfg(test)] diff --git a/src/meta/service/src/stream_service.rs b/src/meta/service/src/stream_service.rs index dfd8ec21187f..4bb9bfb2d448 100644 --- a/src/meta/service/src/stream_service.rs +++ b/src/meta/service/src/stream_service.rs @@ -433,4 +433,16 @@ impl StreamManagerService for StreamServiceImpl { Ok(Response::new(ListActorSplitsResponse { actor_splits })) } + + async fn list_rate_limits( + &self, + _request: Request, + ) -> Result, Status> { + let rate_limits = self + .metadata_manager + .catalog_controller + .list_rate_limits() + .await?; + Ok(Response::new(ListRateLimitsResponse { rate_limits })) + } } diff --git a/src/meta/src/controller/streaming_job.rs b/src/meta/src/controller/streaming_job.rs index 5139b5069d9d..a908704129c7 100644 --- a/src/meta/src/controller/streaming_job.rs +++ b/src/meta/src/controller/streaming_job.rs @@ -39,6 +39,7 @@ use risingwave_meta_model::{ use risingwave_pb::catalog::source::PbOptionalAssociatedTableId; use risingwave_pb::catalog::table::{PbOptionalAssociatedSourceId, PbTableVersion}; use risingwave_pb::catalog::{PbCreateType, PbTable}; +use risingwave_pb::meta::list_rate_limits_response::RateLimitInfo; use risingwave_pb::meta::relation::{PbRelationInfo, RelationInfo}; use risingwave_pb::meta::subscribe_response::{ Info as NotificationInfo, Info, Operation as NotificationOperation, Operation, @@ -53,12 +54,12 @@ use risingwave_pb::stream_plan::update_mutation::PbMergeUpdate; use risingwave_pb::stream_plan::{ PbDispatcher, PbDispatcherType, PbFragmentTypeFlag, PbStreamActor, }; -use sea_orm::sea_query::{Expr, Query, SimpleExpr}; +use sea_orm::sea_query::{BinOper, Expr, Query, SimpleExpr}; use sea_orm::ActiveValue::Set; use sea_orm::{ ActiveEnum, ActiveModelTrait, ColumnTrait, DatabaseTransaction, EntityTrait, IntoActiveModel, - JoinType, ModelTrait, NotSet, PaginatorTrait, QueryFilter, QuerySelect, RelationTrait, - TransactionTrait, + IntoSimpleExpr, JoinType, ModelTrait, NotSet, PaginatorTrait, QueryFilter, QuerySelect, + RelationTrait, TransactionTrait, }; use crate::barrier::{ReplaceTablePlan, Reschedule}; @@ -1332,9 +1333,11 @@ impl CatalogController { }); } if is_fs_source { - // scan all fragments for StreamFsFetch node if using fs connector + // in older versions, there's no fragment type flag for `FsFetch` node, + // so we just scan all fragments for StreamFsFetch node if using fs connector visit_stream_node(stream_node, |node| { if let PbNodeBody::StreamFsFetch(node) = node { + *fragment_type_mask |= PbFragmentTypeFlag::FsFetch as i32; if let Some(node_inner) = &mut node.node_inner && node_inner.source_id == source_id as u32 { @@ -1352,9 +1355,10 @@ impl CatalogController { "source id should be used by at least one fragment" ); let fragment_ids = fragments.iter().map(|(id, _, _)| *id).collect_vec(); - for (id, _, stream_node) in fragments { + for (id, fragment_type_mask, stream_node) in fragments { fragment::ActiveModel { fragment_id: Set(id), + fragment_type_mask: Set(fragment_type_mask), stream_node: Set(StreamNode::from(&stream_node)), ..Default::default() } @@ -1409,9 +1413,7 @@ impl CatalogController { fragments.retain_mut(|(_, fragment_type_mask, stream_node)| { let mut found = false; - if (*fragment_type_mask & PbFragmentTypeFlag::StreamScan as i32 != 0) - || (*fragment_type_mask & PbFragmentTypeFlag::SourceScan as i32 != 
0) - { + if *fragment_type_mask & PbFragmentTypeFlag::backfill_rate_limit_fragments() != 0 { visit_stream_node(stream_node, |node| match node { PbNodeBody::StreamCdcScan(node) => { node.rate_limit = rate_limit; @@ -1778,4 +1780,107 @@ impl CatalogController { Ok(()) } + + /// Note: `FsFetch` created in old versions are not included. + /// Since this is only used for debugging, it should be fine. + pub async fn list_rate_limits(&self) -> MetaResult> { + let inner = self.inner.read().await; + let txn = inner.db.begin().await?; + + let fragments: Vec<(FragmentId, ObjectId, i32, StreamNode)> = Fragment::find() + .select_only() + .columns([ + fragment::Column::FragmentId, + fragment::Column::JobId, + fragment::Column::FragmentTypeMask, + fragment::Column::StreamNode, + ]) + .filter(fragment_type_mask_intersects( + PbFragmentTypeFlag::rate_limit_fragments(), + )) + .into_tuple() + .all(&txn) + .await?; + + let mut rate_limits = Vec::new(); + for (fragment_id, job_id, fragment_type_mask, stream_node) in fragments { + let mut stream_node = stream_node.to_protobuf(); + let mut rate_limit = None; + let mut node_name = None; + + visit_stream_node(&mut stream_node, |node| { + match node { + // source rate limit + PbNodeBody::Source(node) => { + if let Some(node_inner) = &mut node.source_inner { + debug_assert!( + rate_limit.is_none(), + "one fragment should only have 1 rate limit node" + ); + rate_limit = node_inner.rate_limit; + node_name = Some("SOURCE"); + } + } + PbNodeBody::StreamFsFetch(node) => { + if let Some(node_inner) = &mut node.node_inner { + debug_assert!( + rate_limit.is_none(), + "one fragment should only have 1 rate limit node" + ); + rate_limit = node_inner.rate_limit; + node_name = Some("FS_FETCH"); + } + } + // backfill rate limit + PbNodeBody::SourceBackfill(node) => { + debug_assert!( + rate_limit.is_none(), + "one fragment should only have 1 rate limit node" + ); + rate_limit = node.rate_limit; + node_name = Some("SOURCE_BACKFILL"); + } + PbNodeBody::StreamScan(node) => { + debug_assert!( + rate_limit.is_none(), + "one fragment should only have 1 rate limit node" + ); + rate_limit = node.rate_limit; + node_name = Some("STREAM_SCAN"); + } + PbNodeBody::StreamCdcScan(node) => { + debug_assert!( + rate_limit.is_none(), + "one fragment should only have 1 rate limit node" + ); + rate_limit = node.rate_limit; + node_name = Some("STREAM_CDC_SCAN"); + } + _ => {} + } + }); + + if let Some(rate_limit) = rate_limit { + rate_limits.push(RateLimitInfo { + fragment_id: fragment_id as u32, + job_id: job_id as u32, + fragment_type_mask: fragment_type_mask as u32, + rate_limit, + node_name: node_name.unwrap().to_string(), + }); + } + } + + Ok(rate_limits) + } +} + +fn bitflag_intersects(column: SimpleExpr, value: i32) -> SimpleExpr { + column + .binary(BinOper::Custom("&"), value) + .binary(BinOper::NotEqual, 0) +} + +fn fragment_type_mask_intersects(value: i32) -> SimpleExpr { + bitflag_intersects(fragment::Column::FragmentTypeMask.into_simple_expr(), value) } diff --git a/src/meta/src/manager/metadata.rs b/src/meta/src/manager/metadata.rs index db53f5fb8b6b..b974ad82b053 100644 --- a/src/meta/src/manager/metadata.rs +++ b/src/meta/src/manager/metadata.rs @@ -25,6 +25,7 @@ use risingwave_pb::catalog::{PbSink, PbSource, PbTable}; use risingwave_pb::common::worker_node::{PbResource, State}; use risingwave_pb::common::{HostAddress, PbWorkerNode, PbWorkerType, WorkerNode, WorkerType}; use risingwave_pb::meta::add_worker_node_request::Property as AddNodeProperty; +use 
risingwave_pb::meta::list_rate_limits_response::RateLimitInfo; use risingwave_pb::meta::table_fragments::{Fragment, PbFragment}; use risingwave_pb::stream_plan::{PbDispatchStrategy, StreamActor}; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver}; @@ -720,6 +721,11 @@ impl MetadataManager { pub fn cluster_id(&self) -> &ClusterId { self.cluster_controller.cluster_id() } + + pub async fn list_rate_limits(&self) -> MetaResult> { + let rate_limits = self.catalog_controller.list_rate_limits().await?; + Ok(rate_limits) + } } impl MetadataManager { diff --git a/src/prost/src/lib.rs b/src/prost/src/lib.rs index 5974a0566472..a4678df09127 100644 --- a/src/prost/src/lib.rs +++ b/src/prost/src/lib.rs @@ -302,6 +302,25 @@ impl stream_plan::StreamNode { } } +impl stream_plan::FragmentTypeFlag { + /// Fragments that may be affected by `BACKFILL_RATE_LIMIT`. + pub fn backfill_rate_limit_fragments() -> i32 { + stream_plan::FragmentTypeFlag::SourceScan as i32 + | stream_plan::FragmentTypeFlag::StreamScan as i32 + } + + /// Fragments that may be affected by `SOURCE_RATE_LIMIT`. + /// Note: for `FsFetch`, old fragments don't have this flag set, so don't use this to check. + pub fn source_rate_limit_fragments() -> i32 { + stream_plan::FragmentTypeFlag::Source as i32 | stream_plan::FragmentTypeFlag::FsFetch as i32 + } + + /// Note: this doesn't include `FsFetch` created in old versions. + pub fn rate_limit_fragments() -> i32 { + Self::backfill_rate_limit_fragments() | Self::source_rate_limit_fragments() + } +} + impl catalog::StreamSourceInfo { /// Refer to [`Self::cdc_source_job`] for details. pub fn is_shared(&self) -> bool { diff --git a/src/rpc_client/src/meta_client.rs b/src/rpc_client/src/meta_client.rs index be733e8d4ec1..80213d0deda6 100644 --- a/src/rpc_client/src/meta_client.rs +++ b/src/rpc_client/src/meta_client.rs @@ -25,6 +25,7 @@ use async_trait::async_trait; use cluster_limit_service_client::ClusterLimitServiceClient; use either::Either; use futures::stream::BoxStream; +use list_rate_limits_response::RateLimitInfo; use lru::LruCache; use risingwave_common::catalog::{FunctionId, IndexId, SecretId, TableId}; use risingwave_common::config::{MetaConfig, MAX_CONNECTION_WINDOW_SIZE}; @@ -1494,6 +1495,13 @@ impl MetaClient { self.inner.merge_compaction_group(req).await?; Ok(()) } + + /// List all rate limits for sources and backfills + pub async fn list_rate_limits(&self) -> Result> { + let request = ListRateLimitsRequest {}; + let resp = self.inner.list_rate_limits(request).await?; + Ok(resp.rate_limits) + } } #[async_trait] @@ -2044,6 +2052,7 @@ macro_rules! for_all_meta_rpc { ,{ stream_client, list_actor_splits, ListActorSplitsRequest, ListActorSplitsResponse } ,{ stream_client, list_object_dependencies, ListObjectDependenciesRequest, ListObjectDependenciesResponse } ,{ stream_client, recover, RecoverRequest, RecoverResponse } + ,{ stream_client, list_rate_limits, ListRateLimitsRequest, ListRateLimitsResponse } ,{ ddl_client, create_table, CreateTableRequest, CreateTableResponse } ,{ ddl_client, alter_name, AlterNameRequest, AlterNameResponse } ,{ ddl_client, alter_owner, AlterOwnerRequest, AlterOwnerResponse } diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index f93d41aeed2c..f8d449a253c3 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -3462,7 +3462,7 @@ impl Parser<'_> { } else if let Some(rate_limit) = self.parse_alter_source_rate_limit(false)? 
{ AlterSourceOperation::SetSourceRateLimit { rate_limit } } else { - return self.expected("SCHEMA after SET"); + return self.expected("SCHEMA or SOURCE_RATE_LIMIT after SET"); } } else if self.peek_nth_any_of_keywords(0, &[Keyword::FORMAT]) { let format_encode = self.parse_schema()?.unwrap();