diff --git a/src/lib.rs b/src/lib.rs index 10ebd588..94435155 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,7 +23,11 @@ use futures_core::Stream; use futures_util::StreamExt; use inmem::{immutable::Immutable, mutable::Mutable}; use lockable::LockableHashMap; -use parquet::{arrow::ProjectionMask, errors::ParquetError, file::properties::WriterProperties}; +use parquet::{ + arrow::{arrow_to_parquet_schema, ProjectionMask}, + errors::ParquetError, + file::properties::WriterProperties, +}; use record::Record; use thiserror::Error; use timestamp::Timestamp; @@ -39,6 +43,11 @@ use crate::{ type LockMap = Arc>; +pub enum Projection { + All, + Parts(Vec), +} + #[derive(Debug)] pub struct DbOption { pub path: PathBuf, @@ -201,40 +210,23 @@ where &'get self, key: &'get R::Key, ts: Timestamp, - projection_mask: ProjectionMask, - ) -> Result>, ParquetError> - where - FP: FileProvider, - { - self.scan(Bound::Included(key), Bound::Unbounded, ts, projection_mask) - .await? - .next() - .await - .transpose() + projection: Projection, + ) -> Result>, ParquetError> { + let mut scan = self.scan(Bound::Included(key), Bound::Unbounded, ts); + + if let Projection::Parts(projection) = projection { + scan = scan.projection(projection) + } + scan.take().await?.next().await.transpose() } - async fn scan<'scan>( + fn scan<'scan>( &'scan self, lower: Bound<&'scan R::Key>, uppwer: Bound<&'scan R::Key>, ts: Timestamp, - projection_mask: ProjectionMask, - ) -> Result, ParquetError>>, ParquetError> - where - FP: FileProvider, - { - let mut streams = Vec::>::with_capacity(self.immutables.len() + 1); - streams.push(self.mutable.scan((lower, uppwer), ts).into()); - for immutable in &self.immutables { - streams.push( - immutable - .scan((lower, uppwer), ts, projection_mask.clone()) - .into(), - ); - } - // TODO: sstable scan - - MergeStream::from_vec(streams).await + ) -> Scan<'scan, R, FP> { + Scan::new(self, lower, uppwer, ts) } fn check_conflict(&self, key: &R::Key, ts: Timestamp) -> bool { @@ -252,6 +244,79 @@ where } } +pub struct Scan<'scan, R, FP> +where + R: Record, + FP: FileProvider, +{ + schema: &'scan Schema, + lower: Bound<&'scan R::Key>, + uppwer: Bound<&'scan R::Key>, + ts: Timestamp, + + projection: ProjectionMask, +} + +impl<'scan, R, FP> Scan<'scan, R, FP> +where + R: Record + Send, + FP: FileProvider, +{ + fn new( + schema: &'scan Schema, + lower: Bound<&'scan R::Key>, + uppwer: Bound<&'scan R::Key>, + ts: Timestamp, + ) -> Self { + Self { + schema, + lower, + uppwer, + ts, + projection: ProjectionMask::all(), + } + } + + pub fn projection(self, mut projection: Vec) -> Self { + // skip two columns: _null and _ts + for p in &mut projection { + *p += 2; + } + + let mask = ProjectionMask::roots( + &arrow_to_parquet_schema(R::arrow_schema()).unwrap(), + projection, + ); + + Self { + projection: mask, + ..self + } + } + + pub async fn take( + self, + ) -> Result, ParquetError>>, ParquetError> { + let mut streams = Vec::>::with_capacity(self.schema.immutables.len() + 1); + streams.push( + self.schema + .mutable + .scan((self.lower, self.uppwer), self.ts) + .into(), + ); + for immutable in &self.schema.immutables { + streams.push( + immutable + .scan((self.lower, self.uppwer), self.ts, self.projection.clone()) + .into(), + ); + } + // TODO: sstable scan + + MergeStream::from_vec(streams).await + } +} + #[derive(Debug, Error)] pub enum WriteError where @@ -373,7 +438,6 @@ pub(crate) mod tests { if !vbool_array.is_null(offset) { vbool = Some(vbool_array.value(offset)); } - column_i += 1; } let record = TestRef { diff --git a/src/transaction.rs b/src/transaction.rs index 14d1c55d..84008ae9 100644 --- a/src/transaction.rs +++ b/src/transaction.rs @@ -6,7 +6,7 @@ use std::{ use async_lock::RwLockReadGuard; use lockable::SyncLimit; -use parquet::{arrow::ProjectionMask, errors::ParquetError}; +use parquet::errors::ParquetError; use thiserror::Error; use crate::{ @@ -15,7 +15,7 @@ use crate::{ stream, timestamp::Timestamp, version::{set::transaction_ts, VersionRef}, - LockMap, Record, Schema, + LockMap, Projection, Record, Schema, }; pub struct Transaction<'txn, R, FP> @@ -52,13 +52,13 @@ where pub async fn get<'get>( &'get self, key: &'get R::Key, - projection_mask: ProjectionMask, + projection: Projection, ) -> Result>, ParquetError> { Ok(match self.local.get(key).and_then(|v| v.as_ref()) { Some(v) => Some(TransactionEntry::Local(v.as_record_ref())), None => self .share - .get(key, self.ts, projection_mask) + .get(key, self.ts, projection) .await? .map(TransactionEntry::Stream), }) @@ -148,12 +148,11 @@ where mod tests { use std::sync::Arc; - use parquet::arrow::{arrow_to_parquet_schema, ProjectionMask}; use tempfile::TempDir; use crate::{ - executor::tokio::TokioExecutor, record::Record, tests::Test, transaction::CommitError, - DbOption, DB, + executor::tokio::TokioExecutor, tests::Test, transaction::CommitError, DbOption, + Projection, DB, }; #[tokio::test] @@ -172,7 +171,7 @@ mod tests { let txn2 = db.transaction().await; dbg!(txn2 - .get(&"foo".to_string(), ProjectionMask::all()) + .get(&"foo".to_string(), Projection::All) .await .unwrap() .is_none()); @@ -184,7 +183,7 @@ mod tests { { let txn3 = db.transaction().await; dbg!(txn3 - .get(&"foo".to_string(), ProjectionMask::all()) + .get(&"foo".to_string(), Projection::All) .await .unwrap() .is_none()); @@ -246,17 +245,7 @@ mod tests { }); let key = 0.to_string(); - let entry = txn1 - .get( - &key, - ProjectionMask::roots( - &arrow_to_parquet_schema(Test::arrow_schema()).unwrap(), - [0, 1, 2], - ), - ) - .await - .unwrap() - .unwrap(); + let entry = txn1.get(&key, Projection::All).await.unwrap().unwrap(); assert_eq!(entry.get().vstring, 0.to_string()); assert_eq!(entry.get().vu32, Some(0));