Skip to content

Commit

Permalink
chore: support transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
ethe committed Jul 18, 2024
1 parent 6cd34f1 commit 83c1281
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 45 deletions.
47 changes: 26 additions & 21 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@ use std::{
collections::VecDeque, io, marker::PhantomData, mem, ops::Bound, path::PathBuf, sync::Arc,
};

use async_lock::{RwLock, RwLockReadGuard};
use async_lock::RwLock;
use futures_core::Stream;
use futures_util::StreamExt;
use inmem::{immutable::Immutable, mutable::Mutable};
use oracle::Timestamp;
use oracle::{Oracle, Timestamp};
use parquet::errors::ParquetError;
use record::Record;
use transaction::Transaction;

use crate::{
executor::Executor,
Expand All @@ -42,15 +43,6 @@ pub struct DbOption {
pub clean_channel_buffer: usize,
}

pub struct DB<R, E>
where
R: Record,
E: Executor,
{
schema: Arc<RwLock<Schema<R>>>,
_p: PhantomData<E>,
}

impl DbOption {
pub fn new(path: impl Into<PathBuf> + Send) -> Self {
DbOption {
Expand Down Expand Up @@ -90,6 +82,16 @@ impl DbOption {
}
}

pub struct DB<R, E>
where
R: Record,
E: Executor,
{
schema: Arc<RwLock<Schema<R>>>,
oracle: Oracle<R::Key>,
_p: PhantomData<E>,
}

impl<R, E> Default for DB<R, E>
where
R: Record,
Expand All @@ -99,6 +101,7 @@ where
Self {
schema: Arc::new(RwLock::new(Schema::default())),
_p: Default::default(),
oracle: Oracle::default(),
}
}
}
Expand All @@ -109,15 +112,16 @@ where
E: Executor,
{
pub fn empty() -> Self {
Self {
schema: Arc::new(RwLock::new(Schema::default())),
_p: Default::default(),
}
Self::default()
}

pub async fn transaction(&self) -> Transaction<'_, R, E> {
Transaction::new(self, self.schema.read().await)
}

pub(crate) async fn write(&self, record: R, ts: Timestamp) -> io::Result<()> {
let columns = self.schema.read().await;
columns.write(record, ts).await
let schema = self.schema.read().await;
schema.write(record, ts).await
}

pub(crate) async fn write_batch(
Expand All @@ -131,10 +135,6 @@ where
}
Ok(())
}

pub(crate) async fn read(&self) -> RwLockReadGuard<'_, Schema<R>> {
self.schema.read().await
}
}

pub(crate) struct Schema<R>
Expand Down Expand Up @@ -166,6 +166,11 @@ where
Ok(())
}

async fn remove(&self, key: R::Key, ts: Timestamp) -> io::Result<()> {
self.mutable.remove(key, ts);
Ok(())
}

