Skip to content

Commit

Permalink
Improve type safety of persistent savepoint tables
Browse files Browse the repository at this point in the history
  • Loading branch information
cberner committed May 30, 2023
1 parent a171ef5 commit 68aa714
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 100 deletions.
25 changes: 13 additions & 12 deletions src/db.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use crate::transaction_tracker::{SavepointId, TransactionId, TransactionTracker};
use crate::tree_store::{
AllPageNumbersBtreeIter, BtreeRangeIter, Checksum, FreedPageList, FreedTableKey,
InternalTableDefinition, PageHint, PageNumber, RawBtree, TableTree, TableType,
TransactionalMemory, PAGE_SIZE,
InternalTableDefinition, PageHint, PageNumber, RawBtree, SerializedSavepoint, TableTree,
TableType, TransactionalMemory, PAGE_SIZE,
};
use crate::types::{RedbKey, RedbValue};
use crate::{
CompactionError, DatabaseError, Durability, ReadOnlyTable, ReadableTable, Savepoint,
SavepointError, StorageError,
CompactionError, DatabaseError, Durability, ReadOnlyTable, ReadableTable, SavepointError,
StorageError,
};
use crate::{ReadTransaction, Result, WriteTransaction};
use std::fmt::{Display, Formatter};
Expand Down Expand Up @@ -419,20 +419,21 @@ impl Database {
let table_tree = TableTree::new(system_root, mem, freed_list);
let fake_transaction_tracker = Arc::new(Mutex::new(TransactionTracker::new()));
if let Some(savepoint_table_def) = table_tree
.get_table::<u64, &[u8]>(SAVEPOINT_TABLE.name(), TableType::Normal)
.get_table::<SavepointId, SerializedSavepoint>(
SAVEPOINT_TABLE.name(),
TableType::Normal,
)
.map_err(|e| {
e.into_storage_error_or_corrupted("Persistent savepoint table corrupted")
})?
{
let savepoint_table: ReadOnlyTable<u64, &[u8]> =
let savepoint_table: ReadOnlyTable<SavepointId, SerializedSavepoint> =
ReadOnlyTable::new(savepoint_table_def.get_root(), PageHint::None, mem)?;
for result in savepoint_table.range::<u64>(..)? {
for result in savepoint_table.range::<SavepointId>(..)? {
let (_, savepoint_data) = result?;
let savepoint = Savepoint::from_bytes(
savepoint_data.value(),
fake_transaction_tracker.clone(),
false,
);
let savepoint = savepoint_data
.value()
.to_savepoint(fake_transaction_tracker.clone());
if let Some((root, _)) = savepoint.get_user_root() {
Self::mark_tables_recursive(root, mem, true)?;
}
Expand Down
42 changes: 39 additions & 3 deletions src/transaction_tracker.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::Savepoint;
use crate::{RedbKey, RedbValue, Savepoint, TypeName};
use std::cmp::Ordering;
use std::collections::btree_map::BTreeMap;
use std::collections::btree_set::BTreeSet;
use std::mem::size_of;

#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub(crate) struct TransactionId(pub u64);
Expand All @@ -19,15 +21,49 @@ impl TransactionId {
}
}

#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
pub(crate) struct SavepointId(pub u64);

impl SavepointId {
fn next(&self) -> SavepointId {
pub(crate) fn next(&self) -> SavepointId {
SavepointId(self.0 + 1)
}
}

impl RedbValue for SavepointId {
type SelfType<'a> = SavepointId;
type AsBytes<'a> = [u8; size_of::<u64>()];

fn fixed_width() -> Option<usize> {
Some(size_of::<u64>())
}

fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a>
where
Self: 'a,
{
SavepointId(u64::from_le_bytes(data.try_into().unwrap()))
}

fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a>
where
Self: 'a,
Self: 'b,
{
value.0.to_le_bytes()
}

fn type_name() -> TypeName {
TypeName::internal("redb::SavepointId")
}
}

impl RedbKey for SavepointId {
fn compare(data1: &[u8], data2: &[u8]) -> Ordering {
Self::from_bytes(data1).0.cmp(&Self::from_bytes(data2).0)
}
}

pub(crate) struct TransactionTracker {
next_savepoint_id: SavepointId,
// reference count of read transactions per transaction id
Expand Down
38 changes: 21 additions & 17 deletions src/transactions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::sealed::Sealed;
use crate::transaction_tracker::{SavepointId, TransactionId, TransactionTracker};
use crate::tree_store::{
Btree, BtreeMut, FreedPageList, FreedTableKey, InternalTableDefinition, PageHint, PageNumber,
TableTree, TableType, TransactionalMemory,
SerializedSavepoint, TableTree, TableType, TransactionalMemory,
};
use crate::types::{RedbKey, RedbValue};
use crate::{
Expand All @@ -22,9 +22,9 @@ use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, RwLock};
use std::{panic, thread};

const NEXT_SAVEPOINT_TABLE: SystemTableDefinition<(), u64> =
const NEXT_SAVEPOINT_TABLE: SystemTableDefinition<(), SavepointId> =
SystemTableDefinition::new("next_savepoint_id");
pub(crate) const SAVEPOINT_TABLE: SystemTableDefinition<u64, &[u8]> =
pub(crate) const SAVEPOINT_TABLE: SystemTableDefinition<SavepointId, SerializedSavepoint> =
SystemTableDefinition::new("persistent_savepoints");

pub struct SystemTableDefinition<'a, K: RedbKey + 'static, V: RedbValue + 'static> {
Expand Down Expand Up @@ -208,7 +208,7 @@ pub struct WriteTransaction<'db> {
dirty: AtomicBool,
durability: Durability,
// Persistent savepoints created during this transaction
created_persistent_savepoints: Mutex<HashSet<u64>>,
created_persistent_savepoints: Mutex<HashSet<SavepointId>>,
deleted_persistent_savepoints: Mutex<Vec<(SavepointId, TransactionId)>>,
live_write_transaction: MutexGuard<'db, Option<TransactionId>>,
}
Expand Down Expand Up @@ -326,16 +326,19 @@ impl<'db> WriteTransaction<'db> {

let mut next_table = self.open_system_table(NEXT_SAVEPOINT_TABLE)?;
let mut savepoint_table = self.open_system_table(SAVEPOINT_TABLE)?;
next_table.insert((), savepoint.get_id().0 + 1)?;
next_table.insert((), savepoint.get_id().next())?;

savepoint_table.insert(savepoint.get_id().0, savepoint.to_bytes().as_slice())?;
savepoint_table.insert(
savepoint.get_id(),
SerializedSavepoint::from_savepoint(&savepoint),
)?;

savepoint.set_persistent();

self.created_persistent_savepoints
.lock()
.unwrap()
.insert(savepoint.get_id().0);
.insert(savepoint.get_id());

Ok(savepoint.get_id().0)
}
Expand All @@ -344,7 +347,7 @@ impl<'db> WriteTransaction<'db> {
let next_table = self.open_system_table(NEXT_SAVEPOINT_TABLE)?;
let value = next_table.get(())?;
if let Some(next_id) = value {
Ok(Some(SavepointId(next_id.value())))
Ok(Some(next_id.value()))
} else {
Ok(None)
}
Expand All @@ -353,10 +356,10 @@ impl<'db> WriteTransaction<'db> {
/// Get a persistent savepoint given its id
pub fn get_persistent_savepoint(&self, id: u64) -> Result<Savepoint, SavepointError> {
let table = self.open_system_table(SAVEPOINT_TABLE)?;
let value = table.get(id)?;
let value = table.get(SavepointId(id))?;

value
.map(|x| Savepoint::from_bytes(x.value(), self.transaction_tracker.clone(), false))
.map(|x| x.value().to_savepoint(self.transaction_tracker.clone()))
.ok_or(SavepointError::InvalidSavepoint)
}

Expand All @@ -374,10 +377,11 @@ impl<'db> WriteTransaction<'db> {
return Err(SavepointError::InvalidSavepoint);
}
let mut table = self.open_system_table(SAVEPOINT_TABLE)?;
let savepoint = table.remove(id)?;
if let Some(bytes) = savepoint {
let savepoint =
Savepoint::from_bytes(bytes.value(), self.transaction_tracker.clone(), false);
let savepoint = table.remove(SavepointId(id))?;
if let Some(serialized) = savepoint {
let savepoint = serialized
.value()
.to_savepoint(self.transaction_tracker.clone());
self.deleted_persistent_savepoints
.lock()
.unwrap()
Expand All @@ -392,8 +396,8 @@ impl<'db> WriteTransaction<'db> {
pub fn list_persistent_savepoints(&self) -> Result<impl Iterator<Item = u64>> {
let table = self.open_system_table(SAVEPOINT_TABLE)?;
let mut savepoints = vec![];
for savepoint in table.range::<u64>(..)? {
savepoints.push(savepoint?.0.value());
for savepoint in table.range::<SavepointId>(..)? {
savepoints.push(savepoint?.0.value().0);
}
Ok(savepoints.into_iter())
}
Expand Down Expand Up @@ -763,7 +767,7 @@ impl<'db> WriteTransaction<'db> {
#[cfg(feature = "logging")]
info!("Aborting transaction id={:?}", self.transaction_id);
for savepoint in self.created_persistent_savepoints.lock().unwrap().iter() {
match self.delete_persistent_savepoint(*savepoint) {
match self.delete_persistent_savepoint(savepoint.0) {
Ok(_) => {}
Err(err) => match err {
SavepointError::InvalidSavepoint => {
Expand Down
4 changes: 2 additions & 2 deletions src/tree_store/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ pub(crate) use btree_iters::{
};
pub use page_store::Savepoint;
pub(crate) use page_store::{
Page, PageHint, PageNumber, TransactionalMemory, FILE_FORMAT_VERSION, MAX_VALUE_LENGTH,
PAGE_SIZE,
Page, PageHint, PageNumber, SerializedSavepoint, TransactionalMemory, FILE_FORMAT_VERSION,
MAX_VALUE_LENGTH, PAGE_SIZE,
};
pub(crate) use table_tree::{
FreedPageList, FreedTableKey, InternalTableDefinition, TableTree, TableType,
Expand Down
1 change: 1 addition & 0 deletions src/tree_store/page_store/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub(crate) use base::{Page, PageHint, PageNumber, MAX_VALUE_LENGTH};
pub(crate) use header::PAGE_SIZE;
pub(crate) use page_manager::{xxh3_checksum, TransactionalMemory, FILE_FORMAT_VERSION};
pub use savepoint::Savepoint;
pub(crate) use savepoint::SerializedSavepoint;

pub(super) use base::{PageImpl, PageMut};
pub(super) use xxh3::hash128_with_seed;
2 changes: 1 addition & 1 deletion src/tree_store/page_store/page_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ const MIN_DESIRED_USABLE_BYTES: u64 = 1024 * 1024;
const NUM_REGIONS: u32 = 1000;

// TODO: set to 1, when version 1.0 is released
pub(crate) const FILE_FORMAT_VERSION: u8 = 116;
pub(crate) const FILE_FORMAT_VERSION: u8 = 117;

fn ceil_log2(x: usize) -> u8 {
if x.is_power_of_two() {
Expand Down
Loading

0 comments on commit 68aa714

Please sign in to comment.