diff --git a/src/compaction/mod.rs b/src/compaction/mod.rs index d5ba0569..f84c923b 100644 --- a/src/compaction/mod.rs +++ b/src/compaction/mod.rs @@ -30,7 +30,7 @@ where FP: FileProvider, { pub(crate) option: Arc, - pub(crate) schema: Arc>>, + pub(crate) schema: Arc>>, pub(crate) version_set: VersionSet, } @@ -40,7 +40,7 @@ where FP: FileProvider, { pub(crate) fn new( - schema: Arc>>, + schema: Arc>>, option: Arc, version_set: VersionSet, ) -> Self { diff --git a/src/lib.rs b/src/lib.rs index 6d7d4493..f8b9de9f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,14 +17,15 @@ use std::{ collections::VecDeque, io, marker::PhantomData, mem, ops::Bound, path::PathBuf, sync::Arc, }; -use async_lock::{RwLock, RwLockReadGuard}; +use async_lock::RwLock; use fs::FileProvider; use futures_core::Stream; use futures_util::StreamExt; use inmem::{immutable::Immutable, mutable::Mutable}; -use oracle::Timestamp; +use oracle::{Oracle, Timestamp}; use parquet::{errors::ParquetError, file::properties::WriterProperties}; use record::Record; +use transaction::Transaction; use crate::{ executor::Executor, @@ -45,15 +46,6 @@ pub struct DbOption { pub write_parquet_option: Option, } -pub struct DB -where - R: Record, - E: Executor, -{ - schema: Arc>>, - _p: PhantomData, -} - impl DbOption { pub fn new(path: impl Into + Send) -> Self { DbOption { @@ -94,6 +86,16 @@ impl DbOption { } } +pub struct DB +where + R: Record, + E: Executor, +{ + schema: Arc>>, + oracle: Oracle, + _p: PhantomData, +} + impl Default for DB where R: Record, @@ -103,6 +105,7 @@ where Self { schema: Arc::new(RwLock::new(Schema::default())), _p: Default::default(), + oracle: Oracle::default(), } } } @@ -113,15 +116,16 @@ where E: Executor, { pub fn empty() -> Self { - Self { - schema: Arc::new(RwLock::new(Schema::default())), - _p: Default::default(), - } + Self::default() + } + + pub async fn transaction(&self) -> Transaction<'_, R, E> { + Transaction::new(&self.oracle, self.schema.read().await) } pub(crate) async fn write(&self, record: R, ts: Timestamp) -> io::Result<()> { - let columns = self.schema.read().await; - columns.write(record, ts).await + let schema = self.schema.read().await; + schema.write(record, ts).await } pub(crate) async fn write_batch( @@ -135,21 +139,18 @@ where } Ok(()) } - - pub(crate) async fn read(&self) -> RwLockReadGuard<'_, Schema> { - self.schema.read().await - } } -pub(crate) struct Schema +pub(crate) struct Schema where R: Record, { mutable: Mutable, immutables: VecDeque>, + _marker: PhantomData, } -impl Default for Schema +impl Default for Schema where R: Record, { @@ -157,44 +158,51 @@ where Self { mutable: Mutable::default(), immutables: VecDeque::default(), + _marker: Default::default(), } } } -impl Schema +impl Schema where R: Record + Send, + FP: FileProvider, { async fn write(&self, record: R, ts: Timestamp) -> io::Result<()> { self.mutable.insert(record, ts); Ok(()) } - async fn get<'get, E>( + async fn remove(&self, key: R::Key, ts: Timestamp) -> io::Result<()> { + self.mutable.remove(key, ts); + Ok(()) + } + + async fn get<'get>( &'get self, key: &'get R::Key, ts: Timestamp, ) -> Result>, ParquetError> where - E: Executor + 'get, + FP: FileProvider, { - self.scan::(Bound::Included(key), Bound::Unbounded, ts) + self.scan(Bound::Included(key), Bound::Unbounded, ts) .await? .next() .await .transpose() } - async fn scan<'scan, E>( + async fn scan<'scan>( &'scan self, lower: Bound<&'scan R::Key>, uppwer: Bound<&'scan R::Key>, ts: Timestamp, ) -> Result, ParquetError>>, ParquetError> where - E: Executor + 'scan, + FP: FileProvider, { - let mut streams = Vec::>::with_capacity(self.immutables.len() + 1); + let mut streams = Vec::>::with_capacity(self.immutables.len() + 1); streams.push(self.mutable.scan((lower, uppwer), ts).into()); for immutable in &self.immutables { streams.push(immutable.scan((lower, uppwer), ts).into()); diff --git a/src/oracle/mod.rs b/src/oracle/mod.rs index 28011c80..3c6280a8 100644 --- a/src/oracle/mod.rs +++ b/src/oracle/mod.rs @@ -61,7 +61,7 @@ impl Oracle where K: Eq + Hash + Clone, { - fn start_read(&self) -> Timestamp { + pub(crate) fn start_read(&self) -> Timestamp { let mut in_read = self.in_read.lock().unwrap(); let now = self.now.load(Ordering::Relaxed).into(); match in_read.entry(now) { @@ -75,7 +75,7 @@ where now } - fn read_commit(&self, ts: Timestamp) { + pub(crate) fn read_commit(&self, ts: Timestamp) { match self.in_read.lock().unwrap().entry(ts) { Entry::Vacant(_) => panic!("commit non-existing read"), Entry::Occupied(mut o) => match o.get_mut() { @@ -89,27 +89,28 @@ where } } - fn start_write(&self) -> Timestamp { + pub(crate) fn start_write(&self) -> Timestamp { (self.now.fetch_add(1, Ordering::Relaxed) + 1).into() } - fn write_commit( + pub(crate) fn write_commit( &self, read_at: Timestamp, write_at: Timestamp, in_write: HashSet, ) -> Result<(), WriteConflict> { let mut committed_txns = self.committed_txns.lock().unwrap(); - let conflicts: Vec<_> = committed_txns + let conflicts = committed_txns .range((Bound::Excluded(read_at), Bound::Excluded(write_at))) .flat_map(|(_, txn)| txn.intersection(&in_write)) .cloned() - .collect(); + .collect::>(); if !conflicts.is_empty() { return Err(WriteConflict { keys: conflicts }); } + // TODO: clean committed transactions committed_txns.insert(write_at, in_write); Ok(()) } diff --git a/src/oracle/timestamp.rs b/src/oracle/timestamp.rs index 655118b5..049cc673 100644 --- a/src/oracle/timestamp.rs +++ b/src/oracle/timestamp.rs @@ -8,7 +8,7 @@ use std::{ use crate::{oracle::Timestamp, serdes::Encode}; #[derive(PartialEq, Eq, Debug, Clone)] -pub(crate) struct Timestamped { +pub struct Timestamped { pub(crate) ts: Timestamp, pub(crate) value: V, } diff --git a/src/record/mod.rs b/src/record/mod.rs index b76cc376..cac84cd6 100644 --- a/src/record/mod.rs +++ b/src/record/mod.rs @@ -1,7 +1,7 @@ pub(crate) mod internal; mod str; -use std::{fmt::Debug, sync::Arc}; +use std::{hash::Hash, sync::Arc}; use arrow::{ array::{Datum, RecordBatch}, @@ -14,7 +14,7 @@ use crate::{ serdes::{Decode, Encode}, }; -pub trait Key: 'static + Debug + Encode + Decode + Ord + Clone + Send { +pub trait Key: 'static + Encode + Decode + Ord + Clone + Send + Hash + std::fmt::Debug { type Ref<'r>: KeyRef<'r, Key = Self> + Copy where Self: 'r; diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 119af3b7..3a7c2b07 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -138,16 +138,6 @@ where } } -impl<'scan, R, FP> From> for ScanStream<'scan, R, FP> -where - R: Record, - FP: FileProvider, -{ - fn from(inner: LevelStream<'scan, R, FP>) -> Self { - ScanStream::Level { inner } - } -} - impl fmt::Debug for ScanStream<'_, R, FP> where R: Record, diff --git a/src/transaction.rs b/src/transaction.rs index 20fdd367..0d541687 100644 --- a/src/transaction.rs +++ b/src/transaction.rs @@ -1,35 +1,153 @@ -use std::{collections::BTreeMap, io, sync::Arc}; +use std::{ + collections::{btree_map::Entry, BTreeMap}, + io, + mem::transmute, +}; -use crate::{executor::Executor, oracle::Timestamp, Record, DB}; +use async_lock::RwLockReadGuard; +use parquet::errors::ParquetError; +use thiserror::Error; -pub struct Transaction +use crate::{ + fs::FileProvider, + oracle::{Oracle, Timestamp, WriteConflict}, + record::KeyRef, + stream, Record, Schema, +}; + +pub struct Transaction<'txn, R, FP> where R: Record, - E: Executor, { - db: Arc>, read_at: Timestamp, local: BTreeMap>, + share: RwLockReadGuard<'txn, Schema>, + oracle: &'txn Oracle, } -impl Transaction +impl<'txn, R, FP> Transaction<'txn, R, FP> where - R: Record, - E: Executor, + R: Record + Send, + FP: FileProvider, { - pub(crate) fn new(db: Arc>, read_at: Timestamp) -> Self { + pub(crate) fn new( + oracle: &'txn Oracle, + share: RwLockReadGuard<'txn, Schema>, + ) -> Self { Self { - db, - read_at, + read_at: oracle.start_read(), local: BTreeMap::new(), + share, + oracle, + } + } + + pub async fn get<'get>( + &'get self, + key: &'get R::Key, + ) -> Result>, ParquetError> { + Ok(match self.local.get(key).and_then(|v| v.as_ref()) { + Some(v) => Some(TransactionEntry::Local(v.as_record_ref())), + None => self + .share + .get(key, self.read_at) + .await? + .map(TransactionEntry::Stream), + }) + } + + pub fn set(&mut self, value: R) { + self.entry(value.key().to_key(), Some(value)) + } + + pub fn remove(&mut self, key: R::Key) { + self.entry(key, None) + } + + fn entry(&mut self, key: R::Key, value: Option) { + match self.local.entry(key) { + Entry::Vacant(v) => { + v.insert(value); + } + Entry::Occupied(mut o) => *o.get_mut() = value, + } + } + + pub async fn commit(self) -> Result<(), CommitError> { + self.oracle.read_commit(self.read_at); + if self.local.is_empty() { + return Ok(()); + } + let write_at = self.oracle.start_write(); + self.oracle + .write_commit(self.read_at, write_at, self.local.keys().cloned().collect())?; + + for (key, record) in self.local { + match record { + Some(record) => self.share.write(record, write_at).await?, + None => self.share.remove(key, write_at).await?, + } + } + Ok(()) + } +} + +pub enum TransactionEntry<'entry, R> +where + R: Record, +{ + Stream(stream::Entry<'entry, R>), + Local(R::Ref<'entry>), +} + +impl<'entry, R> TransactionEntry<'entry, R> +where + R: Record, +{ + pub fn get(&self) -> R::Ref<'_> { + match self { + TransactionEntry::Stream(entry) => entry.value(), + TransactionEntry::Local(value) => { + // Safety: shorter lifetime must be safe + unsafe { transmute::, R::Ref<'_>>(*value) } + } } } +} + +#[derive(Debug, Error)] +pub enum CommitError +where + R: Record, +{ + #[error("commit transaction error {:?}", .0)] + Io(#[from] io::Error), + #[error(transparent)] + WriteConflict(#[from] WriteConflict), +} + +#[cfg(test)] +mod tests { + use crate::{executor::tokio::TokioExecutor, DB}; - pub async fn get(&self, key: &R::Key) -> io::Result> { - // match self.local.get(key).and_then(|v| v.as_ref()) { - // Some(v) => Ok(Some(v)), - // None => self.db.get(key, self.read_at).await, - // } - todo!() + #[tokio::test] + async fn transaction_read_write() { + let db = DB::::default(); + { + let mut txn1 = db.transaction().await; + txn1.set("foo".to_string()); + + let txn2 = db.transaction().await; + dbg!(txn2.get(&"foo".to_string()).await.unwrap().is_none()); + + txn1.commit().await.unwrap(); + txn2.commit().await.unwrap(); + } + + { + let txn3 = db.transaction().await; + dbg!(txn3.get(&"foo".to_string()).await.unwrap().is_none()); + txn3.commit().await.unwrap(); + } } }