From 45b9f35439600dead276661bb465126212652c86 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Mon, 29 Jan 2024 23:36:16 +0800 Subject: [PATCH] fix temporal join shuffle --- .../tests/testdata/output/temporal_join.yaml | 13 +++++++------ .../src/optimizer/plan_node/logical_join.rs | 14 +++++++------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/frontend/planner_test/tests/testdata/output/temporal_join.yaml b/src/frontend/planner_test/tests/testdata/output/temporal_join.yaml index f49a82be2dd78..ea844cda185b1 100644 --- a/src/frontend/planner_test/tests/testdata/output/temporal_join.yaml +++ b/src/frontend/planner_test/tests/testdata/output/temporal_join.yaml @@ -107,12 +107,13 @@ StreamMaterialize { columns: [k, x1, x2, a1, b1, stream._row_id(hidden), version2.k(hidden)], stream_key: [stream._row_id, k], pk_columns: [stream._row_id, k], pk_conflict: NoCheck } └─StreamExchange { dist: HashShard(stream.k, stream._row_id) } └─StreamTemporalJoin { type: Inner, predicate: stream.k = version2.k, output: [stream.k, version1.x1, version2.x2, stream.a1, stream.b1, stream._row_id, version2.k] } - ├─StreamTemporalJoin { type: Inner, predicate: stream.k = version1.k, output: [stream.k, stream.a1, stream.b1, version1.x1, stream._row_id, version1.k] } - │ ├─StreamExchange { dist: HashShard(stream.k) } - │ │ └─StreamFilter { predicate: (stream.a1 < 10:Int32) } - │ │ └─StreamTableScan { table: stream, columns: [stream.k, stream.a1, stream.b1, stream._row_id], pk: [stream._row_id], dist: UpstreamHashShard(stream._row_id) } - │ └─StreamExchange [no_shuffle] { dist: UpstreamHashShard(version1.k) } - │ └─StreamTableScan { table: version1, columns: [version1.k, version1.x1], pk: [version1.k], dist: UpstreamHashShard(version1.k) } + ├─StreamExchange { dist: HashShard(stream.k) } + │ └─StreamTemporalJoin { type: Inner, predicate: stream.k = version1.k, output: [stream.k, stream.a1, stream.b1, version1.x1, stream._row_id, version1.k] } + │ ├─StreamExchange { dist: 
HashShard(stream.k) } + │ │ └─StreamFilter { predicate: (stream.a1 < 10:Int32) } + │ │ └─StreamTableScan { table: stream, columns: [stream.k, stream.a1, stream.b1, stream._row_id], pk: [stream._row_id], dist: UpstreamHashShard(stream._row_id) } + │ └─StreamExchange [no_shuffle] { dist: UpstreamHashShard(version1.k) } + │ └─StreamTableScan { table: version1, columns: [version1.k, version1.x1], pk: [version1.k], dist: UpstreamHashShard(version1.k) } └─StreamExchange [no_shuffle] { dist: UpstreamHashShard(version2.k) } └─StreamTableScan { table: version2, columns: [version2.k, version2.x2], pk: [version2.k], dist: UpstreamHashShard(version2.k) } - name: multi-way temporal join with different keys diff --git a/src/frontend/src/optimizer/plan_node/logical_join.rs b/src/frontend/src/optimizer/plan_node/logical_join.rs index dd555e5e3a1c0..a5be7e1fb0368 100644 --- a/src/frontend/src/optimizer/plan_node/logical_join.rs +++ b/src/frontend/src/optimizer/plan_node/logical_join.rs @@ -1054,9 +1054,8 @@ impl LogicalJoin { let lookup_prefix_len = reorder_idx.len(); let predicate = predicate.reorder(&reorder_idx); - let left = if dist_key_in_order_key_pos.is_empty() { - self.left() - .to_stream_with_dist_required(&RequiredDist::single(), ctx)? + let required_dist = if dist_key_in_order_key_pos.is_empty() { + RequiredDist::single() } else { let left_eq_indexes = predicate.left_eq_indexes(); let left_dist_key = dist_key_in_order_key_pos @@ -1064,12 +1063,13 @@ impl LogicalJoin { .map(|pos| left_eq_indexes[*pos]) .collect_vec(); - self.left().to_stream_with_dist_required( - &RequiredDist::shard_by_key(self.left().schema().len(), &left_dist_key), - ctx, - )? + RequiredDist::shard_by_key(self.left().schema().len(), &left_dist_key) }; + let left = self.left().to_stream(ctx)?; + // Enforce a shuffle on the temporal join LHS so that the scheduler can schedule the join fragment together with the RHS through a `no_shuffle` exchange. 
+ let left = required_dist.enforce(left, &Order::any()); + if !left.append_only() { return Err(RwError::from(ErrorCode::NotSupported( "Temporal join requires an append-only left input".into(),