From b87f231c0664aa82daf3b27da3041b166f45c577 Mon Sep 17 00:00:00 2001 From: shamb0 Date: Tue, 20 Aug 2024 14:56:49 +0530 Subject: [PATCH] - Integrate PR#30 to support benchmarking and bringup of Parquet cache configurations Signed-off-by: shamb0 --- Cargo.lock | 42 +++++++++++- Cargo.toml | 1 + src/api/csv.rs | 66 +++++++++--------- src/api/duckdb.rs | 34 +++++---- src/api/parquet.rs | 86 ++++++++++++----------- src/duckdb/connection.rs | 145 +++++++++++++++++++-------------------- src/duckdb/csv.rs | 28 ++++---- src/duckdb/delta.rs | 20 +++--- src/duckdb/iceberg.rs | 20 +++--- src/duckdb/json.rs | 6 +- src/duckdb/parquet.rs | 30 ++++---- src/duckdb/spatial.rs | 15 +++- src/env.rs | 131 +++++++++++++++++++++++++++++++++++ src/fdw/base.rs | 64 ++--------------- src/fdw/trigger.rs | 117 ++++++++++++++++++++----------- src/lib.rs | 6 +- 16 files changed, 500 insertions(+), 311 deletions(-) create mode 100644 src/env.rs diff --git a/Cargo.lock b/Cargo.lock index dd9a35a0..ee835aa9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -706,6 +706,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "atomic-polyfill" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4" +dependencies = [ + "critical-section", +] + [[package]] name = "atomic-traits" version = "0.3.0" @@ -1806,6 +1815,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "critical-section" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f64009896348fc5af4222e9cf7d7d82a95a256c634ebcf61c53e4ea461422242" + [[package]] name = "crossbeam-deque" version = "0.8.5" @@ -2935,6 +2950,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "hash32" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67" +dependencies = [ + "byteorder", +] + 
[[package]] name = "hash32" version = "0.3.1" @@ -2972,13 +2996,26 @@ dependencies = [ "hashbrown 0.14.5", ] +[[package]] +name = "heapless" +version = "0.7.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f" +dependencies = [ + "atomic-polyfill", + "hash32 0.2.1", + "rustc_version 0.4.0", + "spin", + "stable_deref_trait", +] + [[package]] name = "heapless" version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bfb9eb618601c89945a70e254898da93b13be0388091d42117462b265bb3fad" dependencies = [ - "hash32", + "hash32 0.3.1", "stable_deref_trait", ] @@ -4020,6 +4057,7 @@ dependencies = [ "duckdb", "futures", "geojson", + "heapless 0.7.17", "pgrx", "pgrx-tests", "rstest", @@ -4050,7 +4088,7 @@ dependencies = [ "bitflags 2.6.0", "bitvec", "enum-map", - "heapless", + "heapless 0.8.0", "libc", "once_cell", "pgrx-macros", diff --git a/Cargo.toml b/Cargo.toml index 020fd677..144da130 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ strum = { version = "0.26.3", features = ["derive"] } supabase-wrappers = { git = "https://github.com/paradedb/wrappers.git", default-features = false, rev = "c32abb7" } thiserror = "1.0.63" uuid = "1.10.0" +heapless = "0.7.16" [dev-dependencies] aws-config = "1.5.6" diff --git a/src/api/csv.rs b/src/api/csv.rs index aa35eb16..f6cacacc 100644 --- a/src/api/csv.rs +++ b/src/api/csv.rs @@ -15,12 +15,14 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
-use anyhow::Result; +use anyhow::{anyhow, Result}; use duckdb::types::Value; use pgrx::*; -use crate::duckdb::connection; use crate::duckdb::utils; +use crate::env::get_global_connection; +use crate::with_connection; +use duckdb::Connection; type SniffCsvRow = ( Option, @@ -62,34 +64,36 @@ pub fn sniff_csv( #[inline] fn sniff_csv_impl(files: &str, sample_size: Option) -> Result> { - let schema_str = vec![ - Some(utils::format_csv(files)), - sample_size.map(|s| s.to_string()), - ] - .into_iter() - .flatten() - .collect::>() - .join(", "); - let conn = unsafe { &*connection::get_global_connection().get() }; - let query = format!("SELECT * FROM sniff_csv({schema_str})"); - let mut stmt = conn.prepare(&query)?; + with_connection!(|conn: &Connection| { + let schema_str = vec![ + Some(utils::format_csv(files)), + sample_size.map(|s| s.to_string()), + ] + .into_iter() + .flatten() + .collect::>() + .join(", "); - Ok(stmt - .query_map([], |row| { - Ok(( - row.get::<_, Option>(0)?, - row.get::<_, Option>(1)?, - row.get::<_, Option>(2)?, - row.get::<_, Option>(3)?, - row.get::<_, Option>(4)?, - row.get::<_, Option>(5)?, - row.get::<_, Option>(6)?.map(|v| format!("{:?}", v)), - row.get::<_, Option>(7)?, - row.get::<_, Option>(8)?, - row.get::<_, Option>(9)?, - row.get::<_, Option>(10)?, - )) - })? - .map(|row| row.unwrap()) - .collect::>()) + let query = format!("SELECT * FROM sniff_csv({schema_str})"); + let mut stmt = conn.prepare(&query)?; + + Ok(stmt + .query_map([], |row| { + Ok(( + row.get::<_, Option>(0)?, + row.get::<_, Option>(1)?, + row.get::<_, Option>(2)?, + row.get::<_, Option>(3)?, + row.get::<_, Option>(4)?, + row.get::<_, Option>(5)?, + row.get::<_, Option>(6)?.map(|v| format!("{:?}", v)), + row.get::<_, Option>(7)?, + row.get::<_, Option>(8)?, + row.get::<_, Option>(9)?, + row.get::<_, Option>(10)?, + )) + })? 
+ .map(|row| row.unwrap()) + .collect::>()) + }) } diff --git a/src/api/duckdb.rs b/src/api/duckdb.rs index 6f220816..68153de4 100644 --- a/src/api/duckdb.rs +++ b/src/api/duckdb.rs @@ -1,7 +1,10 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use pgrx::*; use crate::duckdb::connection; +use crate::env::get_global_connection; +use crate::with_connection; +use duckdb::Connection; type DuckdbSettingsRow = ( Option, @@ -36,19 +39,20 @@ pub fn duckdb_settings() -> iter::TableIterator< #[inline] fn duckdb_settings_impl() -> Result> { - let conn = unsafe { &*connection::get_global_connection().get() }; - let mut stmt = conn.prepare("SELECT * FROM duckdb_settings()")?; + with_connection!(|conn: &Connection| { + let mut stmt = conn.prepare("SELECT * FROM duckdb_settings()")?; - Ok(stmt - .query_map([], |row| { - Ok(( - row.get::<_, Option>(0)?, - row.get::<_, Option>(1)?, - row.get::<_, Option>(2)?, - row.get::<_, Option>(3)?, - row.get::<_, Option>(4)?, - )) - })? - .map(|row| row.unwrap()) - .collect::>()) + Ok(stmt + .query_map([], |row| { + Ok(( + row.get::<_, Option>(0)?, + row.get::<_, Option>(1)?, + row.get::<_, Option>(2)?, + row.get::<_, Option>(3)?, + row.get::<_, Option>(4)?, + )) + })? + .map(|row| row.unwrap()) + .collect::>()) + }) } diff --git a/src/api/parquet.rs b/src/api/parquet.rs index a557a328..2989122e 100644 --- a/src/api/parquet.rs +++ b/src/api/parquet.rs @@ -15,11 +15,13 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
-use anyhow::Result; +use anyhow::{anyhow, Result}; use pgrx::*; -use crate::duckdb::connection; use crate::duckdb::utils; +use crate::env::get_global_connection; +use crate::with_connection; +use duckdb::Connection; type ParquetSchemaRow = ( Option, @@ -87,49 +89,51 @@ pub fn parquet_schema( #[inline] fn parquet_schema_impl(files: &str) -> Result> { - let schema_str = utils::format_csv(files); - let conn = unsafe { &*connection::get_global_connection().get() }; - let query = format!("SELECT * FROM parquet_schema({schema_str})"); - let mut stmt = conn.prepare(&query)?; + with_connection!(|conn: &Connection| { + let schema_str = utils::format_csv(files); + let query = format!("SELECT * FROM parquet_schema({schema_str})"); + let mut stmt = conn.prepare(&query)?; - Ok(stmt - .query_map([], |row| { - Ok(( - row.get::<_, Option>(0)?, - row.get::<_, Option>(1)?, - row.get::<_, Option>(2)?, - row.get::<_, Option>(3)?, - row.get::<_, Option>(4)?, - row.get::<_, Option>(5)?, - row.get::<_, Option>(6)?, - row.get::<_, Option>(7)?, - row.get::<_, Option>(8)?, - row.get::<_, Option>(9)?, - row.get::<_, Option>(10)?, - )) - })? - .map(|row| row.unwrap()) - .collect::>()) + Ok(stmt + .query_map([], |row| { + Ok(( + row.get::<_, Option>(0)?, + row.get::<_, Option>(1)?, + row.get::<_, Option>(2)?, + row.get::<_, Option>(3)?, + row.get::<_, Option>(4)?, + row.get::<_, Option>(5)?, + row.get::<_, Option>(6)?, + row.get::<_, Option>(7)?, + row.get::<_, Option>(8)?, + row.get::<_, Option>(9)?, + row.get::<_, Option>(10)?, + )) + })? 
+ .map(|row| row.unwrap()) + .collect::>()) + }) } #[inline] fn parquet_describe_impl(files: &str) -> Result> { - let schema_str = utils::format_csv(files); - let conn = unsafe { &*connection::get_global_connection().get() }; - let query = format!("DESCRIBE SELECT * FROM {schema_str}"); - let mut stmt = conn.prepare(&query)?; + with_connection!(|conn: &Connection| { + let schema_str = utils::format_csv(files); + let query = format!("DESCRIBE SELECT * FROM {schema_str}"); + let mut stmt = conn.prepare(&query)?; - Ok(stmt - .query_map([], |row| { - Ok(( - row.get::<_, Option>(0)?, - row.get::<_, Option>(1)?, - row.get::<_, Option>(2)?, - row.get::<_, Option>(3)?, - row.get::<_, Option>(4)?, - row.get::<_, Option>(5)?, - )) - })? - .map(|row| row.unwrap()) - .collect::>()) + Ok(stmt + .query_map([], |row| { + Ok(( + row.get::<_, Option>(0)?, + row.get::<_, Option>(1)?, + row.get::<_, Option>(2)?, + row.get::<_, Option>(3)?, + row.get::<_, Option>(4)?, + row.get::<_, Option>(5)?, + )) + })? 
+ .map(|row| row.unwrap()) + .collect::>()) + }) } diff --git a/src/duckdb/connection.rs b/src/duckdb/connection.rs index b420916e..6f475941 100644 --- a/src/duckdb/connection.rs +++ b/src/duckdb/connection.rs @@ -25,18 +25,18 @@ use std::collections::HashMap; use std::sync::Once; use std::thread; +use crate::env::{get_global_connection, interrupt_all_connections}; +use crate::with_connection; + use super::{csv, delta, iceberg, json, parquet, secret, spatial}; // Global mutable static variables -static mut GLOBAL_CONNECTION: Option> = None; static mut GLOBAL_STATEMENT: Option>>> = None; static mut GLOBAL_ARROW: Option>>> = None; static INIT: Once = Once::new(); fn init_globals() { - let conn = Connection::open_in_memory().expect("failed to open duckdb connection"); unsafe { - GLOBAL_CONNECTION = Some(UnsafeCell::new(conn)); GLOBAL_STATEMENT = Some(UnsafeCell::new(None)); GLOBAL_ARROW = Some(UnsafeCell::new(None)); } @@ -44,33 +44,33 @@ fn init_globals() { thread::spawn(move || { let mut signals = Signals::new([SIGTERM, SIGINT, SIGQUIT]).expect("error registering signal listener"); + for _ in signals.forever() { - let conn = unsafe { &mut *get_global_connection().get() }; - conn.interrupt(); + if let Err(err) = interrupt_all_connections() { + eprintln!("Failed to interrupt connections: {}", err); + } } }); } fn check_extension_loaded(extension_name: &str) -> Result { - unsafe { - let conn = &mut *get_global_connection().get(); + with_connection!(|conn: &Connection| { let mut statement = conn.prepare(format!("SELECT * FROM duckdb_extensions() WHERE extension_name = '{extension_name}' AND installed = true AND loaded = true").as_str())?; match statement.query([])?.next() { Ok(Some(_)) => Ok(true), _ => Ok(false), } - } + }) } -pub fn get_global_connection() -> &'static UnsafeCell { - INIT.call_once(|| { - init_globals(); - }); - unsafe { - GLOBAL_CONNECTION - .as_ref() - .expect("Connection not initialized") - } +fn iceberg_loaded() -> Result { + 
with_connection!(|conn: &Connection| { + let mut statement = conn.prepare("SELECT * FROM duckdb_extensions() WHERE extension_name = 'iceberg' AND installed = true AND loaded = true")?; + match statement.query([])?.next() { + Ok(Some(_)) => Ok(true), + _ => Ok(false), + } + }) } fn get_global_statement() -> &'static UnsafeCell>> { @@ -91,48 +91,48 @@ fn get_global_arrow() -> &'static UnsafeCell>> { unsafe { GLOBAL_ARROW.as_ref().expect("Arrow not initialized") } } -pub fn create_csv_view( +pub fn create_csv_relation( table_name: &str, schema_name: &str, table_options: HashMap, ) -> Result { - let statement = csv::create_view(table_name, schema_name, table_options)?; + let statement = csv::create_duckdb_relation(table_name, schema_name, table_options)?; execute(statement.as_str(), []) } -pub fn create_delta_view( +pub fn create_delta_relation( table_name: &str, schema_name: &str, table_options: HashMap, ) -> Result { - let statement = delta::create_view(table_name, schema_name, table_options)?; + let statement = delta::create_duckdb_relation(table_name, schema_name, table_options)?; execute(statement.as_str(), []) } -pub fn create_iceberg_view( +pub fn create_iceberg_relation( table_name: &str, schema_name: &str, table_options: HashMap, ) -> Result { - if !check_extension_loaded("iceberg")? { + if !iceberg_loaded()? 
{ execute("INSTALL iceberg", [])?; execute("LOAD iceberg", [])?; } - let statement = iceberg::create_view(table_name, schema_name, table_options)?; + let statement = iceberg::create_duckdb_relation(table_name, schema_name, table_options)?; execute(statement.as_str(), []) } -pub fn create_parquet_view( +pub fn create_parquet_relation( table_name: &str, schema_name: &str, table_options: HashMap, ) -> Result { - let statement = parquet::create_view(table_name, schema_name, table_options)?; + let statement = parquet::create_duckdb_relation(table_name, schema_name, table_options)?; execute(statement.as_str(), []) } -pub fn create_spatial_view( +pub fn create_spatial_relation( table_name: &str, schema_name: &str, table_options: HashMap, @@ -142,37 +142,37 @@ pub fn create_spatial_view( execute("LOAD spatial", [])?; } - let statement = spatial::create_view(table_name, schema_name, table_options)?; + let statement = spatial::create_duckdb_relation(table_name, schema_name, table_options)?; execute(statement.as_str(), []) } -pub fn create_json_view( +pub fn create_json_relation( table_name: &str, schema_name: &str, table_options: HashMap, ) -> Result { - let statement = json::create_view(table_name, schema_name, table_options)?; + let statement = json::create_duckdb_relation(table_name, schema_name, table_options)?; execute(statement.as_str(), []) } pub fn create_arrow(sql: &str) -> Result { - unsafe { - let conn = &mut *get_global_connection().get(); - let statement = conn.prepare(sql)?; - let static_statement: Statement<'static> = std::mem::transmute(statement); - - *get_global_statement().get() = Some(static_statement); - - if let Some(static_statement) = get_global_statement().get().as_mut().unwrap() { - let arrow = static_statement.query_arrow([])?; - *get_global_arrow().get() = Some(std::mem::transmute::< - duckdb::Arrow<'_>, - duckdb::Arrow<'_>, - >(arrow)); + with_connection!(|conn: &Connection| { + unsafe { + let statement = conn.prepare(sql)?; + let 
static_statement: Statement<'static> = std::mem::transmute(statement); + + *get_global_statement().get() = Some(static_statement); + + if let Some(static_statement) = get_global_statement().get().as_mut().unwrap() { + let arrow = static_statement.query_arrow([])?; + *get_global_arrow().get() = Some(std::mem::transmute::< + duckdb::Arrow<'_>, + duckdb::Arrow<'_>, + >(arrow)); + } } - } - - Ok(true) + Ok(true) + }) } pub fn clear_arrow() { @@ -182,11 +182,9 @@ pub fn clear_arrow() { } } -pub fn create_secret( - secret_name: &str, - user_mapping_options: HashMap, -) -> Result { - let statement = secret::create_secret(secret_name, user_mapping_options)?; +pub fn create_secret(user_mapping_options: HashMap) -> Result { + const DEFAULT_SECRET: &str = "default_secret"; + let statement = secret::create_secret(DEFAULT_SECRET, user_mapping_options)?; execute(statement.as_str(), []) } @@ -211,35 +209,36 @@ pub fn get_batches() -> Result> { } pub fn execute(sql: &str, params: P) -> Result { - unsafe { - let conn = &*get_global_connection().get(); + with_connection!(|conn: &Connection| { conn.execute(sql, params).map_err(|err| anyhow!("{err}")) - } + }) } -pub fn view_exists(table_name: &str, schema_name: &str) -> Result { - unsafe { - let conn = &mut *get_global_connection().get(); - let mut statement = conn.prepare(format!("SELECT * from information_schema.tables WHERE table_schema = '{schema_name}' AND table_name = '{table_name}' AND table_type = 'VIEW'").as_str())?; - match statement.query([])?.next() { - Ok(Some(_)) => Ok(true), - _ => Ok(false), +pub fn drop_relation(table_name: &str, schema_name: &str) -> Result<()> { + with_connection!(|conn: &Connection| { + let mut statement = conn.prepare(format!("SELECT table_type from information_schema.tables WHERE table_schema = '{schema_name}' AND table_name = '{table_name}' LIMIT 1").as_str())?; + if let Ok(Some(row)) = statement.query([])?.next() { + let table_type: String = row.get(0)?; + let table_type = 
table_type.replace("BASE", "").trim().to_string(); + let statement = format!("DROP {table_type} {schema_name}.{table_name}"); + conn.execute(statement.as_str(), [])?; } - } + Ok(()) + }) } pub fn get_available_schemas() -> Result> { - let conn = unsafe { &*get_global_connection().get() }; - let mut stmt = conn.prepare("select DISTINCT(nspname) from pg_namespace;")?; - let schemas: Vec = stmt - .query_map([], |row| { - let s: String = row.get(0)?; - Ok(s) - })? - .map(|x| x.unwrap()) - .collect(); - - Ok(schemas) + with_connection!(|conn: &Connection| { + let mut stmt = conn.prepare("select DISTINCT(nspname) from pg_namespace;")?; + let schemas: Vec = stmt + .query_map([], |row| { + let s: String = row.get(0)?; + Ok(s) + })? + .map(|x| x.unwrap()) + .collect(); + Ok(schemas) + }) } pub fn set_search_path(search_path: Vec) -> Result<()> { diff --git a/src/duckdb/csv.rs b/src/duckdb/csv.rs index c87f9652..5ca7f13c 100644 --- a/src/duckdb/csv.rs +++ b/src/duckdb/csv.rs @@ -30,6 +30,7 @@ pub enum CsvOption { AllowQuotedNulls, AutoDetect, AutoTypeCandidates, + Cache, Columns, Compression, Dateformat, @@ -69,6 +70,7 @@ impl OptionValidator for CsvOption { Self::AllowQuotedNulls => false, Self::AutoDetect => false, Self::AutoTypeCandidates => false, + Self::Cache => false, Self::Columns => false, Self::Compression => false, Self::Dateformat => false, @@ -103,7 +105,7 @@ impl OptionValidator for CsvOption { } } -pub fn create_view( +pub fn create_duckdb_relation( table_name: &str, schema_name: &str, table_options: HashMap, @@ -277,12 +279,14 @@ pub fn create_view( .collect::>() .join(", "); - let default_select = "*".to_string(); - let select = table_options - .get(CsvOption::Select.as_ref()) - .unwrap_or(&default_select); + let cache = table_options + .get(CsvOption::Cache.as_ref()) + .map(|s| s.eq_ignore_ascii_case("true")) + .unwrap_or(false); - Ok(format!("CREATE VIEW IF NOT EXISTS {schema_name}.{table_name} AS SELECT {select} FROM read_csv({create_csv_str})")) + let 
relation = if cache { "TABLE" } else { "VIEW" }; + + Ok(format!("CREATE {relation} IF NOT EXISTS {schema_name}.{table_name} AS SELECT * FROM read_csv({create_csv_str})")) } #[cfg(test)] @@ -291,7 +295,7 @@ mod tests { use duckdb::Connection; #[test] - fn test_create_csv_view_single_file() { + fn test_create_csv_relation_single_file() { let table_name = "test"; let schema_name = "main"; let table_options = HashMap::from([( @@ -300,7 +304,7 @@ mod tests { )]); let expected = "CREATE VIEW IF NOT EXISTS main.test AS SELECT * FROM read_csv('/data/file.csv')"; - let actual = create_view(table_name, schema_name, table_options).unwrap(); + let actual = create_duckdb_relation(table_name, schema_name, table_options).unwrap(); assert_eq!(expected, actual); @@ -312,7 +316,7 @@ mod tests { } #[test] - fn test_create_csv_view_multiple_files() { + fn test_create_csv_relation_multiple_files() { let table_name = "test"; let schema_name = "main"; let table_options = HashMap::from([( @@ -321,7 +325,7 @@ mod tests { )]); let expected = "CREATE VIEW IF NOT EXISTS main.test AS SELECT * FROM read_csv(['/data/file1.csv', '/data/file2.csv'])"; - let actual = create_view(table_name, schema_name, table_options).unwrap(); + let actual = create_duckdb_relation(table_name, schema_name, table_options).unwrap(); assert_eq!(expected, actual); @@ -333,7 +337,7 @@ mod tests { } #[test] - fn test_create_csv_view_with_options() { + fn test_create_csv_relation_with_options() { let table_name = "test"; let schema_name = "main"; let table_options = HashMap::from([ @@ -441,7 +445,7 @@ mod tests { ]); let expected = "CREATE VIEW IF NOT EXISTS main.test AS SELECT * FROM read_csv('/data/file.csv', all_varchar = true, allow_quoted_nulls = true, auto_detect = true, auto_type_candidates = ['BIGINT', 'DATE'], columns = {'col1': 'INTEGER', 'col2': 'VARCHAR'}, compression = 'gzip', dateformat = '%d/%m/%Y', decimal_separator = '.', delim = ',', escape = '\"', filename = true, force_not_null = ['col1', 'col2'], 
header = true, hive_partitioning = true, hive_types = true, hive_types_autocast = true, ignore_errors = true, max_line_size = 1000, names = ['col1', 'col2'], new_line = '\n', normalize_names = true, null_padding = true, nullstr = ['none', 'null'], parallel = true, quote = '\"', sample_size = 100, sep = ',', skip = 0, timestampformat = 'yyyy-MM-dd HH:mm:ss', types = ['BIGINT', 'VARCHAR'], union_by_name = true)"; - let actual = create_view(table_name, schema_name, table_options).unwrap(); + let actual = create_duckdb_relation(table_name, schema_name, table_options).unwrap(); assert_eq!(expected, actual); diff --git a/src/duckdb/delta.rs b/src/duckdb/delta.rs index 70c8a09c..2b024882 100644 --- a/src/duckdb/delta.rs +++ b/src/duckdb/delta.rs @@ -23,6 +23,7 @@ use strum::{AsRefStr, EnumIter}; #[derive(EnumIter, AsRefStr, PartialEq, Debug)] #[strum(serialize_all = "snake_case")] pub enum DeltaOption { + Cache, Files, PreserveCasing, Select, @@ -31,6 +32,7 @@ pub enum DeltaOption { impl OptionValidator for DeltaOption { fn is_required(&self) -> bool { match self { + Self::Cache => false, Self::Files => true, Self::PreserveCasing => false, Self::Select => false, @@ -38,7 +40,7 @@ impl OptionValidator for DeltaOption { } } -pub fn create_view( +pub fn create_duckdb_relation( table_name: &str, schema_name: &str, table_options: HashMap, @@ -50,13 +52,15 @@ pub fn create_view( .ok_or_else(|| anyhow!("files option is required"))? 
); - let default_select = "*".to_string(); - let select = table_options - .get(DeltaOption::Select.as_ref()) - .unwrap_or(&default_select); + let cache = table_options + .get(DeltaOption::Cache.as_ref()) + .map(|s| s.eq_ignore_ascii_case("true")) + .unwrap_or(false); + + let relation = if cache { "TABLE" } else { "VIEW" }; Ok(format!( - "CREATE VIEW IF NOT EXISTS {schema_name}.{table_name} AS SELECT {select} FROM delta_scan({files})" + "CREATE {relation} IF NOT EXISTS {schema_name}.{table_name} AS SELECT * FROM delta_scan({files})" )) } @@ -66,7 +70,7 @@ mod tests { use duckdb::Connection; #[test] - fn test_create_delta_view() { + fn test_create_delta_relation() { let table_name = "test"; let schema_name = "main"; let table_options = HashMap::from([( @@ -76,7 +80,7 @@ mod tests { let expected = "CREATE VIEW IF NOT EXISTS main.test AS SELECT * FROM delta_scan('/data/delta')"; - let actual = create_view(table_name, schema_name, table_options).unwrap(); + let actual = create_duckdb_relation(table_name, schema_name, table_options).unwrap(); assert_eq!(expected, actual); diff --git a/src/duckdb/iceberg.rs b/src/duckdb/iceberg.rs index a803ae56..426ad8c7 100644 --- a/src/duckdb/iceberg.rs +++ b/src/duckdb/iceberg.rs @@ -27,6 +27,7 @@ pub enum IcebergOption { AllowMovedPaths, MetadataCompressionCodec, SkipSchemaInference, + Cache, Files, PreserveCasing, Select, @@ -38,6 +39,7 @@ impl OptionValidator for IcebergOption { Self::AllowMovedPaths => false, Self::MetadataCompressionCodec => false, Self::SkipSchemaInference => false, + Self::Cache => false, Self::Files => true, Self::PreserveCasing => false, Self::Select => false, @@ -45,7 +47,7 @@ impl OptionValidator for IcebergOption { } } -pub fn create_view( +pub fn create_duckdb_relation( table_name: &str, schema_name: &str, table_options: HashMap, @@ -80,12 +82,14 @@ pub fn create_view( .collect::>() .join(", "); - let default_select = "*".to_string(); - let select = table_options - .get(IcebergOption::Select.as_ref()) - 
.unwrap_or(&default_select); + let cache = table_options + .get(IcebergOption::Cache.as_ref()) + .map(|s| s.eq_ignore_ascii_case("true")) + .unwrap_or(false); - Ok(format!("CREATE VIEW IF NOT EXISTS {schema_name}.{table_name} AS SELECT {select} FROM iceberg_scan({create_iceberg_str})")) + let relation = if cache { "TABLE" } else { "VIEW" }; + + Ok(format!("CREATE {relation} IF NOT EXISTS {schema_name}.{table_name} AS SELECT * FROM iceberg_scan({create_iceberg_str})")) } #[cfg(test)] @@ -94,7 +98,7 @@ mod tests { use duckdb::Connection; #[test] - fn test_create_iceberg_view() { + fn test_create_iceberg_relation() { let table_name = "test"; let schema_name = "main"; let table_options = HashMap::from([( @@ -104,7 +108,7 @@ mod tests { let expected = "CREATE VIEW IF NOT EXISTS main.test AS SELECT * FROM iceberg_scan('/data/iceberg')"; - let actual = create_view(table_name, schema_name, table_options).unwrap(); + let actual = create_duckdb_relation(table_name, schema_name, table_options).unwrap(); assert_eq!(expected, actual); diff --git a/src/duckdb/json.rs b/src/duckdb/json.rs index 0772cdd6..1f871b3d 100644 --- a/src/duckdb/json.rs +++ b/src/duckdb/json.rs @@ -51,7 +51,7 @@ impl OptionValidator for JsonOption { } } -pub fn create_view( +pub fn create_duckdb_relation( table_name: &str, schema_name: &str, table_options: HashMap, @@ -119,7 +119,7 @@ mod tests { )]); let expected = "CREATE VIEW IF NOT EXISTS main.json_test AS SELECT * FROM read_json('/data/file1.json')"; - let actual = create_view(table_name, schema_name, table_options).unwrap(); + let actual = create_duckdb_relation(table_name, schema_name, table_options).unwrap(); assert_eq!(expected, actual); @@ -175,7 +175,7 @@ mod tests { ]); let expected = "CREATE VIEW IF NOT EXISTS main.json_test AS SELECT key1 FROM read_json(['/data/file1.json', '/data/file2.json'], columns = {'key1': 'INTEGER', 'key2': 'VARCHAR'}, compression = 'uncompressed', convert_strings_to_integers = false, dateformat = '%d/%m/%Y', 
filename = true, format = 'array', hive_partitioning = false, ignore_errors = true, maximum_depth = 4096, maximum_object_size = 65536, records = auto, sample_size = -1, timestampformat = 'yyyy-MM-dd', union_by_name = true)"; - let actual = create_view(table_name, schema_name, table_options).unwrap(); + let actual = create_duckdb_relation(table_name, schema_name, table_options).unwrap(); assert_eq!(expected, actual); diff --git a/src/duckdb/parquet.rs b/src/duckdb/parquet.rs index dc515415..859859c3 100644 --- a/src/duckdb/parquet.rs +++ b/src/duckdb/parquet.rs @@ -27,6 +27,7 @@ use super::utils; #[strum(serialize_all = "snake_case")] pub enum ParquetOption { BinaryAsString, + Cache, FileName, FileRowNumber, Files, @@ -43,6 +44,7 @@ impl OptionValidator for ParquetOption { fn is_required(&self) -> bool { match self { Self::BinaryAsString => false, + Self::Cache => false, Self::FileName => false, Self::FileRowNumber => false, Self::Files => true, @@ -50,13 +52,13 @@ impl OptionValidator for ParquetOption { Self::HiveTypes => false, Self::HiveTypesAutocast => false, Self::PreserveCasing => false, - Self::Select => false, Self::UnionByName => false, + Self::Select => false, } } } -pub fn create_view( +pub fn create_duckdb_relation( table_name: &str, schema_name: &str, table_options: HashMap, @@ -110,12 +112,14 @@ pub fn create_view( .collect::>() .join(", "); - let default_select = "*".to_string(); - let select = table_options - .get(ParquetOption::Select.as_ref()) - .unwrap_or(&default_select); + let cache = table_options + .get(ParquetOption::Cache.as_ref()) + .map(|s| s.eq_ignore_ascii_case("true")) + .unwrap_or(false); + + let relation = if cache { "TABLE" } else { "VIEW" }; - Ok(format!("CREATE VIEW IF NOT EXISTS {schema_name}.{table_name} AS SELECT {select} FROM read_parquet({create_parquet_str})")) + Ok(format!("CREATE {relation} IF NOT EXISTS {schema_name}.{table_name} AS SELECT * FROM read_parquet({create_parquet_str})")) } #[cfg(test)] @@ -124,14 +128,14 @@ 
mod tests { use duckdb::Connection; #[test] - fn test_create_parquet_view_single_file() { + fn test_create_parquet_relation_single_file() { let table_name = "test"; let schema_name = "main"; let files = "/data/file.parquet"; let table_options = HashMap::from([(ParquetOption::Files.as_ref().to_string(), files.to_string())]); let expected = "CREATE VIEW IF NOT EXISTS main.test AS SELECT * FROM read_parquet('/data/file.parquet')"; - let actual = create_view(table_name, schema_name, table_options).unwrap(); + let actual = create_duckdb_relation(table_name, schema_name, table_options).unwrap(); assert_eq!(expected, actual); @@ -143,7 +147,7 @@ mod tests { } #[test] - fn test_create_parquet_view_multiple_files() { + fn test_create_parquet_relation_multiple_files() { let table_name = "test"; let schema_name = "main"; let files = "/data/file1.parquet, /data/file2.parquet"; @@ -151,7 +155,7 @@ mod tests { HashMap::from([(ParquetOption::Files.as_ref().to_string(), files.to_string())]); let expected = "CREATE VIEW IF NOT EXISTS main.test AS SELECT * FROM read_parquet(['/data/file1.parquet', '/data/file2.parquet'])"; - let actual = create_view(table_name, schema_name, table_options).unwrap(); + let actual = create_duckdb_relation(table_name, schema_name, table_options).unwrap(); assert_eq!(expected, actual); @@ -163,7 +167,7 @@ mod tests { } #[test] - fn test_create_parquet_view_with_options() { + fn test_create_parquet_relation_with_options() { let table_name = "test"; let schema_name = "main"; let table_options = HashMap::from([ @@ -202,7 +206,7 @@ mod tests { ]); let expected = "CREATE VIEW IF NOT EXISTS main.test AS SELECT * FROM read_parquet('/data/file.parquet', binary_as_string = true, filename = false, file_row_number = true, hive_partitioning = true, hive_types = {'release': DATE, 'orders': BIGINT}, hive_types_autocast = true, union_by_name = true)"; - let actual = create_view(table_name, schema_name, table_options).unwrap(); + let actual = 
create_duckdb_relation(table_name, schema_name, table_options).unwrap(); assert_eq!(expected, actual); diff --git a/src/duckdb/spatial.rs b/src/duckdb/spatial.rs index cd76239b..5ac9ece8 100644 --- a/src/duckdb/spatial.rs +++ b/src/duckdb/spatial.rs @@ -28,6 +28,7 @@ use crate::fdw::base::OptionValidator; #[strum(serialize_all = "snake_case")] pub enum SpatialOption { Files, + Cache, SequentialLayerScan, SpatialFilter, OpenOptions, @@ -42,6 +43,7 @@ impl OptionValidator for SpatialOption { fn is_required(&self) -> bool { match self { Self::Files => true, + Self::Cache => false, Self::SequentialLayerScan => false, Self::SpatialFilter => false, Self::OpenOptions => false, @@ -54,7 +56,7 @@ impl OptionValidator for SpatialOption { } } -pub fn create_view( +pub fn create_duckdb_relation( table_name: &str, schema_name: &str, table_options: HashMap, @@ -73,8 +75,15 @@ pub fn create_view( }) .collect::>(); + let cache = table_options + .get(SpatialOption::Cache.as_ref()) + .map(|s| s.eq_ignore_ascii_case("true")) + .unwrap_or(false); + + let relation = if cache { "TABLE" } else { "VIEW" }; + Ok(format!( - "CREATE VIEW IF NOT EXISTS {}.{} AS SELECT * FROM st_read({})", + "CREATE {relation} IF NOT EXISTS {}.{} AS SELECT * FROM st_read({})", schema_name, table_name, spatial_options.join(", "), @@ -97,7 +106,7 @@ mod tests { let expected = "CREATE VIEW IF NOT EXISTS main.test AS SELECT * FROM st_read('/data/spatial')"; - let actual = create_view(table_name, schema_name, table_options).unwrap(); + let actual = create_duckdb_relation(table_name, schema_name, table_options).unwrap(); assert_eq!(expected, actual); diff --git a/src/env.rs b/src/env.rs new file mode 100644 index 00000000..120ce301 --- /dev/null +++ b/src/env.rs @@ -0,0 +1,131 @@ +use anyhow::{anyhow, Result}; +use duckdb::Connection; +use pgrx::*; +use std::ffi::CStr; +use std::path::PathBuf; +use std::sync::{Arc, Mutex}; + +// One connection per database, so 128 databases can have a DuckDB connection +const 
MAX_CONNECTIONS: usize = 128; +pub static DUCKDB_CONNECTION_CACHE: PgLwLock<DuckdbConnection> = PgLwLock::new(); + +pub struct DuckdbConnection { + conn_map: heapless::FnvIndexMap<u32, DuckdbConnectionInner, MAX_CONNECTIONS>, + conn_lru: heapless::Deque<u32, MAX_CONNECTIONS>, +} + +unsafe impl PGRXSharedMemory for DuckdbConnection {} + +impl Default for DuckdbConnection { + fn default() -> Self { + Self::new() + } +} + +impl DuckdbConnection { + fn new() -> Self { + Self { + conn_map: heapless::FnvIndexMap::<_, _, MAX_CONNECTIONS>::new(), + conn_lru: heapless::Deque::<_, MAX_CONNECTIONS>::new(), + } + } +} + +#[derive(Clone, Debug)] +struct DuckdbConnectionInner(Arc<Mutex<Connection>>); + +impl Default for DuckdbConnectionInner { + fn default() -> Self { + let mut duckdb_path = postgres_data_dir_path(); + duckdb_path.push("pg_analytics"); + + if !duckdb_path.exists() { + std::fs::create_dir_all(duckdb_path.clone()) + .expect("failed to create duckdb data directory"); + } + + duckdb_path.push(postgres_database_oid().to_string()); + duckdb_path.set_extension("db3"); + + let conn = Connection::open(duckdb_path).expect("failed to open duckdb connection"); + DuckdbConnectionInner(Arc::new(Mutex::new(conn))) + } +} + +fn postgres_data_dir_path() -> PathBuf { + let data_dir = unsafe { + CStr::from_ptr(pg_sys::DataDir) + .to_string_lossy() + .into_owned() + }; + PathBuf::from(data_dir) +} + +fn postgres_database_oid() -> u32 { + unsafe { pg_sys::MyDatabaseId.as_u32() } +} + +#[macro_export] +macro_rules!
with_connection { + ($body:expr) => {{ + let conn = get_global_connection()?; + let conn = conn + .lock() + .map_err(|e| anyhow!("Failed to acquire lock: {}", e))?; + $body(&*conn) // Dereference the MutexGuard to get &Connection + }}; +} + +pub fn get_global_connection() -> Result>> { + let database_id = postgres_database_oid(); + let mut cache = DUCKDB_CONNECTION_CACHE.exclusive(); + + if cache.conn_map.contains_key(&database_id) { + // Move the accessed connection to the back of the LRU queue + let mut new_lru = heapless::Deque::<_, MAX_CONNECTIONS>::new(); + for &id in cache.conn_lru.iter() { + if id != database_id { + new_lru + .push_back(id) + .unwrap_or_else(|_| panic!("Failed to push to LRU queue")); + } + } + new_lru + .push_back(database_id) + .unwrap_or_else(|_| panic!("Failed to push to LRU queue")); + cache.conn_lru = new_lru; + + // Now we can safely borrow conn_map again + Ok(cache.conn_map.get(&database_id).unwrap().0.clone()) + } else { + if cache.conn_map.len() >= MAX_CONNECTIONS { + if let Some(least_recently_used) = cache.conn_lru.pop_front() { + cache.conn_map.remove(&least_recently_used); + } + } + let conn = DuckdbConnectionInner::default(); + cache + .conn_map + .insert(database_id, conn.clone()) + .map_err(|_| anyhow!("Failed to insert into connection map"))?; + cache + .conn_lru + .push_back(database_id) + .map_err(|_| anyhow!("Failed to push to LRU queue"))?; + Ok(conn.0) + } +} + +pub fn interrupt_all_connections() -> Result<()> { + let cache = DUCKDB_CONNECTION_CACHE.exclusive(); + for &database_id in cache.conn_lru.iter() { + if let Some(conn) = cache.conn_map.get(&database_id) { + let conn = conn + .0 + .lock() + .map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?; + conn.interrupt(); + } + } + Ok(()) +} diff --git a/src/fdw/base.rs b/src/fdw/base.rs index 45ff01fc..095b15c5 100644 --- a/src/fdw/base.rs +++ b/src/fdw/base.rs @@ -15,7 +15,7 @@ // You should have received a copy of the GNU Affero General Public License // 
along with this program. If not, see . -use anyhow::{anyhow, bail, Result}; +use anyhow::{anyhow, Result}; use duckdb::arrow::array::RecordBatch; use pgrx::*; use std::collections::HashMap; @@ -23,14 +23,11 @@ use strum::IntoEnumIterator; use supabase_wrappers::prelude::*; use thiserror::Error; -use super::handler::FdwHandler; use crate::duckdb::connection; use crate::schema::cell::*; #[cfg(debug_assertions)] use crate::DEBUG_GUCS; -const DEFAULT_SECRET: &str = "default_secret"; - pub trait BaseFdw { // Getter methods fn get_current_batch(&self) -> Option; @@ -69,16 +66,9 @@ pub trait BaseFdw { // Register view with DuckDB let user_mapping_options = self.get_user_mapping_options(); - let foreign_table = unsafe { pg_sys::GetForeignTable(pg_relation.oid()) }; - let table_options = unsafe { options_to_hashmap((*foreign_table).options)? }; - let handler = FdwHandler::from(foreign_table); - register_duckdb_view( - table_name, - schema_name, - table_options, - user_mapping_options, - handler, - )?; + if !user_mapping_options.is_empty() { + connection::create_secret(user_mapping_options)?; + } // Construct SQL scan statement let targets = if columns.is_empty() { @@ -212,52 +202,6 @@ pub fn validate_options(opt_list: Vec>, valid_options: Vec, - user_mapping_options: HashMap, - handler: FdwHandler, -) -> Result<()> { - if !user_mapping_options.is_empty() { - connection::create_secret(DEFAULT_SECRET, user_mapping_options)?; - } - - if !connection::view_exists(table_name, schema_name)? 
{ - // Initialize DuckDB view - connection::execute( - format!("CREATE SCHEMA IF NOT EXISTS {schema_name}").as_str(), - [], - )?; - - match handler { - FdwHandler::Csv => { - connection::create_csv_view(table_name, schema_name, table_options)?; - } - FdwHandler::Delta => { - connection::create_delta_view(table_name, schema_name, table_options)?; - } - FdwHandler::Iceberg => { - connection::create_iceberg_view(table_name, schema_name, table_options)?; - } - FdwHandler::Parquet => { - connection::create_parquet_view(table_name, schema_name, table_options)?; - } - FdwHandler::Spatial => { - connection::create_spatial_view(table_name, schema_name, table_options)?; - } - FdwHandler::Json => { - connection::create_json_view(table_name, schema_name, table_options)?; - } - _ => { - bail!("got unexpected fdw_handler") - } - }; - } - - Ok(()) -} - #[derive(Error, Debug)] pub enum BaseFdwError { #[error(transparent)] diff --git a/src/fdw/trigger.rs b/src/fdw/trigger.rs index 9b900963..6b70268c 100644 --- a/src/fdw/trigger.rs +++ b/src/fdw/trigger.rs @@ -15,14 +15,17 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
-use anyhow::{bail, Result}; +use anyhow::{anyhow, bail, Result}; use pgrx::*; +use std::collections::HashMap; use std::ffi::CStr; use supabase_wrappers::prelude::{options_to_hashmap, user_mapping_options}; -use super::base::register_duckdb_view; use crate::duckdb::connection; +use crate::env::get_global_connection; use crate::fdw::handler::FdwHandler; +use crate::with_connection; +use duckdb::Connection; extension_sql!( r#" @@ -118,24 +121,20 @@ unsafe fn auto_create_schema_impl(fcinfo: pg_sys::FunctionCallInfo) -> Result<() ); } - // Drop stale view - connection::execute( - format!("DROP VIEW IF EXISTS {schema_name}.{table_name}").as_str(), - [], - )?; + // Drop stale relation + connection::drop_relation(table_name, schema_name)?; - // Register DuckDB view + // Create DuckDB secrets let foreign_server = unsafe { pg_sys::GetForeignServer((*foreign_table).serverid) }; let user_mapping_options = unsafe { user_mapping_options(foreign_server) }; + if !user_mapping_options.is_empty() { + connection::create_secret(user_mapping_options)?; + } + + // Create DuckDB relation let table_options = unsafe { options_to_hashmap((*foreign_table).options)? 
}; let handler = FdwHandler::from(foreign_table); - register_duckdb_view( - table_name, - schema_name, - table_options.clone(), - user_mapping_options, - handler, - )?; + create_duckdb_relation(table_name, schema_name, table_options.clone(), handler)?; // If the table already has columns, no need for auto schema creation let relation = pg_sys::relation_open(oid, pg_sys::AccessShareLock as i32); @@ -147,30 +146,31 @@ unsafe fn auto_create_schema_impl(fcinfo: pg_sys::FunctionCallInfo) -> Result<() pg_sys::RelationClose(relation); // Get DuckDB schema - let conn = unsafe { &*connection::get_global_connection().get() }; - let query = format!("DESCRIBE {schema_name}.{table_name}"); - let mut stmt = conn.prepare(&query)?; - - let schema_rows = stmt - .query_map([], |row| { - Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?)) - })? - .map(|row| row.unwrap()) - .collect::>(); - - if schema_rows.is_empty() { - return Ok(()); - } - - // Alter Postgres table to match DuckDB schema - let preserve_casing = table_options - .get("preserve_casing") - .map_or(false, |s| s.eq_ignore_ascii_case("true")); - let alter_table_statement = - construct_alter_table_statement(schema_name, table_name, schema_rows, preserve_casing); - Spi::run(alter_table_statement.as_str())?; - - Ok(()) + with_connection!(|conn: &Connection| { + let query = format!("DESCRIBE {schema_name}.{table_name}"); + let mut stmt = conn.prepare(&query)?; + + let schema_rows = stmt + .query_map([], |row| { + Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?)) + })? 
+ .map(|row| row.unwrap()) + .collect::>(); + + if schema_rows.is_empty() { + return Ok(()); + } + + // Alter Postgres table to match DuckDB schema + let preserve_casing = table_options + .get("preserve_casing") + .map_or(false, |s| s.eq_ignore_ascii_case("true")); + let alter_table_statement = + construct_alter_table_statement(schema_name, table_name, schema_rows, preserve_casing); + Spi::run(alter_table_statement.as_str())?; + + Ok(()) + }) } #[inline] @@ -274,3 +274,42 @@ fn construct_alter_table_statement( column_definitions.join(", ") ) } + +#[inline] +pub fn create_duckdb_relation( + table_name: &str, + schema_name: &str, + table_options: HashMap, + handler: FdwHandler, +) -> Result<()> { + connection::execute( + format!("CREATE SCHEMA IF NOT EXISTS {schema_name}").as_str(), + [], + )?; + + match handler { + FdwHandler::Csv => { + connection::create_csv_relation(table_name, schema_name, table_options)?; + } + FdwHandler::Delta => { + connection::create_delta_relation(table_name, schema_name, table_options)?; + } + FdwHandler::Iceberg => { + connection::create_iceberg_relation(table_name, schema_name, table_options)?; + } + FdwHandler::Parquet => { + connection::create_parquet_relation(table_name, schema_name, table_options)?; + } + FdwHandler::Spatial => { + connection::create_spatial_relation(table_name, schema_name, table_options)?; + } + FdwHandler::Json => { + connection::create_json_relation(table_name, schema_name, table_options)?; + } + _ => { + bail!("got unexpected fdw_handler") + } + }; + + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index 5e7d5654..d95c4ad9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,7 @@ mod api; #[cfg(debug_assertions)] mod debug_guc; mod duckdb; +mod env; mod fdw; mod hooks; mod schema; @@ -41,15 +42,14 @@ static mut EXTENSION_HOOK: ExtensionHook = ExtensionHook; #[pg_guard] pub extern "C" fn _PG_init() { + pgrx::warning!("pga:: extension is being initialized"); #[allow(static_mut_refs)] #[allow(deprecated)] unsafe 
{ register_hook(&mut EXTENSION_HOOK) }; - // TODO: Depends on above TODO - // GUCS.init("pg_analytics"); - // setup_telemetry_background_worker(ParadeExtension::PgAnalytics); + pg_shmem_init!(env::DUCKDB_CONNECTION_CACHE); #[cfg(debug_assertions)] DEBUG_GUCS.init();