Skip to content

Commit 436694b

Browse files
committed
feat: add preupdate hook
1 parent 42ce24d commit 436694b

File tree

6 files changed

+430
-15
lines changed

6 files changed

+430
-15
lines changed

Cargo.lock

+6-6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlx-sqlite/Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ default-features = false
5858
features = [
5959
"pkg-config",
6060
"vcpkg",
61-
"unlock_notify"
61+
"unlock_notify",
62+
"preupdate_hook"
6263
]
6364

6465
[dependencies.sqlx-core]

sqlx-sqlite/src/connection/establish.rs

+1
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ impl EstablishParams {
296296
log_settings: self.log_settings.clone(),
297297
progress_handler_callback: None,
298298
update_hook_callback: None,
299+
preupdate_hook_callback: None,
299300
commit_hook_callback: None,
300301
rollback_hook_callback: None,
301302
})

sqlx-sqlite/src/connection/mod.rs

+199-3
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ use futures_core::future::BoxFuture;
1111
use futures_intrusive::sync::MutexGuard;
1212
use futures_util::future;
1313
use libsqlite3_sys::{
14-
sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook,
15-
sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE,
14+
sqlite3, sqlite3_commit_hook, sqlite3_preupdate_count, sqlite3_preupdate_depth,
15+
sqlite3_preupdate_hook, sqlite3_preupdate_new, sqlite3_preupdate_old, sqlite3_progress_handler,
16+
sqlite3_rollback_hook, sqlite3_update_hook, sqlite3_value, sqlite3_value_type, SQLITE_DELETE,
17+
SQLITE_INSERT, SQLITE_OK, SQLITE_UPDATE,
1618
};
1719

1820
pub(crate) use handle::ConnectionHandle;
@@ -26,7 +28,8 @@ use crate::connection::establish::EstablishParams;
2628
use crate::connection::worker::ConnectionWorker;
2729
use crate::options::OptimizeOnClose;
2830
use crate::statement::VirtualStatement;
29-
use crate::{Sqlite, SqliteConnectOptions};
31+
use crate::type_info::DataType;
32+
use crate::{Sqlite, SqliteConnectOptions, SqliteError, SqliteTypeInfo, SqliteValue};
3033

3134
pub(crate) mod collation;
3235
pub(crate) mod describe;
@@ -88,6 +91,14 @@ pub struct UpdateHookResult<'a> {
8891
pub table: &'a str,
8992
pub rowid: i64,
9093
}
94+
95+
pub struct PreupdateHookResult<'a> {
96+
pub operation: SqliteOperation,
97+
pub database: &'a str,
98+
pub table: &'a str,
99+
pub case: PreupdateCase,
100+
}
101+
91102
pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
92103
unsafe impl Send for UpdateHookHandler {}
93104

