Skip to content

Commit 2d985b4

Browse files
authored
fix RecordBatch size in topK (apache#13906)
1 parent 30660e0 commit 2d985b4

File tree

1 file changed

+45
-4
lines changed
  • datafusion/physical-plan/src/topk

1 file changed

+45
-4
lines changed

datafusion/physical-plan/src/topk/mod.rs

+45-4
Original file line number · Diff line number · Diff line change
@@ -24,6 +24,8 @@ use arrow::{
2424
use std::mem::size_of;
2525
use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc};
2626

27+
use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder};
28+
use crate::spill::get_record_batch_memory_size;
2729
use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};
2830
use arrow_array::{Array, ArrayRef, RecordBatch};
2931
use arrow_schema::SchemaRef;
@@ -36,8 +38,6 @@ use datafusion_execution::{
3638
use datafusion_physical_expr::PhysicalSortExpr;
3739
use datafusion_physical_expr_common::sort_expr::LexOrdering;
3840

39-
use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder};
40-
4141
/// Global TopK
4242
///
4343
/// # Background
@@ -575,7 +575,7 @@ impl RecordBatchStore {
575575
pub fn insert(&mut self, entry: RecordBatchEntry) {
576576
// uses of 0 means that none of the rows in the batch were stored in the topk
577577
if entry.uses > 0 {
578-
self.batches_size += entry.batch.get_array_memory_size();
578+
self.batches_size += get_record_batch_memory_size(&entry.batch);
579579
self.batches.insert(entry.id, entry);
580580
}
581581
}
@@ -630,7 +630,7 @@ impl RecordBatchStore {
630630
let old_entry = self.batches.remove(&id).unwrap();
631631
self.batches_size = self
632632
.batches_size
633-
.checked_sub(old_entry.batch.get_array_memory_size())
633+
.checked_sub(get_record_batch_memory_size(&old_entry.batch))
634634
.unwrap();
635635
}
636636
}
@@ -643,3 +643,44 @@ impl RecordBatchStore {
643643
+ self.batches_size
644644
}
645645
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use arrow_array::Float64Array;

    /// Ensures `RecordBatchStore` accounts only for the batch's data buffers
    /// (via `get_record_batch_memory_size`) for a `RecordBatch` with multiple
    /// columns, both when an entry is inserted and when it is released.
    #[test]
    fn test_record_batch_store_size() {
        // given: a two-column schema (nullable Int32, non-nullable Float64)
        let schema = Arc::new(Schema::new(vec![
            Field::new("ints", DataType::Int32, true),
            Field::new("float64", DataType::Float64, false),
        ]));
        let mut record_batch_store = RecordBatchStore::new(Arc::clone(&schema));

        // 5 Int32 values -> 5 * 4 = 20 bytes of data
        let ints = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]);
        // 5 Float64 values -> 5 * 8 = 40 bytes of data
        let floats = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let entry = RecordBatchEntry {
            id: 0,
            batch: RecordBatch::try_new(schema, vec![Arc::new(ints), Arc::new(floats)])
                .unwrap(),
            uses: 1,
        };

        // when the entry is inserted, the tracked size is the sum of the
        // column data sizes (20 + 40), not the arrays' allocated capacity
        record_batch_store.insert(entry);
        assert_eq!(record_batch_store.batches_size, 60);

        // when the last use is released, the size contribution is removed
        record_batch_store.unuse(0);
        assert_eq!(record_batch_store.batches_size, 0);
    }
}

0 commit comments

Comments (0)