From 196b57977cd399d5ade5c558822aa425164e5a12 Mon Sep 17 00:00:00 2001 From: Gwo Tzu-Hsing Date: Tue, 23 Jul 2024 16:44:41 +0800 Subject: [PATCH] chore: ensure safety condition of unsafe blocks --- src/arrows/mod.rs | 4 ++-- src/compaction/mod.rs | 5 ++++- src/inmem/immutable.rs | 18 +++++++++--------- src/ondisk/scan.rs | 14 ++++++++++---- src/ondisk/sstable.rs | 3 +++ src/stream/level.rs | 6 ++++-- src/stream/mod.rs | 13 +++++++------ src/timestamp/timestamped.rs | 32 +++++++------------------------- 8 files changed, 46 insertions(+), 49 deletions(-) diff --git a/src/arrows/mod.rs b/src/arrows/mod.rs index 0fb328bf..7cd89083 100644 --- a/src/arrows/mod.rs +++ b/src/arrows/mod.rs @@ -32,11 +32,11 @@ where let key = match range { Bound::Included(key) => { cmp = >_eq; - Some(unsafe { &*(key as *const _) }) + Some(&*(key as *const _)) } Bound::Excluded(key) => { cmp = > - Some(unsafe { &*(key as *const _) }) + Some(&*(key as *const _)) } Bound::Unbounded => { cmp = &|this, _| { diff --git a/src/compaction/mod.rs b/src/compaction/mod.rs index eb916cb6..63972d5a 100644 --- a/src/compaction/mod.rs +++ b/src/compaction/mod.rs @@ -301,7 +301,10 @@ where version_edits: &mut Vec::Key>>, level: usize, streams: Vec>, - ) -> Result<(), CompactionError> { + ) -> Result<(), CompactionError> + where + FP: 'scan, + { let mut stream = MergeStream::::from_vec(streams).await?; // Kould: is the capacity parameter necessary? diff --git a/src/inmem/immutable.rs b/src/inmem/immutable.rs index 227fa2b4..666d08ae 100644 --- a/src/inmem/immutable.rs +++ b/src/inmem/immutable.rs @@ -10,10 +10,7 @@ use super::mutable::Mutable; use crate::{ record::{internal::InternalRecordRef, Key, Record, RecordRef}, stream::record_batch::RecordBatchEntry, - timestamp::{ - timestamped::{Timestamped, TimestampedRef}, - Timestamp, EPOCH, - }, + timestamp::{Timestamp, Timestamped, TimestampedRef, EPOCH}, }; pub trait ArrowArrays: Sized { @@ -154,10 +151,13 @@ where self.range.next().map(|(_, &offset)| { let record_ref = R::Ref::from_record_batch(self.record_batch, offset as usize); // TODO: remove cloning record batch - RecordBatchEntry::new(self.record_batch.clone(), unsafe { - transmute::>, InternalRecordRef>>( - record_ref, - ) + RecordBatchEntry::new(self.record_batch.clone(), { + // Safety: record_ref self-references the record batch + unsafe { + transmute::>, InternalRecordRef>>( + record_ref, + ) + } }) }) } @@ -179,7 +179,7 @@ pub(crate) mod tests { use crate::{ record::Record, tests::{Test, TestRef}, - timestamp::Timestamped, + timestamp::timestamped::Timestamped, }; #[derive(Debug)] diff --git a/src/ondisk/scan.rs b/src/ondisk/scan.rs index fa5d397a..e96fe4d2 100644 --- a/src/ondisk/scan.rs +++ b/src/ondisk/scan.rs @@ -1,4 +1,5 @@ use std::{ + marker::PhantomData, pin::Pin, task::{Context, Poll}, }; @@ -16,26 +17,31 @@ use crate::{ pin_project! { #[derive(Debug)] - pub struct SsTableScan + pub struct SsTableScan<'scan, R, FP> where FP: FileProvider, { #[pin] stream: ParquetRecordBatchStream>, iter: Option>, + _marker: PhantomData<&'scan ()> } } -impl SsTableScan +impl SsTableScan<'_, R, FP> where FP: FileProvider, { pub fn new(stream: ParquetRecordBatchStream>) -> Self { - SsTableScan { stream, iter: None } + SsTableScan { + stream, + iter: None, + _marker: PhantomData, + } } } -impl Stream for SsTableScan +impl<'scan, R, FP> Stream for SsTableScan<'scan, R, FP> where R: Record, FP: FileProvider, diff --git a/src/ondisk/sstable.rs b/src/ondisk/sstable.rs index 0cb9088f..d25e9c12 100644 --- a/src/ondisk/sstable.rs +++ b/src/ondisk/sstable.rs @@ -111,6 +111,9 @@ where let builder = self.into_parquet_builder(limit).await?; let schema_descriptor = builder.metadata().file_metadata().schema_descr(); + + // Safety: filter's lifetime relies on range's lifetime, sstable must not live longer than + // it let filter = unsafe { get_range_filter::(schema_descriptor, range, ts) }; Ok(SsTableScan::new(builder.with_row_filter(filter).build()?)) diff --git a/src/stream/level.rs b/src/stream/level.rs index 6d92c670..46dab6f1 100644 --- a/src/stream/level.rs +++ b/src/stream/level.rs @@ -27,9 +27,11 @@ where FP: FileProvider, { Init(FileId), - Ready(SsTableScan), + Ready(SsTableScan<'level, R, FP>), OpenFile(Pin> + 'level>>), - LoadStream(Pin, ParquetError>> + 'level>>), + LoadStream( + Pin, ParquetError>> + 'level>>, + ), } pub(crate) struct LevelStream<'level, R, FP> diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 8f5100e2..b021c700 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -39,9 +39,10 @@ where { pub(crate) fn key(&self) -> Timestamped<::Ref<'_>> { match self { - Entry::Mutable(entry) => entry - .key() - .map(|key| unsafe { transmute(key.as_key_ref()) }), + Entry::Mutable(entry) => entry.key().map(|key| { + // Safety: shorter lifetime must be safe + unsafe { transmute(key.as_key_ref()) } + }), Entry::SsTable(entry) => entry.internal_key(), Entry::Immutable(entry) => entry.internal_key(), Entry::Level(entry) => entry.internal_key(), @@ -95,7 +96,7 @@ pin_project! { }, SsTable { #[pin] - inner: SsTableScan, + inner: SsTableScan<'scan, R, FP>, }, Level { #[pin] @@ -128,12 +129,12 @@ where } } -impl<'scan, R, FP> From> for ScanStream<'scan, R, FP> +impl<'scan, R, FP> From> for ScanStream<'scan, R, FP> where R: Record, FP: FileProvider, { - fn from(inner: SsTableScan) -> Self { + fn from(inner: SsTableScan<'scan, R, FP>) -> Self { ScanStream::SsTable { inner } } } diff --git a/src/timestamp/timestamped.rs b/src/timestamp/timestamped.rs index c1833ec3..249749f7 100644 --- a/src/timestamp/timestamped.rs +++ b/src/timestamp/timestamped.rs @@ -104,13 +104,7 @@ where V: PartialEq, { fn eq(&self, other: &Self) -> bool { - unsafe { - let this = transmute::<&TimestampedRef, [usize; 2]>(self); - let other = transmute::<&TimestampedRef, [usize; 2]>(other); - let this_value = transmute::(this[0]); - let other_value = transmute::(other[0]); - this_value == other_value && this[1] == other[1] - } + self.value() == other.value() && self.ts() == other.ts() } } @@ -121,15 +115,9 @@ where V: PartialOrd, { fn partial_cmp(&self, other: &Self) -> Option { - unsafe { - let this = transmute::<&TimestampedRef, [usize; 2]>(self); - let other = transmute::<&TimestampedRef, [usize; 2]>(other); - let this_value = transmute::(this[0]); - let other_value = transmute::(other[0]); - this_value - .partial_cmp(other_value) - .map(|ordering| ordering.then_with(|| other[1].cmp(&this[1]))) - } + self.value() + .partial_cmp(other.value()) + .map(|ordering| ordering.then_with(|| other.ts().cmp(&self.ts()))) } } @@ -138,15 +126,9 @@ where K: Ord, { fn cmp(&self, other: &Self) -> Ordering { - unsafe { - let this = transmute::<&TimestampedRef, [usize; 2]>(self); - let other = transmute::<&TimestampedRef, [usize; 2]>(other); - let this_value = transmute::(this[0]); - let other_value = transmute::(other[0]); - this_value - .cmp(other_value) - .then_with(|| other[1].cmp(&this[1])) - } + self.value() + .cmp(other.value()) + .then_with(|| other.ts().cmp(&self.ts())) } }