From 8a2764d81104073fc60419324e314f684e2ad2e7 Mon Sep 17 00:00:00 2001 From: Gwo Tzu-Hsing Date: Tue, 16 Jul 2024 20:52:59 +0800 Subject: [PATCH] use file as trait object --- Cargo.toml | 4 ++-- src/executor.rs | 6 ++++++ src/fs/mod.rs | 9 ++++++++- src/lib.rs | 13 +++++-------- src/ondisk/scan.rs | 18 ++++++------------ src/ondisk/sstable.rs | 28 ++++++++++++++-------------- src/stream/merge.rs | 33 ++++++++++++++------------------- src/stream/mod.rs | 23 ++++++++--------------- src/version/mod.rs | 13 ++++++++----- 9 files changed, 71 insertions(+), 76 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bbd5b556..8b398acd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ resolver = "2" version = "0.1.0" [features] -tokio = ["dep:tokio", "dep:tokio-util"] +tokio = ["dep:tokio"] [dependencies] arrow = "52" @@ -21,7 +21,7 @@ parquet = { version = "52", features = ["async"] } pin-project-lite = "0.2" thiserror = "1" tokio = { version = "1", optional = true } -tokio-util = { version = "0.7", features = ["compat"], optional = true } +tokio-util = { version = "0.7", features = ["compat"] } tracing = "0.1" ulid = "1" diff --git a/src/executor.rs b/src/executor.rs index 886130b2..07836790 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -20,6 +20,12 @@ pub mod tokio { handle: Handle, } + impl Default for TokioExecutor { + fn default() -> Self { + Self::new() + } + } + impl TokioExecutor { pub fn new() -> Self { Self { diff --git a/src/fs/mod.rs b/src/fs/mod.rs index 51adde4e..7525879f 100644 --- a/src/fs/mod.rs +++ b/src/fs/mod.rs @@ -19,7 +19,14 @@ pub enum FileType { LOG, } -pub trait AsyncFile: AsyncRead + AsyncWrite + AsyncSeek + Send + Sync + Unpin + 'static {} +pub trait AsyncFile: AsyncRead + AsyncWrite + AsyncSeek + Send + Sync + Unpin + 'static { + fn to_file(self) -> Box + where + Self: Sized, + { + Box::new(self) as Box + } +} impl AsyncFile for T where T: AsyncRead + AsyncWrite + AsyncSeek + Send + Sync + Unpin + 'static {} diff --git a/src/lib.rs b/src/lib.rs index c9fc95ba..04e26d51 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,7 @@ use record::Record; use crate::{ executor::Executor, fs::{FileId, FileType}, - stream::{merge::MergeStream, Entry, ScanStream}, + stream::{merge::MergeStream, Entry}, version::Version, }; @@ -176,23 +176,20 @@ where where E: Executor, { - self.scan::(Bound::Included(key), Bound::Unbounded, ts) + self.scan(Bound::Included(key), Bound::Unbounded, ts) .await? .next() .await .transpose() } - async fn scan<'scan, E>( + async fn scan<'scan>( &'scan self, lower: Bound<&'scan R::Key>, uppwer: Bound<&'scan R::Key>, ts: Timestamp, - ) -> Result, ParquetError>>, ParquetError> - where - E: Executor, - { - let mut streams = Vec::>::with_capacity(self.immutables.len() + 1); + ) -> Result, ParquetError>>, ParquetError> { + let mut streams = Vec::with_capacity(self.immutables.len() + 1); streams.push(self.mutable.scan((lower, uppwer), ts).into()); for immutable in &self.immutables { streams.push(immutable.scan((lower, uppwer), ts).into()); diff --git a/src/ondisk/scan.rs b/src/ondisk/scan.rs index 4508be6e..acbb1902 100644 --- a/src/ondisk/scan.rs +++ b/src/ondisk/scan.rs @@ -8,37 +8,31 @@ use parquet::arrow::async_reader::ParquetRecordBatchStream; use pin_project_lite::pin_project; use tokio_util::compat::Compat; +use crate::fs::AsyncFile; use crate::{ - executor::Executor, record::Record, stream::record_batch::{RecordBatchEntry, RecordBatchIterator}, }; pin_project! { #[derive(Debug)] - pub struct SsTableScan - where - E: Executor + pub struct SsTableScan { #[pin] - stream: ParquetRecordBatchStream>, + stream: ParquetRecordBatchStream>>, iter: Option>, } } -impl SsTableScan -where - E: Executor, -{ - pub fn new(stream: ParquetRecordBatchStream>) -> Self { +impl SsTableScan { + pub fn new(stream: ParquetRecordBatchStream>>) -> Self { SsTableScan { stream, iter: None } } } -impl Stream for SsTableScan +impl Stream for SsTableScan where R: Record, - E: Executor, { type Item = Result, parquet::errors::ParquetError>; diff --git a/src/ondisk/sstable.rs b/src/ondisk/sstable.rs index 334281e5..ece6c80f 100644 --- a/src/ondisk/sstable.rs +++ b/src/ondisk/sstable.rs @@ -15,36 +15,34 @@ use parquet::{ use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt}; use super::scan::SsTableScan; +use crate::fs::AsyncFile; use crate::{ arrows::get_range_filter, - executor::Executor, oracle::{timestamp::TimestampedRef, Timestamp}, record::Record, stream::record_batch::RecordBatchEntry, }; -pub(crate) struct SsTable +pub(crate) struct SsTable where R: Record, - E: Executor, { - file: E::File, - _marker: PhantomData<(R, E)>, + file: Box, + _marker: PhantomData, } -impl SsTable +impl SsTable where R: Record, - E: Executor, { - pub(crate) fn open(file: E::File) -> Self { + pub(crate) fn open(file: Box) -> Self { SsTable { file, _marker: PhantomData, } } - fn create_writer(&mut self) -> AsyncArrowWriter> { + fn create_writer(&mut self) -> AsyncArrowWriter> { // TODO: expose writer options let options = ArrowWriterOptions::new().with_properties( WriterProperties::builder() @@ -53,7 +51,7 @@ where .build(), ); AsyncArrowWriter::try_new_with_options( - (&mut self.file).compat(), + (&mut self.file as &mut dyn AsyncFile).compat(), R::arrow_schema().clone(), options, ) @@ -75,7 +73,7 @@ where async fn into_parquet_builder( self, limit: usize, - ) -> parquet::errors::Result>>> { + ) -> parquet::errors::Result>>>> { Ok(ParquetRecordBatchStreamBuilder::new_with_options( self.file.compat(), ArrowReaderOptions::default().with_page_index(true), @@ -99,7 +97,7 @@ where self, range: (Bound<&'scan R::Key>, Bound<&'scan R::Key>), ts: Timestamp, - ) -> Result, parquet::errors::ParquetError> { + ) -> Result, parquet::errors::ParquetError> { let builder = self.into_parquet_builder(1).await?; let schema_descriptor = builder.metadata().file_metadata().schema_descr(); @@ -114,6 +112,7 @@ mod tests { use std::borrow::Borrow; use super::SsTable; + use crate::fs::AsyncFile; use crate::{ executor::tokio::TokioExecutor, fs::Fs, @@ -127,8 +126,9 @@ mod tests { let record_batch = get_test_record_batch::().await; let file = TokioExecutor::open(&temp_dir.path().join("test.parquet")) .await - .unwrap(); - let mut sstable = SsTable::::open(file); + .unwrap() + .to_file(); + let mut sstable = SsTable::::open(file); sstable.write(record_batch).await.unwrap(); diff --git a/src/stream/merge.rs b/src/stream/merge.rs index 2734572b..5932558b 100644 --- a/src/stream/merge.rs +++ b/src/stream/merge.rs @@ -10,27 +10,25 @@ use futures_util::stream::StreamExt; use pin_project_lite::pin_project; use super::{Entry, ScanStream}; -use crate::{executor::Executor, record::Record}; +use crate::record::Record; pin_project! { - pub(crate) struct MergeStream<'merge, R, E> + pub(crate) struct MergeStream<'merge, R> where R: Record, - E: Executor, { - streams: Vec>, + streams: Vec>, peeked: BinaryHeap>, buf: Option>, } } -impl<'merge, R, E> MergeStream<'merge, R, E> +impl<'merge, R> MergeStream<'merge, R> where R: Record, - E: Executor, { pub(crate) async fn from_vec( - mut streams: Vec>, + mut streams: Vec>, ) -> Result { let mut peeked = BinaryHeap::with_capacity(streams.len()); @@ -51,10 +49,9 @@ where } } -impl<'merge, R, E> Stream for MergeStream<'merge, R, E> +impl<'merge, R> Stream for MergeStream<'merge, R> where R: Record, - E: Executor, { type Item = Result, parquet::errors::ParquetError>; @@ -138,7 +135,7 @@ mod tests { use futures_util::StreamExt; use super::MergeStream; - use crate::{executor::tokio::TokioExecutor, inmem::mutable::Mutable}; + use crate::{inmem::mutable::Mutable}; #[tokio::test] async fn merge_mutable() { @@ -158,7 +155,7 @@ mod tests { let lower = "a".to_string(); let upper = "e".to_string(); let bound = (Bound::Included(&lower), Bound::Included(&upper)); - let mut merge = MergeStream::::from_vec(vec![ + let mut merge = MergeStream::from_vec(vec![ m1.scan(bound, 6.into()).into(), m2.scan(bound, 6.into()).into(), m3.scan(bound, 6.into()).into(), @@ -186,10 +183,9 @@ mod tests { let lower = "1".to_string(); let upper = "4".to_string(); let bound = (Bound::Included(&lower), Bound::Included(&upper)); - let mut merge = - MergeStream::::from_vec(vec![m1.scan(bound, 0.into()).into()]) - .await - .unwrap(); + let mut merge = MergeStream::from_vec(vec![m1.scan(bound, 0.into()).into()]) + .await + .unwrap(); dbg!(merge.next().await); dbg!(merge.next().await); @@ -198,10 +194,9 @@ mod tests { let lower = "1".to_string(); let upper = "4".to_string(); let bound = (Bound::Included(&lower), Bound::Included(&upper)); - let mut merge = - MergeStream::::from_vec(vec![m1.scan(bound, 1.into()).into()]) - .await - .unwrap(); + let mut merge = MergeStream::from_vec(vec![m1.scan(bound, 1.into()).into()]) + .await + .unwrap(); dbg!(merge.next().await); dbg!(merge.next().await); diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 75c1bcf6..819601aa 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -14,7 +14,6 @@ use pin_project_lite::pin_project; use record_batch::RecordBatchEntry; use crate::{ - executor::Executor, inmem::{immutable::ImmutableScan, mutable::MutableScan}, ondisk::scan::SsTableScan, oracle::timestamp::Timestamped, @@ -66,10 +65,9 @@ where pin_project! { #[project = ScanStreamProject] - pub enum ScanStream<'scan, R, E> + pub enum ScanStream<'scan, R> where R: Record, - E: Executor, { Mutable { #[pin] @@ -81,15 +79,14 @@ pin_project! { }, SsTable { #[pin] - inner: SsTableScan, + inner: SsTableScan, }, } } -impl<'scan, R, E> From> for ScanStream<'scan, R, E> +impl<'scan, R> From> for ScanStream<'scan, R> where R: Record, - E: Executor, { fn from(inner: MutableScan<'scan, R>) -> Self { ScanStream::Mutable { @@ -98,10 +95,9 @@ where } } -impl<'scan, R, E> From> for ScanStream<'scan, R, E> +impl<'scan, R> From> for ScanStream<'scan, R> where R: Record, - E: Executor, { fn from(inner: ImmutableScan<'scan, R>) -> Self { ScanStream::Immutable { @@ -110,20 +106,18 @@ where } } -impl<'scan, R, E> From> for ScanStream<'scan, R, E> +impl<'scan, R> From> for ScanStream<'scan, R> where R: Record, - E: Executor, { - fn from(inner: SsTableScan) -> Self { + fn from(inner: SsTableScan) -> Self { ScanStream::SsTable { inner } } } -impl fmt::Debug for ScanStream<'_, R, E> +impl fmt::Debug for ScanStream<'_, R> where R: Record, - E: Executor, { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { @@ -134,10 +128,9 @@ where } } -impl<'scan, R, E> Stream for ScanStream<'scan, R, E> +impl<'scan, R> Stream for ScanStream<'scan, R> where R: Record, - E: Executor, { type Item = Result, parquet::errors::ParquetError>; diff --git a/src/version/mod.rs b/src/version/mod.rs index 6dd364be..6630de17 100644 --- a/src/version/mod.rs +++ b/src/version/mod.rs @@ -10,6 +10,7 @@ use futures_util::SinkExt; use thiserror::Error; use tracing::error; +use crate::fs::AsyncFile; use crate::{ executor::Executor, fs::FileId, @@ -103,8 +104,9 @@ where ) -> Result>, VersionError> { let file = E::open(self.option.table_path(gen)) .await - .map_err(VersionError::Io)?; - let table = SsTable::::open(file); + .map_err(VersionError::Io)? + .to_file(); + let table = SsTable::open(file); table.get(key).await.map_err(VersionError::Parquet) } @@ -132,15 +134,16 @@ where pub(crate) async fn iters<'a>( &self, - iters: &mut Vec>, + iters: &mut Vec>, range: (Bound<&'a R::Key>, Bound<&'a R::Key>), ts: Timestamp, ) -> Result<(), VersionError> { for scope in self.level_slice[0].iter() { let file = E::open(self.option.table_path(&scope.gen)) .await - .map_err(VersionError::Io)?; - let table = SsTable::::open(file); + .map_err(VersionError::Io)? + .to_file(); + let table = SsTable::open(file); iters.push(ScanStream::SsTable { inner: table.scan(range, ts).await.map_err(VersionError::Parquet)?,