@@ -24,6 +24,8 @@ use arrow::{
24
24
use std:: mem:: size_of;
25
25
use std:: { cmp:: Ordering , collections:: BinaryHeap , sync:: Arc } ;
26
26
27
+ use super :: metrics:: { BaselineMetrics , Count , ExecutionPlanMetricsSet , MetricBuilder } ;
28
+ use crate :: spill:: get_record_batch_memory_size;
27
29
use crate :: { stream:: RecordBatchStreamAdapter , SendableRecordBatchStream } ;
28
30
use arrow_array:: { Array , ArrayRef , RecordBatch } ;
29
31
use arrow_schema:: SchemaRef ;
@@ -36,8 +38,6 @@ use datafusion_execution::{
36
38
use datafusion_physical_expr:: PhysicalSortExpr ;
37
39
use datafusion_physical_expr_common:: sort_expr:: LexOrdering ;
38
40
39
- use super :: metrics:: { BaselineMetrics , Count , ExecutionPlanMetricsSet , MetricBuilder } ;
40
-
41
41
/// Global TopK
42
42
///
43
43
/// # Background
@@ -575,7 +575,7 @@ impl RecordBatchStore {
575
575
pub fn insert ( & mut self , entry : RecordBatchEntry ) {
576
576
// uses of 0 means that none of the rows in the batch were stored in the topk
577
577
if entry. uses > 0 {
578
- self . batches_size += entry. batch . get_array_memory_size ( ) ;
578
+ self . batches_size += get_record_batch_memory_size ( & entry. batch ) ;
579
579
self . batches . insert ( entry. id , entry) ;
580
580
}
581
581
}
@@ -630,7 +630,7 @@ impl RecordBatchStore {
630
630
let old_entry = self . batches . remove ( & id) . unwrap ( ) ;
631
631
self . batches_size = self
632
632
. batches_size
633
- . checked_sub ( old_entry. batch . get_array_memory_size ( ) )
633
+ . checked_sub ( get_record_batch_memory_size ( & old_entry. batch ) )
634
634
. unwrap ( ) ;
635
635
}
636
636
}
@@ -643,3 +643,44 @@ impl RecordBatchStore {
643
643
+ self . batches_size
644
644
}
645
645
}
646
+
647
+ #[ cfg( test) ]
648
+ mod tests {
649
+ use super :: * ;
650
+ use arrow:: array:: Int32Array ;
651
+ use arrow:: datatypes:: { DataType , Field , Schema } ;
652
+ use arrow:: record_batch:: RecordBatch ;
653
+ use arrow_array:: Float64Array ;
654
+
655
+ /// This test ensures the size calculation is correct for RecordBatches with multiple columns.
656
+ #[ test]
657
+ fn test_record_batch_store_size ( ) {
658
+ // given
659
+ let schema = Arc :: new ( Schema :: new ( vec ! [
660
+ Field :: new( "ints" , DataType :: Int32 , true ) ,
661
+ Field :: new( "float64" , DataType :: Float64 , false ) ,
662
+ ] ) ) ;
663
+ let mut record_batch_store = RecordBatchStore :: new ( Arc :: clone ( & schema) ) ;
664
+ let int_array =
665
+ Int32Array :: from ( vec ! [ Some ( 1 ) , Some ( 2 ) , Some ( 3 ) , Some ( 4 ) , Some ( 5 ) ] ) ; // 5 * 4 = 20
666
+ let float64_array = Float64Array :: from ( vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 ] ) ; // 5 * 8 = 40
667
+
668
+ let record_batch_entry = RecordBatchEntry {
669
+ id : 0 ,
670
+ batch : RecordBatch :: try_new (
671
+ schema,
672
+ vec ! [ Arc :: new( int_array) , Arc :: new( float64_array) ] ,
673
+ )
674
+ . unwrap ( ) ,
675
+ uses : 1 ,
676
+ } ;
677
+
678
+ // when insert record batch entry
679
+ record_batch_store. insert ( record_batch_entry) ;
680
+ assert_eq ! ( record_batch_store. batches_size, 60 ) ;
681
+
682
+ // when unuse record batch entry
683
+ record_batch_store. unuse ( 0 ) ;
684
+ assert_eq ! ( record_batch_store. batches_size, 0 ) ;
685
+ }
686
+ }
0 commit comments