From 68a979597bb0145f04cc970b1db5cdac7032e515 Mon Sep 17 00:00:00 2001 From: Gwo Tzu-Hsing Date: Wed, 24 Jul 2024 01:29:46 +0800 Subject: [PATCH] refactor: lazy projection api --- src/lib.rs | 110 ++++++++++++++++++++++++++++++++++----------- src/transaction.rs | 33 +++----------- 2 files changed, 90 insertions(+), 53 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 10ebd588..4e6d5f5b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,7 +23,11 @@ use futures_core::Stream; use futures_util::StreamExt; use inmem::{immutable::Immutable, mutable::Mutable}; use lockable::LockableHashMap; -use parquet::{arrow::ProjectionMask, errors::ParquetError, file::properties::WriterProperties}; +use parquet::{ + arrow::{arrow_to_parquet_schema, ProjectionMask}, + errors::ParquetError, + file::properties::WriterProperties, +}; use record::Record; use thiserror::Error; use timestamp::Timestamp; @@ -201,40 +205,22 @@ where &'get self, key: &'get R::Key, ts: Timestamp, - projection_mask: ProjectionMask, - ) -> Result>, ParquetError> - where - FP: FileProvider, - { - self.scan(Bound::Included(key), Bound::Unbounded, ts, projection_mask) + ) -> Result>, ParquetError> { + self.scan(Bound::Included(key), Bound::Unbounded, ts) + .take() .await? .next() .await .transpose() } - async fn scan<'scan>( + fn scan<'scan>( &'scan self, lower: Bound<&'scan R::Key>, uppwer: Bound<&'scan R::Key>, ts: Timestamp, - projection_mask: ProjectionMask, - ) -> Result, ParquetError>>, ParquetError> - where - FP: FileProvider, - { - let mut streams = Vec::>::with_capacity(self.immutables.len() + 1); - streams.push(self.mutable.scan((lower, uppwer), ts).into()); - for immutable in &self.immutables { - streams.push( - immutable - .scan((lower, uppwer), ts, projection_mask.clone()) - .into(), - ); - } - // TODO: sstable scan - - MergeStream::from_vec(streams).await + ) -> Scan<'scan, R, FP> { + Scan::new(self, lower, uppwer, ts) } fn check_conflict(&self, key: &R::Key, ts: Timestamp) -> bool { @@ -252,6 +238,79 @@ where } } +pub struct Scan<'scan, R, FP> +where + R: Record, + FP: FileProvider, +{ + schema: &'scan Schema, + lower: Bound<&'scan R::Key>, + uppwer: Bound<&'scan R::Key>, + ts: Timestamp, + + projection: ProjectionMask, +} + +impl<'scan, R, FP> Scan<'scan, R, FP> +where + R: Record + Send, + FP: FileProvider, +{ + fn new( + schema: &'scan Schema, + lower: Bound<&'scan R::Key>, + uppwer: Bound<&'scan R::Key>, + ts: Timestamp, + ) -> Self { + Self { + schema, + lower, + uppwer, + ts, + projection: ProjectionMask::all(), + } + } + + pub fn project(self, mut projection: Vec) -> Self { + // skip two columns: _null and _ts + for p in &mut projection { + *p += 2; + } + + let mask = ProjectionMask::roots( + &arrow_to_parquet_schema(R::arrow_schema()).unwrap(), + projection, + ); + + Self { + projection: mask, + ..self + } + } + + pub async fn take( + self, + ) -> Result, ParquetError>>, ParquetError> { + let mut streams = Vec::>::with_capacity(self.schema.immutables.len() + 1); + streams.push( + self.schema + .mutable + .scan((self.lower, self.uppwer), self.ts) + .into(), + ); + for immutable in &self.schema.immutables { + streams.push( + immutable + .scan((self.lower, self.uppwer), self.ts, self.projection.clone()) + .into(), + ); + } + // TODO: sstable scan + + MergeStream::from_vec(streams).await + } +} + #[derive(Debug, Error)] pub enum WriteError where @@ -373,7 +432,6 @@ pub(crate) mod tests { if !vbool_array.is_null(offset) { vbool = Some(vbool_array.value(offset)); } - column_i += 1; } let record = TestRef { diff --git a/src/transaction.rs b/src/transaction.rs index 14d1c55d..0c1ef44f 100644 --- a/src/transaction.rs +++ b/src/transaction.rs @@ -6,7 +6,7 @@ use std::{ use async_lock::RwLockReadGuard; use lockable::SyncLimit; -use parquet::{arrow::ProjectionMask, errors::ParquetError}; +use parquet::errors::ParquetError; use thiserror::Error; use crate::{ @@ -52,13 +52,12 @@ where pub async fn get<'get>( &'get self, key: &'get R::Key, - projection_mask: ProjectionMask, ) -> Result>, ParquetError> { Ok(match self.local.get(key).and_then(|v| v.as_ref()) { Some(v) => Some(TransactionEntry::Local(v.as_record_ref())), None => self .share - .get(key, self.ts, projection_mask) + .get(key, self.ts) .await? .map(TransactionEntry::Stream), }) @@ -148,12 +147,10 @@ where mod tests { use std::sync::Arc; - use parquet::arrow::{arrow_to_parquet_schema, ProjectionMask}; use tempfile::TempDir; use crate::{ - executor::tokio::TokioExecutor, record::Record, tests::Test, transaction::CommitError, - DbOption, DB, + executor::tokio::TokioExecutor, tests::Test, transaction::CommitError, DbOption, DB, }; #[tokio::test] @@ -171,11 +168,7 @@ mod tests { txn1.set("foo".to_string()); let txn2 = db.transaction().await; - dbg!(txn2 - .get(&"foo".to_string(), ProjectionMask::all()) - .await - .unwrap() - .is_none()); + dbg!(txn2.get(&"foo".to_string()).await.unwrap().is_none()); txn1.commit().await.unwrap(); txn2.commit().await.unwrap(); @@ -183,11 +176,7 @@ mod tests { { let txn3 = db.transaction().await; - dbg!(txn3 - .get(&"foo".to_string(), ProjectionMask::all()) - .await - .unwrap() - .is_none()); + dbg!(txn3.get(&"foo".to_string()).await.unwrap().is_none()); txn3.commit().await.unwrap(); } } @@ -246,17 +235,7 @@ mod tests { }); let key = 0.to_string(); - let entry = txn1 - .get( - &key, - ProjectionMask::roots( - &arrow_to_parquet_schema(Test::arrow_schema()).unwrap(), - [0, 1, 2], - ), - ) - .await - .unwrap() - .unwrap(); + let entry = txn1.get(&key).await.unwrap().unwrap(); assert_eq!(entry.get().vstring, 0.to_string()); assert_eq!(entry.get().vu32, Some(0));