Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions datafusion/core/tests/memory_limit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ async fn sort_merge_join_spill() {
.with_config(config)
.with_disk_manager_builder(DiskManagerBuilder::default())
.with_scenario(Scenario::AccessLogStreaming)
// SMJ spilling succeeds at this memory limit because the
// pre-resolved JoinComparator allocates via the global heap
// rather than the tracked memory pool, leaving more pool
// budget for buffered batch data.
.with_expected_success()
.run()
.await
}
Expand Down
241 changes: 155 additions & 86 deletions datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use std::sync::atomic::Ordering::Relaxed;
use std::task::{Context, Poll};

use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics;
use crate::joins::utils::{JoinFilter, compare_join_arrays};
use crate::joins::utils::JoinFilter;
use crate::metrics::RecordOutput;
use crate::spill::spill_manager::SpillManager;
use crate::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream};
Expand All @@ -44,13 +44,14 @@ use arrow::compute::{
self, BatchCoalescer, SortOptions, concat_batches, filter_record_batch, is_not_null,
take,
};
use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use arrow::datatypes::SchemaRef;
use arrow::error::ArrowError;
use arrow::ipc::reader::StreamReader;
use arrow_ord::ord::{DynComparator, make_comparator};
use datafusion_common::config::SpillCompression;
use datafusion_common::{
DataFusionError, HashSet, JoinSide, JoinType, NullEquality, Result, exec_err,
internal_err, not_impl_err,
internal_err,
};
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::memory_pool::MemoryReservation;
Expand Down Expand Up @@ -286,10 +287,8 @@ pub(super) struct SortMergeJoinStream {
// ========================================================================
/// Output schema
pub schema: SchemaRef,
/// Defines the null equality for the join.
pub null_equality: NullEquality,
/// Sort options of join columns used to sort streamed and buffered data stream
pub sort_options: Vec<SortOptions>,
/// Comparator for join key columns (delegates to Arrow's `make_comparator`)
pub comparator: JoinComparator,
/// optional join filter
pub filter: Option<JoinFilter>,
/// How the join is performed
Expand Down Expand Up @@ -360,6 +359,9 @@ pub(super) struct SortMergeJoinStream {
pub runtime_env: Arc<RuntimeEnv>,
/// A unique number for each batch
pub streamed_batch_counter: AtomicUsize,
/// Cached comparators for `compare_streamed_buffered`, rebuilt only when
/// the streamed or buffered join key arrays change.
compare_cache: Option<ComparatorCache>,
}

/// Staging area for joined data before output
Expand Down Expand Up @@ -947,6 +949,7 @@ impl SortMergeJoinStream {
) -> Result<Self> {
let streamed_schema = streamed.schema();
let buffered_schema = buffered.schema();
let comparator = JoinComparator::new(sort_options, null_equality);
let spill_manager = SpillManager::new(
Arc::clone(&runtime_env),
join_metrics.spill_metrics().clone(),
Expand All @@ -955,8 +958,7 @@ impl SortMergeJoinStream {
.with_compression_type(spill_compression);
Ok(Self {
state: SortMergeJoinState::Init,
sort_options,
null_equality,
comparator,
schema: Arc::clone(&schema),
streamed_schema: Arc::clone(&streamed_schema),
buffered_schema,
Expand Down Expand Up @@ -988,6 +990,7 @@ impl SortMergeJoinStream {
runtime_env,
spill_manager,
streamed_batch_counter: AtomicUsize::new(0),
compare_cache: None,
})
}

Expand Down Expand Up @@ -1190,15 +1193,18 @@ impl SortMergeJoinStream {
if self.buffered_data.tail_batch().range.end
< self.buffered_data.tail_batch().num_rows
{
let comparators = self.comparator.build_comparators(
&self.buffered_data.head_batch().join_arrays,
&self.buffered_data.tail_batch().join_arrays,
)?;
while self.buffered_data.tail_batch().range.end
< self.buffered_data.tail_batch().num_rows
{
if is_join_arrays_equal(
&self.buffered_data.head_batch().join_arrays,
if JoinComparator::is_equal(
&comparators,
self.buffered_data.head_batch().range.start,
&self.buffered_data.tail_batch().join_arrays,
self.buffered_data.tail_batch().range.end,
)? {
) {
self.buffered_data.tail_batch_mut().range.end += 1;
} else {
self.buffered_state = BufferedState::Ready;
Expand Down Expand Up @@ -1240,22 +1246,42 @@ impl SortMergeJoinStream {
}

/// Get comparison result of streamed row and buffered batches
fn compare_streamed_buffered(&self) -> Result<Ordering> {
fn compare_streamed_buffered(&mut self) -> Result<Ordering> {
if self.streamed_state == StreamedState::Exhausted {
return Ok(Ordering::Greater);
}
if !self.buffered_data.has_buffered_rows() {
return Ok(Ordering::Less);
}

compare_join_arrays(
// Check if cached comparators are still valid (same arrays).
let left_ptr = array_id(&self.streamed_batch.join_arrays);
let right_ptr = array_id(&self.buffered_data.head_batch().join_arrays);
let needs_rebuild = match &self.compare_cache {
Some(c) => c.left_ptr != left_ptr || c.right_ptr != right_ptr,
None => true,
};
if needs_rebuild {
let comparators = self.comparator.build_comparators(
&self.streamed_batch.join_arrays,
&self.buffered_data.head_batch().join_arrays,
)?;
self.compare_cache = Some(ComparatorCache {
left_ptr,
right_ptr,
comparators,
});
}

let cache = self.compare_cache.as_ref().unwrap();
Ok(JoinComparator::compare(
self.comparator.null_equality,
&cache.comparators,
&self.streamed_batch.join_arrays,
self.streamed_batch.idx,
&self.buffered_data.head_batch().join_arrays,
self.streamed_batch.idx,
self.buffered_data.head_batch().range.start,
&self.sort_options,
self.null_equality,
)
))
}

/// Produce join and fill output buffer until reaching target batch size
Expand Down Expand Up @@ -2096,77 +2122,120 @@ fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec<ArrayR
.collect()
}

/// A faster version of compare_join_arrays() that only output whether
/// the given two rows are equal
fn is_join_arrays_equal(
left_arrays: &[ArrayRef],
left: usize,
right_arrays: &[ArrayRef],
right: usize,
) -> Result<bool> {
let mut is_equal = true;
for (left_array, right_array) in left_arrays.iter().zip(right_arrays) {
macro_rules! compare_value {
($T:ty) => {{
match (left_array.is_null(left), right_array.is_null(right)) {
(false, false) => {
let left_array =
left_array.as_any().downcast_ref::<$T>().unwrap();
let right_array =
right_array.as_any().downcast_ref::<$T>().unwrap();
if left_array.value(left) != right_array.value(right) {
is_equal = false;
}
}
(true, false) => is_equal = false,
(false, true) => is_equal = false,
_ => {}
}
}};
/// Cached `DynComparator`s keyed by the identity of the arrays they were built from.
///
/// Since `DynComparator` captures array references at construction time, we must
/// rebuild when the underlying arrays change. We detect changes by comparing
/// the `Arc` data pointers of the first array on each side (all join key arrays
/// come from the same batch, so they all change together).
///
/// NOTE(review): the cache key is a raw `Arc` data pointer. If the previous
/// batch's arrays are dropped and a new allocation reuses the same address
/// (ABA), a stale hit is theoretically possible — confirm that both the
/// streamed batch and the buffered head batch stay alive while cached.
struct ComparatorCache {
    /// Data pointer of the first left (streamed-side) array, as produced by `array_id`.
    left_ptr: usize,
    /// Data pointer of the first right (buffered-side) array, as produced by `array_id`.
    right_ptr: usize,
    /// Cached `DynComparator`s, one per join key column.
    comparators: Vec<DynComparator>,
}

/// Produces a data-pointer identity for a set of join key arrays.
///
/// All arrays in a join key set come from the same batch, so a change in
/// the first array's `Arc` pointer implies every array in the set changed.
/// An empty slice yields `0`.
#[inline]
fn array_id(arrays: &[ArrayRef]) -> usize {
    match arrays.first() {
        // Cast through `*const ()` to discard the vtable half of the
        // fat pointer, leaving only the data-address identity.
        Some(first) => Arc::as_ptr(first) as *const () as usize,
        None => 0,
    }
}

/// Comparator for join key columns using Arrow's built-in `make_comparator`.
///
/// Delegates to [`arrow_ord::ord::make_comparator`] which handles all Arrow
/// data types (including List, Struct, Map, RunEndEncoded, etc.) and optimizes
/// away null checks when a column has no nulls.
///
/// Because `make_comparator` captures array references at construction time,
/// comparators are rebuilt each time the underlying batch arrays change.
/// This struct holds only the configuration needed for those rebuilds; the
/// built `DynComparator`s themselves live in `ComparatorCache`.
pub(super) struct JoinComparator {
    /// Sort options per join key column (needed when rebuilding comparators).
    sort_options: Vec<SortOptions>,
    /// Null equality semantics.
    null_equality: NullEquality,
}

impl JoinComparator {
/// Create a new comparator with the given sort options and null equality.
fn new(sort_options: Vec<SortOptions>, null_equality: NullEquality) -> Self {
Self {
sort_options,
null_equality,
}
}

match left_array.data_type() {
DataType::Null => {}
DataType::Boolean => compare_value!(BooleanArray),
DataType::Int8 => compare_value!(Int8Array),
DataType::Int16 => compare_value!(Int16Array),
DataType::Int32 => compare_value!(Int32Array),
DataType::Int64 => compare_value!(Int64Array),
DataType::UInt8 => compare_value!(UInt8Array),
DataType::UInt16 => compare_value!(UInt16Array),
DataType::UInt32 => compare_value!(UInt32Array),
DataType::UInt64 => compare_value!(UInt64Array),
DataType::Float32 => compare_value!(Float32Array),
DataType::Float64 => compare_value!(Float64Array),
DataType::Utf8 => compare_value!(StringArray),
DataType::Utf8View => compare_value!(StringViewArray),
DataType::LargeUtf8 => compare_value!(LargeStringArray),
DataType::Binary => compare_value!(BinaryArray),
DataType::BinaryView => compare_value!(BinaryViewArray),
DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray),
DataType::LargeBinary => compare_value!(LargeBinaryArray),
DataType::Decimal32(..) => compare_value!(Decimal32Array),
DataType::Decimal64(..) => compare_value!(Decimal64Array),
DataType::Decimal128(..) => compare_value!(Decimal128Array),
DataType::Decimal256(..) => compare_value!(Decimal256Array),
DataType::Timestamp(time_unit, None) => match time_unit {
TimeUnit::Second => compare_value!(TimestampSecondArray),
TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray),
TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray),
TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray),
},
DataType::Date32 => compare_value!(Date32Array),
DataType::Date64 => compare_value!(Date64Array),
dt => {
return not_impl_err!(
"Unsupported data type in sort merge join comparator: {}",
dt
);
/// Build per-column `DynComparator`s for the given left/right key arrays.
///
/// Must be called each time the underlying batch arrays change, since
/// `make_comparator` captures array references.
fn build_comparators(
&self,
left_arrays: &[ArrayRef],
right_arrays: &[ArrayRef],
) -> Result<Vec<DynComparator>> {
left_arrays
.iter()
.zip(right_arrays)
.zip(&self.sort_options)
.map(|((l, r), opts)| {
make_comparator(l.as_ref(), r.as_ref(), *opts)
.map_err(DataFusionError::from)
})
.collect()
}

/// Compare two rows for ordering (used in the merge loop to decide
/// advance-streamed vs advance-buffered).
///
/// When `null_equality` is `NullEqualsNothing`, an extra check ensures
/// that null-null pairs are not treated as equal (Arrow's comparator
/// returns `Equal` for null-null).
#[inline]
fn compare(
null_equality: NullEquality,
comparators: &[DynComparator],
left_arrays: &[ArrayRef],
right_arrays: &[ArrayRef],
left_idx: usize,
right_idx: usize,
) -> Ordering {
for cmp in comparators {
let ord = cmp(left_idx, right_idx);
if !ord.is_eq() {
return ord;
}
}
if !is_equal {
return Ok(false);
// All columns compared equal. If null-null should not match,
// check whether any column pair has both sides null.
if null_equality == NullEquality::NullEqualsNothing {
for (l, r) in left_arrays.iter().zip(right_arrays) {
if l.is_null(left_idx) && r.is_null(right_idx) {
return Ordering::Less;
}
}
}
Ordering::Equal
}

/// Check if two rows are equal (used when expanding buffered batches with
/// the same key). Null-null is treated as equal here since we are grouping
/// buffered rows that share the same key.
#[inline]
fn is_equal(
comparators: &[DynComparator],
left_idx: usize,
right_idx: usize,
) -> bool {
comparators
.iter()
.all(|cmp| cmp(left_idx, right_idx).is_eq())
}
Ok(true)
}