Skip to content

Commit

Permalink
Make WriteTransaction implement Send
Browse files Browse the repository at this point in the history
  • Loading branch information
cberner committed Sep 17, 2023
1 parent b15c5bb commit 6c0f3fb
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 22 deletions.
34 changes: 28 additions & 6 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::marker::PhantomData;
use std::ops::RangeFull;
use std::path::Path;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Condvar, Mutex};

use crate::error::TransactionError;
use crate::multimap_table::{parse_subtree_roots, DynamicCollection};
Expand Down Expand Up @@ -236,7 +236,8 @@ pub struct Database {
mem: TransactionalMemory,
next_transaction_id: AtomicTransactionId,
transaction_tracker: Arc<Mutex<TransactionTracker>>,
pub(crate) live_write_transaction: Mutex<Option<TransactionId>>,
live_write_transaction: Mutex<Option<TransactionId>>,
live_write_transaction_available: Condvar,
}

impl Database {
Expand All @@ -253,6 +254,30 @@ impl Database {
Self::builder().open(path)
}

pub(crate) fn start_write_transaction(&self) -> TransactionId {
let mut live_write_transaction = self.live_write_transaction.lock().unwrap();
while live_write_transaction.is_some() {
live_write_transaction = self
.live_write_transaction_available
.wait(live_write_transaction)
.unwrap();
}
assert!(live_write_transaction.is_none());
let transaction_id = self.next_transaction_id.next();
#[cfg(feature = "logging")]
info!("Beginning write transaction id={:?}", transaction_id);
*live_write_transaction = Some(transaction_id);

transaction_id
}

pub(crate) fn end_write_transaction(&self, id: TransactionId) {
let mut live_write_transaction = self.live_write_transaction.lock().unwrap();
assert_eq!(live_write_transaction.unwrap(), id);
*live_write_transaction = None;
self.live_write_transaction_available.notify_one();
}

pub(crate) fn get_memory(&self) -> &TransactionalMemory {
&self.mem
}
Expand Down Expand Up @@ -601,6 +626,7 @@ impl Database {
next_transaction_id: AtomicTransactionId::new(next_transaction_id),
transaction_tracker: Arc::new(Mutex::new(TransactionTracker::new())),
live_write_transaction: Mutex::new(None),
live_write_transaction_available: Condvar::new(),
};

// Restore the tracker state for any persistent savepoints
Expand Down Expand Up @@ -648,10 +674,6 @@ impl Database {
Ok((id, self.allocate_read_transaction()?))
}

pub(crate) fn increment_transaction_id(&self) -> TransactionId {
self.next_transaction_id.next()
}

/// Convenience method for [`Builder::new`]
pub fn builder() -> Builder {
Builder::new()
Expand Down
13 changes: 3 additions & 10 deletions src/transactions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::fmt::{Display, Formatter};
use std::marker::PhantomData;
use std::ops::{RangeBounds, RangeFull};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, MutexGuard};
use std::sync::{Arc, Mutex};
use std::{panic, thread};

const NEXT_SAVEPOINT_TABLE: SystemTableDefinition<(), SavepointId> =
Expand Down Expand Up @@ -394,20 +394,14 @@ pub struct WriteTransaction<'db> {
// Persistent savepoints created during this transaction
created_persistent_savepoints: Mutex<HashSet<SavepointId>>,
deleted_persistent_savepoints: Mutex<Vec<(SavepointId, TransactionId)>>,
live_write_transaction: MutexGuard<'db, Option<TransactionId>>,
}

impl<'db> WriteTransaction<'db> {
pub(crate) fn new(
db: &'db Database,
transaction_tracker: Arc<Mutex<TransactionTracker>>,
) -> Result<Self> {
let mut live_write_transaction = db.live_write_transaction.lock().unwrap();
assert!(live_write_transaction.is_none());
let transaction_id = db.increment_transaction_id();
#[cfg(feature = "logging")]
info!("Beginning write transaction id={:?}", transaction_id);
*live_write_transaction = Some(transaction_id);
let transaction_id = db.start_write_transaction();

let root_page = db.get_memory().get_data_root();
let system_page = db.get_memory().get_system_root();
Expand Down Expand Up @@ -442,7 +436,6 @@ impl<'db> WriteTransaction<'db> {
durability: Durability::Immediate,
created_persistent_savepoints: Mutex::new(Default::default()),
deleted_persistent_savepoints: Mutex::new(vec![]),
live_write_transaction,
})
}

Expand Down Expand Up @@ -1110,7 +1103,7 @@ impl<'db> WriteTransaction<'db> {

impl<'a> Drop for WriteTransaction<'a> {
fn drop(&mut self) {
*self.live_write_transaction = None;
self.db.end_write_transaction(self.transaction_id);
if !self.completed && !thread::panicking() && !self.mem.storage_failure() {
#[allow(unused_variables)]
if let Err(error) = self.abort_inner() {
Expand Down
15 changes: 9 additions & 6 deletions tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1393,8 +1393,8 @@ fn compaction() {
assert!(file_size2 < file_size);
}

fn require_send<T: Send>(_: T) {}
fn require_sync<T: Sync + Send>(_: T) {}
fn require_send<T: Send>(_: &T) {}
fn require_sync<T: Sync + Send>(_: &T) {}

#[test]
fn is_send() {
Expand All @@ -1403,12 +1403,15 @@ fn is_send() {
let definition: TableDefinition<u32, &[u8]> = TableDefinition::new("x");

let txn = db.begin_write().unwrap();
let table = txn.open_table(definition).unwrap();
require_send(table);
{
let table = txn.open_table(definition).unwrap();
require_send(&table);
require_sync(&txn);
}
txn.commit().unwrap();

let txn = db.begin_read().unwrap();
let table = txn.open_table(definition).unwrap();
require_sync(table);
require_sync(txn);
require_sync(&table);
require_sync(&txn);
}

0 comments on commit 6c0f3fb

Please sign in to comment.