diff --git a/dashboard/proto/gen/batch_plan.ts b/dashboard/proto/gen/batch_plan.ts index 0ad51e94575a0..56436c2b71b6d 100644 --- a/dashboard/proto/gen/batch_plan.ts +++ b/dashboard/proto/gen/batch_plan.ts @@ -245,6 +245,8 @@ export interface LocalLookupJoinNode { joinType: JoinType; condition: ExprNode | undefined; outerSideKey: number[]; + innerSideKey: number[]; + lookupPrefixLen: number; innerSideTableDesc: StorageTableDesc | undefined; innerSideVnodeMapping: number[]; innerSideColumnIds: number[]; @@ -265,6 +267,8 @@ export interface DistributedLookupJoinNode { joinType: JoinType; condition: ExprNode | undefined; outerSideKey: number[]; + innerSideKey: number[]; + lookupPrefixLen: number; innerSideTableDesc: StorageTableDesc | undefined; innerSideColumnIds: number[]; outputIndices: number[]; @@ -1575,6 +1579,8 @@ function createBaseLocalLookupJoinNode(): LocalLookupJoinNode { joinType: JoinType.UNSPECIFIED, condition: undefined, outerSideKey: [], + innerSideKey: [], + lookupPrefixLen: 0, innerSideTableDesc: undefined, innerSideVnodeMapping: [], innerSideColumnIds: [], @@ -1590,6 +1596,8 @@ export const LocalLookupJoinNode = { joinType: isSet(object.joinType) ? joinTypeFromJSON(object.joinType) : JoinType.UNSPECIFIED, condition: isSet(object.condition) ? ExprNode.fromJSON(object.condition) : undefined, outerSideKey: Array.isArray(object?.outerSideKey) ? object.outerSideKey.map((e: any) => Number(e)) : [], + innerSideKey: Array.isArray(object?.innerSideKey) ? object.innerSideKey.map((e: any) => Number(e)) : [], + lookupPrefixLen: isSet(object.lookupPrefixLen) ? Number(object.lookupPrefixLen) : 0, innerSideTableDesc: isSet(object.innerSideTableDesc) ? 
StorageTableDesc.fromJSON(object.innerSideTableDesc) : undefined, @@ -1617,6 +1625,12 @@ export const LocalLookupJoinNode = { } else { obj.outerSideKey = []; } + if (message.innerSideKey) { + obj.innerSideKey = message.innerSideKey.map((e) => Math.round(e)); + } else { + obj.innerSideKey = []; + } + message.lookupPrefixLen !== undefined && (obj.lookupPrefixLen = Math.round(message.lookupPrefixLen)); message.innerSideTableDesc !== undefined && (obj.innerSideTableDesc = message.innerSideTableDesc ? StorageTableDesc.toJSON(message.innerSideTableDesc) : undefined); @@ -1655,6 +1669,8 @@ export const LocalLookupJoinNode = { ? ExprNode.fromPartial(object.condition) : undefined; message.outerSideKey = object.outerSideKey?.map((e) => e) || []; + message.innerSideKey = object.innerSideKey?.map((e) => e) || []; + message.lookupPrefixLen = object.lookupPrefixLen ?? 0; message.innerSideTableDesc = (object.innerSideTableDesc !== undefined && object.innerSideTableDesc !== null) ? StorageTableDesc.fromPartial(object.innerSideTableDesc) : undefined; @@ -1672,6 +1688,8 @@ function createBaseDistributedLookupJoinNode(): DistributedLookupJoinNode { joinType: JoinType.UNSPECIFIED, condition: undefined, outerSideKey: [], + innerSideKey: [], + lookupPrefixLen: 0, innerSideTableDesc: undefined, innerSideColumnIds: [], outputIndices: [], @@ -1685,6 +1703,8 @@ export const DistributedLookupJoinNode = { joinType: isSet(object.joinType) ? joinTypeFromJSON(object.joinType) : JoinType.UNSPECIFIED, condition: isSet(object.condition) ? ExprNode.fromJSON(object.condition) : undefined, outerSideKey: Array.isArray(object?.outerSideKey) ? object.outerSideKey.map((e: any) => Number(e)) : [], + innerSideKey: Array.isArray(object?.innerSideKey) ? object.innerSideKey.map((e: any) => Number(e)) : [], + lookupPrefixLen: isSet(object.lookupPrefixLen) ? Number(object.lookupPrefixLen) : 0, innerSideTableDesc: isSet(object.innerSideTableDesc) ? 
StorageTableDesc.fromJSON(object.innerSideTableDesc) : undefined, @@ -1706,6 +1726,12 @@ export const DistributedLookupJoinNode = { } else { obj.outerSideKey = []; } + if (message.innerSideKey) { + obj.innerSideKey = message.innerSideKey.map((e) => Math.round(e)); + } else { + obj.innerSideKey = []; + } + message.lookupPrefixLen !== undefined && (obj.lookupPrefixLen = Math.round(message.lookupPrefixLen)); message.innerSideTableDesc !== undefined && (obj.innerSideTableDesc = message.innerSideTableDesc ? StorageTableDesc.toJSON(message.innerSideTableDesc) : undefined); @@ -1734,6 +1760,8 @@ export const DistributedLookupJoinNode = { ? ExprNode.fromPartial(object.condition) : undefined; message.outerSideKey = object.outerSideKey?.map((e) => e) || []; + message.innerSideKey = object.innerSideKey?.map((e) => e) || []; + message.lookupPrefixLen = object.lookupPrefixLen ?? 0; message.innerSideTableDesc = (object.innerSideTableDesc !== undefined && object.innerSideTableDesc !== null) ? StorageTableDesc.fromPartial(object.innerSideTableDesc) : undefined; diff --git a/e2e_test/batch/basic/local/lookup_join.slt.part b/e2e_test/batch/basic/lookup_join.slt.part similarity index 88% rename from e2e_test/batch/basic/local/lookup_join.slt.part rename to e2e_test/batch/basic/lookup_join.slt.part index 05e2e021ae409..78a1a4c498138 100644 --- a/e2e_test/batch/basic/local/lookup_join.slt.part +++ b/e2e_test/batch/basic/lookup_join.slt.part @@ -220,5 +220,36 @@ drop table t1; statement ok drop table t2; +statement ok +create table t1(a int, b int); + +statement ok +create table t2(c int, d int); + +statement ok +create index idx on t2(c) include(d); + +statement ok +insert into t1 values (1,222); + +statement ok +insert into t2 values (1,222); + +query IIII +select * from t1 join idx on t1.a = idx.c; +---- +1 222 1 222 + +query IIII +select * from t1 join idx on t1.a = idx.c and t1.b = idx.d; +---- +1 222 1 222 + +statement ok +drop table t1; + +statement ok +drop table t2; + statement 
ok set rw_batch_enable_lookup_join to false; diff --git a/proto/batch_plan.proto b/proto/batch_plan.proto index 5a68a07ec0eb6..6852551150e42 100644 --- a/proto/batch_plan.proto +++ b/proto/batch_plan.proto @@ -217,14 +217,16 @@ message LocalLookupJoinNode { plan_common.JoinType join_type = 1; expr.ExprNode condition = 2; repeated uint32 outer_side_key = 3; - plan_common.StorageTableDesc inner_side_table_desc = 4; - repeated uint32 inner_side_vnode_mapping = 5; - repeated int32 inner_side_column_ids = 6; - repeated uint32 output_indices = 7; - repeated common.WorkerNode worker_nodes = 8; + repeated uint32 inner_side_key = 4; + uint32 lookup_prefix_len = 5; + plan_common.StorageTableDesc inner_side_table_desc = 6; + repeated uint32 inner_side_vnode_mapping = 7; + repeated int32 inner_side_column_ids = 8; + repeated uint32 output_indices = 9; + repeated common.WorkerNode worker_nodes = 10; // Null safe means it treats `null = null` as true. // Each key pair can be null safe independently. (left_key, right_key, null_safe) - repeated bool null_safe = 9; + repeated bool null_safe = 11; } // RFC: A new schedule way for distributed lookup join @@ -233,12 +235,14 @@ message DistributedLookupJoinNode { plan_common.JoinType join_type = 1; expr.ExprNode condition = 2; repeated uint32 outer_side_key = 3; - plan_common.StorageTableDesc inner_side_table_desc = 4; - repeated int32 inner_side_column_ids = 5; - repeated uint32 output_indices = 6; + repeated uint32 inner_side_key = 4; + uint32 lookup_prefix_len = 5; + plan_common.StorageTableDesc inner_side_table_desc = 6; + repeated int32 inner_side_column_ids = 7; + repeated uint32 output_indices = 8; // Null safe means it treats `null = null` as true. // Each key pair can be null safe independently. 
(left_key, right_key, null_safe) - repeated bool null_safe = 7; + repeated bool null_safe = 9; } message UnionNode {} diff --git a/src/batch/src/executor/join/distributed_lookup_join.rs b/src/batch/src/executor/join/distributed_lookup_join.rs index c959ccf34e436..b18b76012ef7e 100644 --- a/src/batch/src/executor/join/distributed_lookup_join.rs +++ b/src/batch/src/executor/join/distributed_lookup_join.rs @@ -15,9 +15,10 @@ use std::marker::PhantomData; use std::mem::swap; +use futures::pin_mut; use itertools::Itertools; use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema, TableId, TableOption}; -use risingwave_common::error::{internal_error, Result}; +use risingwave_common::error::Result; use risingwave_common::hash::{HashKey, HashKeyDispatcher}; use risingwave_common::row::Row; use risingwave_common::types::{DataType, Datum}; @@ -31,7 +32,7 @@ use risingwave_pb::batch_plan::plan_node::NodeBody; use risingwave_pb::expr::expr_node::Type; use risingwave_pb::plan_common::OrderType as ProstOrderType; use risingwave_storage::table::batch_table::storage_table::StorageTable; -use risingwave_storage::table::Distribution; +use risingwave_storage::table::{Distribution, TableIter}; use risingwave_storage::{dispatch_state_store, StateStore}; use crate::executor::join::JoinType; @@ -152,15 +153,12 @@ impl BoxedExecutorBuilder for DistributedLookupJoinExecutorBuilder { .map(|&i| outer_side_data_types[i].clone()) .collect_vec(); + let lookup_prefix_len: usize = + distributed_lookup_join_node.get_lookup_prefix_len() as usize; + let mut inner_side_key_idxs = vec![]; - for pk in &table_desc.pk { - let key_idx = inner_side_column_ids - .iter() - .position(|&i| table_desc.columns[pk.index as usize].column_id == i) - .ok_or_else(|| { - internal_error("Inner side key is not part of its output columns") - })?; - inner_side_key_idxs.push(key_idx); + for inner_side_key in distributed_lookup_join_node.get_inner_side_key() { + inner_side_key_idxs.push(*inner_side_key as 
usize) } let inner_side_key_types = inner_side_key_idxs @@ -232,6 +230,7 @@ impl BoxedExecutorBuilder for DistributedLookupJoinExecutorBuilder { let inner_side_builder = InnerSideExecutorBuilder::new( outer_side_key_types, inner_side_key_types.clone(), + lookup_prefix_len, source.epoch(), vec![], table, @@ -248,6 +247,7 @@ impl BoxedExecutorBuilder for DistributedLookupJoinExecutorBuilder { inner_side_key_types, inner_side_key_idxs, null_safe, + lookup_prefix_len, chunk_builder: DataChunkBuilder::new(original_schema.data_types(), chunk_size), schema: actual_schema, output_indices, @@ -269,6 +269,7 @@ struct DistributedLookupJoinExecutorArgs { inner_side_key_types: Vec, inner_side_key_idxs: Vec, null_safe: Vec, + lookup_prefix_len: usize, chunk_builder: DataChunkBuilder, schema: Schema, output_indices: Vec, @@ -291,6 +292,7 @@ impl HashKeyDispatcher for DistributedLookupJoinExecutorArgs { inner_side_key_types: self.inner_side_key_types, inner_side_key_idxs: self.inner_side_key_idxs, null_safe: self.null_safe, + lookup_prefix_len: self.lookup_prefix_len, chunk_builder: self.chunk_builder, schema: self.schema, output_indices: self.output_indices, @@ -310,6 +312,7 @@ impl HashKeyDispatcher for DistributedLookupJoinExecutorArgs { struct InnerSideExecutorBuilder { outer_side_key_types: Vec, inner_side_key_types: Vec, + lookup_prefix_len: usize, epoch: u64, row_list: Vec, table: StorageTable, @@ -320,6 +323,7 @@ impl InnerSideExecutorBuilder { fn new( outer_side_key_types: Vec, inner_side_key_types: Vec, + lookup_prefix_len: usize, epoch: u64, row_list: Vec, table: StorageTable, @@ -328,6 +332,7 @@ impl InnerSideExecutorBuilder { Self { outer_side_key_types, inner_side_key_types, + lookup_prefix_len, epoch, row_list, table, @@ -348,8 +353,16 @@ impl LookupExecutorBuilder for InnerSideExecutorBuilder { for ((datum, outer_type), inner_type) in key_datums .into_iter() - .zip_eq(self.outer_side_key_types.iter()) - .zip_eq(self.inner_side_key_types.iter()) + .zip_eq( + 
self.outer_side_key_types + .iter() + .take(self.lookup_prefix_len), + ) + .zip_eq( + self.inner_side_key_types + .iter() + .take(self.lookup_prefix_len), + ) { let datum = if inner_type == outer_type { datum @@ -367,13 +380,26 @@ impl LookupExecutorBuilder for InnerSideExecutorBuilder { } let pk_prefix = Row::new(scan_range.eq_conds); - let row = self - .table - .get_row(&pk_prefix, HummockReadEpoch::Committed(self.epoch)) - .await?; - if let Some(row) = row { - self.row_list.push(row); + if self.lookup_prefix_len == self.table.pk_indices().len() { + let row = self + .table + .get_row(&pk_prefix, HummockReadEpoch::Committed(self.epoch)) + .await?; + + if let Some(row) = row { + self.row_list.push(row); + } + } else { + let iter = self + .table + .batch_iter_with_pk_bounds(HummockReadEpoch::Committed(self.epoch), &pk_prefix, ..) + .await?; + + pin_mut!(iter); + while let Some(row) = iter.next_row().await? { + self.row_list.push(row); + } } Ok(()) diff --git a/src/batch/src/executor/join/local_lookup_join.rs b/src/batch/src/executor/join/local_lookup_join.rs index 67cbf21e90970..06cba117e2f4e 100644 --- a/src/batch/src/executor/join/local_lookup_join.rs +++ b/src/batch/src/executor/join/local_lookup_join.rs @@ -55,6 +55,7 @@ struct InnerSideExecutorBuilder { inner_side_schema: Schema, inner_side_column_ids: Vec, inner_side_key_types: Vec, + lookup_prefix_len: usize, context: C, task_id: TaskId, epoch: u64, @@ -171,8 +172,16 @@ impl LookupExecutorBuilder for InnerSideExecutorBuilder for ((datum, outer_type), inner_type) in key_datums .into_iter() - .zip_eq(self.outer_side_key_types.iter()) - .zip_eq(self.inner_side_key_types.iter()) + .zip_eq( + self.outer_side_key_types + .iter() + .take(self.lookup_prefix_len), + ) + .zip_eq( + self.inner_side_key_types + .iter() + .take(self.lookup_prefix_len), + ) { let datum = if inner_type == outer_type { datum @@ -346,15 +355,11 @@ impl BoxedExecutorBuilder for LocalLookupJoinExecutorBuilder { .map(|&i| 
outer_side_data_types[i].clone()) .collect_vec(); + let lookup_prefix_len: usize = lookup_join_node.get_lookup_prefix_len() as usize; + let mut inner_side_key_idxs = vec![]; - for pk in &table_desc.pk { - let key_idx = inner_side_column_ids - .iter() - .position(|&i| table_desc.columns[pk.index as usize].column_id == i) - .ok_or_else(|| { - internal_error("Inner side key is not part of its output columns") - })?; - inner_side_key_idxs.push(key_idx); + for inner_side_key in lookup_join_node.get_inner_side_key() { + inner_side_key_idxs.push(*inner_side_key as usize) } let inner_side_key_types = inner_side_key_idxs @@ -376,6 +381,7 @@ impl BoxedExecutorBuilder for LocalLookupJoinExecutorBuilder { inner_side_schema, inner_side_column_ids, inner_side_key_types: inner_side_key_types.clone(), + lookup_prefix_len, context: source.context().clone(), task_id: source.task_id.clone(), epoch: source.epoch(), @@ -394,6 +400,7 @@ impl BoxedExecutorBuilder for LocalLookupJoinExecutorBuilder { inner_side_key_types, inner_side_key_idxs, null_safe, + lookup_prefix_len, chunk_builder: DataChunkBuilder::new(original_schema.data_types(), chunk_size), schema: actual_schema, output_indices, @@ -414,6 +421,7 @@ struct LocalLookupJoinExecutorArgs { inner_side_key_types: Vec, inner_side_key_idxs: Vec, null_safe: Vec, + lookup_prefix_len: usize, chunk_builder: DataChunkBuilder, schema: Schema, output_indices: Vec, @@ -435,6 +443,7 @@ impl HashKeyDispatcher for LocalLookupJoinExecutorArgs { inner_side_key_types: self.inner_side_key_types, inner_side_key_idxs: self.inner_side_key_idxs, null_safe: self.null_safe, + lookup_prefix_len: self.lookup_prefix_len, chunk_builder: self.chunk_builder, schema: self.schema, output_indices: self.output_indices, @@ -536,6 +545,7 @@ mod tests { inner_side_key_types: vec![inner_side_data_types[0].clone()], inner_side_key_idxs: vec![0], null_safe: vec![null_safe], + lookup_prefix_len: 1, chunk_builder: DataChunkBuilder::new(original_schema.data_types(), 
CHUNK_SIZE), schema: original_schema.clone(), output_indices: (0..original_schema.len()).into_iter().collect(), diff --git a/src/batch/src/executor/join/lookup_join_base.rs b/src/batch/src/executor/join/lookup_join_base.rs index c45b578ceb101..59697fb08e323 100644 --- a/src/batch/src/executor/join/lookup_join_base.rs +++ b/src/batch/src/executor/join/lookup_join_base.rs @@ -44,6 +44,7 @@ pub struct LookupJoinBase { pub inner_side_key_types: Vec, // Data types only of key columns of inner side table pub inner_side_key_idxs: Vec, pub null_safe: Vec, + pub lookup_prefix_len: usize, pub chunk_builder: DataChunkBuilder, pub schema: Schema, pub output_indices: Vec, @@ -86,6 +87,7 @@ impl LookupJoinBase { chunk.rows().map(|row| { self.outer_side_key_idxs .iter() + .take(self.lookup_prefix_len) .map(|&idx| row.value_at(idx).to_owned_datum()) .collect_vec() }) diff --git a/src/frontend/src/optimizer/plan_node/batch_lookup_join.rs b/src/frontend/src/optimizer/plan_node/batch_lookup_join.rs index 4073d2d5c0dc3..132a2d7d19c74 100644 --- a/src/frontend/src/optimizer/plan_node/batch_lookup_join.rs +++ b/src/frontend/src/optimizer/plan_node/batch_lookup_join.rs @@ -43,6 +43,9 @@ pub struct BatchLookupJoin { /// Output column ids of the right side table right_output_column_ids: Vec, + /// The prefix length of the order key of right side table. + lookup_prefix_len: usize, + /// If `distributed_lookup` is true, it will generate `DistributedLookupJoinNode` for /// `ToBatchProst`. Otherwise, it will generate `LookupJoinNode`. 
distributed_lookup: bool, @@ -54,6 +57,7 @@ impl BatchLookupJoin { eq_join_predicate: EqJoinPredicate, right_table_desc: TableDesc, right_output_column_ids: Vec, + lookup_prefix_len: usize, distributed_lookup: bool, ) -> Self { let ctx = logical.base.ctx.clone(); @@ -65,6 +69,7 @@ impl BatchLookupJoin { eq_join_predicate, right_table_desc, right_output_column_ids, + lookup_prefix_len, distributed_lookup, } } @@ -148,6 +153,7 @@ impl PlanTreeNodeUnary for BatchLookupJoin { self.eq_join_predicate.clone(), self.right_table_desc.clone(), self.right_output_column_ids.clone(), + self.lookup_prefix_len, self.distributed_lookup, ) } @@ -184,6 +190,12 @@ impl ToBatchProst for BatchLookupJoin { .into_iter() .map(|a| a as _) .collect(), + inner_side_key: self + .eq_join_predicate + .right_eq_indexes() + .into_iter() + .map(|a| a as _) + .collect(), inner_side_table_desc: Some(self.right_table_desc.to_protobuf()), inner_side_column_ids: self .right_output_column_ids @@ -197,6 +209,7 @@ impl ToBatchProst for BatchLookupJoin { .map(|&x| x as u32) .collect(), null_safe: self.eq_join_predicate.null_safes(), + lookup_prefix_len: self.lookup_prefix_len as u32, }) } else { NodeBody::LocalLookupJoin(LocalLookupJoinNode { @@ -212,6 +225,12 @@ impl ToBatchProst for BatchLookupJoin { .into_iter() .map(|a| a as _) .collect(), + inner_side_key: self + .eq_join_predicate + .right_eq_indexes() + .into_iter() + .map(|a| a as _) + .collect(), inner_side_table_desc: Some(self.right_table_desc.to_protobuf()), inner_side_vnode_mapping: vec![], // To be filled in at local.rs inner_side_column_ids: self @@ -227,6 +246,7 @@ impl ToBatchProst for BatchLookupJoin { .collect(), worker_nodes: vec![], // To be filled in at local.rs null_safe: self.eq_join_predicate.null_safes(), + lookup_prefix_len: self.lookup_prefix_len as u32, }) } } diff --git a/src/frontend/src/optimizer/plan_node/logical_join.rs b/src/frontend/src/optimizer/plan_node/logical_join.rs index a59122ad3192d..04e2b004aa89c 100644 --- 
a/src/frontend/src/optimizer/plan_node/logical_join.rs +++ b/src/frontend/src/optimizer/plan_node/logical_join.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::cmp::max; use std::fmt; use fixedbitset::FixedBitSet; @@ -421,23 +422,50 @@ impl LogicalJoin { let table_desc = logical_scan.table_desc().clone(); let output_column_ids = logical_scan.output_column_ids(); - // Verify that the right join key columns are the same as the primary key - // TODO: Refactor Lookup Join so that prefixes of the primary key are allowed + // Verify that the right join key columns are the prefix of the primary key and + // also contain the distribution key. let order_col_ids = table_desc.order_column_ids(); - if order_col_ids.len() != predicate.right_eq_indexes().len() { - // In Lookup Join, the right columns of the equality join predicates must be the same as - // the primary key. A different join will be used instead. + let order_key = table_desc.order_column_indices(); + let dist_key = table_desc.distribution_key.clone(); + // Length of the shortest prefix of the order key that contains the distribution key. + let at_least_prefix_len = { + let mut max_pos = 0; + for d in dist_key { + max_pos = max( + max_pos, + order_key + .iter() + .position(|x| *x == d) + .expect("dist_key must in order_key"), + ); + } + max_pos + 1 + }; + + if predicate.right_eq_indexes().len() < at_least_prefix_len { + // In Lookup Join, the right columns of the equality join predicates must contain the + // prefix of the order key. return None; } - for (order_col_id, eq_idx) in order_col_ids + // Lookup prefix len is the prefix length of the order key.
+ let mut lookup_prefix_len = 0; + #[expect(clippy::disallowed_methods)] + for (i, (order_col_id, eq_idx)) in order_col_ids .into_iter() - .zip_eq(predicate.right_eq_indexes()) + .zip(predicate.right_eq_indexes()) + .enumerate() { if order_col_id != output_column_ids[eq_idx] { - // In Lookup Join, the right columns of the equality join predicates must be the - // same as the primary key. A different join will be used instead. - return None; + if i < at_least_prefix_len { + // In Lookup Join, the right columns of the equality join predicates must + // contain the prefix of the order key. + return None; + } else { + break; + } + } else { + lookup_prefix_len = i + 1; } } @@ -541,6 +569,7 @@ impl LogicalJoin { new_predicate, table_desc, new_scan_output_column_ids, + lookup_prefix_len, false, ) .into(),