Skip to content

Commit 0d6c8b2

Browse files
authored
refactor(rust): implement ChunkArray::(try_)from_chunk_iter (pola-rs#10395)
1 parent ab6e87b commit 0d6c8b2

File tree

43 files changed

+547
-871
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+547
-871
lines changed

contribution/polars_ops_multiple_arguments/src/lib.rs

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,26 +47,21 @@ fn compute_chunked_array_2_args<T: PolarsNumericType>(
4747
ca_1: &ChunkedArray<T>,
4848
ca_2: &ChunkedArray<T>,
4949
) -> ChunkedArray<T> {
50-
// this ensures both ChunkedArrays have the same number of chunks with the same offset
51-
// and the same length.
50+
// This ensures both ChunkedArrays have the same number of chunks with the
51+
// same offset and the same length.
5252
let (ca_1, ca_2) = align_chunks_binary(ca_1, ca_2);
53-
5453
let chunks = ca_1
5554
.downcast_iter()
5655
.zip(ca_2.downcast_iter())
57-
.map(|(arr_1, arr_2)| compute_kernel(arr_1, arr_2).boxed())
58-
.collect::<Vec<_>>();
59-
60-
// Safety: we are sure the `ArrayRef` holds type `T`
61-
unsafe { ChunkedArray::from_chunks(ca_1.name(), chunks) }
56+
.map(|(arr_1, arr_2)| compute_kernel(arr_1, arr_2));
57+
ChunkedArray::from_chunk_iter(ca_1.name(), chunks)
6258
}
6359

6460
pub fn compute_expr_2_args(arg_1: &Series, arg_2: &Series) -> Series {
65-
// dispatch the numerical series to `compute_chunked_array_2_args`
61+
// Dispatch the numerical series to `compute_chunked_array_2_args`.
6662
with_match_physical_numeric_polars_type!(arg_1.dtype(), |$T| {
67-
let ca_1: &ChunkedArray<$T> = arg_1.as_ref().as_ref().as_ref();
68-
let ca_2: &ChunkedArray<$T> = arg_2.as_ref().as_ref().as_ref();
69-
70-
compute_chunked_array_2_args(ca_1, ca_2).into_series()
71-
})
63+
let ca_1: &ChunkedArray<$T> = arg_1.as_ref().as_ref().as_ref();
64+
let ca_2: &ChunkedArray<$T> = arg_2.as_ref().as_ref().as_ref();
65+
compute_chunked_array_2_args(ca_1, ca_2).into_series()
66+
})
7267
}

