@@ -522,8 +522,10 @@ impl<F: FileOpener> RecordBatchStream for FileStream<F> {
522
522
523
523
#[ cfg( test) ]
524
524
mod tests {
525
+ use std:: future:: Future ;
525
526
use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
526
527
use std:: sync:: Arc ;
528
+ use std:: task:: Wake ;
527
529
528
530
use super :: * ;
529
531
use crate :: datasource:: object_store:: ObjectStoreUrl ;
@@ -532,6 +534,7 @@ mod tests {
532
534
533
535
use arrow_schema:: Schema ;
534
536
use datafusion_common:: internal_err;
537
+ use pin_project_lite:: pin_project;
535
538
536
539
/// Test `FileOpener` which will simulate errors during file opening or scanning
537
540
#[ derive( Default ) ]
@@ -624,6 +627,14 @@ mod tests {
624
627
625
628
/// Collect the results of the `FileStream`
626
629
pub async fn result ( self ) -> Result < Vec < RecordBatch > > {
630
+ self . result_with_cnt_yields ( )
631
+ . await
632
+ . map ( |( batches, _) | batches)
633
+ }
634
+
635
+ /// Collect the results of the `FileStream`, and count the number of yields
636
+ /// back to the runtime.
637
+ pub async fn result_with_cnt_yields ( self ) -> Result < ( Vec < RecordBatch > , usize ) > {
627
638
let file_schema = self
628
639
. opener
629
640
. records
@@ -661,11 +672,85 @@ mod tests {
661
672
. unwrap ( )
662
673
. with_on_error ( on_error) ;
663
674
664
- file_stream
665
- . collect :: < Vec < _ > > ( )
666
- . await
667
- . into_iter ( )
668
- . collect :: < Result < Vec < _ > > > ( )
675
+ Collector :: new ( file_stream) . collect ( ) . await
676
+ }
677
+ }
678
+
679
/// A [`Wake`] implementation that records how many times it has been woken
/// and unparks the thread it was created with on every wake.
struct ThreadWaker {
    /// Handle of the thread that is unparked on each wake.
    pub thread: std::thread::Thread,
    /// Running total of wake calls.
    pub wake_cnt: AtomicUsize,
}

impl Wake for ThreadWaker {
    /// Bump the wake counter, then unpark the stored thread.
    fn wake(self: Arc<Self>) {
        let ThreadWaker { thread, wake_cnt } = &*self;
        wake_cnt.fetch_add(1, Ordering::SeqCst);
        thread.unpark();
    }
}
692
+
693
+ pin_project ! {
694
+ /// Perform a stream collect(), with the counting [`ThreadWaker`].
695
+ pub struct Collector <St > {
696
+ #[ pin]
697
+ stream: St ,
698
+ collection: Vec <RecordBatch >,
699
+ }
700
+ }
701
+
702
+ impl < St : Stream < Item = Result < RecordBatch > > > Collector < St > {
703
+ pub fn new ( stream : St ) -> Self {
704
+ Self {
705
+ stream,
706
+ collection : Default :: default ( ) ,
707
+ }
708
+ }
709
+
710
+ pub async fn collect ( mut self ) -> Result < ( Vec < RecordBatch > , usize ) > {
711
+ let mut pinned = std:: pin:: pin!( self ) ;
712
+
713
+ // Create a new context, with the tracked waker.
714
+ let thread = std:: thread:: current ( ) ;
715
+ let counting_waker = Arc :: new ( ThreadWaker {
716
+ thread,
717
+ wake_cnt : Default :: default ( ) ,
718
+ } ) ;
719
+ let waker = Arc :: clone ( & counting_waker) . into ( ) ;
720
+ let mut cx = Context :: from_waker ( & waker) ;
721
+
722
+ // poll future with provided context
723
+ loop {
724
+ match pinned. as_mut ( ) . poll ( & mut cx) {
725
+ Poll :: Ready ( Ok ( batches) ) => {
726
+ return Ok ( (
727
+ batches,
728
+ counting_waker. wake_cnt . load ( Ordering :: SeqCst ) ,
729
+ ) )
730
+ }
731
+ Poll :: Ready ( Err ( e) ) => return Err ( e) ,
732
+ _ => continue ,
733
+ }
734
+ }
735
+ }
736
+
737
+ fn finish ( self : Pin < & mut Self > ) -> Vec < RecordBatch > {
738
+ mem:: take ( self . project ( ) . collection )
739
+ }
740
+ }
741
+
742
+ impl < St : Stream < Item = Result < RecordBatch > > > Future for Collector < St > {
743
+ type Output = Result < Vec < RecordBatch > > ;
744
+
745
+ fn poll ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
746
+ let mut this = self . as_mut ( ) . project ( ) ;
747
+ loop {
748
+ match ready ! ( this. stream. as_mut( ) . poll_next( cx) ) {
749
+ Some ( Ok ( rb) ) => this. collection . push ( rb) ,
750
+ Some ( Err ( e) ) => return Poll :: Ready ( Err ( e) ) ,
751
+ None => return Poll :: Ready ( Ok ( self . finish ( ) ) ) ,
752
+ }
753
+ }
669
754
}
670
755
}
671
756
@@ -974,4 +1059,36 @@ mod tests {
974
1059
975
1060
Ok ( ( ) )
976
1061
}
1062
+
1063
+ #[ tokio:: test]
1064
+ async fn file_stream_will_yield_btwn_open_files ( ) -> Result < ( ) > {
1065
+ let ( batches, woken) = FileStreamTest :: new ( )
1066
+ . with_records ( vec ! [ make_partition( 3 ) , make_partition( 2 ) ] )
1067
+ . with_num_files ( 1 )
1068
+ . result_with_cnt_yields ( )
1069
+ . await ?;
1070
+ assert_eq ! (
1071
+ batches. len( ) ,
1072
+ 2 ,
1073
+ "should have 2 batches per file, and 1 file"
1074
+ ) ;
1075
+ assert_eq ! (
1076
+ woken, 0 ,
1077
+ "never yielded for single open files with reader ready"
1078
+ ) ;
1079
+
1080
+ let ( batches, woken) = FileStreamTest :: new ( )
1081
+ . with_records ( vec ! [ make_partition( 3 ) , make_partition( 2 ) ] )
1082
+ . with_num_files ( 3 )
1083
+ . result_with_cnt_yields ( )
1084
+ . await ?;
1085
+ assert_eq ! (
1086
+ batches. len( ) ,
1087
+ 6 ,
1088
+ "should have 2 batches per file, and 3 files"
1089
+ ) ;
1090
+ assert_eq ! ( woken, 2 , "should yield btwn each open file" ) ;
1091
+
1092
+ Ok ( ( ) )
1093
+ }
977
1094
}
0 commit comments