diff --git a/src/index/src/bloom_filter/creator.rs b/src/index/src/bloom_filter/creator.rs index 2f10521559a5..f8c54239645b 100644 --- a/src/index/src/bloom_filter/creator.rs +++ b/src/index/src/bloom_filter/creator.rs @@ -96,6 +96,49 @@ impl BloomFilterCreator { } } + /// Adds multiple rows of elements to the bloom filter. If the number of accumulated rows + /// reaches `rows_per_segment`, it finalizes the current segment. + pub async fn push_n_row_elems( + &mut self, + mut nrows: usize, + elems: impl IntoIterator, + ) -> Result<()> { + if nrows == 0 { + return Ok(()); + } + if nrows == 1 { + return self.push_row_elems(elems).await; + } + + let elems = elems.into_iter().collect::>(); + while nrows > 0 { + let rows_to_seg_end = + self.rows_per_segment - (self.accumulated_row_count % self.rows_per_segment); + let rows_to_push = nrows.min(rows_to_seg_end); + nrows -= rows_to_push; + + self.accumulated_row_count += rows_to_push; + + let mut mem_diff = 0; + for elem in &elems { + let len = elem.len(); + let is_new = self.cur_seg_distinct_elems.insert(elem.clone()); + if is_new { + mem_diff += len; + } + } + self.cur_seg_distinct_elems_mem_usage += mem_diff; + self.global_memory_usage + .fetch_add(mem_diff, Ordering::Relaxed); + + if self.accumulated_row_count % self.rows_per_segment == 0 { + self.finalize_segment().await?; + } + } + + Ok(()) + } + /// Adds a row of elements to the bloom filter. If the number of accumulated rows /// reaches `rows_per_segment`, it finalizes the current segment. pub async fn push_row_elems(&mut self, elems: impl IntoIterator) -> Result<()> { @@ -181,6 +224,13 @@ impl BloomFilterCreator { } } +impl Drop for BloomFilterCreator { + fn drop(&mut self) { + self.global_memory_usage + .fetch_sub(self.cur_seg_distinct_elems_mem_usage, Ordering::Relaxed); + } +} + #[cfg(test)] mod tests { use fastbloom::BloomFilter; @@ -266,4 +316,79 @@ mod tests { assert!(bfs[1].contains(&b"e")); assert!(bfs[1].contains(&b"f")); } + + #[tokio::test] + async fn test_bloom_filter_creator_batch_push() { + let mut writer = Cursor::new(Vec::new()); + let mut creator = BloomFilterCreator::new( + 2, + Box::new(MockExternalTempFileProvider::new()), + Arc::new(AtomicUsize::new(0)), + None, + ); + + creator + .push_n_row_elems(5, vec![b"a".to_vec(), b"b".to_vec()]) + .await + .unwrap(); + assert!(creator.cur_seg_distinct_elems_mem_usage > 0); + assert!(creator.memory_usage() > 0); + + creator + .push_n_row_elems(5, vec![b"c".to_vec(), b"d".to_vec()]) + .await + .unwrap(); + assert_eq!(creator.cur_seg_distinct_elems_mem_usage, 0); + assert!(creator.memory_usage() > 0); + + creator + .push_n_row_elems(10, vec![b"e".to_vec(), b"f".to_vec()]) + .await + .unwrap(); + assert_eq!(creator.cur_seg_distinct_elems_mem_usage, 0); + assert!(creator.memory_usage() > 0); + + creator.finish(&mut writer).await.unwrap(); + + let bytes = writer.into_inner(); + let total_size = bytes.len(); + let meta_size_offset = total_size - 4; + let meta_size = u32::from_le_bytes((&bytes[meta_size_offset..]).try_into().unwrap()); + + let meta_bytes = &bytes[total_size - meta_size as usize - 4..total_size - 4]; + let meta: BloomFilterMeta = serde_json::from_slice(meta_bytes).unwrap(); + + assert_eq!(meta.rows_per_segment, 2); + assert_eq!(meta.seg_count, 10); + assert_eq!(meta.row_count, 20); + assert_eq!( + meta.bloom_filter_segments_size + meta_bytes.len() + 4, + total_size + ); + + let mut bfs = Vec::new(); + for segment in meta.bloom_filter_segments { + let bloom_filter_bytes = + &bytes[segment.offset as usize..(segment.offset + segment.size) as usize]; + let v = u64_vec_from_bytes(bloom_filter_bytes); + let bloom_filter = BloomFilter::from_vec(v) + .seed(&SEED) + .expected_items(segment.elem_count); + bfs.push(bloom_filter); + } + + assert_eq!(bfs.len(), 10); + for bf in bfs.iter().take(3) { + assert!(bf.contains(&b"a")); + assert!(bf.contains(&b"b")); + } + for bf in bfs.iter().take(5).skip(2) { + assert!(bf.contains(&b"c")); + assert!(bf.contains(&b"d")); + } + for bf in bfs.iter().take(10).skip(5) { + assert!(bf.contains(&b"e")); + assert!(bf.contains(&b"f")); + } + } } diff --git a/src/index/src/bloom_filter/creator/finalize_segment.rs b/src/index/src/bloom_filter/creator/finalize_segment.rs index 65b090de3eee..091b1ee6aac0 100644 --- a/src/index/src/bloom_filter/creator/finalize_segment.rs +++ b/src/index/src/bloom_filter/creator/finalize_segment.rs @@ -183,6 +183,13 @@ impl FinalizedBloomFilterStorage { } } +impl Drop for FinalizedBloomFilterStorage { + fn drop(&mut self) { + self.global_memory_usage + .fetch_sub(self.memory_usage, Ordering::Relaxed); + } +} + /// A finalized Bloom filter segment. #[derive(Debug, Clone, PartialEq, Eq)] pub struct FinalizedBloomFilterSegment {