crates/polars-arrow/src/array/utf8.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ impl<T: AsRef<str>> AsRef<[u8]> for StrAsBytes<T> {
5555

5656
pub trait Utf8FromIter {
5757
#[inline]
58-
fn from_values_iter<I, S>(iter: I, len: usize, value_cap: usize) -> Utf8Array<i64>
58+
fn from_values_iter<I, S>(iter: I, len: usize, size_hint: usize) -> Utf8Array<i64>
5959
where
6060
S: AsRef<str>,
6161
I: Iterator<Item = S>,
6262
{
6363
let iter = iter.map(StrAsBytes);
64-
let (offsets, values) = unsafe { fill_offsets_and_values(iter, value_cap, len) };
64+
let (offsets, values) = unsafe { fill_offsets_and_values(iter, size_hint, len) };
6565
unsafe {
6666
Utf8Array::new_unchecked(DataType::LargeUtf8, offsets.into(), values.into(), None)
6767
}

crates/polars-arrow/src/kernels/time.rs

Lines changed: 32 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,12 @@ use arrow::error::{Error as ArrowError, Result};
55
use arrow::temporal_conversions::{
66
timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_us_to_datetime,
77
};
8-
#[cfg(feature = "timezones")]
98
use chrono::{LocalResult, NaiveDateTime, TimeZone};
10-
#[cfg(feature = "timezones")]
119
use chrono_tz::Tz;
1210
use polars_error::polars_bail;
1311

1412
use crate::error::PolarsResult;
15-
use crate::prelude::ArrayRef;
1613

17-
#[cfg(feature = "timezones")]
1814
fn convert_to_naive_local(
1915
from_tz: &Tz,
2016
to_tz: &Tz,
@@ -41,68 +37,55 @@ fn convert_to_naive_local(
4137
}
4238
}
4339

44-
#[cfg(feature = "timezones")]
4540
fn convert_to_timestamp(
4641
from_tz: Tz,
4742
to_tz: Tz,
4843
arr: &PrimitiveArray<i64>,
4944
tu: TimeUnit,
5045
use_earliest: Option<bool>,
51-
) -> PolarsResult<ArrayRef> {
52-
match tu {
53-
TimeUnit::Millisecond => {
54-
let data = try_unary(
55-
arr,
56-
|value| {
57-
let ndt = timestamp_ms_to_datetime(value);
58-
Ok(convert_to_naive_local(&from_tz, &to_tz, ndt, use_earliest)?
59-
.timestamp_millis())
60-
},
61-
ArrowDataType::Int64,
62-
)?;
63-
Ok(Box::new(data))
64-
}
65-
TimeUnit::Microsecond => {
66-
let data = try_unary(
67-
arr,
68-
|value| {
69-
let ndt = timestamp_us_to_datetime(value);
70-
Ok(convert_to_naive_local(&from_tz, &to_tz, ndt, use_earliest)?
71-
.timestamp_micros())
72-
},
73-
ArrowDataType::Int64,
74-
)?;
75-
Ok(Box::new(data))
76-
}
77-
TimeUnit::Nanosecond => {
78-
let data = try_unary(
79-
arr,
80-
|value| {
81-
let ndt = timestamp_ns_to_datetime(value);
82-
Ok(convert_to_naive_local(&from_tz, &to_tz, ndt, use_earliest)?
83-
.timestamp_nanos())
84-
},
85-
ArrowDataType::Int64,
86-
)?;
87-
Ok(Box::new(data))
88-
}
46+
) -> PolarsResult<PrimitiveArray<i64>> {
47+
let res = match tu {
48+
TimeUnit::Millisecond => try_unary(
49+
arr,
50+
|value| {
51+
let ndt = timestamp_ms_to_datetime(value);
52+
Ok(convert_to_naive_local(&from_tz, &to_tz, ndt, use_earliest)?.timestamp_millis())
53+
},
54+
ArrowDataType::Int64,
55+
),
56+
TimeUnit::Microsecond => try_unary(
57+
arr,
58+
|value| {
59+
let ndt = timestamp_us_to_datetime(value);
60+
Ok(convert_to_naive_local(&from_tz, &to_tz, ndt, use_earliest)?.timestamp_micros())
61+
},
62+
ArrowDataType::Int64,
63+
),
64+
TimeUnit::Nanosecond => try_unary(
65+
arr,
66+
|value| {
67+
let ndt = timestamp_ns_to_datetime(value);
68+
Ok(convert_to_naive_local(&from_tz, &to_tz, ndt, use_earliest)?.timestamp_nanos())
69+
},
70+
ArrowDataType::Int64,
71+
),
8972
_ => unreachable!(),
90-
}
73+
};
74+
Ok(res?)
9175
}
9276

93-
#[cfg(feature = "timezones")]
9477
pub fn replace_time_zone(
9578
arr: &PrimitiveArray<i64>,
9679
tu: TimeUnit,
9780
from: &str,
9881
to: &str,
9982
use_earliest: Option<bool>,
100-
) -> PolarsResult<ArrayRef> {
101-
Ok(match from.parse::<chrono_tz::Tz>() {
83+
) -> PolarsResult<PrimitiveArray<i64>> {
84+
match from.parse::<chrono_tz::Tz>() {
10285
Ok(from_tz) => match to.parse::<chrono_tz::Tz>() {
103-
Ok(to_tz) => convert_to_timestamp(from_tz, to_tz, arr, tu, use_earliest)?,
86+
Ok(to_tz) => convert_to_timestamp(from_tz, to_tz, arr, tu, use_earliest),
10487
Err(_) => polars_bail!(ComputeError: "unable to parse time zone: '{}'", to),
10588
},
10689
Err(_) => polars_bail!(ComputeError: "unable to parse time zone: '{}'", from),
107-
})
90+
}
10891
}

crates/polars-arrow/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#![cfg_attr(
33
feature = "nightly",
44
allow(clippy::incorrect_partial_ord_impl_on_ord_type)
5-
)] // remove once stable
5+
)] // Remove once stable.
66
pub mod array;
77
pub mod bit_util;
88
pub mod bitmap;

crates/polars-core/src/chunked_array/bitwise.rs

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,8 @@ impl BitOr for &BooleanChunked {
7777
let chunks = lhs
7878
.downcast_iter()
7979
.zip(rhs.downcast_iter())
80-
.map(|(lhs, rhs)| Box::new(compute::boolean_kleene::or(lhs, rhs)) as ArrayRef)
81-
.collect();
82-
// safety: same type
83-
unsafe { BooleanChunked::from_chunks(self.name(), chunks) }
80+
.map(|(lhs, rhs)| compute::boolean_kleene::or(lhs, rhs));
81+
BooleanChunked::from_chunk_iter(self.name(), chunks)
8482
}
8583
}
8684

