diff --git a/src/stream/src/common/table/state_table.rs b/src/stream/src/common/table/state_table.rs index be3c2620ddfe9..045fb1fdaeba9 100644 --- a/src/stream/src/common/table/state_table.rs +++ b/src/stream/src/common/table/state_table.rs @@ -17,6 +17,7 @@ use std::ops::Bound::*; use std::sync::Arc; use bytes::{BufMut, Bytes, BytesMut}; +use either::Either; use futures::{pin_mut, FutureExt, Stream, StreamExt}; use futures_async_stream::for_await; use itertools::{izip, Itertools}; @@ -1195,6 +1196,27 @@ where .await } + /// This function scans rows from the relational table with specific `prefix` and `pk_sub_range` under the same + /// `vnode`. + pub async fn iter_row_with_pk_prefix_sub_range( + &self, + pk_prefix: impl Row, + sub_range: &(Bound, Bound), + prefetch_options: PrefetchOptions, + ) -> StreamExecutorResult> { + let vnode = self.compute_prefix_vnode(&pk_prefix).to_be_bytes(); + + let memcomparable_range = + prefix_and_sub_range_to_memcomparable(&self.pk_serde, sub_range, pk_prefix); + + let memcomparable_range_with_vnode = prefixed_range(memcomparable_range, &vnode); + Ok(deserialize_keyed_row_stream( + self.iter_kv(memcomparable_range_with_vnode, None, prefetch_options) + .await?, + &self.row_serde, + )) + } + /// This function scans raw key-values from the relational table with specific `pk_range` under /// the same `vnode`. async fn iter_kv_with_pk_range( @@ -1297,15 +1319,38 @@ pub fn prefix_range_to_memcomparable( range: &(Bound, Bound), ) -> (Bound, Bound) { ( - to_memcomparable(pk_serde, &range.0, false), - to_memcomparable(pk_serde, &range.1, true), + start_range_to_memcomparable(pk_serde, &range.0), + end_range_to_memcomparable(pk_serde, &range.1, None), + ) +} + +fn prefix_and_sub_range_to_memcomparable( + pk_serde: &OrderedRowSerde, + sub_range: &(Bound, Bound), + pk_prefix: impl Row, +) -> (Bound, Bound) { + let (range_start, range_end) = sub_range; + let prefix_serializer = pk_serde.prefix(pk_prefix.len()); + let serialized_pk_prefix = serialize_pk(&pk_prefix, &prefix_serializer); + let start_range = match range_start { + Included(start_range) => Bound::Included(Either::Left((&pk_prefix).chain(start_range))), + Excluded(start_range) => Bound::Excluded(Either::Left((&pk_prefix).chain(start_range))), + Unbounded => Bound::Included(Either::Right(&pk_prefix)), + }; + let end_range = match range_end { + Included(end_range) => Bound::Included((&pk_prefix).chain(end_range)), + Excluded(end_range) => Bound::Excluded((&pk_prefix).chain(end_range)), + Unbounded => Unbounded, + }; + ( + start_range_to_memcomparable(pk_serde, &start_range), + end_range_to_memcomparable(pk_serde, &end_range, Some(serialized_pk_prefix)), ) } -fn to_memcomparable( +fn start_range_to_memcomparable( pk_serde: &OrderedRowSerde, bound: &Bound, - is_upper: bool, ) -> Bound { let serialize_pk_prefix = |pk_prefix: &R| { let prefix_serializer = pk_serde.prefix(pk_prefix.len()); @@ -1315,20 +1360,39 @@ fn to_memcomparable( Unbounded => Unbounded, Included(r) => { let serialized = serialize_pk_prefix(r); - if is_upper { - end_bound_of_prefix(&serialized) - } else { - Included(serialized) - } + + Included(serialized) } Excluded(r) => { let serialized = serialize_pk_prefix(r); - if !is_upper { - // if lower - start_bound_of_excluded_prefix(&serialized) - } else { - Excluded(serialized) - } + + start_bound_of_excluded_prefix(&serialized) + } + } +} + +fn end_range_to_memcomparable( + pk_serde: &OrderedRowSerde, + bound: &Bound, + serialized_pk_prefix: Option, +) -> Bound { + let serialize_pk_prefix = |pk_prefix: &R| { + let prefix_serializer = pk_serde.prefix(pk_prefix.len()); + serialize_pk(pk_prefix, &prefix_serializer) + }; + match bound { + Unbounded => match serialized_pk_prefix { + Some(serialized_pk_prefix) => end_bound_of_prefix(&serialized_pk_prefix), + None => Unbounded, + }, + Included(r) => { + let serialized = serialize_pk_prefix(r); + + end_bound_of_prefix(&serialized) + } + Excluded(r) => { + let serialized = serialize_pk_prefix(r); + Excluded(serialized) } } } diff --git a/src/stream/src/common/table/test_state_table.rs b/src/stream/src/common/table/test_state_table.rs index c3e5759a47ae6..2f5dc3202adb9 100644 --- a/src/stream/src/common/table/test_state_table.rs +++ b/src/stream/src/common/table/test_state_table.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::ops::Bound; + use futures::{pin_mut, StreamExt}; use risingwave_common::array::{Op, StreamChunk}; use risingwave_common::buffer::Bitmap; @@ -1833,3 +1835,183 @@ async fn test_state_table_watermark_cache_refill() { .as_scalar_ref_impl() ) } + +#[tokio::test] +async fn test_state_table_iter_prefix_and_sub_range() { + const TEST_TABLE_ID: TableId = TableId { table_id: 233 }; + let test_env = prepare_hummock_test_env().await; + + let order_types = vec![OrderType::ascending(), OrderType::ascending()]; + let column_ids = [ColumnId::from(0), ColumnId::from(1), ColumnId::from(2)]; + let column_descs = vec![ + ColumnDesc::unnamed(column_ids[0], DataType::Int32), + ColumnDesc::unnamed(column_ids[1], DataType::Int32), + ColumnDesc::unnamed(column_ids[2], DataType::Int32), + ]; + let pk_index = vec![0_usize, 1_usize]; + let read_prefix_len_hint = 0; + let table = gen_prost_table( + TEST_TABLE_ID, + column_descs, + order_types, + pk_index, + read_prefix_len_hint, + ); + + test_env.register_table(table.clone()).await; + let mut state_table = + StateTable::from_table_catalog_inconsistent_op(&table, test_env.storage.clone(), None) + .await; + let mut epoch = EpochPair::new_test_epoch(1); + state_table.init_epoch(epoch); + + state_table.insert(OwnedRow::new(vec![ + Some(1_i32.into()), + Some(11_i32.into()), + Some(111_i32.into()), + ])); + state_table.insert(OwnedRow::new(vec![ + Some(1_i32.into()), + Some(22_i32.into()), + Some(222_i32.into()), + ])); + state_table.insert(OwnedRow::new(vec![ + Some(1_i32.into()), + Some(33_i32.into()), + Some(333_i32.into()), + ])); + + state_table.insert(OwnedRow::new(vec![ + Some(4_i32.into()), + Some(44_i32.into()), + Some(444_i32.into()), + ])); + + epoch.inc(); + state_table.commit(epoch).await.unwrap(); + + let pk_prefix = OwnedRow::new(vec![Some(1_i32.into())]); + + let sub_range1 = ( + std::ops::Bound::Included(OwnedRow::new(vec![Some(11_i32.into())])), + std::ops::Bound::Excluded(OwnedRow::new(vec![Some(33_i32.into())])), + ); + + let iter = state_table + .iter_row_with_pk_prefix_sub_range(pk_prefix, &sub_range1, Default::default()) + .await + .unwrap(); + + pin_mut!(iter); + + let res = iter.next().await.unwrap().unwrap(); + + assert_eq!( + &OwnedRow::new(vec![ + Some(1_i32.into()), + Some(11_i32.into()), + Some(111_i32.into()), + ]), + res.as_ref() + ); + + let res = iter.next().await.unwrap().unwrap(); + + assert_eq!( + &OwnedRow::new(vec![ + Some(1_i32.into()), + Some(22_i32.into()), + Some(222_i32.into()), + ]), + res.as_ref() + ); + + let res = iter.next().await; + assert!(res.is_none()); + + let sub_range2: (Bound, Bound) = ( + std::ops::Bound::Excluded(OwnedRow::new(vec![Some(11_i32.into())])), + std::ops::Bound::Unbounded, + ); + + let pk_prefix = OwnedRow::new(vec![Some(1_i32.into())]); + let iter = state_table + .iter_row_with_pk_prefix_sub_range(pk_prefix, &sub_range2, Default::default()) + .await + .unwrap(); + + pin_mut!(iter); + + let res = iter.next().await.unwrap().unwrap(); + + assert_eq!( + &OwnedRow::new(vec![ + Some(1_i32.into()), + Some(22_i32.into()), + Some(222_i32.into()), + ]), + res.as_ref() + ); + + let res = iter.next().await.unwrap().unwrap(); + + assert_eq!( + &OwnedRow::new(vec![ + Some(1_i32.into()), + Some(33_i32.into()), + Some(333_i32.into()), + ]), + res.as_ref() + ); + + let res = iter.next().await; + assert!(res.is_none()); + + let sub_range3: (Bound, Bound) = ( + std::ops::Bound::Unbounded, + std::ops::Bound::Included(OwnedRow::new(vec![Some(33_i32.into())])), + ); + + let pk_prefix = OwnedRow::new(vec![Some(1_i32.into())]); + let iter = state_table + .iter_row_with_pk_prefix_sub_range(pk_prefix, &sub_range3, Default::default()) + .await + .unwrap(); + + pin_mut!(iter); + let res = iter.next().await.unwrap().unwrap(); + + assert_eq!( + &OwnedRow::new(vec![ + Some(1_i32.into()), + Some(11_i32.into()), + Some(111_i32.into()), + ]), + res.as_ref() + ); + + let res = iter.next().await.unwrap().unwrap(); + + assert_eq!( + &OwnedRow::new(vec![ + Some(1_i32.into()), + Some(22_i32.into()), + Some(222_i32.into()), + ]), + res.as_ref() + ); + + let res = iter.next().await.unwrap().unwrap(); + + assert_eq!( + &OwnedRow::new(vec![ + Some(1_i32.into()), + Some(33_i32.into()), + Some(333_i32.into()), + ]), + res.as_ref() + ); + + let res = iter.next().await; + assert!(res.is_none()); +}