diff --git a/src/db.rs b/src/db.rs index 3c0392de..96ef9ed3 100644 --- a/src/db.rs +++ b/src/db.rs @@ -17,7 +17,7 @@ use std::marker::PhantomData; use std::ops::RangeFull; use std::path::Path; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Condvar, Mutex}; use crate::error::TransactionError; use crate::multimap_table::{parse_subtree_roots, DynamicCollection}; @@ -236,7 +236,8 @@ pub struct Database { mem: TransactionalMemory, next_transaction_id: AtomicTransactionId, transaction_tracker: Arc>, - pub(crate) live_write_transaction: Mutex>, + live_write_transaction: Mutex>, + live_write_transaction_available: Condvar, } impl Database { @@ -253,6 +254,30 @@ impl Database { Self::builder().open(path) } + pub(crate) fn start_write_transaction(&self) -> TransactionId { + let mut live_write_transaction = self.live_write_transaction.lock().unwrap(); + while live_write_transaction.is_some() { + live_write_transaction = self + .live_write_transaction_available + .wait(live_write_transaction) + .unwrap(); + } + assert!(live_write_transaction.is_none()); + let transaction_id = self.next_transaction_id.next(); + #[cfg(feature = "logging")] + info!("Beginning write transaction id={:?}", transaction_id); + *live_write_transaction = Some(transaction_id); + + transaction_id + } + + pub(crate) fn end_write_transaction(&self, id: TransactionId) { + let mut live_write_transaction = self.live_write_transaction.lock().unwrap(); + assert_eq!(live_write_transaction.unwrap(), id); + *live_write_transaction = None; + self.live_write_transaction_available.notify_one(); + } + pub(crate) fn get_memory(&self) -> &TransactionalMemory { &self.mem } @@ -601,6 +626,7 @@ impl Database { next_transaction_id: AtomicTransactionId::new(next_transaction_id), transaction_tracker: Arc::new(Mutex::new(TransactionTracker::new())), live_write_transaction: Mutex::new(None), + live_write_transaction_available: Condvar::new(), }; // Restore the tracker state for any persistent savepoints @@ -648,10 +674,6 @@ impl Database { Ok((id, self.allocate_read_transaction()?)) } - pub(crate) fn increment_transaction_id(&self) -> TransactionId { - self.next_transaction_id.next() - } - /// Convenience method for [`Builder::new`] pub fn builder() -> Builder { Builder::new() diff --git a/src/transactions.rs b/src/transactions.rs index 6659671b..e0b6299d 100644 --- a/src/transactions.rs +++ b/src/transactions.rs @@ -20,7 +20,7 @@ use std::fmt::{Display, Formatter}; use std::marker::PhantomData; use std::ops::{RangeBounds, RangeFull}; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, Mutex, MutexGuard}; +use std::sync::{Arc, Mutex}; use std::{panic, thread}; const NEXT_SAVEPOINT_TABLE: SystemTableDefinition<(), SavepointId> = @@ -394,7 +394,6 @@ pub struct WriteTransaction<'db> { // Persistent savepoints created during this transaction created_persistent_savepoints: Mutex>, deleted_persistent_savepoints: Mutex>, - live_write_transaction: MutexGuard<'db, Option>, } impl<'db> WriteTransaction<'db> { @@ -402,12 +401,7 @@ impl<'db> WriteTransaction<'db> { db: &'db Database, transaction_tracker: Arc>, ) -> Result { - let mut live_write_transaction = db.live_write_transaction.lock().unwrap(); - assert!(live_write_transaction.is_none()); - let transaction_id = db.increment_transaction_id(); - #[cfg(feature = "logging")] - info!("Beginning write transaction id={:?}", transaction_id); - *live_write_transaction = Some(transaction_id); + let transaction_id = db.start_write_transaction(); let root_page = db.get_memory().get_data_root(); let system_page = db.get_memory().get_system_root(); @@ -442,7 +436,6 @@ impl<'db> WriteTransaction<'db> { durability: Durability::Immediate, created_persistent_savepoints: Mutex::new(Default::default()), deleted_persistent_savepoints: Mutex::new(vec![]), - live_write_transaction, }) } @@ -1110,7 +1103,7 @@ impl<'db> WriteTransaction<'db> { impl<'a> Drop for WriteTransaction<'a> { fn drop(&mut self) { - *self.live_write_transaction = None; + self.db.end_write_transaction(self.transaction_id); if !self.completed && !thread::panicking() && !self.mem.storage_failure() { #[allow(unused_variables)] if let Err(error) = self.abort_inner() { diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index dd8488c8..02771922 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -1393,8 +1393,8 @@ fn compaction() { assert!(file_size2 < file_size); } -fn require_send(_: T) {} -fn require_sync(_: T) {} +fn require_send(_: &T) {} +fn require_sync(_: &T) {} #[test] fn is_send() { @@ -1403,12 +1403,15 @@ fn is_send() { let definition: TableDefinition = TableDefinition::new("x"); let txn = db.begin_write().unwrap(); - let table = txn.open_table(definition).unwrap(); - require_send(table); + { + let table = txn.open_table(definition).unwrap(); + require_send(&table); + require_sync(&txn); + } txn.commit().unwrap(); let txn = db.begin_read().unwrap(); let table = txn.open_table(definition).unwrap(); - require_sync(table); - require_sync(txn); + require_sync(&table); + require_sync(&txn); }