@@ -132,14 +130,9 @@ impl BitXor for &BooleanChunked {
132130
.map(|(l_arr, r_arr)| {
133131
let validity = combine_validities_and(l_arr.validity(), r_arr.validity());
134132
let values = l_arr.values() ^ r_arr.values();
135-
136-
let arr = BooleanArray::from_data_default(values, validity);
137-
Box::new(arr) as ArrayRef
138-
})
139-
.collect::<Vec<_>>();
140-
141-
// safety: same type
142-
unsafe { ChunkedArray::from_chunks(self.name(), chunks) }
133+
BooleanArray::from_data_default(values, validity)
134+
});
135+
ChunkedArray::from_chunk_iter(self.name(), chunks)
143136
}
144137
}
145138

@@ -180,10 +173,8 @@ impl BitAnd for &BooleanChunked {
180173
let chunks = lhs
181174
.downcast_iter()
182175
.zip(rhs.downcast_iter())
183-
.map(|(lhs, rhs)| Box::new(compute::boolean_kleene::and(lhs, rhs)) as ArrayRef)
184-
.collect();
185-
// safety: same type
186-
unsafe { BooleanChunked::from_chunks(self.name(), chunks) }
176+
.map(|(lhs, rhs)| compute::boolean_kleene::and(lhs, rhs));
177+
BooleanChunked::from_chunk_iter(self.name(), chunks)
187178
}
188179
}
189180

crates/polars-core/src/chunked_array/builder/mod.rs

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pub use list::*;
2222
pub use primitive::*;
2323
pub use utf8::*;
2424

25+
use crate::chunked_array::to_primitive;
2526
use crate::prelude::*;
2627
use crate::utils::{get_iter_capacity, NoNull};
2728

@@ -46,13 +47,10 @@ where
4647
T: PolarsNumericType,
4748
{
4849
fn from_iter<I: IntoIterator<Item = (Vec<T::Native>, Option<Bitmap>)>>(iter: I) -> Self {
49-
let mut chunks = vec![];
50-
51-
for (values, opt_buffer) in iter {
52-
chunks.push(to_array::<T>(values, opt_buffer))
53-
}
54-
// safety: same type
55-
unsafe { ChunkedArray::from_chunks("from_iter", chunks) }
50+
let chunks = iter
51+
.into_iter()
52+
.map(|(values, opt_buffer)| to_primitive::<T>(values, opt_buffer));
53+
ChunkedArray::from_chunk_iter("from_iter", chunks)
5654
}
5755
}
5856

