diff --git a/proto/plan_common.proto b/proto/plan_common.proto
index 610f40968755c..0f4e988e6c035 100644
--- a/proto/plan_common.proto
+++ b/proto/plan_common.proto
@@ -141,6 +141,29 @@ enum JoinType {
   JOIN_TYPE_RIGHT_ANTI = 8;
 }
 
+enum AsOfJoinType {
+  AS_OF_JOIN_TYPE_UNSPECIFIED = 0;
+  AS_OF_JOIN_TYPE_INNER = 1;
+  AS_OF_JOIN_TYPE_LEFT_OUTER = 2;
+}
+
+enum AsOfJoinInequalityType {
+  AS_OF_INEQUALITY_TYPE_UNSPECIFIED = 0;
+  AS_OF_INEQUALITY_TYPE_GT = 1;
+  AS_OF_INEQUALITY_TYPE_GE = 2;
+  AS_OF_INEQUALITY_TYPE_LT = 3;
+  AS_OF_INEQUALITY_TYPE_LE = 4;
+}
+
+message AsOfJoinDesc {
+  // The index of the right side's AS OF column.
+  uint32 right_idx = 1;
+  // The index of the left side's AS OF column.
+  uint32 left_idx = 2;
+  // The type of the inequality.
+  AsOfJoinInequalityType inequality_type = 3;
+}
+
 // https://github.com/tokio-rs/prost/issues/80
 enum FormatType {
   FORMAT_TYPE_UNSPECIFIED = 0;
diff --git a/proto/stream_plan.proto b/proto/stream_plan.proto
index a96f54818146e..ca67737aeafe0 100644
--- a/proto/stream_plan.proto
+++ b/proto/stream_plan.proto
@@ -455,6 +455,32 @@ message HashJoinNode {
   bool is_append_only = 14;
 }
 
+message AsOfJoinNode {
+  plan_common.AsOfJoinType join_type = 1;
+  repeated int32 left_key = 2;
+  repeated int32 right_key = 3;
+  // Used for internal table states.
+  catalog.Table left_table = 4;
+  // Used for internal table states.
+  catalog.Table right_table = 5;
+  // Used for internal table states.
+  catalog.Table left_degree_table = 6;
+  // Used for internal table states.
+  catalog.Table right_degree_table = 7;
+  // The output indices of the current node.
+  repeated uint32 output_indices = 8;
+  // Left deduped input pk indices.
+  // The pk of the left_table is [left_join_key | left_inequality_key | left_deduped_input_pk_indices].
+  // left_inequality_key is not used yet, but is kept for forward compatibility.
+  repeated uint32 left_deduped_input_pk_indices = 9;
+  // Right deduped input pk indices.
+  // The pk of the right_table is [right_join_key | right_inequality_key | right_deduped_input_pk_indices].
+  // right_inequality_key is not used yet, but is kept for forward compatibility.
+  repeated uint32 right_deduped_input_pk_indices = 10;
+  repeated bool null_safe = 11;
+  optional plan_common.AsOfJoinDesc asof_desc = 12;
+}
+
 message TemporalJoinNode {
   plan_common.JoinType join_type = 1;
   repeated int32 left_key = 2;
diff --git a/src/stream/src/executor/asof_join.rs b/src/stream/src/executor/asof_join.rs
new file mode 100644
index 0000000000000..cb8a141481f28
--- /dev/null
+++ b/src/stream/src/executor/asof_join.rs
@@ -0,0 +1,1377 @@
+// Copyright 2024 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
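For orientation, this is the semantics that `AsOfJoinDesc` encodes: for each left row, the join emits at most one matching right row, namely the one whose AS OF column is closest to the left row's while satisfying the inequality. A minimal reference model over plain vectors (illustrative only, not part of the patch; it assumes `(join_key, as_of)` tuples and the `LT` inequality, and ignores the pk tie-breaking the executor performs):

    /// For each left row, pick the right row with the same join key whose
    /// `as_of` value is the smallest one strictly greater than the left
    /// row's `as_of` (the `LT` inequality).
    fn as_of_inner_lt(
        left: &[(i64, i64)],  // (join_key, as_of)
        right: &[(i64, i64)], // (join_key, as_of)
    ) -> Vec<((i64, i64), (i64, i64))> {
        left.iter()
            .filter_map(|&(k, t)| {
                right
                    .iter()
                    .filter(|&&(rk, rt)| rk == k && t < rt)
                    .min_by_key(|&&(_, rt)| rt)
                    .map(|&r| ((k, t), r))
            })
            .collect()
    }

    fn main() {
        let left = [(2, 5)];
        let right = [(2, 1), (2, 7), (2, 9)];
        // For left (2, 5), the closest right as-of above 5 is 7.
        assert_eq!(as_of_inner_lt(&left, &right), vec![((2, 5), (2, 7))]);
    }

With `LT`, the chosen right row is the one with the smallest `as_of` strictly greater than the left row's; this is exactly the `lower_bound` lookup the executor performs below, and the tests at the bottom of the new file trace the same selection incrementally.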
+use std::collections::{BTreeMap, HashSet};
+use std::ops::Bound;
+use std::time::Duration;
+
+use either::Either;
+use itertools::Itertools;
+use multimap::MultiMap;
+use risingwave_common::array::Op;
+use risingwave_common::hash::{HashKey, NullBitmap};
+use risingwave_common::util::epoch::EpochPair;
+use risingwave_common::util::iter_util::ZipEqDebug;
+use tokio::time::Instant;
+
+use self::builder::JoinChunkBuilder;
+use super::barrier_align::*;
+use super::join::hash_join::*;
+use super::join::*;
+use super::watermark::*;
+use crate::executor::join::builder::JoinStreamChunkBuilder;
+use crate::executor::prelude::*;
+
+/// Evict the cache every n rows.
+const EVICT_EVERY_N_ROWS: u32 = 16;
+
+fn is_subset(vec1: Vec<usize>, vec2: Vec<usize>) -> bool {
+    HashSet::<usize>::from_iter(vec1).is_subset(&vec2.into_iter().collect())
+}
+
+pub struct JoinParams {
+    /// Indices of the join keys
+    pub join_key_indices: Vec<usize>,
+    /// Indices of the input pk after dedup
+    pub deduped_pk_indices: Vec<usize>,
+}
+
+impl JoinParams {
+    pub fn new(join_key_indices: Vec<usize>, deduped_pk_indices: Vec<usize>) -> Self {
+        Self {
+            join_key_indices,
+            deduped_pk_indices,
+        }
+    }
+}
+
+struct JoinSide<K: HashKey, S: StateStore> {
+    /// Stores all data from one side of the stream.
+    ht: JoinHashMap<K, S>,
+    /// Indices of the join key columns
+    join_key_indices: Vec<usize>,
+    /// The data types of all columns, without degree.
+    all_data_types: Vec<DataType>,
+    /// The start position of this side's columns in the output.
+    start_pos: usize,
+    /// The mapping from input indices of this side to output columns.
+    i2o_mapping: Vec<(usize, usize)>,
+    i2o_mapping_indexed: MultiMap<usize, usize>,
+    /// Whether a degree table is needed for this side.
+    need_degree_table: bool,
+}
+
+impl<K: HashKey, S: StateStore> std::fmt::Debug for JoinSide<K, S> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("JoinSide")
+            .field("join_key_indices", &self.join_key_indices)
+            .field("col_types", &self.all_data_types)
+            .field("start_pos", &self.start_pos)
+            .field("i2o_mapping", &self.i2o_mapping)
+            .field("need_degree_table", &self.need_degree_table)
+            .finish()
+    }
+}
+
+impl<K: HashKey, S: StateStore> JoinSide<K, S> {
+    // WARNING: Please do not call this until we implement it.
+    fn is_dirty(&self) -> bool {
+        unimplemented!()
+    }
+
+    #[expect(dead_code)]
+    fn clear_cache(&mut self) {
+        assert!(
+            !self.is_dirty(),
+            "cannot clear cache while states of hash join are dirty"
+        );
+
+        // TODO: not working with rearranged chain
+        // self.ht.clear();
+    }
+
+    pub fn init(&mut self, epoch: EpochPair) {
+        self.ht.init(epoch);
+    }
+}
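The `deduped_pk_indices` above tie into the state-table layout described in the proto comments: each side's state table is keyed by [join_key | inequality_key | deduped_input_pk]. A tiny standalone sketch of that layout (hypothetical column positions, mirroring the tests at the bottom of this file where the left state pk is `[0, asof_desc.left_idx, 1]`):

    fn main() {
        // Columns (k, v, ts): join key k at 0, input pk v at 1, AS OF column ts at 2.
        let join_key_indices = vec![0usize];
        let deduped_pk_indices = vec![1usize];
        let inequality_idx = 2usize;
        // The state table pk assembled by the planner.
        let state_table_pk: Vec<usize> = join_key_indices
            .iter()
            .chain(std::iter::once(&inequality_idx))
            .chain(deduped_pk_indices.iter())
            .copied()
            .collect();
        assert_eq!(state_table_pk, vec![0, 2, 1]);
    }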
+/// `AsOfJoinExecutor` takes two input streams and hash-joins them on the equal join key,
+/// picking at most one match per left row according to the AS OF inequality.
+/// The output columns are the concatenation of the left and right columns.
+pub struct AsOfJoinExecutor<K: HashKey, S: StateStore, const T: AsOfJoinTypePrimitive> {
+    ctx: ActorContextRef,
+    info: ExecutorInfo,
+
+    /// Left input executor
+    input_l: Option<Executor>,
+    /// Right input executor
+    input_r: Option<Executor>,
+    /// The data types of the formed new columns
+    actual_output_data_types: Vec<DataType>,
+    /// The parameters of the left join executor
+    side_l: JoinSide<K, S>,
+    /// The parameters of the right join executor
+    side_r: JoinSide<K, S>,
+
+    metrics: Arc<StreamingMetrics>,
+    /// The maximum size of the chunk produced by the executor at a time
+    chunk_size: usize,
+    /// Count of rows received; cleared to 0 each time it reaches `EVICT_EVERY_N_ROWS`
+    cnt_rows_received: u32,
+
+    /// watermark column index -> `BufferedWatermarks`
+    watermark_buffers: BTreeMap<usize, BufferedWatermarks<SideTypePrimitive>>,
+
+    high_join_amplification_threshold: usize,
+    /// `AsOf` join description
+    asof_desc: AsOfDesc,
+}
+
+impl<K: HashKey, S: StateStore, const T: AsOfJoinTypePrimitive> std::fmt::Debug
+    for AsOfJoinExecutor<K, S, T>
+{
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("AsOfJoinExecutor")
+            .field("join_type", &T)
+            .field("input_left", &self.input_l.as_ref().unwrap().identity())
+            .field("input_right", &self.input_r.as_ref().unwrap().identity())
+            .field("side_l", &self.side_l)
+            .field("side_r", &self.side_r)
+            .field("pk_indices", &self.info.pk_indices)
+            .field("schema", &self.info.schema)
+            .field("actual_output_data_types", &self.actual_output_data_types)
+            .finish()
+    }
+}
+
+impl<K: HashKey, S: StateStore, const T: AsOfJoinTypePrimitive> Execute
+    for AsOfJoinExecutor<K, S, T>
+{
+    fn execute(self: Box<Self>) -> BoxedMessageStream {
+        self.into_stream().boxed()
+    }
+}
+
+struct EqJoinArgs<'a, K: HashKey, S: StateStore> {
+    ctx: &'a ActorContextRef,
+    side_l: &'a mut JoinSide<K, S>,
+    side_r: &'a mut JoinSide<K, S>,
+    asof_desc: &'a AsOfDesc,
+    actual_output_data_types: &'a [DataType],
+    // inequality_watermarks: &'a Watermark,
+    chunk: StreamChunk,
+    chunk_size: usize,
+    cnt_rows_received: &'a mut u32,
+    high_join_amplification_threshold: usize,
+}
+
+impl<K: HashKey, S: StateStore, const T: AsOfJoinTypePrimitive> AsOfJoinExecutor<K, S, T> {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        ctx: ActorContextRef,
+        info: ExecutorInfo,
+        input_l: Executor,
+        input_r: Executor,
+        params_l: JoinParams,
+        params_r: JoinParams,
+        null_safe: Vec<bool>,
+        output_indices: Vec<usize>,
+        state_table_l: StateTable<S>,
+        degree_state_table_l: StateTable<S>,
+        state_table_r: StateTable<S>,
+        degree_state_table_r: StateTable<S>,
+        watermark_epoch: AtomicU64Ref,
+        metrics: Arc<StreamingMetrics>,
+        chunk_size: usize,
+        high_join_amplification_threshold: usize,
+        asof_desc: AsOfDesc,
+    ) -> Self {
+        let side_l_column_n = input_l.schema().len();
+
+        let schema_fields = [
+            input_l.schema().fields.clone(),
+            input_r.schema().fields.clone(),
+        ]
+        .concat();
+
+        let original_output_data_types = schema_fields
+            .iter()
+            .map(|field| field.data_type())
+            .collect_vec();
+        let actual_output_data_types = output_indices
+            .iter()
+            .map(|&idx| original_output_data_types[idx].clone())
+            .collect_vec();
+
+        // Data types of the hash join state.
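+        // Both sides must produce identical join key types (asserted below);
+        // the degree table schemas are derived from each side's storage pk.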
+ let state_all_data_types_l = input_l.schema().data_types(); + let state_all_data_types_r = input_r.schema().data_types(); + + let state_pk_indices_l = input_l.pk_indices().to_vec(); + let state_pk_indices_r = input_r.pk_indices().to_vec(); + + let state_order_key_indices_l = state_table_l.pk_indices(); + let state_order_key_indices_r = state_table_r.pk_indices(); + + let state_join_key_indices_l = params_l.join_key_indices; + let state_join_key_indices_r = params_r.join_key_indices; + + let degree_join_key_indices_l = (0..state_join_key_indices_l.len()).collect_vec(); + let degree_join_key_indices_r = (0..state_join_key_indices_r.len()).collect_vec(); + + let degree_pk_indices_l = (state_join_key_indices_l.len() + ..state_join_key_indices_l.len() + params_l.deduped_pk_indices.len()) + .collect_vec(); + let degree_pk_indices_r = (state_join_key_indices_r.len() + ..state_join_key_indices_r.len() + params_r.deduped_pk_indices.len()) + .collect_vec(); + + // If pk is contained in join key. + let pk_contained_in_jk_l = + is_subset(state_pk_indices_l.clone(), state_join_key_indices_l.clone()); + let pk_contained_in_jk_r = + is_subset(state_pk_indices_r.clone(), state_join_key_indices_r.clone()); + + let join_key_data_types_l = state_join_key_indices_l + .iter() + .map(|idx| state_all_data_types_l[*idx].clone()) + .collect_vec(); + + let join_key_data_types_r = state_join_key_indices_r + .iter() + .map(|idx| state_all_data_types_r[*idx].clone()) + .collect_vec(); + + assert_eq!(join_key_data_types_l, join_key_data_types_r); + + let degree_all_data_types_l = state_order_key_indices_l + .iter() + .map(|idx| state_all_data_types_l[*idx].clone()) + .collect_vec(); + let degree_all_data_types_r = state_order_key_indices_r + .iter() + .map(|idx| state_all_data_types_r[*idx].clone()) + .collect_vec(); + + let null_matched = K::Bitmap::from_bool_vec(null_safe); + + let need_degree_table_l = false; + let need_degree_table_r = false; + + let (left_to_output, right_to_output) = { + let (left_len, right_len) = if is_left_semi_or_anti(T) { + (state_all_data_types_l.len(), 0usize) + } else if is_right_semi_or_anti(T) { + (0usize, state_all_data_types_r.len()) + } else { + (state_all_data_types_l.len(), state_all_data_types_r.len()) + }; + JoinStreamChunkBuilder::get_i2o_mapping(&output_indices, left_len, right_len) + }; + + let l2o_indexed = MultiMap::from_iter(left_to_output.iter().copied()); + let r2o_indexed = MultiMap::from_iter(right_to_output.iter().copied()); + + // handle inequality watermarks + // https://github.com/risingwavelabs/risingwave/issues/18503 + // let inequality_watermarks = None; + let watermark_buffers = BTreeMap::new(); + + let inequal_key_idx_l = Some(asof_desc.left_idx); + let inequal_key_idx_r = Some(asof_desc.right_idx); + + Self { + ctx: ctx.clone(), + info, + input_l: Some(input_l), + input_r: Some(input_r), + actual_output_data_types, + side_l: JoinSide { + ht: JoinHashMap::new( + watermark_epoch.clone(), + join_key_data_types_l, + state_join_key_indices_l.clone(), + state_all_data_types_l.clone(), + state_table_l, + params_l.deduped_pk_indices, + degree_join_key_indices_l, + degree_all_data_types_l, + degree_state_table_l, + degree_pk_indices_l, + null_matched.clone(), + need_degree_table_l, + pk_contained_in_jk_l, + inequal_key_idx_l, + metrics.clone(), + ctx.id, + ctx.fragment_id, + "left", + ), + join_key_indices: state_join_key_indices_l, + all_data_types: state_all_data_types_l, + i2o_mapping: left_to_output, + i2o_mapping_indexed: l2o_indexed, + start_pos: 0, + 
need_degree_table: need_degree_table_l, + }, + side_r: JoinSide { + ht: JoinHashMap::new( + watermark_epoch, + join_key_data_types_r, + state_join_key_indices_r.clone(), + state_all_data_types_r.clone(), + state_table_r, + params_r.deduped_pk_indices, + degree_join_key_indices_r, + degree_all_data_types_r, + degree_state_table_r, + degree_pk_indices_r, + null_matched, + need_degree_table_r, + pk_contained_in_jk_r, + inequal_key_idx_r, + metrics.clone(), + ctx.id, + ctx.fragment_id, + "right", + ), + join_key_indices: state_join_key_indices_r, + all_data_types: state_all_data_types_r, + start_pos: side_l_column_n, + i2o_mapping: right_to_output, + i2o_mapping_indexed: r2o_indexed, + need_degree_table: need_degree_table_r, + }, + metrics, + chunk_size, + cnt_rows_received: 0, + watermark_buffers, + high_join_amplification_threshold, + asof_desc, + } + } + + #[try_stream(ok = Message, error = StreamExecutorError)] + async fn into_stream(mut self) { + let input_l = self.input_l.take().unwrap(); + let input_r = self.input_r.take().unwrap(); + let aligned_stream = barrier_align( + input_l.execute(), + input_r.execute(), + self.ctx.id, + self.ctx.fragment_id, + self.metrics.clone(), + "Join", + ); + pin_mut!(aligned_stream); + + let barrier = expect_first_barrier_from_aligned_stream(&mut aligned_stream).await?; + self.side_l.init(barrier.epoch); + self.side_r.init(barrier.epoch); + + // The first barrier message should be propagated. + yield Message::Barrier(barrier); + let actor_id_str = self.ctx.id.to_string(); + let fragment_id_str = self.ctx.fragment_id.to_string(); + + // initialized some metrics + let join_actor_input_waiting_duration_ns = self + .metrics + .join_actor_input_waiting_duration_ns + .with_guarded_label_values(&[&actor_id_str, &fragment_id_str]); + let left_join_match_duration_ns = self + .metrics + .join_match_duration_ns + .with_guarded_label_values(&[&actor_id_str, &fragment_id_str, "left"]); + let right_join_match_duration_ns = self + .metrics + .join_match_duration_ns + .with_guarded_label_values(&[&actor_id_str, &fragment_id_str, "right"]); + + let barrier_join_match_duration_ns = self + .metrics + .join_match_duration_ns + .with_guarded_label_values(&[&actor_id_str, &fragment_id_str, "barrier"]); + + let left_join_cached_entry_count = self + .metrics + .join_cached_entry_count + .with_guarded_label_values(&[&actor_id_str, &fragment_id_str, "left"]); + + let right_join_cached_entry_count = self + .metrics + .join_cached_entry_count + .with_guarded_label_values(&[&actor_id_str, &fragment_id_str, "right"]); + + let mut start_time = Instant::now(); + + while let Some(msg) = aligned_stream + .next() + .instrument_await("hash_join_barrier_align") + .await + { + join_actor_input_waiting_duration_ns.inc_by(start_time.elapsed().as_nanos() as u64); + match msg? { + AlignedMessage::WatermarkLeft(watermark) => { + for watermark_to_emit in self.handle_watermark(SideType::Left, watermark)? { + yield Message::Watermark(watermark_to_emit); + } + } + AlignedMessage::WatermarkRight(watermark) => { + for watermark_to_emit in self.handle_watermark(SideType::Right, watermark)? 
{ + yield Message::Watermark(watermark_to_emit); + } + } + AlignedMessage::Left(chunk) => { + let mut left_time = Duration::from_nanos(0); + let mut left_start_time = Instant::now(); + #[for_await] + for chunk in Self::eq_join_left(EqJoinArgs { + ctx: &self.ctx, + side_l: &mut self.side_l, + side_r: &mut self.side_r, + asof_desc: &self.asof_desc, + actual_output_data_types: &self.actual_output_data_types, + // inequality_watermarks: &self.inequality_watermarks, + chunk, + chunk_size: self.chunk_size, + cnt_rows_received: &mut self.cnt_rows_received, + high_join_amplification_threshold: self.high_join_amplification_threshold, + }) { + left_time += left_start_time.elapsed(); + yield Message::Chunk(chunk?); + left_start_time = Instant::now(); + } + left_time += left_start_time.elapsed(); + left_join_match_duration_ns.inc_by(left_time.as_nanos() as u64); + self.try_flush_data().await?; + } + AlignedMessage::Right(chunk) => { + let mut right_time = Duration::from_nanos(0); + let mut right_start_time = Instant::now(); + #[for_await] + for chunk in Self::eq_join_right(EqJoinArgs { + ctx: &self.ctx, + side_l: &mut self.side_l, + side_r: &mut self.side_r, + asof_desc: &self.asof_desc, + actual_output_data_types: &self.actual_output_data_types, + // inequality_watermarks: &self.inequality_watermarks, + chunk, + chunk_size: self.chunk_size, + cnt_rows_received: &mut self.cnt_rows_received, + high_join_amplification_threshold: self.high_join_amplification_threshold, + }) { + right_time += right_start_time.elapsed(); + yield Message::Chunk(chunk?); + right_start_time = Instant::now(); + } + right_time += right_start_time.elapsed(); + right_join_match_duration_ns.inc_by(right_time.as_nanos() as u64); + self.try_flush_data().await?; + } + AlignedMessage::Barrier(barrier) => { + let barrier_start_time = Instant::now(); + self.flush_data(barrier.epoch).await?; + + // Update the vnode bitmap for state tables of both sides if asked. + if let Some(vnode_bitmap) = barrier.as_update_vnode_bitmap(self.ctx.id) { + if self.side_l.ht.update_vnode_bitmap(vnode_bitmap.clone()) { + self.watermark_buffers + .values_mut() + .for_each(|buffers| buffers.clear()); + // self.inequality_watermarks.fill(None); + } + self.side_r.ht.update_vnode_bitmap(vnode_bitmap); + } + + // Report metrics of cached join rows/entries + for (join_cached_entry_count, ht) in [ + (&left_join_cached_entry_count, &self.side_l.ht), + (&right_join_cached_entry_count, &self.side_r.ht), + ] { + join_cached_entry_count.set(ht.entry_count() as i64); + } + + barrier_join_match_duration_ns + .inc_by(barrier_start_time.elapsed().as_nanos() as u64); + yield Message::Barrier(barrier); + } + } + start_time = Instant::now(); + } + } + + async fn flush_data(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> { + // All changes to the state has been buffered in the mem-table of the state table. Just + // `commit` them here. + self.side_l.ht.flush(epoch).await?; + self.side_r.ht.flush(epoch).await?; + Ok(()) + } + + async fn try_flush_data(&mut self) -> StreamExecutorResult<()> { + // All changes to the state has been buffered in the mem-table of the state table. Just + // `commit` them here. + self.side_l.ht.try_flush().await?; + self.side_r.ht.try_flush().await?; + Ok(()) + } + + // We need to manually evict the cache. 
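+    // Eviction is batched: every `EVICT_EVERY_N_ROWS` input rows, both sides'
+    // cached join entries are asked to evict stale state (`JoinHashMap::evict`)
+    // instead of doing this bookkeeping on every row.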
+    fn evict_cache(
+        side_update: &mut JoinSide<K, S>,
+        side_match: &mut JoinSide<K, S>,
+        cnt_rows_received: &mut u32,
+    ) {
+        *cnt_rows_received += 1;
+        if *cnt_rows_received == EVICT_EVERY_N_ROWS {
+            side_update.ht.evict();
+            side_match.ht.evict();
+            *cnt_rows_received = 0;
+        }
+    }
+
+    fn handle_watermark(
+        &mut self,
+        side: SideTypePrimitive,
+        watermark: Watermark,
+    ) -> StreamExecutorResult<Vec<Watermark>> {
+        let (side_update, side_match) = if side == SideType::Left {
+            (&mut self.side_l, &mut self.side_r)
+        } else {
+            (&mut self.side_r, &mut self.side_l)
+        };
+
+        // State cleaning
+        if side_update.join_key_indices[0] == watermark.col_idx {
+            side_match.ht.update_watermark(watermark.val.clone());
+        }
+
+        // Select watermarks to yield.
+        let wm_in_jk = side_update
+            .join_key_indices
+            .iter()
+            .positions(|idx| *idx == watermark.col_idx);
+        let mut watermarks_to_emit = vec![];
+        for idx in wm_in_jk {
+            let buffers = self
+                .watermark_buffers
+                .entry(idx)
+                .or_insert_with(|| BufferedWatermarks::with_ids([SideType::Left, SideType::Right]));
+            if let Some(selected_watermark) = buffers.handle_watermark(side, watermark.clone()) {
+                let empty_indices = vec![];
+                let output_indices = side_update
+                    .i2o_mapping_indexed
+                    .get_vec(&side_update.join_key_indices[idx])
+                    .unwrap_or(&empty_indices)
+                    .iter()
+                    .chain(
+                        side_match
+                            .i2o_mapping_indexed
+                            .get_vec(&side_match.join_key_indices[idx])
+                            .unwrap_or(&empty_indices),
+                    );
+                for output_idx in output_indices {
+                    watermarks_to_emit.push(selected_watermark.clone().with_idx(*output_idx));
+                }
+            };
+        }
+        Ok(watermarks_to_emit)
+    }
+
+    /// Look up the given join key in the hash table and match the incoming
+    /// data chunk against the executor state.
+    async fn hash_eq_match(
+        key: &K,
+        ht: &mut JoinHashMap<K, S>,
+    ) -> StreamExecutorResult<Option<HashValueType>> {
+        if !key.null_bitmap().is_subset(ht.null_matched()) {
+            Ok(None)
+        } else {
+            ht.take_state(key).await.map(Some)
+        }
+    }
+
+    #[try_stream(ok = StreamChunk, error = StreamExecutorError)]
+    async fn eq_join_left(args: EqJoinArgs<'_, K, S>) {
+        let EqJoinArgs {
+            ctx: _,
+            side_l,
+            side_r,
+            asof_desc,
+            actual_output_data_types,
+            // inequality_watermarks,
+            chunk,
+            chunk_size,
+            cnt_rows_received,
+            high_join_amplification_threshold: _,
+        } = args;
+
+        let (side_update, side_match) = (side_l, side_r);
+
+        let mut join_chunk_builder =
+            JoinChunkBuilder::<T, { SideType::Left }>::new(JoinStreamChunkBuilder::new(
+                chunk_size,
+                actual_output_data_types.to_vec(),
+                side_update.i2o_mapping.clone(),
+                side_match.i2o_mapping.clone(),
+            ));
+
+        let keys = K::build_many(&side_update.join_key_indices, chunk.data_chunk());
+        for (r, key) in chunk.rows_with_holes().zip_eq_debug(keys.iter()) {
+            let Some((op, row)) = r else {
+                continue;
+            };
+            Self::evict_cache(side_update, side_match, cnt_rows_received);
+
+            let matched_rows = if !side_update.ht.check_inequal_key_null(&row) {
+                Self::hash_eq_match(key, &mut side_match.ht).await?
+ } else { + None + }; + let inequal_key = side_update.ht.serialize_inequal_key_from_row(row); + + if let Some(matched_rows) = matched_rows { + let matched_row_by_inequality = match asof_desc.inequality_type { + AsOfInequalityType::Lt => matched_rows.lower_bound_by_inequality( + Bound::Excluded(&inequal_key), + &side_match.all_data_types, + ), + AsOfInequalityType::Le => matched_rows.lower_bound_by_inequality( + Bound::Included(&inequal_key), + &side_match.all_data_types, + ), + AsOfInequalityType::Gt => matched_rows.upper_bound_by_inequality( + Bound::Excluded(&inequal_key), + &side_match.all_data_types, + ), + AsOfInequalityType::Ge => matched_rows.upper_bound_by_inequality( + Bound::Included(&inequal_key), + &side_match.all_data_types, + ), + }; + match op { + Op::Insert | Op::UpdateInsert => { + if let Some(matched_row_by_inequality) = matched_row_by_inequality { + let matched_row = matched_row_by_inequality?; + + if let Some(chunk) = + join_chunk_builder.with_match_on_insert(&row, &matched_row) + { + yield chunk; + } + } else if let Some(chunk) = + join_chunk_builder.forward_if_not_matched(Op::Insert, row) + { + yield chunk; + } + side_update.ht.insert_row(key, row).await?; + } + Op::Delete | Op::UpdateDelete => { + if let Some(matched_row_by_inequality) = matched_row_by_inequality { + let matched_row = matched_row_by_inequality?; + + if let Some(chunk) = + join_chunk_builder.with_match_on_delete(&row, &matched_row) + { + yield chunk; + } + } else if let Some(chunk) = + join_chunk_builder.forward_if_not_matched(Op::Delete, row) + { + yield chunk; + } + side_update.ht.delete_row(key, row)?; + } + } + // Insert back the state taken from ht. + side_match.ht.update_state(key, matched_rows); + } else { + // Row which violates null-safe bitmap will never be matched so we need not + // store. + match op { + Op::Insert | Op::UpdateInsert => { + if let Some(chunk) = + join_chunk_builder.forward_if_not_matched(Op::Insert, row) + { + yield chunk; + } + } + Op::Delete | Op::UpdateDelete => { + if let Some(chunk) = + join_chunk_builder.forward_if_not_matched(Op::Delete, row) + { + yield chunk; + } + } + } + } + } + if let Some(chunk) = join_chunk_builder.take() { + yield chunk; + } + } + + #[try_stream(ok = StreamChunk, error = StreamExecutorError)] + async fn eq_join_right(args: EqJoinArgs<'_, K, S>) { + let EqJoinArgs { + ctx, + side_l, + side_r, + asof_desc, + actual_output_data_types, + // inequality_watermarks, + chunk, + chunk_size, + cnt_rows_received, + high_join_amplification_threshold, + } = args; + + let (side_update, side_match) = (side_r, side_l); + + let mut join_chunk_builder = JoinStreamChunkBuilder::new( + chunk_size, + actual_output_data_types.to_vec(), + side_update.i2o_mapping.clone(), + side_match.i2o_mapping.clone(), + ); + + let join_matched_rows_metrics = ctx + .streaming_metrics + .join_matched_join_keys + .with_guarded_label_values(&[ + &ctx.id.to_string(), + &ctx.fragment_id.to_string(), + &side_update.ht.table_id().to_string(), + ]); + + let keys = K::build_many(&side_update.join_key_indices, chunk.data_chunk()); + for (r, key) in chunk.rows_with_holes().zip_eq_debug(keys.iter()) { + let Some((op, row)) = r else { + continue; + }; + let mut join_matched_rows_cnt = 0; + + Self::evict_cache(side_update, side_match, cnt_rows_received); + + let matched_rows = if !side_update.ht.check_inequal_key_null(&row) { + Self::hash_eq_match(key, &mut side_match.ht).await? 
+ } else { + None + }; + let inequal_key = side_update.ht.serialize_inequal_key_from_row(row); + + if let Some(matched_rows) = matched_rows { + let update_rows = Self::hash_eq_match(key, &mut side_update.ht).await?.expect("None is not expected because we have checked null in key when getting matched_rows"); + let right_inequality_index = update_rows.inequality_index(); + let (row_to_delete_r, row_to_insert_r) = + if let Some(pks) = right_inequality_index.get(&inequal_key) { + assert!(!pks.is_empty()); + let row_pk = side_match.ht.serialize_pk_from_row(row); + match op { + Op::Insert | Op::UpdateInsert => { + // If there are multiple rows match the inequality key in the right table, we use one with smallest pk. + let smallest_pk = pks.first_key_sorted().unwrap(); + if smallest_pk > &row_pk { + // smallest_pk is in the cache index, so it must exist in the cache. + if let Some(to_delete_row) = update_rows + .get_by_indexed_pk(smallest_pk, &side_update.all_data_types) + { + ( + Some(Either::Left(to_delete_row?.row)), + Some(Either::Right(row)), + ) + } else { + // Something wrong happened. Ignore this row in non strict consistency mode. + (None, None) + } + } else { + // No affected row in the right table. + (None, None) + } + } + Op::Delete | Op::UpdateDelete => { + let smallest_pk = pks.first_key_sorted().unwrap(); + if smallest_pk == &row_pk { + if let Some(second_smallest_pk) = pks.second_key_sorted() { + if let Some(to_insert_row) = update_rows.get_by_indexed_pk( + second_smallest_pk, + &side_update.all_data_types, + ) { + ( + Some(Either::Right(row)), + Some(Either::Left(to_insert_row?.row)), + ) + } else { + // Something wrong happened. Ignore this row in non strict consistency mode. + (None, None) + } + } else { + (Some(Either::Right(row)), None) + } + } else { + // No affected row in the right table. + (None, None) + } + } + } + } else { + match op { + // Decide the row_to_delete later + Op::Insert | Op::UpdateInsert => (None, Some(Either::Right(row))), + // Decide the row_to_insert later + Op::Delete | Op::UpdateDelete => (Some(Either::Right(row)), None), + } + }; + + // 4 cases for row_to_delete_r and row_to_insert_r: + // 1. Some(_), Some(_): delete row_to_delete_r and insert row_to_insert_r + // 2. None, Some(_) : row_to_delete to be decided by the nearest inequality key + // 3. Some(_), None : row_to_insert to be decided by the nearest inequality key + // 4. None, None : do nothing + if row_to_delete_r.is_none() && row_to_insert_r.is_none() { + // no row to delete or insert. + } else { + let prev_inequality_key = + right_inequality_index.upper_bound_key(Bound::Excluded(&inequal_key)); + let next_inequality_key = + right_inequality_index.lower_bound_key(Bound::Excluded(&inequal_key)); + let affected_row_r = match asof_desc.inequality_type { + AsOfInequalityType::Lt | AsOfInequalityType::Le => next_inequality_key + .and_then(|k| { + update_rows.get_first_by_inequality(k, &side_update.all_data_types) + }), + AsOfInequalityType::Gt | AsOfInequalityType::Ge => prev_inequality_key + .and_then(|k| { + update_rows.get_first_by_inequality(k, &side_update.all_data_types) + }), + } + .transpose()? 
+ .map(|r| Either::Left(r.row)); + + let (row_to_delete_r, row_to_insert_r) = + match (&row_to_delete_r, &row_to_insert_r) { + (Some(_), Some(_)) => (row_to_delete_r, row_to_insert_r), + (None, Some(_)) => (affected_row_r, row_to_insert_r), + (Some(_), None) => (row_to_delete_r, affected_row_r), + (None, None) => unreachable!(), + }; + let range = match asof_desc.inequality_type { + AsOfInequalityType::Lt => ( + prev_inequality_key.map_or_else(|| Bound::Unbounded, Bound::Included), + Bound::Excluded(&inequal_key), + ), + AsOfInequalityType::Le => ( + prev_inequality_key.map_or_else(|| Bound::Unbounded, Bound::Excluded), + Bound::Included(&inequal_key), + ), + AsOfInequalityType::Gt => ( + Bound::Excluded(&inequal_key), + next_inequality_key.map_or_else(|| Bound::Unbounded, Bound::Included), + ), + AsOfInequalityType::Ge => ( + Bound::Included(&inequal_key), + next_inequality_key.map_or_else(|| Bound::Unbounded, Bound::Excluded), + ), + }; + + let rows_l = + matched_rows.range_by_inequality(range, &side_match.all_data_types); + for row_l in rows_l { + join_matched_rows_cnt += 1; + let row_l = row_l?.row; + if let Some(row_to_delete_r) = &row_to_delete_r { + if let Some(chunk) = + join_chunk_builder.append_row(Op::Delete, row_to_delete_r, &row_l) + { + yield chunk; + } + } else if is_as_of_left_outer(T) { + if let Some(chunk) = + join_chunk_builder.append_row_matched(Op::Delete, &row_l) + { + yield chunk; + } + } + if let Some(row_to_insert_r) = &row_to_insert_r { + if let Some(chunk) = + join_chunk_builder.append_row(Op::Insert, row_to_insert_r, &row_l) + { + yield chunk; + } + } else if is_as_of_left_outer(T) { + if let Some(chunk) = + join_chunk_builder.append_row_matched(Op::Insert, &row_l) + { + yield chunk; + } + } + } + } + // Insert back the state taken from ht. + side_match.ht.update_state(key, matched_rows); + side_update.ht.update_state(key, update_rows); + + match op { + Op::Insert | Op::UpdateInsert => { + side_update.ht.insert_row(key, row).await?; + } + Op::Delete | Op::UpdateDelete => { + side_update.ht.delete_row(key, row)?; + } + } + } else { + // Row which violates null-safe bitmap will never be matched so we need not + // store. + // Noop here because we only support left outer AsOf join. 
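+                // (A row with a NULL join key or NULL inequality key can never
+                // match anything, and the right side never emits unmatched rows
+                // in INNER / LEFT OUTER AsOf joins, so there is nothing to
+                // output or store here.)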
+ } + join_matched_rows_metrics.observe(join_matched_rows_cnt as _); + if join_matched_rows_cnt > high_join_amplification_threshold { + let join_key_data_types = side_update.ht.join_key_data_types(); + let key = key.deserialize(join_key_data_types)?; + tracing::warn!(target: "high_join_amplification", + matched_rows_len = join_matched_rows_cnt, + update_table_id = side_update.ht.table_id(), + match_table_id = side_match.ht.table_id(), + join_key = ?key, + actor_id = ctx.id, + fragment_id = ctx.fragment_id, + "large rows matched for join key when AsOf join updating right side", + ); + } + } + if let Some(chunk) = join_chunk_builder.take() { + yield chunk; + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::AtomicU64; + + use risingwave_common::array::*; + use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, TableId}; + use risingwave_common::hash::Key64; + use risingwave_common::util::epoch::test_epoch; + use risingwave_common::util::sort_util::OrderType; + use risingwave_storage::memory::MemoryStateStore; + + use super::*; + use crate::executor::test_utils::{MessageSender, MockSource, StreamExecutorTestExt}; + + async fn create_in_memory_state_table( + mem_state: MemoryStateStore, + data_types: &[DataType], + order_types: &[OrderType], + pk_indices: &[usize], + table_id: u32, + ) -> (StateTable, StateTable) { + let column_descs = data_types + .iter() + .enumerate() + .map(|(id, data_type)| ColumnDesc::unnamed(ColumnId::new(id as i32), data_type.clone())) + .collect_vec(); + let state_table = StateTable::new_without_distribution( + mem_state.clone(), + TableId::new(table_id), + column_descs, + order_types.to_vec(), + pk_indices.to_vec(), + ) + .await; + + // Create degree table + let mut degree_table_column_descs = vec![]; + pk_indices.iter().enumerate().for_each(|(pk_id, idx)| { + degree_table_column_descs.push(ColumnDesc::unnamed( + ColumnId::new(pk_id as i32), + data_types[*idx].clone(), + )) + }); + degree_table_column_descs.push(ColumnDesc::unnamed( + ColumnId::new(pk_indices.len() as i32), + DataType::Int64, + )); + let degree_state_table = StateTable::new_without_distribution( + mem_state, + TableId::new(table_id + 1), + degree_table_column_descs, + order_types.to_vec(), + pk_indices.to_vec(), + ) + .await; + (state_table, degree_state_table) + } + + async fn create_executor( + asof_desc: AsOfDesc, + ) -> (MessageSender, MessageSender, BoxedMessageStream) { + let schema = Schema { + fields: vec![ + Field::unnamed(DataType::Int64), // join key + Field::unnamed(DataType::Int64), + Field::unnamed(DataType::Int64), + ], + }; + let (tx_l, source_l) = MockSource::channel(); + let source_l = source_l.into_executor(schema.clone(), vec![1]); + let (tx_r, source_r) = MockSource::channel(); + let source_r = source_r.into_executor(schema, vec![1]); + let params_l = JoinParams::new(vec![0], vec![1]); + let params_r = JoinParams::new(vec![0], vec![1]); + + let mem_state = MemoryStateStore::new(); + + let (state_l, degree_state_l) = create_in_memory_state_table( + mem_state.clone(), + &[DataType::Int64, DataType::Int64, DataType::Int64], + &[ + OrderType::ascending(), + OrderType::ascending(), + OrderType::ascending(), + ], + &[0, asof_desc.left_idx, 1], + 0, + ) + .await; + + let (state_r, degree_state_r) = create_in_memory_state_table( + mem_state, + &[DataType::Int64, DataType::Int64, DataType::Int64], + &[ + OrderType::ascending(), + OrderType::ascending(), + OrderType::ascending(), + ], + &[0, asof_desc.right_idx, 1], + 2, + ) + .await; + + let schema: Schema = 
[source_l.schema().fields(), source_r.schema().fields()] + .concat() + .into_iter() + .collect(); + let schema_len = schema.len(); + let info = ExecutorInfo { + schema, + pk_indices: vec![1], + identity: "HashJoinExecutor".to_string(), + }; + + let executor = AsOfJoinExecutor::::new( + ActorContext::for_test(123), + info, + source_l, + source_r, + params_l, + params_r, + vec![false], + (0..schema_len).collect_vec(), + state_l, + degree_state_l, + state_r, + degree_state_r, + Arc::new(AtomicU64::new(0)), + Arc::new(StreamingMetrics::unused()), + 1024, + 2048, + asof_desc, + ); + (tx_l, tx_r, executor.boxed().execute()) + } + + #[tokio::test] + async fn test_as_of_inner_join() -> StreamExecutorResult<()> { + let asof_desc = AsOfDesc { + left_idx: 0, + right_idx: 2, + inequality_type: AsOfInequalityType::Lt, + }; + + let chunk_l1 = StreamChunk::from_pretty( + " I I I + + 1 4 7 + + 2 5 8 + + 3 6 9", + ); + let chunk_l2 = StreamChunk::from_pretty( + " I I I + + 3 8 1 + - 3 8 1", + ); + let chunk_r1 = StreamChunk::from_pretty( + " I I I + + 2 1 7 + + 2 2 1 + + 2 3 4 + + 2 4 2 + + 6 1 9 + + 6 2 9", + ); + let chunk_r2 = StreamChunk::from_pretty( + " I I I + - 2 3 4", + ); + let chunk_r3 = StreamChunk::from_pretty( + " I I I + + 2 3 3", + ); + let chunk_l3 = StreamChunk::from_pretty( + " I I I + - 2 5 8", + ); + let chunk_l4 = StreamChunk::from_pretty( + " I I I + + 6 3 1 + + 6 4 1", + ); + let chunk_r4 = StreamChunk::from_pretty( + " I I I + - 6 1 9", + ); + + let (mut tx_l, mut tx_r, mut hash_join) = + create_executor::<{ AsOfJoinType::Inner }>(asof_desc).await; + + // push the init barrier for left and right + tx_l.push_barrier(test_epoch(1), false); + tx_r.push_barrier(test_epoch(1), false); + hash_join.next_unwrap_ready_barrier()?; + + // push the 1st left chunk + tx_l.push_chunk(chunk_l1); + hash_join.next_unwrap_pending(); + + // push the init barrier for left and right + tx_l.push_barrier(test_epoch(2), false); + tx_r.push_barrier(test_epoch(2), false); + hash_join.next_unwrap_ready_barrier()?; + + // push the 2nd left chunk + tx_l.push_chunk(chunk_l2); + hash_join.next_unwrap_pending(); + + // push the 1st right chunk + tx_r.push_chunk(chunk_r1); + let chunk = hash_join.next_unwrap_ready_chunk()?; + assert_eq!( + chunk, + StreamChunk::from_pretty( + " I I I I I I + + 2 5 8 2 1 7 + - 2 5 8 2 1 7 + + 2 5 8 2 3 4" + ) + ); + + // push the 2nd right chunk + tx_r.push_chunk(chunk_r2); + let chunk = hash_join.next_unwrap_ready_chunk()?; + assert_eq!( + chunk, + StreamChunk::from_pretty( + " I I I I I I + - 2 5 8 2 3 4 + + 2 5 8 2 1 7" + ) + ); + + // push the 3rd right chunk + tx_r.push_chunk(chunk_r3); + let chunk = hash_join.next_unwrap_ready_chunk()?; + assert_eq!( + chunk, + StreamChunk::from_pretty( + " I I I I I I + - 2 5 8 2 1 7 + + 2 5 8 2 3 3" + ) + ); + + // push the 3rd left chunk + tx_l.push_chunk(chunk_l3); + let chunk = hash_join.next_unwrap_ready_chunk()?; + assert_eq!( + chunk, + StreamChunk::from_pretty( + " I I I I I I + - 2 5 8 2 3 3" + ) + ); + + // push the 4th left chunk + tx_l.push_chunk(chunk_l4); + let chunk = hash_join.next_unwrap_ready_chunk()?; + assert_eq!( + chunk, + StreamChunk::from_pretty( + " I I I I I I + + 6 3 1 6 1 9 + + 6 4 1 6 1 9" + ) + ); + + // push the 4th right chunk + tx_r.push_chunk(chunk_r4); + let chunk = hash_join.next_unwrap_ready_chunk()?; + assert_eq!( + chunk, + StreamChunk::from_pretty( + " I I I I I I + - 6 3 1 6 1 9 + + 6 3 1 6 2 9 + - 6 4 1 6 1 9 + + 6 4 1 6 2 9" + ) + ); + + Ok(()) + } + + #[tokio::test] + async fn 
test_as_of_left_outer_join() -> StreamExecutorResult<()> {
+        let asof_desc = AsOfDesc {
+            left_idx: 1,
+            right_idx: 2,
+            inequality_type: AsOfInequalityType::Ge,
+        };
+
+        let chunk_l1 = StreamChunk::from_pretty(
+            "  I I I
+             + 1 4 7
+             + 2 5 8
+             + 3 6 9",
+        );
+        let chunk_l2 = StreamChunk::from_pretty(
+            "  I I I
+             + 3 8 1
+             - 3 8 1",
+        );
+        let chunk_r1 = StreamChunk::from_pretty(
+            "  I I I
+             + 2 3 4
+             + 2 2 5
+             + 2 1 5
+             + 6 1 8
+             + 6 2 9",
+        );
+        let chunk_r2 = StreamChunk::from_pretty(
+            "  I I I
+             - 2 3 4
+             - 2 1 5
+             - 2 2 5",
+        );
+        let chunk_l3 = StreamChunk::from_pretty(
+            "  I I I
+             + 6 8 9",
+        );
+        let chunk_r3 = StreamChunk::from_pretty(
+            "  I I I
+             - 6 1 8",
+        );
+
+        let (mut tx_l, mut tx_r, mut hash_join) =
+            create_executor::<{ AsOfJoinType::LeftOuter }>(asof_desc).await;
+
+        // push the init barrier for left and right
+        tx_l.push_barrier(test_epoch(1), false);
+        tx_r.push_barrier(test_epoch(1), false);
+        hash_join.next_unwrap_ready_barrier()?;
+
+        // push the 1st left chunk
+        tx_l.push_chunk(chunk_l1);
+        let chunk = hash_join.next_unwrap_ready_chunk()?;
+        assert_eq!(
+            chunk,
+            StreamChunk::from_pretty(
+                "  I I I I I I
+                 + 1 4 7 . . .
+                 + 2 5 8 . . .
+                 + 3 6 9 . . ."
+            )
+        );
+
+        // push the init barrier for left and right
+        tx_l.push_barrier(test_epoch(2), false);
+        tx_r.push_barrier(test_epoch(2), false);
+        hash_join.next_unwrap_ready_barrier()?;
+
+        // push the 2nd left chunk
+        tx_l.push_chunk(chunk_l2);
+        let chunk = hash_join.next_unwrap_ready_chunk()?;
+        assert_eq!(
+            chunk,
+            StreamChunk::from_pretty(
+                "  I I I I I I
+                 + 3 8 1 . . .
+                 - 3 8 1 . . ."
+            )
+        );
+
+        // push the 1st right chunk
+        tx_r.push_chunk(chunk_r1);
+        let chunk = hash_join.next_unwrap_ready_chunk()?;
+        assert_eq!(
+            chunk,
+            StreamChunk::from_pretty(
+                "  I I I I I I
+                 - 2 5 8 . . .
+                 + 2 5 8 2 3 4
+                 - 2 5 8 2 3 4
+                 + 2 5 8 2 2 5
+                 - 2 5 8 2 2 5
+                 + 2 5 8 2 1 5"
+            )
+        );
+
+        // push the 2nd right chunk
+        tx_r.push_chunk(chunk_r2);
+        let chunk = hash_join.next_unwrap_ready_chunk()?;
+        assert_eq!(
+            chunk,
+            StreamChunk::from_pretty(
+                "  I I I I I I
+                 - 2 5 8 2 1 5
+                 + 2 5 8 2 2 5
+                 - 2 5 8 2 2 5
+                 + 2 5 8 . . ."
+            )
+        );
+
+        // push the 3rd left chunk
+        tx_l.push_chunk(chunk_l3);
+        let chunk = hash_join.next_unwrap_ready_chunk()?;
+        assert_eq!(
+            chunk,
+            StreamChunk::from_pretty(
+                "  I I I I I I
+                 + 6 8 9 6 1 8"
+            )
+        );
+
+        // push the 3rd right chunk
+        tx_r.push_chunk(chunk_r3);
+        let chunk = hash_join.next_unwrap_ready_chunk()?;
+        assert_eq!(
+            chunk,
+            StreamChunk::from_pretty(
+                "  I I I I I I
+                 - 6 8 9 6 1 8
+                 + 6 8 9 . . ."
+            )
+        );
+        Ok(())
+    }
+}
diff --git a/src/stream/src/executor/hash_join.rs b/src/stream/src/executor/hash_join.rs
index e1a1b177bcfcc..e23c17724be02 100644
--- a/src/stream/src/executor/hash_join.rs
+++ b/src/stream/src/executor/hash_join.rs
@@ -396,6 +396,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K, S, T>
             pk_contained_in_jk_l,
+            None,
             metrics.clone(),
diff --git a/src/stream/src/executor/join/hash_join.rs b/src/stream/src/executor/join/hash_join.rs
--- a/src/stream/src/executor/join/hash_join.rs
+++ b/src/stream/src/executor/join/hash_join.rs
 type PkType = Vec<u8>;
+type InequalKeyType = Vec<u8>;
 
 pub type StateValueType = EncodedJoinRow;
 pub type HashValueType = Box<JoinEntryState>;
 
@@ -154,6 +157,21 @@ impl JoinHashMapMetrics {
     }
 }
 
+/// Inequality key description for `AsOf` join.
+struct InequalityKeyDesc {
+    idx: usize,
+    serializer: OrderedRowSerde,
+}
+
+impl InequalityKeyDesc {
+    /// Serialize the inequality key from a row.
+    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
+        let indices = vec![self.idx];
+        let inequality_key = row.project(&indices);
+        inequality_key.memcmp_serialize(&self.serializer)
+    }
+}
+
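`serialize_inequal_key_from_row` produces a memcmp-comparable byte string: byte-wise order of the encoding agrees with the order of the underlying values, which is what lets the inequality index below order and range-scan keys without decoding them. A toy demonstration of the property (plain big-endian integers; the real `OrderedRowSerde` encoding additionally handles NULLs and descending order):

    fn main() {
        let (a, b) = (3u64, 10u64);
        // Lexicographic byte order agrees with numeric order, so a map keyed
        // by the encoded bytes iterates its entries in value order.
        assert_eq!(a < b, a.to_be_bytes() < b.to_be_bytes());
    }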
 pub struct JoinHashMap<K: HashKey, S: StateStore> {
     /// Store the join states.
     inner: JoinHashMapInner<K, S>,
@@ -182,6 +200,8 @@ pub struct JoinHashMap<K: HashKey, S: StateStore> {
     need_degree_table: bool,
     /// Pk is part of the join key.
     pk_contained_in_jk: bool,
+    /// Inequality key description for `AsOf` join.
+    inequality_key_desc: Option<InequalityKeyDesc>,
     /// Metrics of the hash map
     metrics: JoinHashMapMetrics,
 }
@@ -230,6 +250,7 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
         null_matched: K::Bitmap,
         need_degree_table: bool,
         pk_contained_in_jk: bool,
+        inequality_key_idx: Option<usize>,
         metrics: Arc<StreamingMetrics>,
         actor_id: ActorId,
         fragment_id: FragmentId,
@@ -246,6 +267,14 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
             vec![OrderType::ascending(); state_pk_indices.len()],
         );
 
+        let inequality_key_desc = inequality_key_idx.map(|idx| {
+            let serializer = OrderedRowSerde::new(
+                vec![state_all_data_types[idx].clone()],
+                vec![OrderType::ascending()],
+            );
+            InequalityKeyDesc { idx, serializer }
+        });
+
         let join_table_id = state_table.table_id();
         let state = TableInner {
             pk_indices: state_pk_indices,
@@ -286,6 +315,7 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
             degree_state,
             need_degree_table,
             pk_contained_in_jk,
+            inequality_key_desc,
             metrics: JoinHashMapMetrics::new(&metrics, actor_id, fragment_id, side, join_table_id),
         }
     }
@@ -427,11 +457,16 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
                 let degree_i64 = degree_row
                     .datum_at(degree_row.len() - 1)
                     .expect("degree should not be NULL");
+                let inequality_key = self
+                    .inequality_key_desc
+                    .as_ref()
+                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                 entry_state
                     .insert(
                         pk,
                         JoinRow::new(row.row(), degree_i64.into_int64() as u64)
                             .encode(),
+                        inequality_key,
                     )
                     .with_context(|| self.state.error_context(row.row()))?;
             }
@@ -459,6 +494,10 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
                 .as_ref()
                 .project(&self.state.pk_indices)
                 .memcmp_serialize(&self.pk_serializer);
+            let inequality_key = self
+                .inequality_key_desc
+                .as_ref()
+                .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
             let degree_i64 = degree_row
                 .datum_at(degree_row.len() - 1)
                 .expect("degree should not be NULL");
             entry_state
                 .insert(
                     pk,
                     JoinRow::new(row.row(), degree_i64.into_int64() as u64).encode(),
+                    inequality_key,
                 )
                 .with_context(|| self.state.error_context(row.row()))?;
         }
@@ -486,8 +526,12 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
                 .as_ref()
                 .project(&self.state.pk_indices)
                 .memcmp_serialize(&self.pk_serializer);
+            let inequality_key = self
+                .inequality_key_desc
+                .as_ref()
+                .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
             entry_state
-                .insert(pk, JoinRow::new(row.row(), 0).encode())
+                .insert(pk, JoinRow::new(row.row(), 0).encode(), inequality_key)
                 .with_context(|| self.state.error_context(row.row()))?;
         }
     };
@@ -511,9 +555,12 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
     /// Insert a join row
     #[allow(clippy::unused_async)]
     pub async fn insert(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
-        let pk = (&value.row)
-            .project(&self.state.pk_indices)
-            .memcmp_serialize(&self.pk_serializer);
+        let pk = self.serialize_pk_from_row(&value.row);
+
+        let inequality_key = self
+            .inequality_key_desc
+            .as_ref()
+            .map(|desc| desc.serialize_inequal_key_from_row(&value.row));
 
         // TODO(yuhao): avoid this `contains`.
         // https://github.com/risingwavelabs/risingwave/issues/9233
@@ -521,14 +568,14 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
         if self.inner.contains(key) {
             // Update cache
             let mut entry = self.inner.get_mut(key).unwrap();
             entry
-                .insert(pk, value.encode())
+                .insert(pk, value.encode(), inequality_key)
                 .with_context(|| self.state.error_context(&value.row))?;
         } else if self.pk_contained_in_jk {
             // Refill the cache when the join key exists in neither the cache nor storage.
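+            // This shortcut is sound because `pk_contained_in_jk` implies each
+            // join key identifies at most one row: a cache miss on an insert
+            // therefore means storage holds no row for this key, so the freshly
+            // refilled entry is complete.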
self.metrics.insert_cache_miss_count += 1; let mut state = JoinEntryState::default(); state - .insert(pk, value.encode()) + .insert(pk, value.encode(), inequality_key) .with_context(|| self.state.error_context(&value.row))?; self.update_state(key, state.into()); } @@ -545,24 +592,25 @@ impl JoinHashMap { #[allow(clippy::unused_async)] pub async fn insert_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> { let join_row = JoinRow::new(&value, 0); - let pk = (&value) - .project(&self.state.pk_indices) - .memcmp_serialize(&self.pk_serializer); - + let pk = self.serialize_pk_from_row(&value); + let inequality_key = self + .inequality_key_desc + .as_ref() + .map(|desc| desc.serialize_inequal_key_from_row(&value)); // TODO(yuhao): avoid this `contains`. // https://github.com/risingwavelabs/risingwave/issues/9233 if self.inner.contains(key) { // Update cache let mut entry = self.inner.get_mut(key).unwrap(); entry - .insert(pk, join_row.encode()) + .insert(pk, join_row.encode(), inequality_key) .with_context(|| self.state.error_context(&value))?; } else if self.pk_contained_in_jk { // Refill cache when the join key exist in neither cache or storage. self.metrics.insert_cache_miss_count += 1; let mut state = JoinEntryState::default(); state - .insert(pk, join_row.encode()) + .insert(pk, join_row.encode(), inequality_key) .with_context(|| self.state.error_context(&value))?; self.update_state(key, state.into()); } @@ -578,8 +626,12 @@ impl JoinHashMap { let pk = (&value.row) .project(&self.state.pk_indices) .memcmp_serialize(&self.pk_serializer); + let inequality_key = self + .inequality_key_desc + .as_ref() + .map(|desc| desc.serialize_inequal_key_from_row(&value.row)); entry - .remove(pk) + .remove(pk, inequality_key.as_ref()) .with_context(|| self.state.error_context(&value.row))?; } @@ -597,8 +649,13 @@ impl JoinHashMap { let pk = (&value) .project(&self.state.pk_indices) .memcmp_serialize(&self.pk_serializer); + + let inequality_key = self + .inequality_key_desc + .as_ref() + .map(|desc| desc.serialize_inequal_key_from_row(&value)); entry - .remove(pk) + .remove(pk, inequality_key.as_ref()) .with_context(|| self.state.error_context(&value))?; } @@ -680,6 +737,29 @@ impl JoinHashMap { pub fn join_key_data_types(&self) -> &[DataType] { &self.join_key_data_types } + + /// Return true if the inequality key is null. + /// # Panics + /// Panics if the inequality key is not set. + pub fn check_inequal_key_null(&self, row: &impl Row) -> bool { + let desc = self.inequality_key_desc.as_ref().unwrap(); + row.datum_at(desc.idx).is_none() + } + + /// Serialize the inequality key from a row. + /// # Panics + /// Panics if the inequality key is not set. + pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType { + self.inequality_key_desc + .as_ref() + .unwrap() + .serialize_inequal_key_from_row(&row) + } + + pub fn serialize_pk_from_row(&self, row: impl Row) -> PkType { + row.project(&self.state.pk_indices) + .memcmp_serialize(&self.pk_serializer) + } } use risingwave_common_estimate_size::KvSize; @@ -695,7 +775,9 @@ use super::*; #[derive(Default)] pub struct JoinEntryState { /// The full copy of the state. - cached: join_row_set::JoinRowSet, + cached: JoinRowSet, + /// Index used for AS OF join. The key is inequal column value. The value is the primary key in `cached`. 
+ inequality_index: JoinRowSet>, kv_heap_size: KvSize, } @@ -710,9 +792,11 @@ impl EstimateSize for JoinEntryState { #[derive(Error, Debug)] pub enum JoinEntryError { #[error("double inserting a join state entry")] - OccupiedError, + Occupied, #[error("removing a join state entry but it is not in the cache")] - RemoveError, + Remove, + #[error("retrieving a pk from the inequality index but it is not in the cache")] + InequalIndex, } impl JoinEntryState { @@ -721,11 +805,15 @@ impl JoinEntryState { &mut self, key: PkType, value: StateValueType, + inequality_key: Option, ) -> Result<&mut StateValueType, JoinEntryError> { let mut removed = false; if !enable_strict_consistency() { // strict consistency is off, let's remove existing (if any) first if let Some(old_value) = self.cached.remove(&key) { + if let Some(inequality_key) = inequality_key.as_ref() { + self.remove_pk_from_inequality_index(&key, inequality_key); + } self.kv_heap_size.sub(&key, &old_value); removed = true; } @@ -733,6 +821,9 @@ impl JoinEntryState { self.kv_heap_size.add(&key, &value); + if let Some(inequality_key) = inequality_key { + self.insert_pk_to_inequality_index(key.clone(), inequality_key); + } let ret = self.cached.try_insert(key.clone(), value); if !enable_strict_consistency() { @@ -743,22 +834,77 @@ impl JoinEntryState { } } - ret.map_err(|_| JoinEntryError::OccupiedError) + ret.map_err(|_| JoinEntryError::Occupied) } /// Delete from the cache. - pub fn remove(&mut self, pk: PkType) -> Result<(), JoinEntryError> { + pub fn remove( + &mut self, + pk: PkType, + inequality_key: Option<&InequalKeyType>, + ) -> Result<(), JoinEntryError> { if let Some(value) = self.cached.remove(&pk) { self.kv_heap_size.sub(&pk, &value); + if let Some(inequality_key) = inequality_key { + self.remove_pk_from_inequality_index(&pk, inequality_key); + } Ok(()) } else if enable_strict_consistency() { - Err(JoinEntryError::RemoveError) + Err(JoinEntryError::Remove) } else { consistency_error!(?pk, "removing a join state entry but it's not in the cache"); Ok(()) } } + fn remove_pk_from_inequality_index(&mut self, pk: &PkType, inequality_key: &InequalKeyType) { + if let Some(pk_set) = self.inequality_index.get_mut(inequality_key) { + if pk_set.remove(pk).is_none() { + if enable_strict_consistency() { + panic!("removing a pk that it not in the inequality index"); + } else { + consistency_error!(?pk, "removing a pk that it not in the inequality index"); + }; + } else { + self.kv_heap_size.sub(pk, &()); + } + if pk_set.is_empty() { + self.inequality_index.remove(inequality_key); + } + } + } + + fn insert_pk_to_inequality_index(&mut self, pk: PkType, inequality_key: InequalKeyType) { + if let Some(pk_set) = self.inequality_index.get_mut(&inequality_key) { + let pk_size = pk.estimated_size(); + if pk_set.try_insert(pk, ()).is_err() { + if enable_strict_consistency() { + panic!("inserting a pk that it already in the inequality index"); + } else { + consistency_error!("inserting a pk that it already in the inequality index"); + }; + } else { + self.kv_heap_size.add_size(pk_size); + } + } else { + let mut pk_set = JoinRowSet::default(); + pk_set.try_insert(pk, ()).unwrap(); + self.inequality_index + .try_insert(inequality_key, pk_set) + .unwrap(); + } + } + + pub fn get( + &self, + pk: &PkType, + data_types: &[DataType], + ) -> Option>> { + self.cached + .get(pk) + .map(|encoded| encoded.decode(data_types)) + } + /// Note: the first item in the tuple is the mutable reference to the value in this entry, while /// the second item is the decoded value. 
To mutate the degree, one **must not** forget to apply /// the changes to the first item. @@ -782,6 +928,92 @@ impl JoinEntryState { pub fn len(&self) -> usize { self.cached.len() } + + /// Range scan the cache using the inequality index. + pub fn range_by_inequality<'a, R>( + &'a self, + range: R, + data_types: &'a [DataType], + ) -> impl Iterator>> + 'a + where + R: RangeBounds + 'a, + { + self.inequality_index.range(range).flat_map(|(_, pk_set)| { + pk_set + .keys() + .flat_map(|pk| self.get_by_indexed_pk(pk, data_types)) + }) + } + + /// Get the records whose inequality key upper bound satisfy the given bound. + pub fn upper_bound_by_inequality<'a>( + &'a self, + bound: Bound<&InequalKeyType>, + data_types: &'a [DataType], + ) -> Option>> { + if let Some((_, pk_set)) = self.inequality_index.upper_bound(bound) { + if let Some(pk) = pk_set.first_key_sorted() { + self.get_by_indexed_pk(pk, data_types) + } else { + panic!("pk set for a index record must has at least one element"); + } + } else { + None + } + } + + pub fn get_by_indexed_pk( + &self, + pk: &PkType, + data_types: &[DataType], + ) -> Option>> +where { + if let Some(value) = self.cached.get(pk) { + Some(value.decode(data_types)) + } else if enable_strict_consistency() { + Some(Err(anyhow!(JoinEntryError::InequalIndex).into())) + } else { + consistency_error!(?pk, "{}", JoinEntryError::InequalIndex.as_report()); + None + } + } + + /// Get the records whose inequality key lower bound satisfy the given bound. + pub fn lower_bound_by_inequality<'a>( + &'a self, + bound: Bound<&InequalKeyType>, + data_types: &'a [DataType], + ) -> Option>> { + if let Some((_, pk_set)) = self.inequality_index.lower_bound(bound) { + if let Some(pk) = pk_set.first_key_sorted() { + self.get_by_indexed_pk(pk, data_types) + } else { + panic!("pk set for a index record must has at least one element"); + } + } else { + None + } + } + + pub fn get_first_by_inequality<'a>( + &'a self, + inequality_key: &InequalKeyType, + data_types: &'a [DataType], + ) -> Option>> { + if let Some(pk_set) = self.inequality_index.get(inequality_key) { + if let Some(pk) = pk_set.first_key_sorted() { + self.get_by_indexed_pk(pk, data_types) + } else { + panic!("pk set for a index record must has at least one element"); + } + } else { + None + } + } + + pub fn inequality_index(&self) -> &JoinRowSet> { + &self.inequality_index + } } #[cfg(test)] @@ -795,16 +1027,36 @@ mod tests { fn insert_chunk( managed_state: &mut JoinEntryState, pk_indices: &[usize], + col_types: &[DataType], + inequality_key_idx: Option, data_chunk: &DataChunk, ) { + let pk_col_type = pk_indices + .iter() + .map(|idx| col_types[*idx].clone()) + .collect_vec(); + let pk_serializer = + OrderedRowSerde::new(pk_col_type, vec![OrderType::ascending(); pk_indices.len()]); + let inequality_key_type = inequality_key_idx.map(|idx| col_types[idx].clone()); + let inequality_key_serializer = inequality_key_type + .map(|data_type| OrderedRowSerde::new(vec![data_type], vec![OrderType::ascending()])); for row_ref in data_chunk.rows() { let row: OwnedRow = row_ref.into_owned_row(); let value_indices = (0..row.len() - 1).collect_vec(); let pk = pk_indices.iter().map(|idx| row[*idx].clone()).collect_vec(); // Pk is only a `i64` here, so encoding method does not matter. 
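+            // (With the inequality index the pk encoding now matters: indexed
+            // pks are ordered via `first_key_sorted`, so the tests use the
+            // order-preserving memcmp encoding below instead of value encoding.)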
- let pk = OwnedRow::new(pk).project(&value_indices).value_serialize(); + let pk = OwnedRow::new(pk) + .project(&value_indices) + .memcmp_serialize(&pk_serializer); + let inequality_key = inequality_key_idx.map(|idx| { + (&row) + .project(&[idx]) + .memcmp_serialize(inequality_key_serializer.as_ref().unwrap()) + }); let join_row = JoinRow { row, degree: 0 }; - managed_state.insert(pk, join_row.encode()).unwrap(); + managed_state + .insert(pk, join_row.encode(), inequality_key) + .unwrap(); } } @@ -826,7 +1078,7 @@ mod tests { } #[tokio::test] - async fn test_managed_all_or_none_state() { + async fn test_managed_join_state() { let mut managed_state = JoinEntryState::default(); let col_types = vec![DataType::Int64, DataType::Int64]; let pk_indices = [0]; @@ -841,7 +1093,13 @@ mod tests { ); // `Vec` in state - insert_chunk(&mut managed_state, &pk_indices, &data_chunk1); + insert_chunk( + &mut managed_state, + &pk_indices, + &col_types, + None, + &data_chunk1, + ); check(&mut managed_state, &col_types, &col1, &col2); // `BtreeMap` in state @@ -852,7 +1110,76 @@ mod tests { 5 8 4 9", ); - insert_chunk(&mut managed_state, &pk_indices, &data_chunk2); + insert_chunk( + &mut managed_state, + &pk_indices, + &col_types, + None, + &data_chunk2, + ); check(&mut managed_state, &col_types, &col1, &col2); } + + #[tokio::test] + async fn test_managed_join_state_w_inequality_index() { + let mut managed_state = JoinEntryState::default(); + let col_types = vec![DataType::Int64, DataType::Int64]; + let pk_indices = [0]; + let inequality_key_idx = Some(1); + let inequality_key_serializer = + OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]); + + let col1 = [3, 2, 1]; + let col2 = [4, 5, 5]; + let data_chunk1 = DataChunk::from_pretty( + "I I + 3 4 + 2 5 + 1 5", + ); + + // `Vec` in state + insert_chunk( + &mut managed_state, + &pk_indices, + &col_types, + inequality_key_idx, + &data_chunk1, + ); + check(&mut managed_state, &col_types, &col1, &col2); + let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(5))]) + .memcmp_serialize(&inequality_key_serializer); + let row = managed_state + .upper_bound_by_inequality(Bound::Included(&bound), &col_types) + .unwrap() + .unwrap(); + assert_eq!(row.row[0], Some(ScalarImpl::Int64(1))); + let row = managed_state + .upper_bound_by_inequality(Bound::Excluded(&bound), &col_types) + .unwrap() + .unwrap(); + assert_eq!(row.row[0], Some(ScalarImpl::Int64(3))); + + // `BtreeMap` in state + let col1 = [1, 2, 3, 4, 5]; + let col2 = [5, 5, 4, 4, 8]; + let data_chunk2 = DataChunk::from_pretty( + "I I + 5 8 + 4 4", + ); + insert_chunk( + &mut managed_state, + &pk_indices, + &col_types, + inequality_key_idx, + &data_chunk2, + ); + check(&mut managed_state, &col_types, &col1, &col2); + + let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(8))]) + .memcmp_serialize(&inequality_key_serializer); + let row = managed_state.lower_bound_by_inequality(Bound::Excluded(&bound), &col_types); + assert!(row.is_none()); + } } diff --git a/src/stream/src/executor/join/join_row_set.rs b/src/stream/src/executor/join/join_row_set.rs index de6f5ce2f0279..b34e163410eec 100644 --- a/src/stream/src/executor/join/join_row_set.rs +++ b/src/stream/src/executor/join/join_row_set.rs @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
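The `JoinRowSet` extended below keeps small entries in a plain `Vec` and switches to a `BTreeMap` for larger ones (the state tests above exercise both as the "`Vec` in state" and "`BtreeMap` in state" cases). Each accessor therefore dispatches per variant, with the `Vec` arm doing a linear scan. A stripped-down sketch of the pattern (`SmallMap` is a made-up name; the real type also does size accounting and Vec-to-BTree promotion):

    use std::collections::BTreeMap;

    enum SmallMap<K: Ord, V> {
        Vec(Vec<(K, V)>),
        BTree(BTreeMap<K, V>),
    }

    impl<K: Ord, V> SmallMap<K, V> {
        fn get(&self, key: &K) -> Option<&V> {
            match self {
                // O(n) scan, cheap for the handful of rows a key usually has.
                SmallMap::Vec(v) => v.iter().find(|(k, _)| k == key).map(|(_, v)| v),
                SmallMap::BTree(m) => m.get(key),
            }
        }

        fn first_key(&self) -> Option<&K> {
            match self {
                SmallMap::Vec(v) => v.iter().map(|(k, _)| k).min(),
                SmallMap::BTree(m) => m.first_key_value().map(|(k, _)| k),
            }
        }
    }

    fn main() {
        let m = SmallMap::Vec(vec![(2, "b"), (1, "a")]);
        assert_eq!(m.get(&1), Some(&"a"));
        assert_eq!(m.first_key(), Some(&1));
    }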
+use std::borrow::Borrow;
 use std::collections::btree_map::OccupiedError as BTreeMapOccupiedError;
 use std::collections::BTreeMap;
 use std::fmt::Debug;
 use std::mem;
+use std::ops::{Bound, RangeBounds};
 
 use auto_enums::auto_enum;
 use enum_as_inner::EnumAsInner;
@@ -110,6 +112,13 @@ impl<K: Ord, V> JoinRowSet<K, V> {
         }
     }
 
+    pub fn is_empty(&self) -> bool {
+        match self {
+            Self::BTree(inner) => inner.is_empty(),
+            Self::Vec(inner) => inner.is_empty(),
+        }
+    }
+
     #[auto_enum(Iterator)]
     pub fn values_mut(&mut self) -> impl Iterator<Item = &mut V> {
         match self {
@@ -117,4 +126,161 @@ impl<K: Ord, V> JoinRowSet<K, V> {
             Self::Vec(inner) => inner.iter_mut().map(|(_, v)| v),
         }
     }
+
+    #[auto_enum(Iterator)]
+    pub fn keys(&self) -> impl Iterator<Item = &K> {
+        match self {
+            Self::BTree(inner) => inner.keys(),
+            Self::Vec(inner) => inner.iter().map(|(k, _v)| k),
+        }
+    }
+
+    #[auto_enum(Iterator)]
+    pub fn range<T, R>(&self, range: R) -> impl Iterator<Item = (&K, &V)>
+    where
+        T: Ord + ?Sized,
+        K: Borrow<T> + Ord,
+        R: RangeBounds<T>,
+    {
+        match self {
+            Self::BTree(inner) => inner.range(range),
+            Self::Vec(inner) => inner
+                .iter()
+                .filter(move |(k, _)| range.contains(k.borrow()))
+                .map(|(k, v)| (k, v)),
+        }
+    }
+
+    pub fn lower_bound_key(&self, bound: Bound<&K>) -> Option<&K> {
+        self.lower_bound(bound).map(|(k, _v)| k)
+    }
+
+    pub fn upper_bound_key(&self, bound: Bound<&K>) -> Option<&K> {
+        self.upper_bound(bound).map(|(k, _v)| k)
+    }
+
+    pub fn lower_bound(&self, bound: Bound<&K>) -> Option<(&K, &V)> {
+        match self {
+            Self::BTree(inner) => inner.lower_bound(bound).next(),
+            Self::Vec(inner) => inner
+                .iter()
+                .filter(|(k, _)| (bound, Bound::Unbounded).contains(k))
+                .min_by_key(|(k, _)| k)
+                .map(|(k, v)| (k, v)),
+        }
+    }
+
+    pub fn upper_bound(&self, bound: Bound<&K>) -> Option<(&K, &V)> {
+        match self {
+            Self::BTree(inner) => inner.upper_bound(bound).prev(),
+            Self::Vec(inner) => inner
+                .iter()
+                .filter(|(k, _)| (Bound::Unbounded, bound).contains(k))
+                .max_by_key(|(k, _)| k)
+                .map(|(k, v)| (k, v)),
+        }
+    }
+
+    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
+        match self {
+            Self::BTree(inner) => inner.get_mut(key),
+            Self::Vec(inner) => inner.iter_mut().find(|(k, _)| k == key).map(|(_, v)| v),
+        }
+    }
+
+    pub fn get(&self, key: &K) -> Option<&V> {
+        match self {
+            Self::BTree(inner) => inner.get(key),
+            Self::Vec(inner) => inner.iter().find(|(k, _)| k == key).map(|(_, v)| v),
+        }
+    }
+
+    /// Returns the smallest key in the map.
+    pub fn first_key_sorted(&self) -> Option<&K> {
+        match self {
+            Self::BTree(inner) => inner.first_key_value().map(|(k, _)| k),
+            Self::Vec(inner) => inner.iter().map(|(k, _)| k).min(),
+        }
+    }
+
+    /// Returns the second smallest key in the map.
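+    ///
+    /// The `Vec` arm below does this in a single pass, tracking both the
+    /// smallest key seen so far and the current runner-up; e.g. scanning the
+    /// keys `[3, 1, 2]` ends with `smallest = 1` and returns `Some(&2)`.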
+    pub fn second_key_sorted(&self) -> Option<&K> {
+        match self {
+            Self::BTree(inner) => inner.iter().nth(1).map(|(k, _)| k),
+            Self::Vec(inner) => {
+                let mut res = None;
+                let mut smallest = None;
+                for (k, _) in inner {
+                    if let Some(smallest_k) = smallest {
+                        if k < smallest_k {
+                            res = Some(smallest_k);
+                            smallest = Some(k);
+                        } else if let Some(res_k) = res {
+                            if k < res_k {
+                                res = Some(k);
+                            }
+                        } else {
+                            res = Some(k);
+                        }
+                    } else {
+                        smallest = Some(k);
+                    }
+                }
+                res
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    #[test]
+    fn test_join_row_set_bounds() {
+        let mut join_row_set: JoinRowSet<u32, u32> = JoinRowSet::default();
+
+        // Insert elements
+        assert!(join_row_set.try_insert(1, 10).is_ok());
+        assert!(join_row_set.try_insert(2, 20).is_ok());
+        assert!(join_row_set.try_insert(3, 30).is_ok());
+
+        // Check lower bound
+        assert_eq!(join_row_set.lower_bound_key(Bound::Included(&2)), Some(&2));
+        assert_eq!(join_row_set.lower_bound_key(Bound::Excluded(&2)), Some(&3));
+
+        // Check upper bound
+        assert_eq!(join_row_set.upper_bound_key(Bound::Included(&2)), Some(&2));
+        assert_eq!(join_row_set.upper_bound_key(Bound::Excluded(&2)), Some(&1));
+    }
+
+    #[test]
+    fn test_join_row_set_first_and_second_key_sorted() {
+        {
+            let mut join_row_set: JoinRowSet<u32, u32> = JoinRowSet::default();
+
+            // Insert elements
+            assert!(join_row_set.try_insert(3, 30).is_ok());
+            assert!(join_row_set.try_insert(1, 10).is_ok());
+            assert!(join_row_set.try_insert(2, 20).is_ok());
+
+            // Check first key sorted
+            assert_eq!(join_row_set.first_key_sorted(), Some(&1));
+
+            // Check second key sorted
+            assert_eq!(join_row_set.second_key_sorted(), Some(&2));
+        }
+        {
+            let mut join_row_set: JoinRowSet<u32, u32> = JoinRowSet::default();
+
+            // Insert elements
+            assert!(join_row_set.try_insert(1, 10).is_ok());
+            assert!(join_row_set.try_insert(2, 20).is_ok());
+
+            // Check first key sorted
+            assert_eq!(join_row_set.first_key_sorted(), Some(&1));
+
+            // Check second key sorted
+            assert_eq!(join_row_set.second_key_sorted(), Some(&2));
+        }
+    }
 }
diff --git a/src/stream/src/executor/join/mod.rs b/src/stream/src/executor/join/mod.rs
index b8bd5ff84d95f..ea53a7992f265 100644
--- a/src/stream/src/executor/join/mod.rs
+++ b/src/stream/src/executor/join/mod.rs
@@ -12,6 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
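+
+// The additions below mirror the `JoinType` pattern used in this module: a
+// const "enum" (`AsOfJoinType`) that fills the executor's const generic
+// parameter, plus `AsOfDesc`, the executor-side form of the
+// `plan_common.AsOfJoinDesc` proto. A sketch of the expected decoding
+// (field values are illustrative):
+//
+//     // proto: AsOfJoinDesc { left_idx: 0, right_idx: 1, inequality_type: GE }
+//     let desc = AsOfDesc::from_protobuf(&desc_proto)?;
+//     assert_eq!((desc.left_idx, desc.right_idx), (0, 1));
+//     assert!(matches!(desc.inequality_type, AsOfInequalityType::Ge));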
+use risingwave_expr::bail;
+use risingwave_pb::plan_common::{AsOfJoinDesc, AsOfJoinInequalityType};
+
+use crate::error::StreamResult;
+
 pub mod builder;
 pub mod hash_join;
 pub mod join_row_set;
@@ -35,6 +40,15 @@ pub mod JoinType {
     pub const RightAnti: JoinTypePrimitive = 7;
 }
 
+pub type AsOfJoinTypePrimitive = u8;
+
+#[allow(non_snake_case, non_upper_case_globals)]
+pub mod AsOfJoinType {
+    use super::AsOfJoinTypePrimitive;
+    pub const Inner: AsOfJoinTypePrimitive = 0;
+    pub const LeftOuter: AsOfJoinTypePrimitive = 1;
+}
+
 pub type SideTypePrimitive = u8;
 #[allow(non_snake_case, non_upper_case_globals)]
 pub mod SideType {
@@ -43,6 +57,38 @@ pub mod SideType {
     pub const Right: SideTypePrimitive = 1;
 }
 
+pub enum AsOfInequalityType {
+    Le,
+    Lt,
+    Ge,
+    Gt,
+}
+
+pub struct AsOfDesc {
+    pub left_idx: usize,
+    pub right_idx: usize,
+    pub inequality_type: AsOfInequalityType,
+}
+
+impl AsOfDesc {
+    pub fn from_protobuf(desc_proto: &AsOfJoinDesc) -> StreamResult<Self> {
+        let typ = match desc_proto.inequality_type() {
+            AsOfJoinInequalityType::AsOfInequalityTypeLt => AsOfInequalityType::Lt,
+            AsOfJoinInequalityType::AsOfInequalityTypeLe => AsOfInequalityType::Le,
+            AsOfJoinInequalityType::AsOfInequalityTypeGt => AsOfInequalityType::Gt,
+            AsOfJoinInequalityType::AsOfInequalityTypeGe => AsOfInequalityType::Ge,
+            AsOfJoinInequalityType::AsOfInequalityTypeUnspecified => {
+                bail!("unspecified AsOf join inequality type")
+            }
+        };
+        Ok(Self {
+            left_idx: desc_proto.left_idx as usize,
+            right_idx: desc_proto.right_idx as usize,
+            inequality_type: typ,
+        })
+    }
+}
+
 pub const fn is_outer_side(join_type: JoinTypePrimitive, side_type: SideTypePrimitive) -> bool {
     join_type == JoinType::FullOuter
         || (join_type == JoinType::LeftOuter && side_type == SideType::Left)
@@ -106,3 +152,7 @@ pub const fn need_right_degree(join_type: JoinTypePrimitive) -> bool {
         || join_type == JoinType::RightAnti
         || join_type == JoinType::RightSemi
 }
+
+pub const fn is_as_of_left_outer(join_type: AsOfJoinTypePrimitive) -> bool {
+    join_type == AsOfJoinType::LeftOuter
+}
diff --git a/src/stream/src/executor/mod.rs b/src/stream/src/executor/mod.rs
index 8b9f7b3f2242b..3d1ca35b6d610 100644
--- a/src/stream/src/executor/mod.rs
+++ b/src/stream/src/executor/mod.rs
@@ -57,6 +57,7 @@ pub mod monitor;
 
 pub mod agg_common;
 pub mod aggregation;
+pub mod asof_join;
 mod backfill;
 mod barrier_recv;
 mod batch_query;
@@ -133,7 +134,7 @@ pub use filter::FilterExecutor;
 pub use hash_agg::HashAggExecutor;
 pub use hash_join::*;
 pub use hop_window::HopWindowExecutor;
-pub use join::JoinType;
+pub use join::{AsOfDesc, AsOfJoinType, JoinType};
 pub use lookup::*;
 pub use lookup_union::LookupUnionExecutor;
 pub use merge::MergeExecutor;
diff --git a/src/stream/src/from_proto/asof_join.rs b/src/stream/src/from_proto/asof_join.rs
new file mode 100644
index 0000000000000..3d74ac884b4f0
--- /dev/null
+++ b/src/stream/src/from_proto/asof_join.rs
@@ -0,0 +1,192 @@
+// Copyright 2024 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
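+
+// Builder that turns an `AsOfJoinNode` plan node into an `AsOfJoinExecutor`.
+// It follows the hash-join builder's shape: recover the join params, state
+// tables, and the AsOf inequality description from the proto, then dispatch
+// on the hash-key data types and the (Inner | LeftOuter) join type.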
+
+use std::sync::Arc;
+
+use risingwave_common::hash::{HashKey, HashKeyDispatcher};
+use risingwave_common::types::DataType;
+use risingwave_pb::plan_common::AsOfJoinType as JoinTypeProto;
+use risingwave_pb::stream_plan::AsOfJoinNode;
+
+use super::*;
+use crate::common::table::state_table::StateTable;
+use crate::executor::asof_join::*;
+use crate::executor::monitor::StreamingMetrics;
+use crate::executor::{ActorContextRef, AsOfDesc, AsOfJoinType, JoinType};
+use crate::task::AtomicU64Ref;
+
+pub struct AsOfJoinExecutorBuilder;
+
+impl ExecutorBuilder for AsOfJoinExecutorBuilder {
+    type Node = AsOfJoinNode;
+
+    async fn new_boxed_executor(
+        params: ExecutorParams,
+        node: &Self::Node,
+        store: impl StateStore,
+    ) -> StreamResult<Executor> {
+        // These asserts make sure that the AsOf join can reuse `JoinChunkBuilder`
+        // the same way hash join does.
+        assert_eq!(AsOfJoinType::Inner, JoinType::Inner);
+        assert_eq!(AsOfJoinType::LeftOuter, JoinType::LeftOuter);
+        let vnodes = Arc::new(params.vnode_bitmap.expect("vnodes not set for AsOf join"));
+
+        let [source_l, source_r]: [_; 2] = params.input.try_into().unwrap();
+
+        let table_l = node.get_left_table()?;
+        let degree_table_l = node.get_left_degree_table()?;
+
+        let table_r = node.get_right_table()?;
+        let degree_table_r = node.get_right_degree_table()?;
+
+        let params_l = JoinParams::new(
+            node.get_left_key()
+                .iter()
+                .map(|key| *key as usize)
+                .collect_vec(),
+            node.get_left_deduped_input_pk_indices()
+                .iter()
+                .map(|key| *key as usize)
+                .collect_vec(),
+        );
+        let params_r = JoinParams::new(
+            node.get_right_key()
+                .iter()
+                .map(|key| *key as usize)
+                .collect_vec(),
+            node.get_right_deduped_input_pk_indices()
+                .iter()
+                .map(|key| *key as usize)
+                .collect_vec(),
+        );
+        let null_safe = node.get_null_safe().to_vec();
+        let output_indices = node
+            .get_output_indices()
+            .iter()
+            .map(|&x| x as usize)
+            .collect_vec();
+
+        let join_key_data_types = params_l
+            .join_key_indices
+            .iter()
+            .map(|idx| source_l.schema().fields[*idx].data_type())
+            .collect_vec();
+
+        let state_table_l =
+            StateTable::from_table_catalog(table_l, store.clone(), Some(vnodes.clone())).await;
+        let degree_state_table_l =
+            StateTable::from_table_catalog(degree_table_l, store.clone(), Some(vnodes.clone()))
+                .await;
+
+        let state_table_r =
+            StateTable::from_table_catalog(table_r, store.clone(), Some(vnodes.clone())).await;
+        let degree_state_table_r =
+            StateTable::from_table_catalog(degree_table_r, store, Some(vnodes)).await;
+
+        let join_type_proto = node.get_join_type()?;
+        let as_of_desc_proto = node.get_asof_desc()?;
+        let asof_desc = AsOfDesc::from_protobuf(as_of_desc_proto)?;
+
+        let args = AsOfJoinExecutorDispatcherArgs {
+            ctx: params.actor_context,
+            info: params.info.clone(),
+            source_l,
+            source_r,
+            params_l,
+            params_r,
+            null_safe,
+            output_indices,
+            state_table_l,
+            degree_state_table_l,
+            state_table_r,
+            degree_state_table_r,
+            lru_manager: params.watermark_epoch,
+            metrics: params.executor_stats,
+            join_type_proto,
+            join_key_data_types,
+            chunk_size: params.env.config().developer.chunk_size,
+            high_join_amplification_threshold: params
+                .env
+                .config()
+                .developer
+                .high_join_amplification_threshold,
+            asof_desc,
+        };
+
+        let exec = args.dispatch()?;
+        Ok((params.info, exec).into())
+    }
+}
+
+struct AsOfJoinExecutorDispatcherArgs<S: StateStore> {
+    ctx: ActorContextRef,
+    info: ExecutorInfo,
+    source_l: Executor,
+    source_r: Executor,
+    params_l: JoinParams,
+    params_r: JoinParams,
+    null_safe: Vec<bool>,
+    output_indices: Vec<usize>,
+    state_table_l: StateTable<S>,
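+    // One data table and one degree table per side, mirroring the hash-join
+    // state layout (`left_table`, `left_degree_table`, ... in the proto).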
+    degree_state_table_l: StateTable<S>,
+    state_table_r: StateTable<S>,
+    degree_state_table_r: StateTable<S>,
+    lru_manager: AtomicU64Ref,
+    metrics: Arc<StreamingMetrics>,
+    join_type_proto: JoinTypeProto,
+    join_key_data_types: Vec<DataType>,
+    chunk_size: usize,
+    high_join_amplification_threshold: usize,
+    asof_desc: AsOfDesc,
+}
+
+impl<S: StateStore> HashKeyDispatcher for AsOfJoinExecutorDispatcherArgs<S> {
+    type Output = StreamResult<Box<dyn Execute>>;
+
+    fn dispatch_impl<K: HashKey>(self) -> Self::Output {
+        /// This macro helps to fill the const generic type parameter.
+        macro_rules! build {
+            ($join_type:ident) => {
+                Ok(AsOfJoinExecutor::<K, S, { AsOfJoinType::$join_type }>::new(
+                    self.ctx,
+                    self.info,
+                    self.source_l,
+                    self.source_r,
+                    self.params_l,
+                    self.params_r,
+                    self.null_safe,
+                    self.output_indices,
+                    self.state_table_l,
+                    self.degree_state_table_l,
+                    self.state_table_r,
+                    self.degree_state_table_r,
+                    self.lru_manager,
+                    self.metrics,
+                    self.chunk_size,
+                    self.high_join_amplification_threshold,
+                    self.asof_desc,
+                )
+                .boxed())
+            };
+        }
+        match self.join_type_proto {
+            JoinTypeProto::Unspecified => unreachable!(),
+            JoinTypeProto::Inner => build!(Inner),
+            JoinTypeProto::LeftOuter => build!(LeftOuter),
+        }
+    }
+
+    fn data_types(&self) -> &[DataType] {
+        &self.join_key_data_types
+    }
+}
diff --git a/src/stream/src/from_proto/mod.rs b/src/stream/src/from_proto/mod.rs
index 6f185695eadf7..1f63b6cd5db85 100644
--- a/src/stream/src/from_proto/mod.rs
+++ b/src/stream/src/from_proto/mod.rs
@@ -16,6 +16,7 @@
 
 mod agg_common;
 mod append_only_dedup;
+mod asof_join;
 mod barrier_recv;
 mod batch_query;
 mod cdc_filter;