diff --git a/clockwork.go b/clockwork.go index 85a9934..3e6c8a7 100644 --- a/clockwork.go +++ b/clockwork.go @@ -121,7 +121,15 @@ func (fc *FakeClock) After(d time.Duration) <-chan time.Time { // Sleep blocks until the given duration has passed on the fakeClock. func (fc *FakeClock) Sleep(d time.Duration) { - <-fc.After(d) + fc.SleepNotify(d, make(chan struct{})) +} + +// SleepNotify blocks until the given duration has passed on the fakeClock. +// Notify "ch" once the waiters has been updated +func (fc *FakeClock) SleepNotify(d time.Duration, ch chan struct{}) { + afterCh := fc.After(d) + close(ch) + <-afterCh } // Now returns the current time of the fakeClock diff --git a/clockwork_test.go b/clockwork_test.go index 7f25606..1b143d6 100644 --- a/clockwork_test.go +++ b/clockwork_test.go @@ -3,6 +3,8 @@ package clockwork import ( "context" "errors" + "math/rand" + "sync/atomic" "testing" "time" ) @@ -209,3 +211,26 @@ func TestFakeClockRace(t *testing.T) { go func() { fc.NewTimer(d) }() go func() { fc.Sleep(d) }() } + +func TestSleepNotify(t *testing.T) { + var calls atomic.Int32 + clock := NewFakeClock() + beforeCh := make(chan struct{}) + afterCh := make(chan struct{}) + go func() { // thread #1 + clock.SleepNotify(time.Minute, beforeCh) // We want to wait for this before advancing the clock + calls.Add(1) + close(afterCh) + }() + go func() { // thread #2 + if rand.Intn(2) == 0 { // 50% chance of making another Sleep + clock.Sleep(time.Hour) + } + }() + <-beforeCh + clock.Advance(time.Minute) + <-afterCh + if calls.Load() != 1 { + t.Fatalf("SleepNotify() did not call the callback") + } +}