Skip to content

Commit b9e4d8b

Browse files
committed
test: demonstrate yielding btwn each open file being scanned
1 parent 9a808eb commit b9e4d8b

File tree

2 files changed

+123
-5
lines changed

2 files changed

+123
-5
lines changed

datafusion/core/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ datafusion-functions-window-common = { workspace = true }
144144
doc-comment = { workspace = true }
145145
env_logger = { workspace = true }
146146
paste = "^1.0"
147+
pin-project-lite = "0.2.16"
147148
rand = { workspace = true, features = ["small_rng"] }
148149
rand_distr = "0.4.3"
149150
regex = { workspace = true }

datafusion/core/src/datasource/physical_plan/file_stream.rs

+122-5
Original file line numberDiff line numberDiff line change
@@ -522,8 +522,10 @@ impl<F: FileOpener> RecordBatchStream for FileStream<F> {
522522

523523
#[cfg(test)]
524524
mod tests {
525+
use std::future::Future;
525526
use std::sync::atomic::{AtomicUsize, Ordering};
526527
use std::sync::Arc;
528+
use std::task::Wake;
527529

528530
use super::*;
529531
use crate::datasource::object_store::ObjectStoreUrl;
@@ -532,6 +534,7 @@ mod tests {
532534

533535
use arrow_schema::Schema;
534536
use datafusion_common::internal_err;
537+
use pin_project_lite::pin_project;
535538

536539
/// Test `FileOpener` which will simulate errors during file opening or scanning
537540
#[derive(Default)]
@@ -624,6 +627,14 @@ mod tests {
624627

625628
/// Collect the results of the `FileStream`
626629
pub async fn result(self) -> Result<Vec<RecordBatch>> {
630+
self.result_with_cnt_yields()
631+
.await
632+
.map(|(batches, _)| batches)
633+
}
634+
635+
/// Collect the results of the `FileStream`, and count the number of yields
636+
/// back to the runtime.
637+
pub async fn result_with_cnt_yields(self) -> Result<(Vec<RecordBatch>, usize)> {
627638
let file_schema = self
628639
.opener
629640
.records
@@ -661,11 +672,85 @@ mod tests {
661672
.unwrap()
662673
.with_on_error(on_error);
663674

664-
file_stream
665-
.collect::<Vec<_>>()
666-
.await
667-
.into_iter()
668-
.collect::<Result<Vec<_>>>()
675+
Collector::new(file_stream).collect().await
676+
}
677+
}
678+
679+
/// A [`Wake`] implementation that counts every wake-up it receives and
/// unparks the thread whose handle it was constructed with.
struct ThreadWaker {
    /// Handle of the thread to unpark on each wake.
    pub thread: std::thread::Thread,
    /// Number of wake-ups observed so far.
    pub wake_cnt: AtomicUsize,
}

impl Wake for ThreadWaker {
    /// Wake by value; delegates to [`Wake::wake_by_ref`] so both entry points
    /// share a single implementation.
    fn wake(self: Arc<Self>) {
        self.wake_by_ref();
    }

    /// Record the wake-up and unpark the tracked thread.
    ///
    /// Implemented explicitly because the default `wake_by_ref` clones the
    /// `Arc` and then calls `wake`, which is an avoidable refcount round-trip
    /// on every by-reference wake.
    fn wake_by_ref(self: &Arc<Self>) {
        self.wake_cnt.fetch_add(1, Ordering::SeqCst);
        self.thread.unpark();
    }
}
692+
693+
pin_project! {
694+
/// Perform a stream collect(), with the counting [`ThreadWaker`].
695+
pub struct Collector<St> {
696+
#[pin]
697+
stream: St,
698+
collection: Vec<RecordBatch>,
699+
}
700+
}
701+
702+
impl<St: Stream<Item = Result<RecordBatch>>> Collector<St> {
703+
pub fn new(stream: St) -> Self {
704+
Self {
705+
stream,
706+
collection: Default::default(),
707+
}
708+
}
709+
710+
pub async fn collect(mut self) -> Result<(Vec<RecordBatch>, usize)> {
711+
let mut pinned = std::pin::pin!(self);
712+
713+
// Create a new context, with the tracked waker.
714+
let thread = std::thread::current();
715+
let counting_waker = Arc::new(ThreadWaker {
716+
thread,
717+
wake_cnt: Default::default(),
718+
});
719+
let waker = Arc::clone(&counting_waker).into();
720+
let mut cx = Context::from_waker(&waker);
721+
722+
// poll future with provided context
723+
loop {
724+
match pinned.as_mut().poll(&mut cx) {
725+
Poll::Ready(Ok(batches)) => {
726+
return Ok((
727+
batches,
728+
counting_waker.wake_cnt.load(Ordering::SeqCst),
729+
))
730+
}
731+
Poll::Ready(Err(e)) => return Err(e),
732+
_ => continue,
733+
}
734+
}
735+
}
736+
737+
fn finish(self: Pin<&mut Self>) -> Vec<RecordBatch> {
738+
mem::take(self.project().collection)
739+
}
740+
}
741+
742+
impl<St: Stream<Item = Result<RecordBatch>>> Future for Collector<St> {
743+
type Output = Result<Vec<RecordBatch>>;
744+
745+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
746+
let mut this = self.as_mut().project();
747+
loop {
748+
match ready!(this.stream.as_mut().poll_next(cx)) {
749+
Some(Ok(rb)) => this.collection.push(rb),
750+
Some(Err(e)) => return Poll::Ready(Err(e)),
751+
None => return Poll::Ready(Ok(self.finish())),
752+
}
753+
}
669754
}
670755
}
671756

@@ -974,4 +1059,36 @@ mod tests {
9741059

9751060
Ok(())
9761061
}
1062+
1063+
#[tokio::test]
1064+
async fn file_stream_will_yield_btwn_open_files() -> Result<()> {
1065+
let (batches, woken) = FileStreamTest::new()
1066+
.with_records(vec![make_partition(3), make_partition(2)])
1067+
.with_num_files(1)
1068+
.result_with_cnt_yields()
1069+
.await?;
1070+
assert_eq!(
1071+
batches.len(),
1072+
2,
1073+
"should have 2 batches per file, and 1 file"
1074+
);
1075+
assert_eq!(
1076+
woken, 0,
1077+
"never yielded for single open files with reader ready"
1078+
);
1079+
1080+
let (batches, woken) = FileStreamTest::new()
1081+
.with_records(vec![make_partition(3), make_partition(2)])
1082+
.with_num_files(3)
1083+
.result_with_cnt_yields()
1084+
.await?;
1085+
assert_eq!(
1086+
batches.len(),
1087+
6,
1088+
"should have 2 batches per file, and 3 files"
1089+
);
1090+
assert_eq!(woken, 2, "should yield btwn each open file");
1091+
1092+
Ok(())
1093+
}
9771094
}

0 commit comments

Comments
 (0)