From 808bb089f45903c203de226174d7ab833f351e02 Mon Sep 17 00:00:00 2001 From: Gwo Tzu-Hsing Date: Thu, 18 Jul 2024 17:16:40 +0800 Subject: [PATCH] chore: support transaction --- src/lib.rs | 69 ++++++++++-------- src/ondisk/scan.rs | 18 ++--- src/oracle/mod.rs | 13 ++-- src/oracle/timestamp.rs | 2 +- src/record/mod.rs | 4 +- src/stream/merge.rs | 19 ++--- src/stream/mod.rs | 30 ++++---- src/transaction.rs | 153 +++++++++++++++++++++++++++++++++++----- 8 files changed, 219 insertions(+), 89 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index e40e5877..e64591dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,13 +16,15 @@ use std::{ collections::VecDeque, io, marker::PhantomData, mem, ops::Bound, path::PathBuf, sync::Arc, }; -use async_lock::{RwLock, RwLockReadGuard}; +use async_lock::RwLock; +use fs::FileProvider; use futures_core::Stream; use futures_util::StreamExt; use inmem::{immutable::Immutable, mutable::Mutable}; -use oracle::Timestamp; +use oracle::{Oracle, Timestamp}; use parquet::errors::ParquetError; use record::Record; +use transaction::Transaction; use crate::{ executor::Executor, @@ -42,15 +44,6 @@ pub struct DbOption { pub clean_channel_buffer: usize, } -pub struct DB -where - R: Record, - E: Executor, -{ - schema: Arc>>, - _p: PhantomData, -} - impl DbOption { pub fn new(path: impl Into + Send) -> Self { DbOption { @@ -90,6 +83,16 @@ impl DbOption { } } +pub struct DB +where + R: Record, + E: Executor, +{ + schema: Arc>>, + oracle: Oracle, + _p: PhantomData, +} + impl Default for DB where R: Record, @@ -99,6 +102,7 @@ where Self { schema: Arc::new(RwLock::new(Schema::default())), _p: Default::default(), + oracle: Oracle::default(), } } } @@ -109,15 +113,16 @@ where E: Executor, { pub fn empty() -> Self { - Self { - schema: Arc::new(RwLock::new(Schema::default())), - _p: Default::default(), - } + Self::default() + } + + pub async fn transaction(&self) -> Transaction<'_, R, E> { + Transaction::new(&self.oracle, self.schema.read().await) } pub(crate) async fn write(&self, record: R, ts: Timestamp) -> io::Result<()> { - let columns = self.schema.read().await; - columns.write(record, ts).await + let schema = self.schema.read().await; + schema.write(record, ts).await } pub(crate) async fn write_batch( @@ -131,21 +136,18 @@ where } Ok(()) } - - pub(crate) async fn read(&self) -> RwLockReadGuard<'_, Schema> { - self.schema.read().await - } } -pub(crate) struct Schema +pub(crate) struct Schema where R: Record, { mutable: Mutable, immutables: VecDeque>, + _marker: PhantomData, } -impl Default for Schema +impl Default for Schema where R: Record, { @@ -153,44 +155,51 @@ where Self { mutable: Mutable::default(), immutables: VecDeque::default(), + _marker: Default::default(), } } } -impl Schema +impl Schema where R: Record + Send, + FP: FileProvider, { async fn write(&self, record: R, ts: Timestamp) -> io::Result<()> { self.mutable.insert(record, ts); Ok(()) } - async fn get<'get, E>( + async fn remove(&self, key: R::Key, ts: Timestamp) -> io::Result<()> { + self.mutable.remove(key, ts); + Ok(()) + } + + async fn get<'get>( &'get self, key: &'get R::Key, ts: Timestamp, ) -> Result>, ParquetError> where - E: Executor, + FP: FileProvider, { - self.scan::(Bound::Included(key), Bound::Unbounded, ts) + self.scan(Bound::Included(key), Bound::Unbounded, ts) .await? .next() .await .transpose() } - async fn scan<'scan, E>( + async fn scan<'scan>( &'scan self, lower: Bound<&'scan R::Key>, uppwer: Bound<&'scan R::Key>, ts: Timestamp, ) -> Result, ParquetError>>, ParquetError> where - E: Executor, + FP: FileProvider, { - let mut streams = Vec::>::with_capacity(self.immutables.len() + 1); + let mut streams = Vec::>::with_capacity(self.immutables.len() + 1); streams.push(self.mutable.scan((lower, uppwer), ts).into()); for immutable in &self.immutables { streams.push(immutable.scan((lower, uppwer), ts).into()); diff --git a/src/ondisk/scan.rs b/src/ondisk/scan.rs index 4616e738..1e4de7a8 100644 --- a/src/ondisk/scan.rs +++ b/src/ondisk/scan.rs @@ -8,37 +8,37 @@ use parquet::arrow::async_reader::ParquetRecordBatchStream; use pin_project_lite::pin_project; use tokio_util::compat::Compat; +use crate::fs::FileProvider; use crate::{ - executor::Executor, record::Record, stream::record_batch::{RecordBatchEntry, RecordBatchIterator}, }; pin_project! { #[derive(Debug)] - pub struct SsTableScan + pub struct SsTableScan where - E: Executor, + FP: FileProvider, { #[pin] - stream: ParquetRecordBatchStream>, + stream: ParquetRecordBatchStream>, iter: Option>, } } -impl SsTableScan +impl SsTableScan where - E: Executor, + FP: FileProvider, { - pub fn new(stream: ParquetRecordBatchStream>) -> Self { + pub fn new(stream: ParquetRecordBatchStream>) -> Self { SsTableScan { stream, iter: None } } } -impl Stream for SsTableScan +impl Stream for SsTableScan where R: Record, - E: Executor, + FP: FileProvider, { type Item = Result, parquet::errors::ParquetError>; diff --git a/src/oracle/mod.rs b/src/oracle/mod.rs index 28011c80..3c6280a8 100644 --- a/src/oracle/mod.rs +++ b/src/oracle/mod.rs @@ -61,7 +61,7 @@ impl Oracle where K: Eq + Hash + Clone, { - fn start_read(&self) -> Timestamp { + pub(crate) fn start_read(&self) -> Timestamp { let mut in_read = self.in_read.lock().unwrap(); let now = self.now.load(Ordering::Relaxed).into(); match in_read.entry(now) { @@ -75,7 +75,7 @@ where now } - fn read_commit(&self, ts: Timestamp) { + pub(crate) fn read_commit(&self, ts: Timestamp) { match self.in_read.lock().unwrap().entry(ts) { Entry::Vacant(_) => panic!("commit non-existing read"), Entry::Occupied(mut o) => match o.get_mut() { @@ -89,27 +89,28 @@ where } } - fn start_write(&self) -> Timestamp { + pub(crate) fn start_write(&self) -> Timestamp { (self.now.fetch_add(1, Ordering::Relaxed) + 1).into() } - fn write_commit( + pub(crate) fn write_commit( &self, read_at: Timestamp, write_at: Timestamp, in_write: HashSet, ) -> Result<(), WriteConflict> { let mut committed_txns = self.committed_txns.lock().unwrap(); - let conflicts: Vec<_> = committed_txns + let conflicts = committed_txns .range((Bound::Excluded(read_at), Bound::Excluded(write_at))) .flat_map(|(_, txn)| txn.intersection(&in_write)) .cloned() - .collect(); + .collect::>(); if !conflicts.is_empty() { return Err(WriteConflict { keys: conflicts }); } + // TODO: clean committed transactions committed_txns.insert(write_at, in_write); Ok(()) } diff --git a/src/oracle/timestamp.rs b/src/oracle/timestamp.rs index d26c2d2c..ca50c98a 100644 --- a/src/oracle/timestamp.rs +++ b/src/oracle/timestamp.rs @@ -3,7 +3,7 @@ use std::{borrow::Borrow, cmp::Ordering, marker::PhantomData, mem::transmute}; use crate::oracle::Timestamp; #[derive(PartialEq, Eq, Debug, Clone)] -pub(crate) struct Timestamped { +pub struct Timestamped { pub(crate) ts: Timestamp, pub(crate) value: V, } diff --git a/src/record/mod.rs b/src/record/mod.rs index fa805c19..ea4172e4 100644 --- a/src/record/mod.rs +++ b/src/record/mod.rs @@ -1,7 +1,7 @@ pub(crate) mod internal; mod str; -use std::sync::Arc; +use std::{hash::Hash, sync::Arc}; use arrow::{ array::{Datum, RecordBatch}, @@ -14,7 +14,7 @@ use crate::{ serdes::{Decode, Encode}, }; -pub trait Key: 'static + Encode + Decode + Ord + Clone + Send { +pub trait Key: 'static + Encode + Decode + Ord + Clone + Send + Hash + std::fmt::Debug { type Ref<'r>: KeyRef<'r, Key = Self> + Copy where Self: 'r; diff --git a/src/stream/merge.rs b/src/stream/merge.rs index 2734572b..3162e553 100644 --- a/src/stream/merge.rs +++ b/src/stream/merge.rs @@ -10,27 +10,28 @@ use futures_util::stream::StreamExt; use pin_project_lite::pin_project; use super::{Entry, ScanStream}; -use crate::{executor::Executor, record::Record}; +use crate::fs::FileProvider; +use crate::record::Record; pin_project! { - pub(crate) struct MergeStream<'merge, R, E> + pub(crate) struct MergeStream<'merge, R, FP> where R: Record, - E: Executor, + FP: FileProvider, { - streams: Vec>, + streams: Vec>, peeked: BinaryHeap>, buf: Option>, } } -impl<'merge, R, E> MergeStream<'merge, R, E> +impl<'merge, R, FP> MergeStream<'merge, R, FP> where R: Record, - E: Executor, + FP: FileProvider, { pub(crate) async fn from_vec( - mut streams: Vec>, + mut streams: Vec>, ) -> Result { let mut peeked = BinaryHeap::with_capacity(streams.len()); @@ -51,10 +52,10 @@ where } } -impl<'merge, R, E> Stream for MergeStream<'merge, R, E> +impl<'merge, R, FP> Stream for MergeStream<'merge, R, FP> where R: Record, - E: Executor, + FP: FileProvider, { type Item = Result, parquet::errors::ParquetError>; diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 733a252d..0f740474 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -14,8 +14,8 @@ use futures_util::{ready, stream}; use pin_project_lite::pin_project; use record_batch::RecordBatchEntry; +use crate::fs::FileProvider; use crate::{ - executor::Executor, inmem::{immutable::ImmutableScan, mutable::MutableScan}, ondisk::scan::SsTableScan, oracle::timestamp::Timestamped, @@ -75,10 +75,10 @@ where pin_project! { #[project = ScanStreamProject] - pub enum ScanStream<'scan, R, E> + pub enum ScanStream<'scan, R, FP> where R: Record, - E: Executor, + FP: FileProvider, { Mutable { #[pin] @@ -90,15 +90,15 @@ pin_project! { }, SsTable { #[pin] - inner: SsTableScan, + inner: SsTableScan, }, } } -impl<'scan, R, E> From> for ScanStream<'scan, R, E> +impl<'scan, R, FP> From> for ScanStream<'scan, R, FP> where R: Record, - E: Executor, + FP: FileProvider, { fn from(inner: MutableScan<'scan, R>) -> Self { ScanStream::Mutable { @@ -107,10 +107,10 @@ where } } -impl<'scan, R, E> From> for ScanStream<'scan, R, E> +impl<'scan, R, FP> From> for ScanStream<'scan, R, FP> where R: Record, - E: Executor, + FP: FileProvider, { fn from(inner: ImmutableScan<'scan, R>) -> Self { ScanStream::Immutable { @@ -119,20 +119,20 @@ where } } -impl<'scan, R, E> From> for ScanStream<'scan, R, E> +impl<'scan, R, FP> From> for ScanStream<'scan, R, FP> where R: Record, - E: Executor, + FP: FileProvider, { - fn from(inner: SsTableScan) -> Self { + fn from(inner: SsTableScan) -> Self { ScanStream::SsTable { inner } } } -impl fmt::Debug for ScanStream<'_, R, E> +impl fmt::Debug for ScanStream<'_, R, FP> where R: Record, - E: Executor, + FP: FileProvider, { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { @@ -143,10 +143,10 @@ where } } -impl<'scan, R, E> Stream for ScanStream<'scan, R, E> +impl<'scan, R, FP> Stream for ScanStream<'scan, R, FP> where R: Record, - E: Executor, + FP: FileProvider, { type Item = Result, parquet::errors::ParquetError>; diff --git a/src/transaction.rs b/src/transaction.rs index 20fdd367..356c99d8 100644 --- a/src/transaction.rs +++ b/src/transaction.rs @@ -1,35 +1,154 @@ -use std::{collections::BTreeMap, io, sync::Arc}; +use std::{ + collections::{btree_map::Entry, BTreeMap}, + io, + mem::transmute, +}; -use crate::{executor::Executor, oracle::Timestamp, Record, DB}; +use async_lock::RwLockReadGuard; +use parquet::errors::ParquetError; +use thiserror::Error; -pub struct Transaction +use crate::fs::FileProvider; +use crate::oracle::Oracle; +use crate::{ + oracle::{Timestamp, WriteConflict}, + record::KeyRef, + stream, Record, Schema, +}; + +pub struct Transaction<'txn, R, FP> where R: Record, - E: Executor, { - db: Arc>, read_at: Timestamp, local: BTreeMap>, + share: RwLockReadGuard<'txn, Schema>, + oracle: &'txn Oracle, } -impl Transaction +impl<'txn, R, FP> Transaction<'txn, R, FP> where - R: Record, - E: Executor, + R: Record + Send, + FP: FileProvider, { - pub(crate) fn new(db: Arc>, read_at: Timestamp) -> Self { + pub(crate) fn new( + oracle: &'txn Oracle, + share: RwLockReadGuard<'txn, Schema>, + ) -> Self { Self { - db, - read_at, + read_at: oracle.start_read(), local: BTreeMap::new(), + share, + oracle, + } + } + + pub async fn get<'get>( + &'get self, + key: &'get R::Key, + ) -> Result>, ParquetError> { + Ok(match self.local.get(key).and_then(|v| v.as_ref()) { + Some(v) => Some(TransactionEntry::Local(v.as_record_ref())), + None => self + .share + .get(key, self.read_at) + .await? + .map(TransactionEntry::Stream), + }) + } + + pub fn set(&mut self, value: R) { + self.entry(value.key().to_key(), Some(value)) + } + + pub fn remove(&mut self, key: R::Key) { + self.entry(key, None) + } + + fn entry(&mut self, key: R::Key, value: Option) { + match self.local.entry(key) { + Entry::Vacant(v) => { + v.insert(value); + } + Entry::Occupied(mut o) => *o.get_mut() = value, + } + } + + pub async fn commit(self) -> Result<(), CommitError> { + self.oracle.read_commit(self.read_at); + if self.local.is_empty() { + return Ok(()); + } + let write_at = self.oracle.start_write(); + self.oracle + .write_commit(self.read_at, write_at, self.local.keys().cloned().collect())?; + + for (key, record) in self.local { + match record { + Some(record) => self.share.write(record, write_at).await?, + None => self.share.remove(key, write_at).await?, + } + } + Ok(()) + } +} + +pub enum TransactionEntry<'entry, R> +where + R: Record, +{ + Stream(stream::Entry<'entry, R>), + Local(R::Ref<'entry>), +} + +impl<'entry, R> TransactionEntry<'entry, R> +where + R: Record, +{ + pub fn get(&self) -> R::Ref<'_> { + match self { + TransactionEntry::Stream(entry) => entry.value(), + TransactionEntry::Local(value) => { + // Safety: shorter lifetime must be safe + unsafe { transmute::, R::Ref<'_>>(*value) } + } } } +} + +#[derive(Debug, Error)] +pub enum CommitError +where + R: Record, +{ + #[error("commit transaction error {:?}", .0)] + Io(#[from] io::Error), + #[error(transparent)] + WriteConflict(#[from] WriteConflict), +} + +#[cfg(test)] +mod tests { + use crate::{executor::tokio::TokioExecutor, DB}; - pub async fn get(&self, key: &R::Key) -> io::Result> { - // match self.local.get(key).and_then(|v| v.as_ref()) { - // Some(v) => Ok(Some(v)), - // None => self.db.get(key, self.read_at).await, - // } - todo!() + #[tokio::test] + async fn transaction_read_write() { + let db = DB::::default(); + { + let mut txn1 = db.transaction().await; + txn1.set("foo".to_string()); + + let txn2 = db.transaction().await; + dbg!(txn2.get(&"foo".to_string()).await.unwrap().is_none()); + + txn1.commit().await.unwrap(); + txn2.commit().await.unwrap(); + } + + { + let txn3 = db.transaction().await; + dbg!(txn3.get(&"foo".to_string()).await.unwrap().is_none()); + txn3.commit().await.unwrap(); + } } }