diff --git a/sqlx-core/src/pool/connect.rs b/sqlx-core/src/pool/connect.rs index f1f7ce7d4b..7fdc8c4739 100644 --- a/sqlx-core/src/pool/connect.rs +++ b/sqlx-core/src/pool/connect.rs @@ -33,10 +33,11 @@ use std::io; /// use sqlx::PgConnection; /// use sqlx::postgres::PgPoolOptions; /// use sqlx::Connection; +/// use sqlx::pool::PoolConnectMetadata; /// -/// # async fn _example() -> sqlx::Result<()> { -/// // `PoolConnector` is implemented for closures but has restrictions on returning borrows -/// // due to current language limitations. +/// async fn _example() -> sqlx::Result<()> { +/// // `PoolConnector` is implemented for closures but this has restrictions on returning borrows +/// // due to current language limitations. Custom implementations are not subject to this. /// // /// // This example shows how to get around this using `Arc`. /// let database_url: Arc = "postgres://...".into(); @@ -44,7 +45,8 @@ use std::io; /// let pool = PgPoolOptions::new() /// .min_connections(5) /// .max_connections(30) -/// .connect_with_connector(move |meta| { +/// // Type annotation on the argument is required for the trait impl to reseolve. +/// .connect_with_connector(move |meta: PoolConnectMetadata| { /// let database_url = database_url.clone(); /// async move { /// println!( @@ -57,7 +59,9 @@ use std::io; /// let mut conn = PgConnection::connect(&database_url).await?; /// /// // Override the time zone of the connection. -/// sqlx::raw_sql("SET TIME ZONE 'Europe/Berlin'").await?; +/// sqlx::raw_sql("SET TIME ZONE 'Europe/Berlin'") +/// .execute(&mut conn) +/// .await?; /// /// Ok(conn) /// } @@ -76,13 +80,14 @@ use std::io; /// /// ```rust,no_run /// use std::sync::Arc; -/// use tokio::sync::{Mutex, RwLock}; +/// use tokio::sync::RwLock; /// use sqlx::PgConnection; /// use sqlx::postgres::PgConnectOptions; /// use sqlx::postgres::PgPoolOptions; /// use sqlx::ConnectOptions; +/// use sqlx::pool::PoolConnectMetadata; /// -/// # async fn _example() -> sqlx::Result<()> { +/// async fn _example() -> sqlx::Result<()> { /// // If you do not wish to hold the lock during the connection attempt, /// // you could use `Arc` instead. /// let connect_opts: Arc> = Arc::new(RwLock::new("postgres://...".parse()?)); @@ -90,7 +95,7 @@ use std::io; /// let connect_opts_ = connect_opts.clone(); /// /// let pool = PgPoolOptions::new() -/// .connect_with_connector(move |meta| { +/// .connect_with_connector(move |meta: PoolConnectMetadata| { /// let connect_opts_ = connect_opts.clone(); /// async move { /// println!( diff --git a/tests/any/pool.rs b/tests/any/pool.rs index 3130b4f1c6..2502bac8ab 100644 --- a/tests/any/pool.rs +++ b/tests/any/pool.rs @@ -1,44 +1,13 @@ use sqlx::any::{AnyConnectOptions, AnyPoolOptions}; use sqlx::Executor; +use sqlx_core::connection::ConnectOptions; +use sqlx_core::pool::PoolConnectMetadata; use std::sync::{ - atomic::{AtomicI32, AtomicUsize, Ordering}, + atomic::{AtomicI32, Ordering}, Arc, Mutex, }; use std::time::Duration; -#[sqlx_macros::test] -async fn pool_should_invoke_after_connect() -> anyhow::Result<()> { - sqlx::any::install_default_drivers(); - - let counter = Arc::new(AtomicUsize::new(0)); - - let pool = AnyPoolOptions::new() - .after_connect({ - let counter = counter.clone(); - move |_conn, _meta| { - let counter = counter.clone(); - Box::pin(async move { - counter.fetch_add(1, Ordering::SeqCst); - - Ok(()) - }) - } - }) - .connect(&dotenvy::var("DATABASE_URL")?) - .await?; - - let _ = pool.acquire().await?; - let _ = pool.acquire().await?; - let _ = pool.acquire().await?; - let _ = pool.acquire().await?; - - // since connections are released asynchronously, - // `.after_connect()` may be called more than once - assert!(counter.load(Ordering::SeqCst) >= 1); - - Ok(()) -} - // https://github.com/launchbadge/sqlx/issues/527 #[sqlx_macros::test] async fn pool_should_be_returned_failed_transactions() -> anyhow::Result<()> { @@ -83,38 +52,13 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { sqlx_test::setup_if_needed(); - let conn_options: AnyConnectOptions = std::env::var("DATABASE_URL")?.parse()?; + let conn_options: Arc = Arc::new(std::env::var("DATABASE_URL")?.parse()?); let current_id = AtomicI32::new(0); let pool = AnyPoolOptions::new() .max_connections(1) .acquire_timeout(Duration::from_secs(5)) - .after_connect(move |conn, meta| { - assert_eq!(meta.age, Duration::ZERO); - assert_eq!(meta.idle_for, Duration::ZERO); - - let id = current_id.fetch_add(1, Ordering::AcqRel); - - Box::pin(async move { - let statement = format!( - // language=SQL - r#" - CREATE TEMPORARY TABLE conn_stats( - id int primary key, - before_acquire_calls int default 0, - after_release_calls int default 0 - ); - INSERT INTO conn_stats(id) VALUES ({}); - "#, - // Until we have generalized bind parameters - id - ); - - conn.execute(&statement[..]).await?; - Ok(()) - }) - }) .before_acquire(|conn, meta| { // `age` and `idle_for` should both be nonzero assert_ne!(meta.age, Duration::ZERO); @@ -165,7 +109,31 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { }) }) // Don't establish a connection yet. - .connect_lazy_with(conn_options); + .connect_lazy_with_connector(move |_meta: PoolConnectMetadata| { + let connect_opts = Arc::clone(&conn_options); + let id = current_id.fetch_add(1, Ordering::AcqRel); + + async move { + let mut conn = connect_opts.connect().await?; + + let statement = format!( + // language=SQL + r#" + CREATE TEMPORARY TABLE conn_stats( + id int primary key, + before_acquire_calls int default 0, + after_release_calls int default 0 + ); + INSERT INTO conn_stats(id) VALUES ({}); + "#, + // Until we have generalized bind parameters + id + ); + + conn.execute(&statement[..]).await?; + Ok(conn) + } + }); // Expected pattern of (id, before_acquire_calls, after_release_calls) let pattern = [ diff --git a/tests/sqlite/any.rs b/tests/sqlite/any.rs index 856db70c05..b71c3ba43d 100644 --- a/tests/sqlite/any.rs +++ b/tests/sqlite/any.rs @@ -1,4 +1,4 @@ -use sqlx::{Any, Sqlite}; +use sqlx::Any; use sqlx_test::new; #[sqlx_macros::test]