diff --git a/src/duckdb/connection.rs b/src/duckdb/connection.rs index a40ac0f4..530b2412 100644 --- a/src/duckdb/connection.rs +++ b/src/duckdb/connection.rs @@ -66,6 +66,7 @@ pub fn get_global_connection() -> &'static UnsafeCell { INIT.call_once(|| { init_globals(); }); + #[allow(static_mut_refs)] unsafe { GLOBAL_CONNECTION .as_ref() @@ -77,6 +78,7 @@ fn get_global_statement() -> &'static UnsafeCell>> { INIT.call_once(|| { init_globals(); }); + #[allow(static_mut_refs)] unsafe { GLOBAL_STATEMENT .as_ref() @@ -88,7 +90,10 @@ fn get_global_arrow() -> &'static UnsafeCell>> { INIT.call_once(|| { init_globals(); }); - unsafe { GLOBAL_ARROW.as_ref().expect("Arrow not initialized") } + #[allow(static_mut_refs)] + unsafe { + GLOBAL_ARROW.as_ref().expect("Arrow not initialized") + } } pub fn create_csv_view( diff --git a/src/duckdb/json.rs b/src/duckdb/json.rs index 0772cdd6..efb1ac39 100644 --- a/src/duckdb/json.rs +++ b/src/duckdb/json.rs @@ -98,10 +98,10 @@ fn extract_option( table_options: &HashMap, quote: bool, ) -> Option { - return table_options.get(option.as_ref()).map(|res| match quote { + table_options.get(option.as_ref()).map(|res| match quote { true => format!("{option} = '{res}'"), false => format!("{option} = {res}"), - }); + }) } #[cfg(test)] diff --git a/tests/tests/explain.rs b/tests/tests/explain.rs index 64b3bc25..8c4e3f7b 100644 --- a/tests/tests/explain.rs +++ b/tests/tests/explain.rs @@ -193,3 +193,136 @@ async fn test_explain_foreign_table(#[future(awt)] s3: S3, mut conn: PgConnectio Ok(()) } + +#[rstest] +async fn test_explain_foreign_table_duckdb_style( + #[future(awt)] s3: S3, + mut conn: PgConnection, +) -> Result<()> { + NycTripsTable::setup().execute(&mut conn); + + let rows: Vec = "SELECT * FROM nyc_trips".fetch(&mut conn); + s3.client.create_bucket().bucket(S3_BUCKET).send().await?; + s3.create_bucket(S3_BUCKET).await?; + s3.put_rows(S3_BUCKET, S3_KEY, &rows).await?; + + NycTripsTable::setup_s3_listing_fdw(&s3.url.clone(), &format!("s3://{S3_BUCKET}/{S3_KEY}")) + .execute(&mut conn); + + let explain: Vec<(String,)> = "EXPLAIN SELECT COUNT(*) FROM trips".fetch(&mut conn); + assert_eq!(explain[0].0, "DuckDB Scan: SELECT COUNT(*) FROM trips"); + + let explain: Result, sqlx::Error> = + "EXPLAIN (style duckdb) SELECT COUNT(*) FROM trips".fetch_result(&mut conn); + + let expected_plan = vec![ + "┌───────────────────────────┐", + "│ UNGROUPED_AGGREGATE │", + "│ ──────────────────── │", + "│ Aggregates: │", + "│ count_star() │", + "└─────────────┬─────────────┘", + "┌─────────────┴─────────────┐", + "│ PROJECTION │", + "│ ──────────────────── │", + "│ 42 │", + "│ │", + "│ ~100 Rows │", + "└─────────────┬─────────────┘", + "┌─────────────┴─────────────┐", + "│ READ_PARQUET │", + "│ ──────────────────── │", + "│ Function: │", + "│ READ_PARQUET │", + "│ │", + "│ ~100 Rows │", + "└───────────────────────────┘", + ]; + + assert!(explain.is_ok()); + if let Ok(plan) = explain { + assert_eq!(plan.len(), expected_plan.len()); + for ((row,), expect_row) in plan.iter().zip(expected_plan) { + assert_eq!(row, expect_row); + } + } + + // test (style duckdb, analyze) + let explain: Result, sqlx::Error> = + "EXPLAIN (style duckdb, analyze) SELECT COUNT(*) FROM trips".fetch_result(&mut conn); + + let expected_plan = vec![ + "┌─────────────────────────────────────┐", + "│┌───────────────────────────────────┐│", + "││ Query Profiling Information ││", + "│└───────────────────────────────────┘│", + "└─────────────────────────────────────┘", + "EXPLAIN ANALYZE SELECT COUNT(*) FROM trips", + "┌─────────────────────────────────────┐", + "│┌───────────────────────────────────┐│", + "││ HTTPFS HTTP Stats ││", + "││ ││", + "││ in: 3.0 KiB ││", + "││ out: 0 bytes ││", + "││ #HEAD: 1 ││", + "││ #GET: 2 ││", + "││ #PUT: 0 ││", + "││ #POST: 0 ││", + "│└───────────────────────────────────┘│", + "└─────────────────────────────────────┘", + "┌────────────────────────────────────────────────┐", + "│┌──────────────────────────────────────────────┐│", + "││ Total Time: 0.0007s ││", + "│└──────────────────────────────────────────────┘│", + "└────────────────────────────────────────────────┘", + "┌───────────────────────────┐", + "│ QUERY │", + "│ ──────────────────── │", + "│ 0 Rows │", + "│ (0.00s) │", + "└─────────────┬─────────────┘", + "┌─────────────┴─────────────┐", + "│ EXPLAIN_ANALYZE │", + "│ ──────────────────── │", + "│ 0 Rows │", + "│ (0.00s) │", + "└─────────────┬─────────────┘", + "┌─────────────┴─────────────┐", + "│ UNGROUPED_AGGREGATE │", + "│ ──────────────────── │", + "│ Aggregates: │", + "│ count_star() │", + "│ │", + "│ 1 Rows │", + "│ (0.00s) │", + "└─────────────┬─────────────┘", + "┌─────────────┴─────────────┐", + "│ PROJECTION │", + "│ ──────────────────── │", + "│ 42 │", + "│ │", + "│ 100 Rows │", + "│ (0.00s) │", + "└─────────────┬─────────────┘", + "┌─────────────┴─────────────┐", + "│ TABLE_SCAN │", + "│ ──────────────────── │", + "│ Function: │", + "│ READ_PARQUET │", + "│ │", + "│ 100 Rows │", + "│ (0.00s) │", + "└───────────────────────────┘", + ]; + assert!(explain.is_ok()); + if let Ok(plan) = explain { + assert_eq!(plan.len(), expected_plan.len()); + for ((row,), expect_row) in plan.iter().zip(expected_plan) { + if expect_row.contains("Time") { + continue; + } + assert_eq!(row, expect_row); + } + } + Ok(()) +}