diff --git a/src/transactions.rs b/src/transactions.rs index b60b8b8a..9e22f5f2 100644 --- a/src/transactions.rs +++ b/src/transactions.rs @@ -1183,7 +1183,6 @@ impl Drop for WriteTransaction { pub struct ReadTransaction { mem: Arc, tree: TableTree, - transaction_guard: Arc, } impl ReadTransaction { @@ -1195,9 +1194,8 @@ impl ReadTransaction { let guard = Arc::new(guard); Ok(Self { mem: mem.clone(), - tree: TableTree::new(root_page, PageHint::Clean, guard.clone(), mem) + tree: TableTree::new(root_page, PageHint::Clean, guard, mem) .map_err(TransactionError::Storage)?, - transaction_guard: guard, }) } @@ -1215,7 +1213,7 @@ impl ReadTransaction { definition.name().to_string(), header.get_root(), PageHint::Clean, - self.transaction_guard.clone(), + self.tree.clone_transaction_guard(), self.mem.clone(), )?) } @@ -1252,7 +1250,7 @@ impl ReadTransaction { header.get_root(), header.get_length(), PageHint::Clean, - self.transaction_guard.clone(), + self.tree.clone_transaction_guard(), self.mem.clone(), )?) } @@ -1298,7 +1296,9 @@ impl ReadTransaction { /// /// Returns `ReadTransactionStillInUse` error if a table or other object retrieved from the transaction still references this transaction pub fn close(self) -> Result<(), TransactionError> { - if Arc::strong_count(&self.transaction_guard) > 1 { + let cloned = self.tree.clone_transaction_guard(); + // Check for count greater than 2 because we just cloned the guard to get a reference to it + if Arc::strong_count(&cloned) > 2 { return Err(TransactionError::ReadTransactionStillInUse(self)); } // No-op, just drop ourself diff --git a/src/tree_store/btree.rs b/src/tree_store/btree.rs index 8abf7e0b..bedcd631 100644 --- a/src/tree_store/btree.rs +++ b/src/tree_store/btree.rs @@ -543,7 +543,7 @@ impl RawBtree { pub(crate) struct Btree { mem: Arc, - _transaction_guard: Arc, + transaction_guard: Arc, // Cache of the root page to avoid repeated lookups cached_root: Option, root: Option, @@ -566,7 +566,7 @@ impl Btree { }; Ok(Self { mem, - _transaction_guard: guard, + transaction_guard: guard, cached_root, root, hint, @@ -575,6 +575,10 @@ impl Btree { }) } + pub(crate) fn clone_transaction_guard(&self) -> Arc { + self.transaction_guard.clone() + } + pub(crate) fn get_root(&self) -> Option { self.root } diff --git a/src/tree_store/table_tree.rs b/src/tree_store/table_tree.rs index f1435461..909bba8d 100644 --- a/src/tree_store/table_tree.rs +++ b/src/tree_store/table_tree.rs @@ -439,6 +439,10 @@ impl TableTree { }) } + pub(crate) fn clone_transaction_guard(&self) -> Arc { + self.tree.clone_transaction_guard() + } + // root_page: the root of the master table pub(crate) fn list_tables(&self, table_type: TableType) -> Result> { let iter = self.tree.range::(&(..))?; diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 408a19a7..b9be0165 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -9,7 +9,7 @@ use rand::Rng; use redb::{ AccessGuard, Builder, CompactionError, Database, Durability, Key, MultimapRange, MultimapTableDefinition, MultimapValue, Range, ReadableTable, ReadableTableMetadata, - TableDefinition, TableStats, Value, + TableDefinition, TableStats, TransactionError, Value, }; use redb::{DatabaseError, ReadableMultimapTable, SavepointError, StorageError, TableError}; @@ -253,6 +253,27 @@ fn many_pairs() { wtx.commit().unwrap(); } +#[test] +fn explicit_close() { + let tmpfile = create_tempfile(); + const TABLE: TableDefinition = TableDefinition::new("TABLE"); + let db = Database::create(tmpfile.path()).unwrap(); + let wtx = db.begin_write().unwrap(); + wtx.open_table(TABLE).unwrap(); + wtx.commit().unwrap(); + + let tx = db.begin_read().unwrap(); + let table = tx.open_table(TABLE).unwrap(); + assert!(matches!( + tx.close(), + Err(TransactionError::ReadTransactionStillInUse(_)) + )); + drop(table); + + let tx2 = db.begin_read().unwrap(); + tx2.close().unwrap(); +} + #[test] fn large_keys() { let tmpfile = create_tempfile();