diff --git a/Cargo.lock b/Cargo.lock index a0b3e08b46441..db79c002c83e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8353,6 +8353,7 @@ dependencies = [ "async-recursion", "async-stream", "async-trait", + "auto_enums", "await-tree", "bytes", "criterion", diff --git a/src/stream/Cargo.toml b/src/stream/Cargo.toml index 9e9e77b92ceec..655effee51cfd 100644 --- a/src/stream/Cargo.toml +++ b/src/stream/Cargo.toml @@ -19,6 +19,7 @@ anyhow = "1" async-recursion = "1" async-stream = "0.3" async-trait = "0.1" +auto_enums = "0.8" await-tree = { workspace = true } bytes = "1" educe = "0.4" diff --git a/src/stream/src/executor/managed_state/join/join_entry_state.rs b/src/stream/src/executor/managed_state/join/join_entry_state.rs index 69fa706883bae..1efda013838ea 100644 --- a/src/stream/src/executor/managed_state/join/join_entry_state.rs +++ b/src/stream/src/executor/managed_state/join/join_entry_state.rs @@ -12,21 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::{btree_map, BTreeMap}; - use risingwave_common::estimate_size::KvSize; use super::*; -#[expect(dead_code)] -type JoinEntryStateIter<'a> = btree_map::Iter<'a, PkType, StateValueType>; - -#[expect(dead_code)] -type JoinEntryStateValues<'a> = btree_map::Values<'a, PkType, StateValueType>; - -#[expect(dead_code)] -type JoinEntryStateValuesMut<'a> = btree_map::ValuesMut<'a, PkType, StateValueType>; - /// We manages a `HashMap` in memory for all entries belonging to a join key. /// When evicted, `cached` does not hold any entries. /// @@ -35,7 +24,7 @@ type JoinEntryStateValuesMut<'a> = btree_map::ValuesMut<'a, PkType, StateValueTy #[derive(Default)] pub struct JoinEntryState { /// The full copy of the state. - cached: BTreeMap, + cached: join_row_set::JoinRowSet, kv_heap_size: KvSize, } @@ -97,20 +86,11 @@ mod tests { use super::*; - #[tokio::test] - async fn test_managed_all_or_none_state() { - let mut managed_state = JoinEntryState::default(); - let pk_indices = [0]; - let col1 = [1, 2, 3]; - let col2 = [6, 5, 4]; - let col_types = vec![DataType::Int64, DataType::Int64]; - let data_chunk = DataChunk::from_pretty( - "I I - 3 4 - 2 5 - 1 6", - ); - + fn insert_chunk( + managed_state: &mut JoinEntryState, + pk_indices: &[usize], + data_chunk: &DataChunk, + ) { for row_ref in data_chunk.rows() { let row: OwnedRow = row_ref.into_owned_row(); let value_indices = (0..row.len() - 1).collect_vec(); @@ -120,9 +100,16 @@ mod tests { let join_row = JoinRow { row, degree: 0 }; managed_state.insert(pk, join_row.encode()); } + } + fn check( + managed_state: &mut JoinEntryState, + col_types: &[DataType], + col1: &[i64], + col2: &[i64], + ) { for ((_, matched_row), (d1, d2)) in managed_state - .values_mut(&col_types) + .values_mut(col_types) .zip_eq_debug(col1.iter().zip_eq_debug(col2.iter())) { let matched_row = matched_row.unwrap(); @@ -131,4 +118,35 @@ mod tests { assert_eq!(matched_row.degree, 0); } } + + #[tokio::test] + async fn test_managed_all_or_none_state() { + let mut managed_state = JoinEntryState::default(); + let col_types = vec![DataType::Int64, DataType::Int64]; + let pk_indices = [0]; + + let col1 = [3, 2, 1]; + let col2 = [4, 5, 6]; + let data_chunk1 = DataChunk::from_pretty( + "I I + 3 4 + 2 5 + 1 6", + ); + + // `Vec` in state + insert_chunk(&mut managed_state, &pk_indices, &data_chunk1); + check(&mut managed_state, &col_types, &col1, &col2); + + // `BtreeMap` in state + let col1 = [1, 2, 3, 4, 5]; + let col2 = [6, 5, 4, 9, 8]; + let data_chunk2 = DataChunk::from_pretty( + "I I + 5 8 + 4 9", + ); + insert_chunk(&mut managed_state, &pk_indices, &data_chunk2); + check(&mut managed_state, &col_types, &col1, &col2); + } } diff --git a/src/stream/src/executor/managed_state/join/join_row_set.rs b/src/stream/src/executor/managed_state/join/join_row_set.rs new file mode 100644 index 0000000000000..5b35ff0c68bc6 --- /dev/null +++ b/src/stream/src/executor/managed_state/join/join_row_set.rs @@ -0,0 +1,116 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::btree_map::OccupiedError as BTreeMapOccupiedError; +use std::collections::BTreeMap; +use std::fmt::Debug; +use std::mem; + +use auto_enums::auto_enum; +use enum_as_inner::EnumAsInner; + +const MAX_VEC_SIZE: usize = 4; + +#[derive(Debug, EnumAsInner)] +pub enum JoinRowSet { + BTree(BTreeMap), + Vec(Vec<(K, V)>), +} + +impl Default for JoinRowSet { + fn default() -> Self { + Self::Vec(Vec::new()) + } +} + +#[derive(Debug)] +#[allow(dead_code)] +pub struct VecOccupiedError<'a, K, V> { + key: &'a K, + old_value: &'a V, + new_value: V, +} + +#[derive(Debug)] +pub enum JoinRowSetOccupiedError<'a, K: Ord, V> { + BTree(BTreeMapOccupiedError<'a, K, V>), + Vec(VecOccupiedError<'a, K, V>), +} + +impl JoinRowSet { + pub fn try_insert( + &mut self, + key: K, + value: V, + ) -> Result<&'_ mut V, JoinRowSetOccupiedError<'_, K, V>> { + if let Self::Vec(inner) = self && inner.len() >= MAX_VEC_SIZE { + let btree = BTreeMap::from_iter(inner.drain(..)); + mem::swap(self, &mut Self::BTree(btree)); + } + + match self { + Self::BTree(inner) => inner + .try_insert(key, value) + .map_err(JoinRowSetOccupiedError::BTree), + Self::Vec(inner) => { + if let Some(pos) = inner.iter().position(|elem| elem.0 == key) { + Err(JoinRowSetOccupiedError::Vec(VecOccupiedError { + key: &inner[pos].0, + old_value: &inner[pos].1, + new_value: value, + })) + } else { + if inner.capacity() == 0 { + // `Vec` will give capacity 4 when `1 < mem::size_of:: <= 1024` + // We only give one for memory optimization + inner.reserve_exact(1); + } + inner.push((key, value)); + Ok(&mut inner.last_mut().unwrap().1) + } + } + } + } + + pub fn remove(&mut self, key: &K) -> Option { + let ret = match self { + Self::BTree(inner) => inner.remove(key), + Self::Vec(inner) => inner + .iter() + .position(|elem| &elem.0 == key) + .map(|pos| inner.swap_remove(pos).1), + }; + if let Self::BTree(inner) = self && inner.len() <= MAX_VEC_SIZE / 2 { + let btree = mem::take(inner); + let vec = Vec::from_iter(btree); + mem::swap(self, &mut Self::Vec(vec)); + } + ret + } + + pub fn len(&self) -> usize { + match self { + Self::BTree(inner) => inner.len(), + Self::Vec(inner) => inner.len(), + } + } + + #[auto_enum(Iterator)] + pub fn values_mut(&mut self) -> impl Iterator { + match self { + Self::BTree(inner) => inner.values_mut(), + Self::Vec(inner) => inner.iter_mut().map(|(_, v)| v), + } + } +} diff --git a/src/stream/src/executor/managed_state/join/mod.rs b/src/stream/src/executor/managed_state/join/mod.rs index b7a81a0f75745..d8ad231c677c7 100644 --- a/src/stream/src/executor/managed_state/join/mod.rs +++ b/src/stream/src/executor/managed_state/join/mod.rs @@ -13,6 +13,7 @@ // limitations under the License. mod join_entry_state; +mod join_row_set; use std::alloc::Global; use std::ops::{Bound, Deref, DerefMut};