diff --git a/src/lib.rs b/src/lib.rs index e40e5877..6ecdf53f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,13 +16,14 @@ use std::{ collections::VecDeque, io, marker::PhantomData, mem, ops::Bound, path::PathBuf, sync::Arc, }; -use async_lock::{RwLock, RwLockReadGuard}; +use async_lock::RwLock; use futures_core::Stream; use futures_util::StreamExt; use inmem::{immutable::Immutable, mutable::Mutable}; -use oracle::Timestamp; +use oracle::{Oracle, Timestamp}; use parquet::errors::ParquetError; use record::Record; +use transaction::Transaction; use crate::{ executor::Executor, @@ -42,15 +43,6 @@ pub struct DbOption { pub clean_channel_buffer: usize, } -pub struct DB -where - R: Record, - E: Executor, -{ - schema: Arc>>, - _p: PhantomData, -} - impl DbOption { pub fn new(path: impl Into + Send) -> Self { DbOption { @@ -90,6 +82,16 @@ impl DbOption { } } +pub struct DB +where + R: Record, + E: Executor, +{ + schema: Arc>>, + oracle: Oracle, + _p: PhantomData, +} + impl Default for DB where R: Record, @@ -99,6 +101,7 @@ where Self { schema: Arc::new(RwLock::new(Schema::default())), _p: Default::default(), + oracle: Oracle::default(), } } } @@ -109,15 +112,16 @@ where E: Executor, { pub fn empty() -> Self { - Self { - schema: Arc::new(RwLock::new(Schema::default())), - _p: Default::default(), - } + Self::default() + } + + pub async fn transaction(&self) -> Transaction<'_, R, E> { + Transaction::new(self, self.schema.read().await) } pub(crate) async fn write(&self, record: R, ts: Timestamp) -> io::Result<()> { - let columns = self.schema.read().await; - columns.write(record, ts).await + let schema = self.schema.read().await; + schema.write(record, ts).await } pub(crate) async fn write_batch( @@ -131,10 +135,6 @@ where } Ok(()) } - - pub(crate) async fn read(&self) -> RwLockReadGuard<'_, Schema> { - self.schema.read().await - } } pub(crate) struct Schema @@ -166,6 +166,11 @@ where Ok(()) } + async fn remove(&self, key: R::Key, ts: Timestamp) -> io::Result<()> { + self.mutable.remove(key, ts); + Ok(()) + } + async fn get<'get, E>( &'get self, key: &'get R::Key, diff --git a/src/oracle/mod.rs b/src/oracle/mod.rs index 28011c80..3c6280a8 100644 --- a/src/oracle/mod.rs +++ b/src/oracle/mod.rs @@ -61,7 +61,7 @@ impl Oracle where K: Eq + Hash + Clone, { - fn start_read(&self) -> Timestamp { + pub(crate) fn start_read(&self) -> Timestamp { let mut in_read = self.in_read.lock().unwrap(); let now = self.now.load(Ordering::Relaxed).into(); match in_read.entry(now) { @@ -75,7 +75,7 @@ where now } - fn read_commit(&self, ts: Timestamp) { + pub(crate) fn read_commit(&self, ts: Timestamp) { match self.in_read.lock().unwrap().entry(ts) { Entry::Vacant(_) => panic!("commit non-existing read"), Entry::Occupied(mut o) => match o.get_mut() { @@ -89,27 +89,28 @@ where } } - fn start_write(&self) -> Timestamp { + pub(crate) fn start_write(&self) -> Timestamp { (self.now.fetch_add(1, Ordering::Relaxed) + 1).into() } - fn write_commit( + pub(crate) fn write_commit( &self, read_at: Timestamp, write_at: Timestamp, in_write: HashSet, ) -> Result<(), WriteConflict> { let mut committed_txns = self.committed_txns.lock().unwrap(); - let conflicts: Vec<_> = committed_txns + let conflicts = committed_txns .range((Bound::Excluded(read_at), Bound::Excluded(write_at))) .flat_map(|(_, txn)| txn.intersection(&in_write)) .cloned() - .collect(); + .collect::>(); if !conflicts.is_empty() { return Err(WriteConflict { keys: conflicts }); } + // TODO: clean committed transactions committed_txns.insert(write_at, in_write); Ok(()) } diff --git a/src/oracle/timestamp.rs b/src/oracle/timestamp.rs index d26c2d2c..ca50c98a 100644 --- a/src/oracle/timestamp.rs +++ b/src/oracle/timestamp.rs @@ -3,7 +3,7 @@ use std::{borrow::Borrow, cmp::Ordering, marker::PhantomData, mem::transmute}; use crate::oracle::Timestamp; #[derive(PartialEq, Eq, Debug, Clone)] -pub(crate) struct Timestamped { +pub struct Timestamped { pub(crate) ts: Timestamp, pub(crate) value: V, } diff --git a/src/record/mod.rs b/src/record/mod.rs index fa805c19..ea4172e4 100644 --- a/src/record/mod.rs +++ b/src/record/mod.rs @@ -1,7 +1,7 @@ pub(crate) mod internal; mod str; -use std::sync::Arc; +use std::{hash::Hash, sync::Arc}; use arrow::{ array::{Datum, RecordBatch}, @@ -14,7 +14,7 @@ use crate::{ serdes::{Decode, Encode}, }; -pub trait Key: 'static + Encode + Decode + Ord + Clone + Send { +pub trait Key: 'static + Encode + Decode + Ord + Clone + Send + Hash + std::fmt::Debug { type Ref<'r>: KeyRef<'r, Key = Self> + Copy where Self: 'r; diff --git a/src/transaction.rs b/src/transaction.rs index 20fdd367..0d99fa70 100644 --- a/src/transaction.rs +++ b/src/transaction.rs @@ -1,35 +1,154 @@ -use std::{collections::BTreeMap, io, sync::Arc}; +use std::{ + collections::{btree_map::Entry, BTreeMap}, + io, + mem::transmute, +}; -use crate::{executor::Executor, oracle::Timestamp, Record, DB}; +use async_lock::RwLockReadGuard; +use parquet::errors::ParquetError; +use thiserror::Error; -pub struct Transaction +use crate::{ + executor::Executor, + oracle::{Timestamp, WriteConflict}, + record::KeyRef, + stream, Record, Schema, DB, +}; + +pub struct Transaction<'txn, R, E> where R: Record, E: Executor, { - db: Arc>, read_at: Timestamp, local: BTreeMap>, + share: RwLockReadGuard<'txn, Schema>, + db: &'txn DB, } -impl Transaction +impl<'txn, R, E> Transaction<'txn, R, E> where - R: Record, + R: Record + Send, E: Executor, { - pub(crate) fn new(db: Arc>, read_at: Timestamp) -> Self { + pub(crate) fn new(db: &'txn DB, share: RwLockReadGuard<'txn, Schema>) -> Self { Self { - db, - read_at, + read_at: db.oracle.start_read(), local: BTreeMap::new(), + share, + db, + } + } + + pub async fn get<'get>( + &'get self, + key: &'get R::Key, + ) -> Result>, ParquetError> { + Ok(match self.local.get(key).and_then(|v| v.as_ref()) { + Some(v) => Some(TransactionEntry::Local(v.as_record_ref())), + None => self + .share + .get::(key, self.read_at) + .await? + .map(TransactionEntry::Stream), + }) + } + + pub fn set(&mut self, value: R) { + self.entry(value.key().to_key(), Some(value)) + } + + pub fn remove(&mut self, key: R::Key) { + self.entry(key, None) + } + + fn entry(&mut self, key: R::Key, value: Option) { + match self.local.entry(key) { + Entry::Vacant(v) => { + v.insert(value); + } + Entry::Occupied(mut o) => *o.get_mut() = value, + } + } + + pub async fn commit(self) -> Result<(), CommitError> { + self.db.oracle.read_commit(self.read_at); + if self.local.is_empty() { + return Ok(()); + } + let write_at = self.db.oracle.start_write(); + self.db.oracle.write_commit( + self.read_at, + write_at, + self.local.keys().cloned().collect(), + )?; + + for (key, record) in self.local { + match record { + Some(record) => self.share.write(record, write_at).await?, + None => self.share.remove(key, write_at).await?, + } + } + Ok(()) + } +} + +pub enum TransactionEntry<'entry, R> +where + R: Record, +{ + Stream(stream::Entry<'entry, R>), + Local(R::Ref<'entry>), +} + +impl<'entry, R> TransactionEntry<'entry, R> +where + R: Record, +{ + pub fn get(&self) -> R::Ref<'_> { + match self { + TransactionEntry::Stream(entry) => entry.value(), + TransactionEntry::Local(value) => { + // Safety: shorter lifetime must be safe + unsafe { transmute::, R::Ref<'_>>(*value) } + } } } +} - pub async fn get(&self, key: &R::Key) -> io::Result> { - // match self.local.get(key).and_then(|v| v.as_ref()) { - // Some(v) => Ok(Some(v)), - // None => self.db.get(key, self.read_at).await, - // } - todo!() +#[derive(Debug, Error)] +pub enum CommitError +where + R: Record, +{ + #[error("commit transaction error {:?}", .0)] + Io(#[from] io::Error), + #[error(transparent)] + WriteConflict(#[from] WriteConflict), +} + +#[cfg(test)] +mod tests { + use crate::{executor::tokio::TokioExecutor, DB}; + + #[tokio::test] + async fn transaction_read_write() { + let db = DB::::default(); + { + let mut txn1 = db.transaction().await; + txn1.set("foo".to_string()); + + let txn2 = db.transaction().await; + dbg!(txn2.get(&"foo".to_string()).await.unwrap().is_none()); + + txn1.commit().await.unwrap(); + txn2.commit().await.unwrap(); + } + + { + let txn3 = db.transaction().await; + dbg!(txn3.get(&"foo".to_string()).await.unwrap().is_none()); + txn3.commit().await.unwrap(); + } } }