@@ -97,6 +108,108 @@ unsafe impl Send for CommitHookHandler {}
97108
pub(crate) struct RollbackHookHandler(NonNull<dyn FnMut() + Send + 'static>);
98109
unsafe impl Send for RollbackHookHandler {}
99110

111+
pub(crate) struct PreupdateHookHandler(NonNull<dyn FnMut(PreupdateHookResult) + Send + 'static>);
112+
unsafe impl Send for PreupdateHookHandler {}
113+
114+
/// The possible cases for when a PreUpdate Hook gets triggered. Allows access to the relevant
115+
/// functions for each case through the contained values.
116+
pub enum PreupdateCase {
117+
/// Pre-update hook was triggered by an insert.
118+
Insert(PreupdateNewValueAccessor),
119+
/// Pre-update hook was triggered by a delete.
120+
Delete(PreupdateOldValueAccessor),
121+
/// Pre-update hook was triggered by an update.
122+
Update {
123+
old_value_accessor: PreupdateOldValueAccessor,
124+
new_value_accessor: PreupdateNewValueAccessor,
125+
},
126+
/// This variant is not normally produced by SQLite. You may encounter it
127+
/// if you're using a different version than what's supported by this library.
128+
Unknown,
129+
}
130+
131+
/// An accessor for the new values of the row being inserted/updated during the preupdate callback.
132+
#[derive(Debug)]
133+
pub struct PreupdateNewValueAccessor {
134+
db: *mut sqlite3,
135+
new_row_id: i64,
136+
}
137+
138+
impl PreupdateNewValueAccessor {
139+
/// Gets the amount of columns in the row being inserted/updated.
140+
pub fn get_column_count(&self) -> i32 {
141+
unsafe { sqlite3_preupdate_count(self.db) }
142+
}
143+
144+
/// Gets the depth of the query that triggered the preupdate hook.
145+
/// Returns 0 if the preupdate callback was invoked as a result of
146+
/// a direct insert, update, or delete operation;
147+
/// 1 for inserts, updates, or deletes invoked by top-level triggers;
148+
/// 2 for changes resulting from triggers called by top-level triggers; and so forth.
149+
pub fn get_query_depth(&self) -> i32 {
150+
unsafe { sqlite3_preupdate_depth(self.db) }
151+
}
152+
153+
/// Gets the row id of the row being inserted/updated.
154+
pub fn get_new_row_id(&self) -> i64 {
155+
self.new_row_id
156+
}
157+
158+
/// Gets the value of the row being updated/deleted at the specified index.
159+
pub fn get_new_column_value(&self, i: i32) -> Result<SqliteValue, Error> {
160+
let mut p_value: *mut sqlite3_value = ptr::null_mut();
161+
unsafe {
162+
let ret = sqlite3_preupdate_new(self.db, i, &mut p_value);
163+
if ret != SQLITE_OK {
164+
return Err(Error::Database(Box::new(SqliteError::new(self.db))));
165+
}
166+
let data_type = DataType::from_code(sqlite3_value_type(p_value));
167+
Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type)))
168+
}
169+
}
170+
}
171+
172+
/// An accessor for the old values of the row being deleted/updated during the preupdate callback.
173+
#[derive(Debug)]
174+
pub struct PreupdateOldValueAccessor {
175+
db: *mut sqlite3,
176+
old_row_id: i64,
177+
}
178+
179+
impl PreupdateOldValueAccessor {
180+
/// Gets the amount of columns in the row being deleted/updated.
181+
pub fn get_column_count(&self) -> i32 {
182+
unsafe { sqlite3_preupdate_count(self.db) }
183+
}
184+
185+
/// Gets the depth of the query that triggered the preupdate hook.
186+
/// Returns 0 if the preupdate callback was invoked as a result of
187+
/// a direct insert, update, or delete operation;
188+
/// 1 for inserts, updates, or deletes invoked by top-level triggers;
189+
/// 2 for changes resulting from triggers called by top-level triggers; and so forth.
190+
pub fn get_query_depth(&self) -> i32 {
191+
unsafe { sqlite3_preupdate_depth(self.db) }
192+
}
193+
194+
/// Gets the row id of the row being updated/deleted.
195+
pub fn get_old_row_id(&self) -> i64 {
196+
self.old_row_id
197+
}
198+
199+
/// Gets the value of the row being updated/deleted at the specified index.
200+
pub fn get_old_column_value(&self, i: i32) -> Result<SqliteValue, Error> {
201+
let mut p_value: *mut sqlite3_value = ptr::null_mut();
202+
unsafe {
203+
let ret = sqlite3_preupdate_old(self.db, i, &mut p_value);
204+
if ret != SQLITE_OK {
205+
return Err(Error::Database(Box::new(SqliteError::new(self.db))));
206+
}
207+
let data_type = DataType::from_code(sqlite3_value_type(p_value));
208+
Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type)))
209+
}
210+
}
211+
}
212+
100213
pub(crate) struct ConnectionState {
101214
pub(crate) handle: ConnectionHandle,
102215

@@ -113,6 +226,8 @@ pub(crate) struct ConnectionState {
113226

114227
update_hook_callback: Option<UpdateHookHandler>,
115228

229+
preupdate_hook_callback: Option<PreupdateHookHandler>,
230+
116231
commit_hook_callback: Option<CommitHookHandler>,
117232

118233
rollback_hook_callback: Option<RollbackHookHandler>,
@@ -138,6 +253,15 @@ impl ConnectionState {
138253
}
139254
}
140255

256+
pub(crate) fn remove_preupdate_hook(&mut self) {
257+
if let Some(mut handler) = self.preupdate_hook_callback.take() {
258+
unsafe {
259+
sqlite3_preupdate_hook(self.handle.as_ptr(), None, ptr::null_mut());
260+
let _ = { Box::from_raw(handler.0.as_mut()) };
261+
}
262+
}
263+
}
264+
141265
pub(crate) fn remove_commit_hook(&mut self) {
142266
if let Some(mut handler) = self.commit_hook_callback.take() {
143267
unsafe {
@@ -312,6 +436,47 @@ extern "C" fn update_hook<F>(
312436
}
313437
}
314438

439+
extern "C" fn preupdate_hook<F>(
440+
callback: *mut c_void,
441+
db: *mut sqlite3,
442+
op_code: c_int,
443+
database: *const i8,
444+
table: *const i8,
445+
old_row_id: i64,
446+
new_row_id: i64,
447+
) where
448+
F: FnMut(PreupdateHookResult),
449+
{
450+
unsafe {
451+
let _ = catch_unwind(|| {
452+
let callback: *mut F = callback.cast::<F>();
453+
let operation: SqliteOperation = op_code.into();
454+
let database = CStr::from_ptr(database).to_str().unwrap_or_default();
455+
let table = CStr::from_ptr(table).to_str().unwrap_or_default();
456+
457+
let preupdate_case = match operation {
458+
SqliteOperation::Insert => {
459+
PreupdateCase::Insert(PreupdateNewValueAccessor { db, new_row_id })
460+
}
461+
SqliteOperation::Delete => {
462+
PreupdateCase::Delete(PreupdateOldValueAccessor { db, old_row_id })
463+
}
464+
SqliteOperation::Update => PreupdateCase::Update {
465+
old_value_accessor: PreupdateOldValueAccessor { db, old_row_id },
466+
new_value_accessor: PreupdateNewValueAccessor { db, new_row_id },
467+
},
468+
SqliteOperation::Unknown(_) => PreupdateCase::Unknown,
469+
};
470+
(*callback)(PreupdateHookResult {
471+
operation,
472+
database,
473+
table,
474+
case: preupdate_case,
475+
})
476+
});
477+
}
478+
}
479+
315480
extern "C" fn commit_hook<F>(callback: *mut c_void) -> c_int
316481
where
317482
F: FnMut() -> bool,
@@ -476,6 +641,33 @@ impl LockedSqliteHandle<'_> {
476641
}
477642
}
478643

644+
/// Registers a hook that is invoked prior to each `INSERT`, `UPDATE`, and `DELETE` operation on a database table.
645+
/// At most one preupdate hook may be registered at a time on a single database connection.
646+
///
647+
/// The preupdate hook only fires for changes to real database tables;
648+
/// it is not invoked for changes to virtual tables or to system tables like sqlite_sequence or sqlite_stat1.
649+
///
650+
/// See https://sqlite.org/c3ref/preupdate_count.html
651+
pub fn set_preupdate_hook<F>(&mut self, callback: F)
652+
where
653+
F: FnMut(PreupdateHookResult) + Send + 'static,
654+
{
655+
unsafe {
656+
let callback_boxed = Box::new(callback);
657+
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
658+
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
659+
let handler = callback.as_ptr() as *mut _;
660+
self.guard.remove_preupdate_hook();
661+
self.guard.preupdate_hook_callback = Some(PreupdateHookHandler(callback));
662+
663+
sqlite3_preupdate_hook(
664+
self.as_raw_handle().as_mut(),
665+
Some(preupdate_hook::<F>),
666+
handler,
667+
);
668+
}
669+
}
670+
479671
/// Removes the progress handler on a database connection. The method does nothing if no handler was set.
480672
pub fn remove_progress_handler(&mut self) {
481673
self.guard.remove_progress_handler();
@@ -492,6 +684,10 @@ impl LockedSqliteHandle<'_> {
492684
pub fn remove_rollback_hook(&mut self) {
493685
self.guard.remove_rollback_hook();
494686
}
687+
688+
pub fn remove_preupdate_hook(&mut self) {
689+
self.guard.remove_preupdate_hook();
690+
}
495691
}
496692

497693
impl Drop for ConnectionState {

sqlx-sqlite/src/lib.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ use std::sync::atomic::AtomicBool;
4646

4747
pub use arguments::{SqliteArgumentValue, SqliteArguments};
4848
pub use column::SqliteColumn;
49-
pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult};
49+
pub use connection::{
50+
LockedSqliteHandle, PreupdateCase, PreupdateHookResult, PreupdateNewValueAccessor,
51+
PreupdateOldValueAccessor, SqliteConnection, SqliteOperation, UpdateHookResult,
52+
};
5053
pub use database::Sqlite;
5154
pub use error::SqliteError;
5255
pub use options::{

0 commit comments

Comments
 (0)