async fn get<'get, E>(
&'get self,
key: &'get R::Key,
Expand Down
13 changes: 7 additions & 6 deletions src/oracle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl<K> Oracle<K>
where
K: Eq + Hash + Clone,
{
fn start_read(&self) -> Timestamp {
pub(crate) fn start_read(&self) -> Timestamp {
let mut in_read = self.in_read.lock().unwrap();
let now = self.now.load(Ordering::Relaxed).into();
match in_read.entry(now) {
Expand All @@ -75,7 +75,7 @@ where
now
}

fn read_commit(&self, ts: Timestamp) {
pub(crate) fn read_commit(&self, ts: Timestamp) {
match self.in_read.lock().unwrap().entry(ts) {
Entry::Vacant(_) => panic!("commit non-existing read"),
Entry::Occupied(mut o) => match o.get_mut() {
Expand All @@ -89,27 +89,28 @@ where
}
}

fn start_write(&self) -> Timestamp {
pub(crate) fn start_write(&self) -> Timestamp {
(self.now.fetch_add(1, Ordering::Relaxed) + 1).into()
}

fn write_commit(
pub(crate) fn write_commit(
&self,
read_at: Timestamp,
write_at: Timestamp,
in_write: HashSet<K>,
) -> Result<(), WriteConflict<K>> {
let mut committed_txns = self.committed_txns.lock().unwrap();
let conflicts: Vec<_> = committed_txns
let conflicts = committed_txns
.range((Bound::Excluded(read_at), Bound::Excluded(write_at)))
.flat_map(|(_, txn)| txn.intersection(&in_write))
.cloned()
.collect();
.collect::<Vec<_>>();

if !conflicts.is_empty() {
return Err(WriteConflict { keys: conflicts });
}

// TODO: clean committed transactions
committed_txns.insert(write_at, in_write);
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion src/oracle/timestamp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{borrow::Borrow, cmp::Ordering, marker::PhantomData, mem::transmute};
use crate::oracle::Timestamp;

#[derive(PartialEq, Eq, Debug, Clone)]
pub(crate) struct Timestamped<V> {
pub struct Timestamped<V> {
pub(crate) ts: Timestamp,
pub(crate) value: V,
}
Expand Down
4 changes: 2 additions & 2 deletions src/record/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
pub(crate) mod internal;
mod str;

use std::sync::Arc;
use std::{hash::Hash, sync::Arc};

use arrow::{
array::{Datum, RecordBatch},
Expand All @@ -14,7 +14,7 @@ use crate::{
serdes::{Decode, Encode},
};

pub trait Key: 'static + Encode + Decode + Ord + Clone + Send {
pub trait Key: 'static + Encode + Decode + Ord + Clone + Send + Hash + std::fmt::Debug {
type Ref<'r>: KeyRef<'r, Key = Self> + Copy
where
Self: 'r;
Expand Down
149 changes: 134 additions & 15 deletions src/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,154 @@
use std::{collections::BTreeMap, io, sync::Arc};
use std::{
collections::{btree_map::Entry, BTreeMap},
io,
mem::transmute,
};

use crate::{executor::Executor, oracle::Timestamp, Record, DB};
use async_lock::RwLockReadGuard;
use parquet::errors::ParquetError;
use thiserror::Error;

pub struct Transaction<R, E>
use crate::{
executor::Executor,
oracle::{Timestamp, WriteConflict},
record::KeyRef,
stream, Record, Schema, DB,
};

pub struct Transaction<'txn, R, E>
where
R: Record,
E: Executor,
{
db: Arc<DB<R, E>>,
read_at: Timestamp,
local: BTreeMap<R::Key, Option<R>>,
share: RwLockReadGuard<'txn, Schema<R>>,
db: &'txn DB<R, E>,
}

impl<R, E> Transaction<R, E>
impl<'txn, R, E> Transaction<'txn, R, E>
where
R: Record,
R: Record + Send,
E: Executor,
{
pub(crate) fn new(db: Arc<DB<R, E>>, read_at: Timestamp) -> Self {
pub(crate) fn new(db: &'txn DB<R, E>, share: RwLockReadGuard<'txn, Schema<R>>) -> Self {
Self {
db,
read_at,
read_at: db.oracle.start_read(),
local: BTreeMap::new(),
share,
db,
}
}

pub async fn get<'get>(
&'get self,
key: &'get R::Key,
) -> Result<Option<TransactionEntry<'get, R>>, ParquetError> {
Ok(match self.local.get(key).and_then(|v| v.as_ref()) {
Some(v) => Some(TransactionEntry::Local(v.as_record_ref())),
None => self
.share
.get::<E>(key, self.read_at)
.await?
.map(TransactionEntry::Stream),
})
}

pub fn set(&mut self, value: R) {
self.entry(value.key().to_key(), Some(value))
}

pub fn remove(&mut self, key: R::Key) {
self.entry(key, None)
}

fn entry(&mut self, key: R::Key, value: Option<R>) {
match self.local.entry(key) {
Entry::Vacant(v) => {
v.insert(value);
}
Entry::Occupied(mut o) => *o.get_mut() = value,
}
}

pub async fn commit(self) -> Result<(), CommitError<R>> {
self.db.oracle.read_commit(self.read_at);
if self.local.is_empty() {
return Ok(());
}
let write_at = self.db.oracle.start_write();
self.db.oracle.write_commit(
self.read_at,
write_at,
self.local.keys().cloned().collect(),
)?;

for (key, record) in self.local {
match record {
Some(record) => self.share.write(record, write_at).await?,
None => self.share.remove(key, write_at).await?,
}
}
Ok(())
}
}

pub enum TransactionEntry<'entry, R>
where
R: Record,
{
Stream(stream::Entry<'entry, R>),
Local(R::Ref<'entry>),
}

impl<'entry, R> TransactionEntry<'entry, R>
where
R: Record,
{
pub fn get(&self) -> R::Ref<'_> {
match self {
TransactionEntry::Stream(entry) => entry.value(),
TransactionEntry::Local(value) => {
// Safety: shorter lifetime must be safe
unsafe { transmute::<R::Ref<'entry>, R::Ref<'_>>(*value) }
}
}
}
}

pub async fn get(&self, key: &R::Key) -> io::Result<Option<&R>> {
// match self.local.get(key).and_then(|v| v.as_ref()) {
// Some(v) => Ok(Some(v)),
// None => self.db.get(key, self.read_at).await,
// }
todo!()
#[derive(Debug, Error)]
pub enum CommitError<R>
where
R: Record,
{
#[error("commit transaction error {:?}", .0)]
Io(#[from] io::Error),
#[error(transparent)]
WriteConflict(#[from] WriteConflict<R::Key>),
}

#[cfg(test)]
mod tests {
use crate::{executor::tokio::TokioExecutor, DB};

#[tokio::test]
async fn transaction_read_write() {
let db = DB::<String, TokioExecutor>::default();
{
let mut txn1 = db.transaction().await;
txn1.set("foo".to_string());

let txn2 = db.transaction().await;
dbg!(txn2.get(&"foo".to_string()).await.unwrap().is_none());

txn1.commit().await.unwrap();
txn2.commit().await.unwrap();
}

{
let txn3 = db.transaction().await;
dbg!(txn3.get(&"foo".to_string()).await.unwrap().is_none());
txn3.commit().await.unwrap();
}
}
}

0 comments on commit 83c1281

Please sign in to comment.