diff --git a/src/common/src/hash/consistent_hash/vnode.rs b/src/common/src/hash/consistent_hash/vnode.rs index 4dcfaf8d4a105..2e2d9016f2531 100644 --- a/src/common/src/hash/consistent_hash/vnode.rs +++ b/src/common/src/hash/consistent_hash/vnode.rs @@ -209,7 +209,8 @@ mod tests { #[test] fn test_serial_key_chunk() { - let mut gen = RowIdGenerator::new([VirtualNode::from_index(100)]); + let mut gen = + RowIdGenerator::new([VirtualNode::from_index(100)], VirtualNode::COUNT_FOR_TEST); let chunk = format!( "SRL I {} 1 @@ -229,7 +230,8 @@ mod tests { #[test] fn test_serial_key_row() { - let mut gen = RowIdGenerator::new([VirtualNode::from_index(100)]); + let mut gen = + RowIdGenerator::new([VirtualNode::from_index(100)], VirtualNode::COUNT_FOR_TEST); let row = OwnedRow::new(vec![ Some(ScalarImpl::Serial(gen.next().into())), Some(ScalarImpl::Int64(12345)), @@ -242,7 +244,10 @@ mod tests { #[test] fn test_serial_key_chunk_multiple_vnodes() { - let mut gen = RowIdGenerator::new([100, 200].map(VirtualNode::from_index)); + let mut gen = RowIdGenerator::new( + [100, 200].map(VirtualNode::from_index), + VirtualNode::COUNT_FOR_TEST, + ); let chunk = format!( "SRL I {} 1 diff --git a/src/common/src/util/row_id.rs b/src/common/src/util/row_id.rs index ef41f61b5f535..5d03d7691f432 100644 --- a/src/common/src/util/row_id.rs +++ b/src/common/src/util/row_id.rs @@ -15,23 +15,19 @@ use std::cmp::Ordering; use std::time::SystemTime; -use static_assertions::const_assert; - use super::epoch::UNIX_RISINGWAVE_DATE_EPOCH; use crate::hash::VirtualNode; -const TIMESTAMP_SHIFT_BITS: u8 = 22; -const VNODE_ID_SHIFT_BITS: u8 = 12; -const SEQUENCE_UPPER_BOUND: u16 = 1 << 12; -const VNODE_ID_UPPER_BOUND: u32 = 1 << 10; - -const_assert!(VNODE_ID_UPPER_BOUND >= VirtualNode::COUNT as u32); +const TIMESTAMP_SHIFT_BITS: u32 = 22; /// `RowIdGenerator` generates unique row ids using snowflake algorithm as following format: /// -/// | timestamp | vnode id | sequence | -/// |-----------|----------|----------| -/// | 41 bits | 10 bits | 12 bits | +/// | timestamp | vnode & sequence | +/// |-----------|------------------| +/// | 41 bits | 22 bits | +/// +/// The vnode part can occupy 10..=15 bits, which is determined by the vnode count. Thus, +/// the sequence part will occupy 7..=12 bits. See [`bit_for_vnode_count`] for more details. #[derive(Debug)] pub struct RowIdGenerator { /// Specific base timestamp using for generating row ids. @@ -40,8 +36,11 @@ pub struct RowIdGenerator { /// Last timestamp part of row id, based on `base`. last_timestamp_ms: i64, + /// The number of bits used for vnode. + vnode_bit: u32, + /// Virtual nodes used by this generator. - pub vnodes: Vec, + vnodes: Vec, /// Current index of `vnodes`. vnodes_index: u16, @@ -52,11 +51,28 @@ pub struct RowIdGenerator { pub type RowId = i64; +fn bit_for_vnode_count(vnode_count: usize) -> u32 { + debug_assert!( + vnode_count <= VirtualNode::MAX_COUNT as usize, + "invalid vnode count {vnode_count}" + ); + + if vnode_count <= 1024 { + 10 + } else { + vnode_count.next_power_of_two().ilog2() + } +} + #[inline] +// TODO(var-vnode): rename, not `extract` but `compute` pub fn extract_vnode_id_from_row_id(id: RowId, vnode_count: usize) -> VirtualNode { - let vnode_id = ((id >> VNODE_ID_SHIFT_BITS) & (VNODE_ID_UPPER_BOUND as i64 - 1)) as u32; - assert!(vnode_id < VNODE_ID_UPPER_BOUND); + let vnode_bit = bit_for_vnode_count(vnode_count); + let sequence_bit = TIMESTAMP_SHIFT_BITS - vnode_bit; + let vnode_part = ((id >> sequence_bit) & ((1 << vnode_bit) - 1)) as usize; + + // TODO: update comments // Previously, the vnode count was fixed to 256 for all jobs in all clusters. As a result, the // `vnode_id` must reside in the range of `0..256` and the following modulo operation will be // no-op. So this will retrieve the exact same vnode as when it was generated. @@ -65,22 +81,36 @@ pub fn extract_vnode_id_from_row_id(id: RowId, vnode_count: usize) -> VirtualNod // within the range, we need to apply modulo operation here. Therefore, there is no guarantee // that the vnode retrieved here is the same as when it was generated. However, the row ids // generated under the same vnode will still yield the same result. - VirtualNode::from_index(vnode_id as usize % vnode_count) + VirtualNode::from_index(vnode_part % vnode_count) } impl RowIdGenerator { - /// Create a new `RowIdGenerator` with given virtual nodes. - pub fn new(vnodes: impl IntoIterator) -> Self { + /// Create a new `RowIdGenerator` with given virtual nodes and vnode count. + pub fn new(vnodes: impl IntoIterator, vnode_count: usize) -> Self { let base = *UNIX_RISINGWAVE_DATE_EPOCH; + let vnode_bit = bit_for_vnode_count(vnode_count); + Self { base, last_timestamp_ms: base.elapsed().unwrap().as_millis() as i64, + vnode_bit, vnodes: vnodes.into_iter().collect(), vnodes_index: 0, sequence: 0, } } + /// Create a new `RowIdGenerator` with given virtual nodes and [`VirtualNode::COUNT_FOR_TEST`] + /// as vnode count. + pub fn new_for_test(vnodes: impl IntoIterator) -> Self { + Self::new(vnodes, VirtualNode::COUNT_FOR_TEST) + } + + /// The upper bound of the sequence part, exclusive. + fn sequence_upper_bound(&self) -> u16 { + 1 << (TIMESTAMP_SHIFT_BITS - self.vnode_bit) + } + /// Update the timestamp, so that the millisecond part of row id is **always** increased. /// /// This method will immediately return if the timestamp is increased or there's remaining @@ -99,7 +129,10 @@ impl RowIdGenerator { ); true } - Ordering::Equal => self.sequence == SEQUENCE_UPPER_BOUND, + Ordering::Equal => { + // Update the timestamp if the sequence reaches the upper bound. + self.sequence == self.sequence_upper_bound() + } Ordering::Greater => true, }; @@ -129,7 +162,7 @@ impl RowIdGenerator { /// timestamp, and `try_update_timestamp` should be called to update the timestamp and reset the /// sequence. After that, the next call of this method always returns `Some`. fn next_row_id_in_current_timestamp(&mut self) -> Option { - if self.sequence >= SEQUENCE_UPPER_BOUND { + if self.sequence >= self.sequence_upper_bound() { return None; } @@ -143,7 +176,7 @@ impl RowIdGenerator { Some( self.last_timestamp_ms << TIMESTAMP_SHIFT_BITS - | (vnode << VNODE_ID_SHIFT_BITS) as i64 + | (vnode << (TIMESTAMP_SHIFT_BITS - self.vnode_bit)) as i64 | sequence as i64, ) } @@ -196,9 +229,9 @@ mod tests { use super::*; - #[tokio::test] // `async` in favor of `madsim::time::advance` - async fn test_generator() { - let mut generator = RowIdGenerator::new([VirtualNode::from_index(0)]); + async fn test_generator_with_vnode_count(vnode_count: usize) { + let mut generator = RowIdGenerator::new([VirtualNode::from_index(0)], vnode_count); + let sequence_upper_bound = generator.sequence_upper_bound(); let mut last_row_id = generator.next(); for _ in 0..100000 { @@ -219,34 +252,75 @@ mod tests { row_id >> TIMESTAMP_SHIFT_BITS, last_row_id >> TIMESTAMP_SHIFT_BITS ); - assert_eq!(row_id & (SEQUENCE_UPPER_BOUND as i64 - 1), 0); + assert_eq!(row_id & (sequence_upper_bound as i64 - 1), 0); - let mut generator = RowIdGenerator::new([VirtualNode::from_index(1)]); - let row_ids = generator.next_batch((SEQUENCE_UPPER_BOUND + 10) as usize); - let mut expected = (0..SEQUENCE_UPPER_BOUND).collect_vec(); + let mut generator = RowIdGenerator::new([VirtualNode::from_index(1)], vnode_count); + let row_ids = generator.next_batch((sequence_upper_bound + 10) as usize); + let mut expected = (0..sequence_upper_bound).collect_vec(); expected.extend(0..10); assert_eq!( row_ids .into_iter() - .map(|id| (id as u16) & (SEQUENCE_UPPER_BOUND - 1)) + .map(|id| (id as u16) & (sequence_upper_bound - 1)) .collect_vec(), expected ); } - #[tokio::test] // `async` in favor of `madsim::time::advance` - async fn test_generator_multiple_vnodes() { - let mut generator = RowIdGenerator::new((0..10).map(VirtualNode::from_index)); + async fn test_generator_multiple_vnodes_with_vnode_count(vnode_count: usize) { + assert!(vnode_count >= 20); - let row_ids = generator.next_batch((SEQUENCE_UPPER_BOUND as usize) * 10 + 1); + let vnodes = || { + (0..10) + .chain((vnode_count - 10)..vnode_count) + .map(VirtualNode::from_index) + }; + let vnode_of = |row_id: RowId| extract_vnode_id_from_row_id(row_id, vnode_count); + + let mut generator = RowIdGenerator::new(vnodes(), vnode_count); + let sequence_upper_bound = generator.sequence_upper_bound(); + + let row_ids = generator.next_batch((sequence_upper_bound as usize) * 20 + 1); + + // Check timestamps. let timestamps = row_ids - .into_iter() - .map(|r| r >> TIMESTAMP_SHIFT_BITS) + .iter() + .map(|&r| r >> TIMESTAMP_SHIFT_BITS) .collect_vec(); let (last_timestamp, first_timestamps) = timestamps.split_last().unwrap(); let first_timestamp = first_timestamps.iter().unique().exactly_one().unwrap(); + // Check vnodes. + let expected_vnodes = vnodes().cycle(); + let actual_vnodes = row_ids.iter().map(|&r| vnode_of(r)); + + for (expected, actual) in expected_vnodes.zip(actual_vnodes) { + assert_eq!(expected, actual); + } + assert!(last_timestamp > first_timestamp); } + + macro_rules! test { + ($vnode_count:expr, $name:ident, $name_mul:ident) => { + #[tokio::test] + async fn $name() { + test_generator_with_vnode_count($vnode_count).await; + } + + #[tokio::test] + async fn $name_mul() { + test_generator_multiple_vnodes_with_vnode_count($vnode_count).await; + } + }; + } + + test!(64, test_64, test_64_mul); // less than default value + test!(114, test_114, test_114_mul); // not a power of 2, less than default value + test!(256, test_256, test_256_mul); // default value, backward compatibility + test!(1024, test_1024, test_1024_mul); // max value with 10 bits + test!(2048, test_2048, test_2048_mul); // more than 10 bits + test!(2333, test_2333, test_2333_mul); // not a power of 2, larger than default value + test!(VirtualNode::MAX_COUNT, test_max, test_max_mul); // max supported } diff --git a/src/stream/src/executor/row_id_gen.rs b/src/stream/src/executor/row_id_gen.rs index 5465a1b54ec2e..216d62432191b 100644 --- a/src/stream/src/executor/row_id_gen.rs +++ b/src/stream/src/executor/row_id_gen.rs @@ -50,7 +50,7 @@ impl RowIdGenExecutor { /// Create a new row id generator based on the assigned vnodes. fn new_generator(vnodes: &Bitmap) -> RowIdGenerator { - RowIdGenerator::new(vnodes.iter_vnodes()) + RowIdGenerator::new(vnodes.iter_vnodes(), vnodes.len()) } /// Generate a row ID column according to ops.