From b164efc73f596659c4b594a2d4b609e8b028b5d8 Mon Sep 17 00:00:00 2001 From: Yuhao Su Date: Tue, 30 Jan 2024 15:47:22 +0800 Subject: [PATCH 1/2] refactor --- src/stream/src/common/mod.rs | 3 +- src/stream/src/executor/hash_join.rs | 222 +----- src/stream/src/executor/join/builder.rs | 276 +++++++ src/stream/src/executor/join/hash_join.rs | 769 +++++++++++++++++++ src/stream/src/executor/join/join_row_set.rs | 120 +++ src/stream/src/executor/join/mod.rs | 108 +++ src/stream/src/executor/join/row.rs | 82 ++ src/stream/src/executor/lookup/impl_.rs | 2 +- src/stream/src/executor/managed_state/mod.rs | 15 - src/stream/src/executor/mod.rs | 3 +- src/stream/src/executor/temporal_join.rs | 5 +- src/stream/src/from_proto/hash_join.rs | 2 +- 12 files changed, 1371 insertions(+), 236 deletions(-) create mode 100644 src/stream/src/executor/join/builder.rs create mode 100644 src/stream/src/executor/join/hash_join.rs create mode 100644 src/stream/src/executor/join/join_row_set.rs create mode 100644 src/stream/src/executor/join/mod.rs create mode 100644 src/stream/src/executor/join/row.rs delete mode 100644 src/stream/src/executor/managed_state/mod.rs diff --git a/src/stream/src/common/mod.rs b/src/stream/src/common/mod.rs index 18129884db16c..c12fb5de1fbab 100644 --- a/src/stream/src/common/mod.rs +++ b/src/stream/src/common/mod.rs @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-pub use builder::*; pub use column_mapping::*; +pub use risingwave_common::array::stream_chunk_builder::StreamChunkBuilder; -mod builder; pub mod cache; mod column_mapping; pub mod log_store_impl; diff --git a/src/stream/src/executor/hash_join.rs b/src/stream/src/executor/hash_join.rs index af35e7c7b9603..471d382a763d0 100644 --- a/src/stream/src/executor/hash_join.rs +++ b/src/stream/src/executor/hash_join.rs @@ -33,10 +33,12 @@ use risingwave_expr::ExprError; use risingwave_storage::StateStore; use tokio::time::Instant; -use self::JoinType::{FullOuter, LeftOuter, LeftSemi, RightAnti, RightOuter, RightSemi}; +use self::builder::JoinChunkBuilder; use super::barrier_align::*; use super::error::{StreamExecutorError, StreamExecutorResult}; -use super::managed_state::join::*; +use super::join::hash_join::*; +use super::join::row::JoinRow; +use super::join::{JoinTypePrimitive, SideTypePrimitive, *}; use super::monitor::StreamingMetrics; use super::watermark::*; use super::{ @@ -44,101 +46,13 @@ use super::{ PkIndicesRef, Watermark, }; use crate::common::table::state_table::StateTable; -use crate::common::JoinStreamChunkBuilder; use crate::executor::expect_first_barrier_from_aligned_stream; -use crate::executor::JoinType::LeftAnti; +use crate::executor::join::builder::JoinStreamChunkBuilder; use crate::task::AtomicU64Ref; -/// The `JoinType` and `SideType` are to mimic a enum, because currently -/// enum is not supported in const generic. -// TODO: Use enum to replace this once [feature(adt_const_params)](https://github.com/rust-lang/rust/issues/95174) get completed. -pub type JoinTypePrimitive = u8; - /// Evict the cache every n rows. 
const EVICT_EVERY_N_ROWS: u32 = 16; -#[allow(non_snake_case, non_upper_case_globals)] -pub mod JoinType { - use super::JoinTypePrimitive; - pub const Inner: JoinTypePrimitive = 0; - pub const LeftOuter: JoinTypePrimitive = 1; - pub const RightOuter: JoinTypePrimitive = 2; - pub const FullOuter: JoinTypePrimitive = 3; - pub const LeftSemi: JoinTypePrimitive = 4; - pub const LeftAnti: JoinTypePrimitive = 5; - pub const RightSemi: JoinTypePrimitive = 6; - pub const RightAnti: JoinTypePrimitive = 7; -} - -pub type SideTypePrimitive = u8; -#[allow(non_snake_case, non_upper_case_globals)] -pub mod SideType { - use super::SideTypePrimitive; - pub const Left: SideTypePrimitive = 0; - pub const Right: SideTypePrimitive = 1; -} - -const fn is_outer_side(join_type: JoinTypePrimitive, side_type: SideTypePrimitive) -> bool { - join_type == JoinType::FullOuter - || (join_type == JoinType::LeftOuter && side_type == SideType::Left) - || (join_type == JoinType::RightOuter && side_type == SideType::Right) -} - -const fn outer_side_null(join_type: JoinTypePrimitive, side_type: SideTypePrimitive) -> bool { - join_type == JoinType::FullOuter - || (join_type == JoinType::LeftOuter && side_type == SideType::Right) - || (join_type == JoinType::RightOuter && side_type == SideType::Left) -} - -/// Send the update only once if the join type is semi/anti and the update is the same side as the -/// join -const fn forward_exactly_once(join_type: JoinTypePrimitive, side_type: SideTypePrimitive) -> bool { - ((join_type == JoinType::LeftSemi || join_type == JoinType::LeftAnti) - && side_type == SideType::Left) - || ((join_type == JoinType::RightSemi || join_type == JoinType::RightAnti) - && side_type == SideType::Right) -} - -const fn only_forward_matched_side( - join_type: JoinTypePrimitive, - side_type: SideTypePrimitive, -) -> bool { - ((join_type == JoinType::LeftSemi || join_type == JoinType::LeftAnti) - && side_type == SideType::Right) - || ((join_type == JoinType::RightSemi || join_type == 
JoinType::RightAnti) - && side_type == SideType::Left) -} - -const fn is_semi(join_type: JoinTypePrimitive) -> bool { - join_type == JoinType::LeftSemi || join_type == JoinType::RightSemi -} - -const fn is_anti(join_type: JoinTypePrimitive) -> bool { - join_type == JoinType::LeftAnti || join_type == JoinType::RightAnti -} - -const fn is_left_semi_or_anti(join_type: JoinTypePrimitive) -> bool { - join_type == JoinType::LeftSemi || join_type == JoinType::LeftAnti -} - -const fn is_right_semi_or_anti(join_type: JoinTypePrimitive) -> bool { - join_type == JoinType::RightSemi || join_type == JoinType::RightAnti -} - -const fn need_left_degree(join_type: JoinTypePrimitive) -> bool { - join_type == FullOuter - || join_type == LeftOuter - || join_type == LeftAnti - || join_type == LeftSemi -} - -const fn need_right_degree(join_type: JoinTypePrimitive) -> bool { - join_type == FullOuter - || join_type == RightOuter - || join_type == RightAnti - || join_type == RightSemi -} - fn is_subset(vec1: Vec, vec2: Vec) -> bool { HashSet::::from_iter(vec1).is_subset(&vec2.into_iter().collect()) } @@ -295,10 +209,6 @@ impl Executor for HashJoi } } -struct HashJoinChunkBuilder { - stream_chunk_builder: JoinStreamChunkBuilder, -} - struct EqJoinArgs<'a, K: HashKey, S: StateStore> { ctx: &'a ActorContextRef, side_l: &'a mut JoinSide, @@ -312,121 +222,6 @@ struct EqJoinArgs<'a, K: HashKey, S: StateStore> { cnt_rows_received: &'a mut u32, } -impl HashJoinChunkBuilder { - fn with_match_on_insert( - &mut self, - row: &RowRef<'_>, - matched_row: &JoinRow, - ) -> Option { - // Left/Right Anti sides - if is_anti(T) { - if matched_row.is_zero_degree() && only_forward_matched_side(T, SIDE) { - self.stream_chunk_builder - .append_row_matched(Op::Delete, &matched_row.row) - } else { - None - } - // Left/Right Semi sides - } else if is_semi(T) { - if matched_row.is_zero_degree() && only_forward_matched_side(T, SIDE) { - self.stream_chunk_builder - .append_row_matched(Op::Insert, &matched_row.row) - } 
else { - None - } - // Outer sides - } else if matched_row.is_zero_degree() && outer_side_null(T, SIDE) { - // if the matched_row does not have any current matches - // `StreamChunkBuilder` guarantees that `UpdateDelete` will never - // issue an output chunk. - if self - .stream_chunk_builder - .append_row_matched(Op::UpdateDelete, &matched_row.row) - .is_some() - { - unreachable!("`Op::UpdateDelete` should not yield chunk"); - } - self.stream_chunk_builder - .append_row(Op::UpdateInsert, row, &matched_row.row) - // Inner sides - } else { - self.stream_chunk_builder - .append_row(Op::Insert, row, &matched_row.row) - } - } - - fn with_match_on_delete( - &mut self, - row: &RowRef<'_>, - matched_row: &JoinRow, - ) -> Option { - // Left/Right Anti sides - if is_anti(T) { - if matched_row.is_zero_degree() && only_forward_matched_side(T, SIDE) { - self.stream_chunk_builder - .append_row_matched(Op::Insert, &matched_row.row) - } else { - None - } - // Left/Right Semi sides - } else if is_semi(T) { - if matched_row.is_zero_degree() && only_forward_matched_side(T, SIDE) { - self.stream_chunk_builder - .append_row_matched(Op::Delete, &matched_row.row) - } else { - None - } - // Outer sides - } else if matched_row.is_zero_degree() && outer_side_null(T, SIDE) { - // if the matched_row does not have any current - // matches - if self - .stream_chunk_builder - .append_row(Op::UpdateDelete, row, &matched_row.row) - .is_some() - { - unreachable!("`Op::UpdateDelete` should not yield chunk"); - } - self.stream_chunk_builder - .append_row_matched(Op::UpdateInsert, &matched_row.row) - // Inner sides - } else { - // concat with the matched_row and append the new - // row - // FIXME: we always use `Op::Delete` here to avoid - // violating - // the assumption for U+ after U-. 
- self.stream_chunk_builder - .append_row(Op::Delete, row, &matched_row.row) - } - } - - #[inline] - fn forward_exactly_once_if_matched(&mut self, op: Op, row: RowRef<'_>) -> Option { - // if it's a semi join and the side needs to be maintained. - if is_semi(T) && forward_exactly_once(T, SIDE) { - self.stream_chunk_builder.append_row_update(op, row) - } else { - None - } - } - - #[inline] - fn forward_if_not_matched(&mut self, op: Op, row: RowRef<'_>) -> Option { - // if it's outer join or anti join and the side needs to be maintained. - if (is_anti(T) && forward_exactly_once(T, SIDE)) || is_outer_side(T, SIDE) { - self.stream_chunk_builder.append_row_update(op, row) - } else { - None - } - } - - #[inline] - fn take(&mut self) -> Option { - self.stream_chunk_builder.take() - } -} - impl HashJoinExecutor { #[allow(clippy::too_many_arguments)] pub fn new( @@ -1025,14 +820,13 @@ impl HashJoinExecutor { - stream_chunk_builder: JoinStreamChunkBuilder::new( + let mut hashjoin_chunk_builder = + JoinChunkBuilder::::new(JoinStreamChunkBuilder::new( chunk_size, actual_output_data_types.to_vec(), side_update.i2o_mapping.clone(), side_match.i2o_mapping.clone(), - ), - }; + )); let join_matched_join_keys = ctx .streaming_metrics diff --git a/src/stream/src/executor/join/builder.rs b/src/stream/src/executor/join/builder.rs new file mode 100644 index 0000000000000..ca34ba9cad9e8 --- /dev/null +++ b/src/stream/src/executor/join/builder.rs @@ -0,0 +1,276 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +use risingwave_common::array::stream_chunk_builder::StreamChunkBuilder; +use risingwave_common::array::{Op, RowRef, StreamChunk}; +use risingwave_common::row::{OwnedRow, Row}; +use risingwave_common::types::{DataType, DatumRef}; + +use self::row::JoinRow; +// Re-export `StreamChunkBuilder`. +use super::*; +use super::{JoinTypePrimitive, SideTypePrimitive}; + +type IndexMappings = Vec<(usize, usize)>; + +/// Build stream chunks with fixed chunk size from joined two sides of rows. +pub struct JoinStreamChunkBuilder { + builder: StreamChunkBuilder, + + /// The column index mapping from update side to output. + update_to_output: IndexMappings, + + /// The column index mapping from matched side to output. + matched_to_output: IndexMappings, +} + +impl JoinStreamChunkBuilder { + pub fn new( + chunk_size: usize, + data_types: Vec, + update_to_output: IndexMappings, + matched_to_output: IndexMappings, + ) -> Self { + Self { + builder: StreamChunkBuilder::new(chunk_size, data_types), + update_to_output, + matched_to_output, + } + } + + /// Get the mappings from left/right input indices to the output indices. The mappings can be + /// used to create [`JoinStreamChunkBuilder`] later. + /// + /// Please note the semantics of `update` and `matched` when creating the builder: either left + /// or right side can be `update` side or `matched` side, the key is to call the corresponding + /// append method once you passed `left_to_output`/`right_to_output` to + /// `update_to_output`/`matched_to_output`. 
+ pub fn get_i2o_mapping( + output_indices: &[usize], + left_len: usize, + right_len: usize, + ) -> (IndexMappings, IndexMappings) { + let mut left_to_output = vec![]; + let mut right_to_output = vec![]; + + for (output_idx, &idx) in output_indices.iter().enumerate() { + if idx < left_len { + left_to_output.push((idx, output_idx)) + } else if idx >= left_len && idx < left_len + right_len { + right_to_output.push((idx - left_len, output_idx)); + } else { + unreachable!("output_indices out of bound") + } + } + (left_to_output, right_to_output) + } + + /// Append a row with coming update value and matched value. + /// + /// A [`StreamChunk`] will be returned when `size == capacity`. + #[must_use] + pub fn append_row( + &mut self, + op: Op, + row_update: impl Row, + row_matched: impl Row, + ) -> Option { + self.builder.append_iter( + op, + self.update_to_output + .iter() + .map(|&(update_idx, output_idx)| (output_idx, row_update.datum_at(update_idx))) + .chain( + self.matched_to_output + .iter() + .map(|&(matched_idx, output_idx)| { + (output_idx, row_matched.datum_at(matched_idx)) + }), + ), + ) + } + + /// Append a row with coming update value and fill the other side with null. + /// + /// A [`StreamChunk`] will be returned when `size == capacity`. + #[must_use] + pub fn append_row_update(&mut self, op: Op, row_update: impl Row) -> Option { + self.builder.append_iter( + op, + self.update_to_output + .iter() + .map(|&(update_idx, output_idx)| (output_idx, row_update.datum_at(update_idx))) + .chain( + self.matched_to_output + .iter() + .map(|&(_, output_idx)| (output_idx, DatumRef::None)), + ), + ) + } + + /// Append a row with matched value and fill the coming side with null. + /// + /// A [`StreamChunk`] will be returned when `size == capacity`. 
+ #[must_use] + pub fn append_row_matched(&mut self, op: Op, row_matched: impl Row) -> Option { + self.builder.append_iter( + op, + self.update_to_output + .iter() + .map(|&(_, output_idx)| (output_idx, DatumRef::None)) + .chain( + self.matched_to_output + .iter() + .map(|&(matched_idx, output_idx)| { + (output_idx, row_matched.datum_at(matched_idx)) + }), + ), + ) + } + + /// Take out the remaining rows as a chunk. Return `None` if the builder is empty. + #[must_use] + pub fn take(&mut self) -> Option { + self.builder.take() + } +} + +pub struct JoinChunkBuilder { + stream_chunk_builder: JoinStreamChunkBuilder, +} + +impl JoinChunkBuilder { + pub fn new(stream_chunk_builder: JoinStreamChunkBuilder) -> Self { + Self { + stream_chunk_builder, + } + } + + pub fn with_match_on_insert( + &mut self, + row: &RowRef<'_>, + matched_row: &JoinRow, + ) -> Option { + // Left/Right Anti sides + if is_anti(T) { + if matched_row.is_zero_degree() && only_forward_matched_side(T, SIDE) { + self.stream_chunk_builder + .append_row_matched(Op::Delete, &matched_row.row) + } else { + None + } + // Left/Right Semi sides + } else if is_semi(T) { + if matched_row.is_zero_degree() && only_forward_matched_side(T, SIDE) { + self.stream_chunk_builder + .append_row_matched(Op::Insert, &matched_row.row) + } else { + None + } + // Outer sides + } else if matched_row.is_zero_degree() && outer_side_null(T, SIDE) { + // if the matched_row does not have any current matches + // `StreamChunkBuilder` guarantees that `UpdateDelete` will never + // issue an output chunk. 
+ if self + .stream_chunk_builder + .append_row_matched(Op::UpdateDelete, &matched_row.row) + .is_some() + { + unreachable!("`Op::UpdateDelete` should not yield chunk"); + } + self.stream_chunk_builder + .append_row(Op::UpdateInsert, row, &matched_row.row) + // Inner sides + } else { + self.stream_chunk_builder + .append_row(Op::Insert, row, &matched_row.row) + } + } + + pub fn with_match_on_delete( + &mut self, + row: &RowRef<'_>, + matched_row: &JoinRow, + ) -> Option { + // Left/Right Anti sides + if is_anti(T) { + if matched_row.is_zero_degree() && only_forward_matched_side(T, SIDE) { + self.stream_chunk_builder + .append_row_matched(Op::Insert, &matched_row.row) + } else { + None + } + // Left/Right Semi sides + } else if is_semi(T) { + if matched_row.is_zero_degree() && only_forward_matched_side(T, SIDE) { + self.stream_chunk_builder + .append_row_matched(Op::Delete, &matched_row.row) + } else { + None + } + // Outer sides + } else if matched_row.is_zero_degree() && outer_side_null(T, SIDE) { + // if the matched_row does not have any current + // matches + if self + .stream_chunk_builder + .append_row(Op::UpdateDelete, row, &matched_row.row) + .is_some() + { + unreachable!("`Op::UpdateDelete` should not yield chunk"); + } + self.stream_chunk_builder + .append_row_matched(Op::UpdateInsert, &matched_row.row) + // Inner sides + } else { + // concat with the matched_row and append the new + // row + // FIXME: we always use `Op::Delete` here to avoid + // violating + // the assumption for U+ after U-. + self.stream_chunk_builder + .append_row(Op::Delete, row, &matched_row.row) + } + } + + #[inline] + pub fn forward_exactly_once_if_matched( + &mut self, + op: Op, + row: RowRef<'_>, + ) -> Option { + // if it's a semi join and the side needs to be maintained. 
+ if is_semi(T) && forward_exactly_once(T, SIDE) { + self.stream_chunk_builder.append_row_update(op, row) + } else { + None + } + } + + #[inline] + pub fn forward_if_not_matched(&mut self, op: Op, row: RowRef<'_>) -> Option { + // if it's outer join or anti join and the side needs to be maintained. + if (is_anti(T) && forward_exactly_once(T, SIDE)) || is_outer_side(T, SIDE) { + self.stream_chunk_builder.append_row_update(op, row) + } else { + None + } + } + + #[inline] + pub fn take(&mut self) -> Option { + self.stream_chunk_builder.take() + } +} diff --git a/src/stream/src/executor/join/hash_join.rs b/src/stream/src/executor/join/hash_join.rs new file mode 100644 index 0000000000000..c3fbaeb83face --- /dev/null +++ b/src/stream/src/executor/join/hash_join.rs @@ -0,0 +1,769 @@ +use std::alloc::Global; +use std::ops::{Bound, Deref, DerefMut}; +use std::sync::Arc; + +use anyhow::Context; +use futures::future::{join, try_join}; +use futures::StreamExt; +use futures_async_stream::for_await; +use local_stats_alloc::{SharedStatsAlloc, StatsAlloc}; +use risingwave_common::buffer::Bitmap; +use risingwave_common::estimate_size::EstimateSize; +use risingwave_common::hash::{HashKey, PrecomputedBuildHasher}; +use risingwave_common::metrics::LabelGuardedIntCounter; +use risingwave_common::row::{OwnedRow, Row, RowExt}; +use risingwave_common::types::{DataType, ScalarImpl}; +use risingwave_common::util::epoch::EpochPair; +use risingwave_common::util::row_serde::OrderedRowSerde; +use risingwave_common::util::sort_util::OrderType; +use risingwave_storage::store::PrefetchOptions; +use risingwave_storage::StateStore; + +use super::row::{DegreeType, EncodedJoinRow}; +use crate::cache::{new_with_hasher_in, ManagedLruCache}; +use crate::common::metrics::MetricsInfo; +use crate::common::table::state_table::StateTable; +use crate::executor::error::StreamExecutorResult; +use crate::executor::join::row::JoinRow; +use crate::executor::monitor::StreamingMetrics; +use crate::task::{ActorId, 
AtomicU64Ref, FragmentId}; + +/// Memcomparable encoding. +type PkType = Vec; + +pub type StateValueType = EncodedJoinRow; +pub type HashValueType = Box; + +impl EstimateSize for HashValueType { + fn estimated_heap_size(&self) -> usize { + self.as_ref().estimated_heap_size() + } +} + +/// The wrapper for [`JoinEntryState`] which should be `Some` most of the time in the hash table. +/// +/// When the executor is operating on the specific entry of the map, it can hold the ownership of +/// the entry by taking the value out of the `Option`, instead of holding a mutable reference to the +/// map, which can make the compiler happy. +struct HashValueWrapper(Option); + +impl EstimateSize for HashValueWrapper { + fn estimated_heap_size(&self) -> usize { + self.0.estimated_heap_size() + } +} + +impl HashValueWrapper { + const MESSAGE: &'static str = "the state should always be `Some`"; + + /// Take the value out of the wrapper. Panic if the value is `None`. + pub fn take(&mut self) -> HashValueType { + self.0.take().expect(Self::MESSAGE) + } +} + +impl Deref for HashValueWrapper { + type Target = HashValueType; + + fn deref(&self) -> &Self::Target { + self.0.as_ref().expect(Self::MESSAGE) + } +} + +impl DerefMut for HashValueWrapper { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.as_mut().expect(Self::MESSAGE) + } +} + +type JoinHashMapInner = + ManagedLruCache>; + +pub struct JoinHashMapMetrics { + /// Basic information + /// How many times have we hit the cache of join executor + lookup_miss_count: usize, + total_lookup_count: usize, + /// How many times have we missed the cache when inserting a row + insert_cache_miss_count: usize, + + // Metrics + join_lookup_total_count_metric: LabelGuardedIntCounter<5>, + join_lookup_miss_count_metric: LabelGuardedIntCounter<5>, + join_insert_cache_miss_count_metrics: LabelGuardedIntCounter<5>, +} + +impl JoinHashMapMetrics { + pub fn new( + metrics: &StreamingMetrics, + actor_id: ActorId, + fragment_id: FragmentId, + side: 
&'static str, + join_table_id: u32, + degree_table_id: u32, + ) -> Self { + let actor_id = actor_id.to_string(); + let fragment_id = fragment_id.to_string(); + let join_table_id = join_table_id.to_string(); + let degree_table_id = degree_table_id.to_string(); + let join_lookup_total_count_metric = + metrics.join_lookup_total_count.with_guarded_label_values(&[ + (side), + &join_table_id, + &degree_table_id, + &actor_id, + &fragment_id, + ]); + let join_lookup_miss_count_metric = + metrics.join_lookup_miss_count.with_guarded_label_values(&[ + (side), + &join_table_id, + &degree_table_id, + &actor_id, + &fragment_id, + ]); + let join_insert_cache_miss_count_metrics = metrics + .join_insert_cache_miss_count + .with_guarded_label_values(&[ + (side), + &join_table_id, + &degree_table_id, + &actor_id, + &fragment_id, + ]); + + Self { + lookup_miss_count: 0, + total_lookup_count: 0, + insert_cache_miss_count: 0, + join_lookup_total_count_metric, + join_lookup_miss_count_metric, + join_insert_cache_miss_count_metrics, + } + } + + pub fn flush(&mut self) { + self.join_lookup_total_count_metric + .inc_by(self.total_lookup_count as u64); + self.join_lookup_miss_count_metric + .inc_by(self.lookup_miss_count as u64); + self.join_insert_cache_miss_count_metrics + .inc_by(self.insert_cache_miss_count as u64); + self.total_lookup_count = 0; + self.lookup_miss_count = 0; + self.insert_cache_miss_count = 0; + } +} + +pub struct JoinHashMap { + /// Store the join states. + inner: JoinHashMapInner, + /// Data types of the join key columns + join_key_data_types: Vec, + /// Null safe bitmap for each join pair + null_matched: K::Bitmap, + /// The memcomparable serializer of primary key. + pk_serializer: OrderedRowSerde, + /// State table. Contains the data from upstream. + state: TableInner, + /// Degree table. + /// + /// The degree is generated from the hash join executor. + /// Each row in `state` has a corresponding degree in `degree state`. 
+ /// A degree value `d` for a row means the row has `d` matched rows on the other join side. + /// + /// It will only be used when needed in a side. + /// + /// - Full Outer: both sides + /// - Left Outer/Semi/Anti: left side + /// - Right Outer/Semi/Anti: right side + /// - Inner: None. + degree_state: TableInner, + /// Whether the degree table is needed + need_degree_table: bool, + /// Pk is part of the join key. + pk_contained_in_jk: bool, + /// Metrics of the hash map + metrics: JoinHashMapMetrics, +} + +struct TableInner { + /// Indices of the (cache) pk in a state row + pk_indices: Vec, + /// Indices of the join key in a state row + join_key_indices: Vec, + // This should be identical to the pk in state table. + order_key_indices: Vec, + // This should be identical to the data types in table schema. + #[expect(dead_code)] + all_data_types: Vec, + pub(crate) table: StateTable, +} + +impl TableInner { + fn error_context(&self, row: &impl Row) -> String { + let pk = row.project(&self.pk_indices); + let jk = row.project(&self.join_key_indices); + format!( + "join key: {}, pk: {}, row: {}, state_table_id: {}", + jk.display(), + pk.display(), + row.display(), + self.table.table_id() + ) + } +} + +impl JoinHashMap { + /// Create a [`JoinHashMap`] with the given LRU capacity. + #[allow(clippy::too_many_arguments)] + pub fn new( + watermark_epoch: AtomicU64Ref, + join_key_data_types: Vec, + state_join_key_indices: Vec, + state_all_data_types: Vec, + state_table: StateTable, + state_pk_indices: Vec, + degree_join_key_indices: Vec, + degree_all_data_types: Vec, + degree_table: StateTable, + degree_pk_indices: Vec, + null_matched: K::Bitmap, + need_degree_table: bool, + pk_contained_in_jk: bool, + metrics: Arc, + actor_id: ActorId, + fragment_id: FragmentId, + side: &'static str, + ) -> Self { + let alloc = StatsAlloc::new(Global).shared(); + // TODO: unify pk encoding with state table. 
+ let pk_data_types = state_pk_indices + .iter() + .map(|i| state_all_data_types[*i].clone()) + .collect(); + let pk_serializer = OrderedRowSerde::new( + pk_data_types, + vec![OrderType::ascending(); state_pk_indices.len()], + ); + + let join_table_id = state_table.table_id(); + let degree_table_id = degree_table.table_id(); + let state = TableInner { + pk_indices: state_pk_indices, + join_key_indices: state_join_key_indices, + order_key_indices: state_table.pk_indices().to_vec(), + all_data_types: state_all_data_types, + table: state_table, + }; + + let degree_state = TableInner { + pk_indices: degree_pk_indices, + join_key_indices: degree_join_key_indices, + order_key_indices: degree_table.pk_indices().to_vec(), + all_data_types: degree_all_data_types, + table: degree_table, + }; + + let metrics_info = MetricsInfo::new( + metrics.clone(), + join_table_id, + actor_id, + &format!("hash join {}", side), + ); + + let cache = + new_with_hasher_in(watermark_epoch, metrics_info, PrecomputedBuildHasher, alloc); + + Self { + inner: cache, + join_key_data_types, + null_matched, + pk_serializer, + state, + degree_state, + need_degree_table, + pk_contained_in_jk, + metrics: JoinHashMapMetrics::new( + &metrics, + actor_id, + fragment_id, + side, + join_table_id, + degree_table_id, + ), + } + } + + pub fn init(&mut self, epoch: EpochPair) { + self.update_epoch(epoch.curr); + self.state.table.init_epoch(epoch); + self.degree_state.table.init_epoch(epoch); + } + + pub fn update_epoch(&mut self, epoch: u64) { + // Update the current epoch in `ManagedLruCache` + self.inner.update_epoch(epoch) + } + + /// Update the vnode bitmap and manipulate the cache if necessary. 
+ pub fn update_vnode_bitmap(&mut self, vnode_bitmap: Arc) -> bool { + let (_previous_vnode_bitmap, cache_may_stale) = + self.state.table.update_vnode_bitmap(vnode_bitmap.clone()); + let _ = self.degree_state.table.update_vnode_bitmap(vnode_bitmap); + + if cache_may_stale { + self.inner.clear(); + } + + cache_may_stale + } + + pub fn update_watermark(&mut self, watermark: ScalarImpl) { + // TODO: remove data in cache. + self.state.table.update_watermark(watermark.clone(), false); + self.degree_state.table.update_watermark(watermark, false); + } + + /// Take the state for the given `key` out of the hash table and return it. One **MUST** call + /// `update_state` after some operations to put the state back. + /// + /// If the state does not exist in the cache, fetch the remote storage and return. If it still + /// does not exist in the remote storage, a [`JoinEntryState`] with empty cache will be + /// returned. + /// + /// Note: This will NOT remove anything from remote storage. + pub async fn take_state<'a>(&mut self, key: &K) -> StreamExecutorResult { + self.metrics.total_lookup_count += 1; + let state = if self.inner.contains(key) { + // Do not update the LRU statistics here with `peek_mut` since we will put the state + // back. + let mut state = self.inner.peek_mut(key).unwrap(); + state.take() + } else { + self.metrics.lookup_miss_count += 1; + self.fetch_cached_state(key).await?.into() + }; + Ok(state) + } + + /// Fetch cache from the state store. Should only be called if the key does not exist in memory. + /// Will return an empty `JoinEntryState` even when state does not exist in remote. 
+ async fn fetch_cached_state(&self, key: &K) -> StreamExecutorResult { + let key = key.deserialize(&self.join_key_data_types)?; + + let mut entry_state = JoinEntryState::default(); + + if self.need_degree_table { + let sub_range: &(Bound, Bound) = + &(Bound::Unbounded, Bound::Unbounded); + let table_iter_fut = + self.state + .table + .iter_with_prefix(&key, sub_range, PrefetchOptions::default()); + let degree_table_iter_fut = self.degree_state.table.iter_with_prefix( + &key, + sub_range, + PrefetchOptions::default(), + ); + + let (table_iter, degree_table_iter) = + try_join(table_iter_fut, degree_table_iter_fut).await?; + + let mut pinned_table_iter = std::pin::pin!(table_iter); + let mut pinned_degree_table_iter = std::pin::pin!(degree_table_iter); + loop { + // Iterate on both iterators and ensure they have same size. Basically `zip_eq()`. + let (row, degree) = + join(pinned_table_iter.next(), pinned_degree_table_iter.next()).await; + let (row, degree) = match (row, degree) { + (None, None) => break, + (None, Some(_)) | (Some(_), None) => { + panic!("mismatched row and degree table of join key: {:?}", &key) + } + (Some(r), Some(d)) => (r, d), + }; + + let row = row?; + let degree_row = degree?; + let pk1 = row.key(); + let pk2 = degree_row.key(); + debug_assert_eq!( + pk1, pk2, + "mismatched pk in degree table: pk1: {pk1:?}, pk2: {pk2:?}", + ); + let pk = row + .as_ref() + .project(&self.state.pk_indices) + .memcmp_serialize(&self.pk_serializer); + let degree_i64 = degree_row + .datum_at(degree_row.len() - 1) + .expect("degree should not be NULL"); + entry_state + .insert( + pk, + JoinRow::new(row.row(), degree_i64.into_int64() as u64).encode(), + ) + .with_context(|| self.state.error_context(row.row()))?; + } + } else { + let sub_range: &(Bound, Bound) = + &(Bound::Unbounded, Bound::Unbounded); + let table_iter = self + .state + .table + .iter_with_prefix(&key, sub_range, PrefetchOptions::default()) + .await?; + + #[for_await] + for entry in table_iter { + let 
row = entry?; + let pk = row + .as_ref() + .project(&self.state.pk_indices) + .memcmp_serialize(&self.pk_serializer); + entry_state + .insert(pk, JoinRow::new(row.row(), 0).encode()) + .with_context(|| self.state.error_context(row.row()))?; + } + }; + + Ok(entry_state) + } + + pub async fn flush(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> { + self.metrics.flush(); + self.state.table.commit(epoch).await?; + self.degree_state.table.commit(epoch).await?; + Ok(()) + } + + pub async fn try_flush(&mut self) -> StreamExecutorResult<()> { + self.state.table.try_flush().await?; + self.degree_state.table.try_flush().await?; + Ok(()) + } + + /// Insert a join row + #[allow(clippy::unused_async)] + pub async fn insert(&mut self, key: &K, value: JoinRow) -> StreamExecutorResult<()> { + let pk = (&value.row) + .project(&self.state.pk_indices) + .memcmp_serialize(&self.pk_serializer); + + // TODO(yuhao): avoid this `contains`. + // https://github.com/risingwavelabs/risingwave/issues/9233 + if self.inner.contains(key) { + // Update cache + let mut entry = self.inner.get_mut(key).unwrap(); + entry + .insert(pk, value.encode()) + .with_context(|| self.state.error_context(&value.row))?; + } else if self.pk_contained_in_jk { + // Refill cache when the join key exist in neither cache or storage. + self.metrics.insert_cache_miss_count += 1; + let mut state = JoinEntryState::default(); + state + .insert(pk, value.encode()) + .with_context(|| self.state.error_context(&value.row))?; + self.update_state(key, state.into()); + } + + // Update the flush buffer. + let (row, degree) = value.to_table_rows(&self.state.order_key_indices); + self.state.table.insert(row); + self.degree_state.table.insert(degree); + Ok(()) + } + + /// Insert a row. + /// Used when the side does not need to update degree. 
+ #[allow(clippy::unused_async)] + pub async fn insert_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> { + let join_row = JoinRow::new(&value, 0); + let pk = (&value) + .project(&self.state.pk_indices) + .memcmp_serialize(&self.pk_serializer); + + // TODO(yuhao): avoid this `contains`. + // https://github.com/risingwavelabs/risingwave/issues/9233 + if self.inner.contains(key) { + // Update cache + let mut entry = self.inner.get_mut(key).unwrap(); + entry + .insert(pk, join_row.encode()) + .with_context(|| self.state.error_context(&value))?; + } else if self.pk_contained_in_jk { + // Refill cache when the join key exists in neither cache nor storage. + self.metrics.insert_cache_miss_count += 1; + let mut state = JoinEntryState::default(); + state + .insert(pk, join_row.encode()) + .with_context(|| self.state.error_context(&value))?; + self.update_state(key, state.into()); + } + + // Update the flush buffer. + self.state.table.insert(value); + Ok(()) + } + + /// Delete a join row + pub fn delete(&mut self, key: &K, value: JoinRow) -> StreamExecutorResult<()> { + if let Some(mut entry) = self.inner.get_mut(key) { + let pk = (&value.row) + .project(&self.state.pk_indices) + .memcmp_serialize(&self.pk_serializer); + entry + .remove(pk) + .with_context(|| self.state.error_context(&value.row))?; + } + + // If no cache is maintained, only update the state table. + let (row, degree) = value.to_table_rows(&self.state.order_key_indices); + self.state.table.delete(row); + self.degree_state.table.delete(degree); + Ok(()) + } + + /// Delete a row + /// Used when the side does not need to update degree. 
+ pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> { + if let Some(mut entry) = self.inner.get_mut(key) { + let pk = (&value) + .project(&self.state.pk_indices) + .memcmp_serialize(&self.pk_serializer); + entry + .remove(pk) + .with_context(|| self.state.error_context(&value))?; + } + + // If no cache is maintained, only update the state table. + self.state.table.delete(value); + Ok(()) + } + + /// Update a [`JoinEntryState`] into the hash table. + pub fn update_state(&mut self, key: &K, state: HashValueType) { + self.inner.put(key.clone(), HashValueWrapper(Some(state))); + } + + /// Manipulate the degree of the given [`JoinRow`] and [`EncodedJoinRow`] with `action`, both in + /// memory and in the degree table. + fn manipulate_degree( + &mut self, + join_row_ref: &mut StateValueType, + join_row: &mut JoinRow, + action: impl Fn(&mut DegreeType), + ) { + // TODO: no need to `into_owned_row` here due to partial borrow. + let old_degree = join_row + .to_table_rows(&self.state.order_key_indices) + .1 + .into_owned_row(); + + action(&mut join_row_ref.degree); + action(&mut join_row.degree); + + let new_degree = join_row.to_table_rows(&self.state.order_key_indices).1; + + self.degree_state.table.update(old_degree, new_degree); + } + + /// Increment the degree of the given [`JoinRow`] and [`EncodedJoinRow`] by one, both in + /// memory and in the degree table. + pub fn inc_degree( + &mut self, + join_row_ref: &mut StateValueType, + join_row: &mut JoinRow, + ) { + self.manipulate_degree(join_row_ref, join_row, |d| *d += 1) + } + + /// Decrement the degree of the given [`JoinRow`] and [`EncodedJoinRow`] by one, both in + /// memory and in the degree table. Panics if the degree is already zero. + pub fn dec_degree( + &mut self, + join_row_ref: &mut StateValueType, + join_row: &mut JoinRow, + ) { + self.manipulate_degree(join_row_ref, join_row, |d| { + *d = d + .checked_sub(1) + .expect("Tried to decrement zero join row degree") + }) + } + + /// Evict the cache. 
+ pub fn evict(&mut self) { + self.inner.evict(); + } + + /// Cached entry count for this hash table. + pub fn entry_count(&self) -> usize { + self.inner.len() + } + + pub fn null_matched(&self) -> &K::Bitmap { + &self.null_matched + } + + pub fn table_id(&self) -> u32 { + self.state.table.table_id() + } +} + +use risingwave_common::estimate_size::KvSize; +use thiserror::Error; + +use super::*; + +/// We manage a `JoinRowSet` in memory for all entries belonging to a join key. +/// When evicted, `cached` does not hold any entries. +/// +/// If a `JoinEntryState` exists for a join key, all records under this +/// join key will be present in the cache. +#[derive(Default)] +pub struct JoinEntryState { + /// The full copy of the state. + cached: join_row_set::JoinRowSet, + kv_heap_size: KvSize, +} + +impl EstimateSize for JoinEntryState { + fn estimated_heap_size(&self) -> usize { + // TODO: Add btreemap internal size. + // https://github.com/risingwavelabs/risingwave/issues/9713 + self.kv_heap_size.size() + } +} + +#[derive(Error, Debug)] +pub enum JoinEntryError { + #[error("double inserting a join state entry")] + OccupiedError, + #[error("removing a join state entry but it is not in the cache")] + RemoveError, +} + +impl JoinEntryState { + /// Insert into the cache. + pub fn insert( + &mut self, + key: PkType, + value: StateValueType, + ) -> Result<&mut StateValueType, JoinEntryError> { + self.kv_heap_size.add(&key, &value); + self.cached + .try_insert(key, value) + .map_err(|_| JoinEntryError::OccupiedError) + } + + /// Delete from the cache. + pub fn remove(&mut self, pk: PkType) -> Result<(), JoinEntryError> { + if let Some(value) = self.cached.remove(&pk) { + self.kv_heap_size.sub(&pk, &value); + Ok(()) + } else { + Err(JoinEntryError::RemoveError) + } + } + + /// Note: the first item in the tuple is the mutable reference to the value in this entry, while + /// the second item is the decoded value. 
To mutate the degree, one **must not** forget to apply + /// the changes to the first item. + /// + /// WARNING: Should not change the heap size of `StateValueType` with the mutable reference. + pub fn values_mut<'a>( + &'a mut self, + data_types: &'a [DataType], + ) -> impl Iterator< + Item = ( + &'a mut StateValueType, + StreamExecutorResult>, + ), + > + 'a { + self.cached.values_mut().map(|encoded| { + let decoded = encoded.decode(data_types); + (encoded, decoded) + }) + } + + pub fn len(&self) -> usize { + self.cached.len() + } +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use risingwave_common::array::*; + use risingwave_common::types::{DataType, ScalarImpl}; + use risingwave_common::util::iter_util::ZipEqDebug; + + use super::*; + + fn insert_chunk( + managed_state: &mut JoinEntryState, + pk_indices: &[usize], + data_chunk: &DataChunk, + ) { + for row_ref in data_chunk.rows() { + let row: OwnedRow = row_ref.into_owned_row(); + let value_indices = (0..row.len() - 1).collect_vec(); + let pk = pk_indices.iter().map(|idx| row[*idx].clone()).collect_vec(); + // Pk is only an `i64` here, so encoding method does not matter. 
+ let pk = OwnedRow::new(pk).project(&value_indices).value_serialize(); + let join_row = JoinRow { row, degree: 0 }; + managed_state.insert(pk, join_row.encode()).unwrap(); + } + } + + fn check( + managed_state: &mut JoinEntryState, + col_types: &[DataType], + col1: &[i64], + col2: &[i64], + ) { + for ((_, matched_row), (d1, d2)) in managed_state + .values_mut(col_types) + .zip_eq_debug(col1.iter().zip_eq_debug(col2.iter())) + { + let matched_row = matched_row.unwrap(); + assert_eq!(matched_row.row[0], Some(ScalarImpl::Int64(*d1))); + assert_eq!(matched_row.row[1], Some(ScalarImpl::Int64(*d2))); + assert_eq!(matched_row.degree, 0); + } + } + + #[tokio::test] + async fn test_managed_all_or_none_state() { + let mut managed_state = JoinEntryState::default(); + let col_types = vec![DataType::Int64, DataType::Int64]; + let pk_indices = [0]; + + let col1 = [3, 2, 1]; + let col2 = [4, 5, 6]; + let data_chunk1 = DataChunk::from_pretty( + "I I + 3 4 + 2 5 + 1 6", + ); + + // `Vec` in state + insert_chunk(&mut managed_state, &pk_indices, &data_chunk1); + check(&mut managed_state, &col_types, &col1, &col2); + + // `BtreeMap` in state + let col1 = [1, 2, 3, 4, 5]; + let col2 = [6, 5, 4, 9, 8]; + let data_chunk2 = DataChunk::from_pretty( + "I I + 5 8 + 4 9", + ); + insert_chunk(&mut managed_state, &pk_indices, &data_chunk2); + check(&mut managed_state, &col_types, &col1, &col2); + } +} diff --git a/src/stream/src/executor/join/join_row_set.rs b/src/stream/src/executor/join/join_row_set.rs new file mode 100644 index 0000000000000..de6f5ce2f0279 --- /dev/null +++ b/src/stream/src/executor/join/join_row_set.rs @@ -0,0 +1,120 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::btree_map::OccupiedError as BTreeMapOccupiedError; +use std::collections::BTreeMap; +use std::fmt::Debug; +use std::mem; + +use auto_enums::auto_enum; +use enum_as_inner::EnumAsInner; + +const MAX_VEC_SIZE: usize = 4; + +#[derive(Debug, EnumAsInner)] +pub enum JoinRowSet { + BTree(BTreeMap), + Vec(Vec<(K, V)>), +} + +impl Default for JoinRowSet { + fn default() -> Self { + Self::Vec(Vec::new()) + } +} + +#[derive(Debug)] +#[allow(dead_code)] +pub struct VecOccupiedError<'a, K, V> { + key: &'a K, + old_value: &'a V, + new_value: V, +} + +#[derive(Debug)] +pub enum JoinRowSetOccupiedError<'a, K: Ord, V> { + BTree(BTreeMapOccupiedError<'a, K, V>), + Vec(VecOccupiedError<'a, K, V>), +} + +impl JoinRowSet { + pub fn try_insert( + &mut self, + key: K, + value: V, + ) -> Result<&'_ mut V, JoinRowSetOccupiedError<'_, K, V>> { + if let Self::Vec(inner) = self + && inner.len() >= MAX_VEC_SIZE + { + let btree = BTreeMap::from_iter(inner.drain(..)); + mem::swap(self, &mut Self::BTree(btree)); + } + + match self { + Self::BTree(inner) => inner + .try_insert(key, value) + .map_err(JoinRowSetOccupiedError::BTree), + Self::Vec(inner) => { + if let Some(pos) = inner.iter().position(|elem| elem.0 == key) { + Err(JoinRowSetOccupiedError::Vec(VecOccupiedError { + key: &inner[pos].0, + old_value: &inner[pos].1, + new_value: value, + })) + } else { + if inner.capacity() == 0 { + // `Vec` will give capacity 4 when `1 < mem::size_of:: <= 1024` + // We only give one for memory optimization + inner.reserve_exact(1); + } + inner.push((key, value)); + 
Ok(&mut inner.last_mut().unwrap().1) + } + } + } + } + + pub fn remove(&mut self, key: &K) -> Option { + let ret = match self { + Self::BTree(inner) => inner.remove(key), + Self::Vec(inner) => inner + .iter() + .position(|elem| &elem.0 == key) + .map(|pos| inner.swap_remove(pos).1), + }; + if let Self::BTree(inner) = self + && inner.len() <= MAX_VEC_SIZE / 2 + { + let btree = mem::take(inner); + let vec = Vec::from_iter(btree); + mem::swap(self, &mut Self::Vec(vec)); + } + ret + } + + pub fn len(&self) -> usize { + match self { + Self::BTree(inner) => inner.len(), + Self::Vec(inner) => inner.len(), + } + } + + #[auto_enum(Iterator)] + pub fn values_mut(&mut self) -> impl Iterator { + match self { + Self::BTree(inner) => inner.values_mut(), + Self::Vec(inner) => inner.iter_mut().map(|(_, v)| v), + } + } +} diff --git a/src/stream/src/executor/join/mod.rs b/src/stream/src/executor/join/mod.rs new file mode 100644 index 0000000000000..b8bd5ff84d95f --- /dev/null +++ b/src/stream/src/executor/join/mod.rs @@ -0,0 +1,108 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod builder; +pub mod hash_join; +pub mod join_row_set; +pub mod row; + +/// The `JoinType` and `SideType` are to mimic an enum, because currently +/// enums are not supported in const generics. +// TODO: Use enum to replace this once [feature(adt_const_params)](https://github.com/rust-lang/rust/issues/95174) is completed. 
+pub type JoinTypePrimitive = u8; + +#[allow(non_snake_case, non_upper_case_globals)] +pub mod JoinType { + use super::JoinTypePrimitive; + pub const Inner: JoinTypePrimitive = 0; + pub const LeftOuter: JoinTypePrimitive = 1; + pub const RightOuter: JoinTypePrimitive = 2; + pub const FullOuter: JoinTypePrimitive = 3; + pub const LeftSemi: JoinTypePrimitive = 4; + pub const LeftAnti: JoinTypePrimitive = 5; + pub const RightSemi: JoinTypePrimitive = 6; + pub const RightAnti: JoinTypePrimitive = 7; +} + +pub type SideTypePrimitive = u8; +#[allow(non_snake_case, non_upper_case_globals)] +pub mod SideType { + use super::SideTypePrimitive; + pub const Left: SideTypePrimitive = 0; + pub const Right: SideTypePrimitive = 1; +} + +pub const fn is_outer_side(join_type: JoinTypePrimitive, side_type: SideTypePrimitive) -> bool { + join_type == JoinType::FullOuter + || (join_type == JoinType::LeftOuter && side_type == SideType::Left) + || (join_type == JoinType::RightOuter && side_type == SideType::Right) +} + +pub const fn outer_side_null(join_type: JoinTypePrimitive, side_type: SideTypePrimitive) -> bool { + join_type == JoinType::FullOuter + || (join_type == JoinType::LeftOuter && side_type == SideType::Right) + || (join_type == JoinType::RightOuter && side_type == SideType::Left) +} + +/// Send the update only once if the join type is semi/anti and the update is the same side as the +/// join +pub const fn forward_exactly_once( + join_type: JoinTypePrimitive, + side_type: SideTypePrimitive, +) -> bool { + ((join_type == JoinType::LeftSemi || join_type == JoinType::LeftAnti) + && side_type == SideType::Left) + || ((join_type == JoinType::RightSemi || join_type == JoinType::RightAnti) + && side_type == SideType::Right) +} + +pub const fn only_forward_matched_side( + join_type: JoinTypePrimitive, + side_type: SideTypePrimitive, +) -> bool { + ((join_type == JoinType::LeftSemi || join_type == JoinType::LeftAnti) + && side_type == SideType::Right) + || ((join_type == 
JoinType::RightSemi || join_type == JoinType::RightAnti) + && side_type == SideType::Left) +} + +pub const fn is_semi(join_type: JoinTypePrimitive) -> bool { + join_type == JoinType::LeftSemi || join_type == JoinType::RightSemi +} + +pub const fn is_anti(join_type: JoinTypePrimitive) -> bool { + join_type == JoinType::LeftAnti || join_type == JoinType::RightAnti +} + +pub const fn is_left_semi_or_anti(join_type: JoinTypePrimitive) -> bool { + join_type == JoinType::LeftSemi || join_type == JoinType::LeftAnti +} + +pub const fn is_right_semi_or_anti(join_type: JoinTypePrimitive) -> bool { + join_type == JoinType::RightSemi || join_type == JoinType::RightAnti +} + +pub const fn need_left_degree(join_type: JoinTypePrimitive) -> bool { + join_type == JoinType::FullOuter + || join_type == JoinType::LeftOuter + || join_type == JoinType::LeftAnti + || join_type == JoinType::LeftSemi +} + +pub const fn need_right_degree(join_type: JoinTypePrimitive) -> bool { + join_type == JoinType::FullOuter + || join_type == JoinType::RightOuter + || join_type == JoinType::RightAnti + || join_type == JoinType::RightSemi +} diff --git a/src/stream/src/executor/join/row.rs b/src/stream/src/executor/join/row.rs new file mode 100644 index 0000000000000..9ab133fc314ba --- /dev/null +++ b/src/stream/src/executor/join/row.rs @@ -0,0 +1,82 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use risingwave_common::estimate_size::EstimateSize; +use risingwave_common::row::{self, CompactedRow, OwnedRow, Row, RowExt}; +use risingwave_common::types::{DataType, ScalarImpl}; + +use crate::executor::StreamExecutorResult; + +/// This is a row with a match degree +#[derive(Clone, Debug)] +pub struct JoinRow { + pub row: R, + pub degree: DegreeType, +} + +impl JoinRow { + pub fn new(row: R, degree: DegreeType) -> Self { + Self { row, degree } + } + + pub fn is_zero_degree(&self) -> bool { + self.degree == 0 + } + + /// Return row and degree in `Row` format. The degree part will be inserted in degree table + /// later, so a pk prefix will be added. + /// + /// * `state_order_key_indices` - the order key of `row` + pub fn to_table_rows<'a>( + &'a self, + state_order_key_indices: &'a [usize], + ) -> (&'a R, impl Row + 'a) { + let order_key = (&self.row).project(state_order_key_indices); + let degree = build_degree_row(order_key, self.degree); + (&self.row, degree) + } + + pub fn encode(&self) -> EncodedJoinRow { + EncodedJoinRow { + compacted_row: (&self.row).into(), + degree: self.degree, + } + } +} + +pub type DegreeType = u64; + +fn build_degree_row(order_key: impl Row, degree: DegreeType) -> impl Row { + order_key.chain(row::once(Some(ScalarImpl::Int64(degree as i64)))) +} + +#[derive(Clone, Debug, EstimateSize)] +pub struct EncodedJoinRow { + pub compacted_row: CompactedRow, + pub degree: DegreeType, +} + +impl EncodedJoinRow { + pub fn decode(&self, data_types: &[DataType]) -> StreamExecutorResult> { + Ok(JoinRow { + row: self.decode_row(data_types)?, + degree: self.degree, + }) + } + + fn decode_row(&self, data_types: &[DataType]) -> StreamExecutorResult { + let row = self.compacted_row.deserialize(data_types)?; + Ok(row) + } +} diff --git a/src/stream/src/executor/lookup/impl_.rs b/src/stream/src/executor/lookup/impl_.rs index 8b86cfc602f08..a503a73a0c8a0 100644 --- a/src/stream/src/executor/lookup/impl_.rs +++ b/src/stream/src/executor/lookup/impl_.rs 
@@ -31,8 +31,8 @@ use risingwave_storage::StateStore; use super::sides::{stream_lookup_arrange_prev_epoch, stream_lookup_arrange_this_epoch}; use crate::cache::cache_may_stale; use crate::common::metrics::MetricsInfo; -use crate::common::JoinStreamChunkBuilder; use crate::executor::error::{StreamExecutorError, StreamExecutorResult}; +use crate::executor::join::builder::JoinStreamChunkBuilder; use crate::executor::lookup::cache::LookupCache; use crate::executor::lookup::sides::{ArrangeJoinSide, ArrangeMessage, StreamJoinSide}; use crate::executor::lookup::LookupExecutor; diff --git a/src/stream/src/executor/managed_state/mod.rs b/src/stream/src/executor/managed_state/mod.rs deleted file mode 100644 index c32dfb11be7c6..0000000000000 --- a/src/stream/src/executor/managed_state/mod.rs +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -pub mod join; diff --git a/src/stream/src/executor/mod.rs b/src/stream/src/executor/mod.rs index cb1d2a497ef8a..40ae252c03cf0 100644 --- a/src/stream/src/executor/mod.rs +++ b/src/stream/src/executor/mod.rs @@ -70,9 +70,9 @@ mod flow_control; mod hash_agg; pub mod hash_join; mod hop_window; +mod join; mod lookup; mod lookup_union; -mod managed_state; mod merge; mod mview; mod no_op; @@ -123,6 +123,7 @@ pub use flow_control::FlowControlExecutor; pub use hash_agg::HashAggExecutor; pub use hash_join::*; pub use hop_window::HopWindowExecutor; +pub use join::JoinType; pub use lookup::*; pub use lookup_union::LookupUnionExecutor; pub use merge::MergeExecutor; diff --git a/src/stream/src/executor/temporal_join.rs b/src/stream/src/executor/temporal_join.rs index 099abe658f615..32a0c5747083b 100644 --- a/src/stream/src/executor/temporal_join.rs +++ b/src/stream/src/executor/temporal_join.rs @@ -39,15 +39,16 @@ use risingwave_storage::table::batch_table::storage_table::StorageTable; use risingwave_storage::table::TableIter; use risingwave_storage::StateStore; +use super::join::{JoinType, JoinTypePrimitive}; use super::{ Barrier, Executor, ExecutorInfo, Message, MessageStream, StreamExecutorError, StreamExecutorResult, }; use crate::cache::{cache_may_stale, new_with_hasher_in, ManagedLruCache}; use crate::common::metrics::MetricsInfo; -use crate::common::JoinStreamChunkBuilder; +use crate::executor::join::builder::JoinStreamChunkBuilder; use crate::executor::monitor::StreamingMetrics; -use crate::executor::{ActorContextRef, BoxedExecutor, JoinType, JoinTypePrimitive, Watermark}; +use crate::executor::{ActorContextRef, BoxedExecutor, Watermark}; use crate::task::AtomicU64Ref; pub struct TemporalJoinExecutor { diff --git a/src/stream/src/from_proto/hash_join.rs b/src/stream/src/from_proto/hash_join.rs index d04db948853ba..0c50f26f941b6 100644 --- a/src/stream/src/from_proto/hash_join.rs +++ b/src/stream/src/from_proto/hash_join.rs @@ -27,7 +27,7 @@ use super::*; use 
crate::common::table::state_table::StateTable; use crate::executor::hash_join::*; use crate::executor::monitor::StreamingMetrics; -use crate::executor::ActorContextRef; +use crate::executor::{ActorContextRef, JoinType}; use crate::task::AtomicU64Ref; pub struct HashJoinExecutorBuilder; From 1ef8e5917f59026a4ca8d5991c7a1bdab4ea3e93 Mon Sep 17 00:00:00 2001 From: Yuhao Su Date: Tue, 30 Jan 2024 15:52:24 +0800 Subject: [PATCH 2/2] license --- src/stream/src/executor/join/hash_join.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/stream/src/executor/join/hash_join.rs b/src/stream/src/executor/join/hash_join.rs index c3fbaeb83face..123bd6e42e45e 100644 --- a/src/stream/src/executor/join/hash_join.rs +++ b/src/stream/src/executor/join/hash_join.rs @@ -1,3 +1,17 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use std::alloc::Global; use std::ops::{Bound, Deref, DerefMut}; use std::sync::Arc;