diff --git a/Cargo.toml b/Cargo.toml index a486b58..4ab166e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ homepage = "https://github.com/al8n/wg" repository = "https://github.com/al8n/wg.git" documentation = "https://docs.rs/wg/" readme = "README.md" -version = "0.7.2" +version = "0.7.3" license = "MIT OR Apache-2.0" keywords = ["waitgroup", "async", "sync", "notify", "wake"] categories = ["asynchronous", "concurrency", "data-structures"] @@ -18,7 +18,7 @@ full = ["triomphe", "parking_lot"] triomphe = ["dep:triomphe"] parking_lot = ["dep:parking_lot"] -future = ["event-listener", "event-listener-strategy", "pin-project-lite"] +future = ["event-listener", "pin-project-lite"] tokio = ["dep:tokio", "futures-core", "pin-project-lite"] @@ -26,7 +26,6 @@ tokio = ["dep:tokio", "futures-core", "pin-project-lite"] parking_lot = { version = "0.12", optional = true } triomphe = { version = "0.1", optional = true } event-listener = { version = "5", optional = true } -event-listener-strategy = { version = "0.5", optional = true } pin-project-lite = { version = "0.2", optional = true } tokio = { version = "1", default-features = false, optional = true, features = ["sync", "rt"] } @@ -50,3 +49,7 @@ name = "future" path = "tests/future.rs" required-features = ["future"] +[[test]] +name = "sync" +path = "tests/sync.rs" + diff --git a/src/future.rs b/src/future.rs index 2c51d08..af4572a 100644 --- a/src/future.rs +++ b/src/future.rs @@ -1,11 +1,10 @@ use super::*; use event_listener::{Event, EventListener}; -use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy}; use std::{ pin::Pin, sync::atomic::{AtomicUsize, Ordering}, - task::Poll, + task::{Context, Poll}, }; #[derive(Debug)] @@ -163,7 +162,7 @@ impl AsyncWaitGroup { /// }); /// # }) /// ``` - pub fn done(&self) { + pub fn done(self) { if self.inner.counter.fetch_sub(1, Ordering::SeqCst) == 1 { self.inner.event.notify(usize::MAX); } @@ -197,7 +196,11 @@ impl AsyncWaitGroup { /// # }) /// ``` pub fn wait(&self) -> WaitGroupFuture<'_> { - WaitGroupFuture::_new(WaitGroupFutureInner::new(&self.inner)) + WaitGroupFuture { + inner: self, + notified: self.inner.event.listen(), + _pin: std::marker::PhantomPinned, + } } /// Wait blocks until the [`AsyncWaitGroup`] counter is zero. This method is @@ -222,79 +225,64 @@ impl AsyncWaitGroup { /// t_wg.done() /// }); /// + /// let spawner = |fut| { + /// spawn(fut); + /// }; + /// /// // wait other thread completes - /// wg.block_wait(); + /// wg.block_wait(spawner); /// # }) /// ``` - pub fn block_wait(&self) { - WaitGroupFutureInner::new(&self.inner).wait(); + pub fn block_wait(&self, spawner: S) + where + S: FnOnce(Pin + Send + 'static>>), + { + let this = self.clone(); + let (tx, rx) = std::sync::mpsc::channel(); + spawner(Box::pin(async move { + this.wait().await; + let _ = tx.send(()); + })); + + let _ = rx.recv(); } } -easy_wrapper! { +pin_project_lite::pin_project! { /// A future returned by [`AsyncWaitGroup::wait()`]. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] - #[cfg_attr(docsrs, doc(cfg(feature = "future")))] - pub struct WaitGroupFuture<'a>(WaitGroupFutureInner<'a> => ()); - - #[cfg(all(feature = "std", not(target_family = "wasm")))] - pub(crate) wait(); -} - -pin_project_lite::pin_project! { - /// A future that used to wait for the [`AsyncWaitGroup`] counter is zero. - #[must_use = "futures do nothing unless you `.await` or poll them"] - #[project(!Unpin)] - #[derive(Debug)] - struct WaitGroupFutureInner<'a> { - inner: &'a Arc, - listener: Option, + #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))] + pub struct WaitGroupFuture<'a> { + inner: &'a AsyncWaitGroup, + #[pin] + notified: EventListener, #[pin] _pin: std::marker::PhantomPinned, } } -impl<'a> WaitGroupFutureInner<'a> { - fn new(inner: &'a Arc) -> Self { - Self { - inner, - listener: None, - _pin: std::marker::PhantomPinned, - } - } -} - -impl EventListenerFuture for WaitGroupFutureInner<'_> { +impl<'a> std::future::Future for WaitGroupFuture<'a> { type Output = (); - fn poll_with_strategy<'a, S: Strategy<'a>>( - self: Pin<&mut Self>, - strategy: &mut S, - context: &mut S::Context, - ) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.inner.inner.counter.load(Ordering::SeqCst) == 0 { + return Poll::Ready(()); + } + let this = self.project(); - loop { - if this.inner.counter.load(Ordering::SeqCst) == 0 { - return Poll::Ready(()); + match this.notified.poll(cx) { + Poll::Pending => { + cx.waker().wake_by_ref(); + Poll::Pending } - - if this.listener.is_some() { - // Poll using the given strategy - match S::poll(strategy, &mut *this.listener, context) { - Poll::Ready(_) => { - // Event received, check the condition again. - if this.inner.counter.load(Ordering::SeqCst) == 0 { - return Poll::Ready(()); - } - - // Event received but condition not met, reset listener. - *this.listener = None; - } - Poll::Pending => return Poll::Pending, + Poll::Ready(_) => { + if this.inner.inner.counter.load(Ordering::SeqCst) == 0 { + Poll::Ready(()) + } else { + cx.waker().wake_by_ref(); + Poll::Pending } - } else { - *this.listener = Some(this.inner.event.listen()); } } } diff --git a/src/lib.rs b/src/lib.rs index 8d5dffe..78e9bd8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -220,7 +220,7 @@ impl WaitGroup { /// }); /// /// ``` - pub fn done(&self) { + pub fn done(self) { let mut val = self.inner.count.lock_me(); *val = if val.eq(&1) { @@ -277,92 +277,3 @@ impl WaitGroup { } } } - -#[cfg(test)] -mod test { - use super::*; - use std::sync::atomic::{AtomicUsize, Ordering}; - use std::sync::Arc; - use std::time::Duration; - - #[test] - fn test_sync_wait_group_reuse() { - let wg = WaitGroup::new(); - let ctr = Arc::new(AtomicUsize::new(0)); - for _ in 0..6 { - let wg = wg.add(1); - let ctrx = ctr.clone(); - std::thread::spawn(move || { - std::thread::sleep(Duration::from_millis(5)); - ctrx.fetch_add(1, Ordering::Relaxed); - wg.done(); - }); - } - - wg.wait(); - assert_eq!(ctr.load(Ordering::Relaxed), 6); - - let worker = wg.add(1); - let ctrx = ctr.clone(); - std::thread::spawn(move || { - std::thread::sleep(Duration::from_millis(5)); - ctrx.fetch_add(1, Ordering::Relaxed); - worker.done(); - }); - wg.wait(); - assert_eq!(ctr.load(Ordering::Relaxed), 7); - } - - #[test] - fn test_sync_wait_group_nested() { - let wg = WaitGroup::new(); - let ctr = Arc::new(AtomicUsize::new(0)); - for _ in 0..5 { - let worker = wg.add(1); - let ctrx = ctr.clone(); - std::thread::spawn(move || { - let nested_worker = worker.add(1); - let ctrxx = ctrx.clone(); - std::thread::spawn(move || { - ctrxx.fetch_add(1, Ordering::Relaxed); - nested_worker.done(); - }); - ctrx.fetch_add(1, Ordering::Relaxed); - worker.done(); - }); - } - - wg.wait(); - assert_eq!(ctr.load(Ordering::Relaxed), 10); - } - - #[test] - fn test_sync_wait_group_from() { - std::thread::scope(|s| { - let wg = WaitGroup::from(5); - for _ in 0..5 { - let t = wg.clone(); - s.spawn(move || { - t.done(); - }); - } - wg.wait(); - }); - } - - #[test] - fn test_clone_and_fmt() { - let swg = WaitGroup::new(); - let swg1 = swg.clone(); - swg1.add(3); - assert_eq!(format!("{:?}", swg), format!("{:?}", swg1)); - } - - #[test] - fn test_waitings() { - let wg = WaitGroup::new(); - wg.add(1); - wg.add(1); - assert_eq!(wg.waitings(), 2); - } -} diff --git a/src/tokio.rs b/src/tokio.rs index 6cb4d85..429044a 100644 --- a/src/tokio.rs +++ b/src/tokio.rs @@ -161,7 +161,7 @@ impl AsyncWaitGroup { /// }); /// } /// ``` - pub fn done(&self) { + pub fn done(self) { if self.inner.counter.fetch_sub(1, Ordering::SeqCst) == 1 { self.inner.notify.notify_waiters(); } @@ -261,6 +261,20 @@ impl<'a> Future for WaitGroupFuture<'a> { return Poll::Ready(()); } - self.project().notified.poll(cx) + let this = self.project(); + match this.notified.poll(cx) { + Poll::Pending => { + cx.waker().wake_by_ref(); + Poll::Pending + } + Poll::Ready(_) => { + if this.inner.inner.counter.load(Ordering::SeqCst) == 0 { + Poll::Ready(()) + } else { + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } } } diff --git a/tests/future.rs b/tests/future.rs index 93af4b3..75e1a63 100644 --- a/tests/future.rs +++ b/tests/future.rs @@ -127,9 +127,11 @@ fn test_async_block_wait() { // do some time consuming task t_wg.done(); }); - + let spawner = |fut| { + async_std::task::spawn(fut); + }; // wait other thread completes - wg.block_wait(); + wg.block_wait(spawner); assert_eq!(wg.waitings(), 0); } diff --git a/tests/sync.rs b/tests/sync.rs new file mode 100644 index 0000000..2a10e43 --- /dev/null +++ b/tests/sync.rs @@ -0,0 +1,85 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; +use wg::WaitGroup; + +#[test] +fn test_sync_wait_group_reuse() { + let wg = WaitGroup::new(); + let ctr = Arc::new(AtomicUsize::new(0)); + for _ in 0..6 { + let wg = wg.add(1); + let ctrx = ctr.clone(); + std::thread::spawn(move || { + std::thread::sleep(Duration::from_millis(5)); + ctrx.fetch_add(1, Ordering::Relaxed); + wg.done(); + }); + } + + wg.wait(); + assert_eq!(ctr.load(Ordering::Relaxed), 6); + + let worker = wg.add(1); + let ctrx = ctr.clone(); + std::thread::spawn(move || { + std::thread::sleep(Duration::from_millis(5)); + ctrx.fetch_add(1, Ordering::Relaxed); + worker.done(); + }); + wg.wait(); + assert_eq!(ctr.load(Ordering::Relaxed), 7); +} + +#[test] +fn test_sync_wait_group_nested() { + let wg = WaitGroup::new(); + let ctr = Arc::new(AtomicUsize::new(0)); + for _ in 0..5 { + let worker = wg.add(1); + let ctrx = ctr.clone(); + std::thread::spawn(move || { + let nested_worker = worker.add(1); + let ctrxx = ctrx.clone(); + std::thread::spawn(move || { + ctrxx.fetch_add(1, Ordering::Relaxed); + nested_worker.done(); + }); + ctrx.fetch_add(1, Ordering::Relaxed); + worker.done(); + }); + } + + wg.wait(); + assert_eq!(ctr.load(Ordering::Relaxed), 10); +} + +#[test] +fn test_sync_wait_group_from() { + std::thread::scope(|s| { + let wg = WaitGroup::from(5); + for _ in 0..5 { + let t = wg.clone(); + s.spawn(move || { + t.done(); + }); + } + wg.wait(); + }); +} + +#[test] +fn test_clone_and_fmt() { + let swg = WaitGroup::new(); + let swg1 = swg.clone(); + swg1.add(3); + assert_eq!(format!("{:?}", swg), format!("{:?}", swg1)); +} + +#[test] +fn test_waitings() { + let wg = WaitGroup::new(); + wg.add(1); + wg.add(1); + assert_eq!(wg.waitings(), 2); +}