Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PgBindIter for encoding and use it as the implementation encoding &[T] #3651

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions sqlx-postgres/src/bind_iter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use sqlx_core::{
database::Database,
encode::{Encode, IsNull},
types::Type,
};

use crate::{type_info::PgType, PgArgumentBuffer, PgHasArrayType, PgTypeInfo, Postgres};

pub struct PgBindIter<I>(I);

impl<I> PgBindIter<I> {
pub fn new(inner: I) -> Self {
Self(inner)
}
}

impl<I> From<I> for PgBindIter<I> {
fn from(inner: I) -> Self {
Self::new(inner)
}
}

impl<T, I> Type<Postgres> for PgBindIter<I>
where
T: Type<Postgres> + PgHasArrayType,
I: Iterator<Item = T>,
{
fn type_info() -> <Postgres as Database>::TypeInfo {
T::array_type_info()
}
fn compatible(ty: &PgTypeInfo) -> bool {
T::array_compatible(ty)
}
}

impl<'q, T, I> PgBindIter<I>
where
I: Iterator<Item = T>,
T: Type<Postgres> + Encode<'q, Postgres>,
{
fn encode_inner(
// need ownership to iterate
mut iter: I,
buf: &mut PgArgumentBuffer,
) -> Result<IsNull, Box<dyn std::error::Error + Send + Sync + 'static>> {
tylerhawkes marked this conversation as resolved.
Show resolved Hide resolved
let first = iter.next();
let type_info = first
.as_ref()
.and_then(Encode::produces)
.unwrap_or_else(T::type_info);

buf.extend(&1_i32.to_be_bytes()); // number of dimensions
buf.extend(&0_i32.to_be_bytes()); // flags

match type_info.0 {
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
PgType::DeclareArrayOf(array) => buf.patch_array_type(array),

ty => {
buf.extend(&ty.oid().0.to_be_bytes());
}
}

let len_start = buf.len();
buf.extend(0_i32.to_be_bytes()); // len (unknown so far)
buf.extend(1_i32.to_be_bytes()); // lower bound

match first {
Some(first) => buf.encode(first)?,
None => return Ok(IsNull::No),
}

let mut count = 1_i32;
const MAX: usize = i32::MAX as usize;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like an oversight. We've already written 1 item at this point.

Suggested change
const MAX: usize = i32::MAX as usize;
const MAX: usize = i32::MAX as usize - 1;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I missed this when removing the peeking iterator.


for value in (&mut iter).take(MAX) {
buf.encode(value)?;
count += 1;
}

const OVERFLOW: usize = MAX + 1;
if iter.next().is_some() {
return Err(format!("encoded iterator is too large for Postgres: {OVERFLOW}").into());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if printing the Iterator::size_hint would be more interesting here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've got it printing the larger of the size_hint lower bound and OVERFLOW now

}

// set the length now that we know what it is.
buf[len_start..(len_start + 4)].copy_from_slice(count.to_be_bytes().as_slice());
tylerhawkes marked this conversation as resolved.
Show resolved Hide resolved

Ok(IsNull::No)
}
}

impl<'q, T, I> Encode<'q, Postgres> for PgBindIter<I>
where
T: Type<Postgres> + Encode<'q, Postgres>,
// Clone is required for the encode_by_ref call since we can't iterate with a shared reference
I: Iterator<Item = T> + Clone,
{
fn encode_by_ref(
&self,
buf: &mut PgArgumentBuffer,
) -> Result<IsNull, Box<dyn std::error::Error + Send + Sync + 'static>> {
tylerhawkes marked this conversation as resolved.
Show resolved Hide resolved
Self::encode_inner(self.0.clone(), buf)
}
fn encode(
self,
buf: &mut PgArgumentBuffer,
) -> Result<IsNull, Box<dyn std::error::Error + Send + Sync + 'static>>
tylerhawkes marked this conversation as resolved.
Show resolved Hide resolved
where
Self: Sized,
{
Self::encode_inner(self.0, buf)
}
}
2 changes: 2 additions & 0 deletions sqlx-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::executor::Executor;

mod advisory_lock;
mod arguments;
mod bind_iter;
mod column;
mod connection;
mod copy;
Expand Down Expand Up @@ -44,6 +45,7 @@ pub(crate) use sqlx_core::driver_prelude::*;

pub use advisory_lock::{PgAdvisoryLock, PgAdvisoryLockGuard, PgAdvisoryLockKey};
pub use arguments::{PgArgumentBuffer, PgArguments};
pub use bind_iter::PgBindIter;
pub use column::PgColumn;
pub use connection::PgConnection;
pub use copy::{PgCopyIn, PgPoolCopyExt};
Expand Down
32 changes: 3 additions & 29 deletions sqlx-postgres/src/types/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::borrow::Cow;
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::type_info::PgType;
use crate::types::Oid;
use crate::types::Type;
use crate::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
Expand Down Expand Up @@ -156,39 +155,14 @@ where
T: Encode<'q, Postgres> + Type<Postgres>,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
let type_info = self
.first()
.and_then(Encode::produces)
.unwrap_or_else(T::type_info);

buf.extend(&1_i32.to_be_bytes()); // number of dimensions
buf.extend(&0_i32.to_be_bytes()); // flags

// element type
match type_info.0 {
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
PgType::DeclareArrayOf(array) => buf.patch_array_type(array),

ty => {
buf.extend(&ty.oid().0.to_be_bytes());
}
}

let array_len = i32::try_from(self.len()).map_err(|_| {
// do the length check early to avoid doing unnecessary work
i32::try_from(self.len()).map_err(|_| {
format!(
"encoded array length is too large for Postgres: {}",
self.len()
)
})?;

buf.extend(array_len.to_be_bytes()); // len
buf.extend(&1_i32.to_be_bytes()); // lower bound

for element in self.iter() {
buf.encode(element)?;
}

Ok(IsNull::No)
crate::bind_iter::PgBindIter::new(self.iter()).encode(buf)
}
}

Expand Down
Loading