@@ -72,9 +70,8 @@ where
7270
T: PolarsNumericType,
7371
{
7472
fn from_slice(name: &str, v: &[T::Native]) -> Self {
75-
let arr = PrimitiveArray::<T::Native>::from_slice(v).to(T::get_dtype().to_arrow());
76-
// safety: same type
77-
unsafe { ChunkedArray::from_chunks(name, vec![Box::new(arr)]) }
73+
let arr = PrimitiveArray::from_slice(v).to(T::get_dtype().to_arrow());
74+
ChunkedArray::from_chunk_iter(name, [arr])
7875
}
7976

8077
fn from_slice_options(name: &str, opt_v: &[Option<T::Native>]) -> Self {
@@ -131,13 +128,10 @@ where
131128
{
132129
fn from_slice(name: &str, v: &[S]) -> Self {
133130
let values_size = v.iter().fold(0, |acc, s| acc + s.as_ref().len());
134-
135131
let mut builder = MutableUtf8Array::<i64>::with_capacities(v.len(), values_size);
136132
builder.extend_trusted_len_values(v.iter().map(|s| s.as_ref()));
137-
138-
let chunks = vec![builder.as_box()];
139-
// safety: same type
140-
unsafe { ChunkedArray::from_chunks(name, chunks) }
133+
let imm: Utf8Array<i64> = builder.into();
134+
ChunkedArray::from_chunk_iter(name, [imm])
141135
}
142136

143137
fn from_slice_options(name: &str, opt_v: &[Option<S>]) -> Self {
@@ -147,10 +141,8 @@ where
147141
});
148142
let mut builder = MutableUtf8Array::<i64>::with_capacities(opt_v.len(), values_size);
149143
builder.extend_trusted_len(opt_v.iter().map(|s| s.as_ref()));
150-
151-
let chunks = vec![builder.as_box()];
152-
// safety: same type
153-
unsafe { ChunkedArray::from_chunks(name, chunks) }
144+
let imm: Utf8Array<i64> = builder.into();
145+
ChunkedArray::from_chunk_iter(name, [imm])
154146
}
155147

156148
fn from_iter_options(name: &str, it: impl Iterator<Item = Option<S>>) -> Self {
@@ -175,13 +167,10 @@ where
175167
{
176168
fn from_slice(name: &str, v: &[B]) -> Self {
177169
let values_size = v.iter().fold(0, |acc, s| acc + s.as_ref().len());
178-
179170
let mut builder = MutableBinaryArray::<i64>::with_capacities(v.len(), values_size);
180171
builder.extend_trusted_len_values(v.iter().map(|s| s.as_ref()));
181-
182-
let chunks = vec![builder.as_box()];
183-
// safety: same type
184-
unsafe { ChunkedArray::from_chunks(name, chunks) }
172+
let imm: BinaryArray<i64> = builder.into();
173+
ChunkedArray::from_chunk_iter(name, [imm])
185174
}
186175

187176
fn from_slice_options(name: &str, opt_v: &[Option<B>]) -> Self {
@@ -191,10 +180,8 @@ where
191180
});
192181
let mut builder = MutableBinaryArray::<i64>::with_capacities(opt_v.len(), values_size);
193182
builder.extend_trusted_len(opt_v.iter().map(|s| s.as_ref()));
194-
195-
let chunks = vec![builder.as_box()];
196-
// safety: same type
197-
unsafe { ChunkedArray::from_chunks(name, chunks) }
183+
let imm: BinaryArray<i64> = builder.into();
184+
ChunkedArray::from_chunk_iter(name, [imm])
198185
}
199186

200187
fn from_iter_options(name: &str, it: impl Iterator<Item = Option<B>>) -> Self {
@@ -233,7 +220,7 @@ mod test {
233220
let mut builder =
234221
ListPrimitiveChunkedBuilder::<Int32Type>::new("a", 10, 5, DataType::Int32);
235222

236-
// create a series containing two chunks
223+
// Create a series containing two chunks.
237224
let mut s1 = Int32Chunked::from_slice("a", &[1, 2, 3]).into_series();
238225
let s2 = Int32Chunked::from_slice("b", &[4, 5, 6]).into_series();
239226
s1.append(&s2).unwrap();
@@ -252,7 +239,8 @@ mod test {
252239
} else {
253240
panic!()
254241
}
255-
// test list collect
242+
243+
// Test list collect.
256244
let out = [&s1, &s2].iter().copied().collect::<ListChunked>();
257245
assert_eq!(out.get(0).unwrap().len(), 6);
258246
assert_eq!(out.get(1).unwrap().len(), 3);

crates/polars-core/src/chunked_array/from.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,30 @@ impl<T> ChunkedArray<T>
7474
where
7575
T: PolarsDataType,
7676
{
77+
pub fn from_chunk_iter<I>(name: &str, iter: I) -> Self
78+
where
79+
I: IntoIterator,
80+
<I as IntoIterator>::Item: StaticallyMatchesPolarsType<T> + Array,
81+
{
82+
let chunks = iter
83+
.into_iter()
84+
.map(|x| Box::new(x) as Box<dyn Array>)
85+
.collect();
86+
unsafe { Self::from_chunks(name, chunks) }
87+
}
88+
89+
pub fn try_from_chunk_iter<I, A, E>(name: &str, iter: I) -> Result<Self, E>
90+
where
91+
I: IntoIterator<Item = Result<A, E>>,
92+
A: StaticallyMatchesPolarsType<T> + Array,
93+
{
94+
let chunks: Result<_, _> = iter
95+
.into_iter()
96+
.map(|x| Ok(Box::new(x?) as Box<dyn Array>))
97+
.collect();
98+
unsafe { Ok(Self::from_chunks(name, chunks?)) }
99+
}
100+
77101
/// Create a new ChunkedArray from existing chunks.
78102
///
79103
/// # Safety

crates/polars-core/src/chunked_array/list/mod.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,14 @@ impl ListChunked {
6868
let inner_dtype = self.inner_dtype().to_arrow();
6969

7070
let chunks = ca.downcast_iter().map(|arr| {
71-
let elements = unsafe { Series::try_from_arrow_unchecked(self.name(), vec![(*arr.values()).clone()], &inner_dtype).unwrap() } ;
71+
let elements = unsafe {
72+
Series::try_from_arrow_unchecked(
73+
self.name(),
74+
vec![(*arr.values()).clone()],
75+
&inner_dtype,
76+
)
77+
.unwrap()
78+
};
7279

7380
let expected_len = elements.len();
7481
let out: Series = func(elements)?;
@@ -86,9 +93,9 @@ impl ListChunked {
8693
values,
8794
arr.validity().cloned(),
8895
);
89-
Ok(Box::new(arr) as ArrayRef)
90-
}).collect::<PolarsResult<Vec<_>>>()?;
96+
Ok(arr)
97+
});
9198

92-
unsafe { Ok(ListChunked::from_chunks(self.name(), chunks)) }
99+
ListChunked::try_from_chunk_iter(self.name(), chunks)
93100
}
94101
}

0 commit comments

Comments
 (0)