diff --git a/README.md b/README.md index bccf4de..fd4ddc5 100644 --- a/README.md +++ b/README.md @@ -55,12 +55,13 @@ logFile.Chdir() // Do something with the file ### HTTP request with strategies and backoff ```go -var response *http.Response +action := func(ctx context.Context, attempt uint) error { + var response *http.Response -action := func(attempt uint) error { - var err error - - response, err = http.Get("https://api.github.com/repos/Rican7/retry") + req, err := NewRequestWithContext(ctx, "GET", "https://api.github.com/repos/Rican7/retry", nil) + if err == nil { + response, err = c.Do(req) + } if nil == err && nil != response && response.StatusCode > 200 { err = fmt.Errorf("failed to fetch (attempt #%d) with status code: %d", attempt, response.StatusCode) @@ -69,7 +70,8 @@ action := func(attempt uint) error { return err } -err := retry.Retry( +err := retry.RetryWithContext( + context.TODO(), action, strategy.Limit(5), strategy.Backoff(backoff.Fibonacci(10*time.Millisecond)), diff --git a/retry.go b/retry.go index 15015db..b7fcd44 100644 --- a/retry.go +++ b/retry.go @@ -4,20 +4,48 @@ // Copyright © 2016 Trevor N. Suarez (Rican7) package retry -import "github.com/Rican7/retry/strategy" +import ( + "context" + "time" + + "github.com/Rican7/retry/strategy" +) // Action defines a callable function that package retry can handle. type Action func(attempt uint) error +// ActionWithContext defines a callable function that package retry can handle. +type ActionWithContext func(ctx context.Context, attempt uint) error + // Retry takes an action and performs it, repetitively, until successful. // // Optionally, strategies may be passed that assess whether or not an attempt // should be made. func Retry(action Action, strategies ...strategy.Strategy) error { + return RetryWithContext(context.Background(), func(ctx context.Context, attempt uint) error { return action(attempt) }, strategies...) +} + +// RetryWithContext takes an action and performs it, repetitively, until successful +// or the context is done. +// +// Optionally, strategies may be passed that assess whether or not an attempt +// should be made. +// +// Context errors take precedence over action errors so this commonplace test: +// +// err := retry.RetryWithContext(...) +// if err != nil { return err } +// +// will pass cancellation errors up the call chain. +func RetryWithContext(ctx context.Context, action ActionWithContext, strategies ...strategy.Strategy) error { var err error - for attempt := uint(0); (0 == attempt || nil != err) && shouldAttempt(attempt, strategies...); attempt++ { - err = action(attempt) + for attempt := uint(0); (0 == attempt || nil != err) && shouldAttempt(attempt, sleepFunc(ctx), strategies...) && nil == ctx.Err(); attempt++ { + err = action(ctx, attempt) + } + + if ctx.Err() != nil { + return ctx.Err() } return err @@ -25,12 +53,24 @@ func Retry(action Action, strategies ...strategy.Strategy) error { // shouldAttempt evaluates the provided strategies with the given attempt to // determine if the Retry loop should make another attempt. -func shouldAttempt(attempt uint, strategies ...strategy.Strategy) bool { +func shouldAttempt(attempt uint, sleep func(time.Duration), strategies ...strategy.Strategy) bool { shouldAttempt := true for i := 0; shouldAttempt && i < len(strategies); i++ { - shouldAttempt = shouldAttempt && strategies[i](attempt) + shouldAttempt = shouldAttempt && strategies[i](attempt, sleep) } return shouldAttempt } + +// sleepFunc returns a function with the same signature as time.Sleep() +// that blocks for the given duration, but will return sooner if the context is +// cancelled or its deadline passes. +func sleepFunc(ctx context.Context) func(time.Duration) { + return func(d time.Duration) { + select { + case <-ctx.Done(): + case <-time.After(d): + } + } +} diff --git a/retry_test.go b/retry_test.go index 8340a15..97b96b7 100644 --- a/retry_test.go +++ b/retry_test.go @@ -1,10 +1,16 @@ package retry import ( + "context" "errors" "testing" + "time" ) +// timeMarginOfError represents the acceptable amount of time that may pass for +// a time-based (sleep) unit before considering invalid. +const timeMarginOfError = time.Millisecond + func TestRetry(t *testing.T) { action := func(attempt uint) error { return nil @@ -47,8 +53,99 @@ func TestRetryRetriesUntilNoErrorReturned(t *testing.T) { } } +func TestRetryWithContextChecksContextAfterLastAttempt(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + strategy := func(attempt uint, sleep func(time.Duration)) bool { + if attempt == 0 { + return true + } + + cancel() + return false + } + + action := func(ctx context.Context, attempt uint) error { + return errors.New("erroring") + } + + err := RetryWithContext(ctx, action, strategy) + + if context.Canceled != err { + t.Error("expected a context error") + } +} + +func TestRetryWithContextCancelStopsAttempts(t *testing.T) { + var numCalls int + + ctx, cancel := context.WithCancel(context.Background()) + + action := func(ctx context.Context, attempt uint) error { + numCalls++ + + if numCalls == 1 { + cancel() + return ctx.Err() + } + + return nil + } + + err := RetryWithContext(ctx, action) + + if 1 != numCalls { + t.Errorf("expected the action to be tried once, not %d times", numCalls) + } + + if context.Canceled != err { + t.Error("expected a context error") + } +} + +func TestRetryWithContextSleepIsInterrupted(t *testing.T) { + const sleepDuration = 100 * timeMarginOfError + fullySleptBy := time.Now().Add(sleepDuration) + + strategy := func(attempt uint, sleep func(time.Duration)) bool { + if attempt > 0 { + sleep(sleepDuration) + } + return attempt <= 1 + } + + var numCalls int + + action := func(ctx context.Context, attempt uint) error { + numCalls++ + return errors.New("erroring") + } + + stopAfter := 10 * timeMarginOfError + deadline := time.Now().Add(stopAfter) + ctx, _ := context.WithDeadline(context.Background(), deadline) + + err := RetryWithContext(ctx, action, strategy) + + if time.Now().Before(deadline) { + t.Errorf("expected to stop after %s", stopAfter) + } + + if time.Now().After(fullySleptBy) { + t.Errorf("expected to stop before %s", sleepDuration) + } + + if 1 != numCalls { + t.Errorf("expected the action to be tried once, not %d times", numCalls) + } + + if context.DeadlineExceeded != err { + t.Error("expected a context error") + } +} + func TestShouldAttempt(t *testing.T) { - shouldAttempt := shouldAttempt(1) + shouldAttempt := shouldAttempt(1, time.Sleep) if !shouldAttempt { t.Error("expected to return true") @@ -58,23 +155,23 @@ func TestShouldAttempt(t *testing.T) { func TestShouldAttemptWithStrategy(t *testing.T) { const attemptNumberShouldReturnFalse = 7 - strategy := func(attempt uint) bool { + strategy := func(attempt uint, sleep func(time.Duration)) bool { return (attemptNumberShouldReturnFalse != attempt) } - should := shouldAttempt(1, strategy) + should := shouldAttempt(1, time.Sleep, strategy) if !should { t.Error("expected to return true") } - should = shouldAttempt(1+attemptNumberShouldReturnFalse, strategy) + should = shouldAttempt(1+attemptNumberShouldReturnFalse, time.Sleep, strategy) if !should { t.Error("expected to return true") } - should = shouldAttempt(attemptNumberShouldReturnFalse, strategy) + should = shouldAttempt(attemptNumberShouldReturnFalse, time.Sleep, strategy) if should { t.Error("expected to return false") @@ -82,39 +179,39 @@ func TestShouldAttemptWithStrategy(t *testing.T) { } func TestShouldAttemptWithMultipleStrategies(t *testing.T) { - trueStrategy := func(attempt uint) bool { + trueStrategy := func(attempt uint, sleep func(time.Duration)) bool { return true } - falseStrategy := func(attempt uint) bool { + falseStrategy := func(attempt uint, sleep func(time.Duration)) bool { return false } - should := shouldAttempt(1, trueStrategy) + should := shouldAttempt(1, time.Sleep, trueStrategy) if !should { t.Error("expected to return true") } - should = shouldAttempt(1, falseStrategy) + should = shouldAttempt(1, time.Sleep, falseStrategy) if should { t.Error("expected to return false") } - should = shouldAttempt(1, trueStrategy, trueStrategy, trueStrategy) + should = shouldAttempt(1, time.Sleep, trueStrategy, trueStrategy, trueStrategy) if !should { t.Error("expected to return true") } - should = shouldAttempt(1, falseStrategy, falseStrategy, falseStrategy) + should = shouldAttempt(1, time.Sleep, falseStrategy, falseStrategy, falseStrategy) if should { t.Error("expected to return false") } - should = shouldAttempt(1, trueStrategy, trueStrategy, falseStrategy) + should = shouldAttempt(1, time.Sleep, trueStrategy, trueStrategy, falseStrategy) if should { t.Error("expected to return false") diff --git a/strategy/strategy.go b/strategy/strategy.go index a315fa0..5d4b2b6 100644 --- a/strategy/strategy.go +++ b/strategy/strategy.go @@ -18,22 +18,22 @@ import ( // The strategy will be passed an "attempt" number on each successive retry // iteration, starting with a `0` value before the first attempt is actually // made. This allows for a pre-action delay, etc. -type Strategy func(attempt uint) bool +type Strategy func(attempt uint, sleep func(time.Duration)) bool // Limit creates a Strategy that limits the number of attempts that Retry will // make. func Limit(attemptLimit uint) Strategy { - return func(attempt uint) bool { - return (attempt <= attemptLimit) + return func(attempt uint, sleep func(time.Duration)) bool { + return attempt <= attemptLimit } } // Delay creates a Strategy that waits the given duration before the first // attempt is made. func Delay(duration time.Duration) Strategy { - return func(attempt uint) bool { + return func(attempt uint, sleep func(time.Duration)) bool { if 0 == attempt { - time.Sleep(duration) + sleep(duration) } return true @@ -44,7 +44,7 @@ func Delay(duration time.Duration) Strategy { // the first. If the number of attempts is greater than the number of durations // provided, then the strategy uses the last duration provided. func Wait(durations ...time.Duration) Strategy { - return func(attempt uint) bool { + return func(attempt uint, sleep func(time.Duration)) bool { if 0 < attempt && 0 < len(durations) { durationIndex := int(attempt - 1) @@ -52,7 +52,7 @@ func Wait(durations ...time.Duration) Strategy { durationIndex = len(durations) - 1 } - time.Sleep(durations[durationIndex]) + sleep(durations[durationIndex]) } return true @@ -68,9 +68,9 @@ func Backoff(algorithm backoff.Algorithm) Strategy { // BackoffWithJitter creates a Strategy that waits before each attempt, with a // duration as defined by the given backoff.Algorithm and jitter.Transformation. func BackoffWithJitter(algorithm backoff.Algorithm, transformation jitter.Transformation) Strategy { - return func(attempt uint) bool { + return func(attempt uint, sleep func(time.Duration)) bool { if 0 < attempt { - time.Sleep(transformation(algorithm(attempt))) + sleep(transformation(algorithm(attempt))) } return true diff --git a/strategy/strategy_test.go b/strategy/strategy_test.go index 17488f5..6b9c644 100644 --- a/strategy/strategy_test.go +++ b/strategy/strategy_test.go @@ -5,45 +5,38 @@ import ( "time" ) -// timeMarginOfError represents the acceptable amount of time that may pass for -// a time-based (sleep) unit before considering invalid. -const timeMarginOfError = time.Millisecond - func TestLimit(t *testing.T) { const attemptLimit = 3 strategy := Limit(attemptLimit) - if !strategy(1) { + if !strategy(1, time.Sleep) { t.Error("strategy expected to return true") } - if !strategy(2) { + if !strategy(2, time.Sleep) { t.Error("strategy expected to return true") } - if !strategy(3) { + if !strategy(3, time.Sleep) { t.Error("strategy expected to return true") } - if strategy(4) { + if strategy(4, time.Sleep) { t.Error("strategy expected to return false") } } func TestDelay(t *testing.T) { - const delayDuration = time.Duration(10 * timeMarginOfError) + const delayDuration = time.Duration(10) strategy := Delay(delayDuration) - if now := time.Now(); !strategy(0) || delayDuration > time.Since(now) { - t.Errorf( - "strategy expected to return true in %s", - time.Duration(delayDuration), - ) + if spy, actual := sleepSpy(); !strategy(0, spy) || delayDuration != *actual { + t.Errorf("strategy expected to return true in %s", delayDuration) } - if now := time.Now(); !strategy(5) || (delayDuration/10) < time.Since(now) { + if spy, actual := sleepSpy(); !strategy(5, spy) || 0 != *actual { t.Error("strategy expected to return true in ~0 time") } } @@ -51,71 +44,59 @@ func TestDelay(t *testing.T) { func TestWait(t *testing.T) { strategy := Wait() - if now := time.Now(); !strategy(0) || timeMarginOfError < time.Since(now) { + if spy, actual := sleepSpy(); !strategy(0, spy) || 0 != *actual { t.Error("strategy expected to return true in ~0 time") } - if now := time.Now(); !strategy(999) || timeMarginOfError < time.Since(now) { + if spy, actual := sleepSpy(); !strategy(999, spy) || 0 != *actual { t.Error("strategy expected to return true in ~0 time") } } func TestWaitWithDuration(t *testing.T) { - const waitDuration = time.Duration(10 * timeMarginOfError) + const waitDuration = time.Duration(10) strategy := Wait(waitDuration) - if now := time.Now(); !strategy(0) || timeMarginOfError < time.Since(now) { + if spy, actual := sleepSpy(); !strategy(0, spy) || 0 != *actual { t.Error("strategy expected to return true in ~0 time") } - if now := time.Now(); !strategy(1) || waitDuration > time.Since(now) { - t.Errorf( - "strategy expected to return true in %s", - time.Duration(waitDuration), - ) + if spy, actual := sleepSpy(); !strategy(1, spy) || waitDuration != *actual { + t.Errorf("strategy expected to return true in %s", waitDuration) } } func TestWaitWithMultipleDurations(t *testing.T) { waitDurations := []time.Duration{ - time.Duration(10 * timeMarginOfError), - time.Duration(20 * timeMarginOfError), - time.Duration(30 * timeMarginOfError), - time.Duration(40 * timeMarginOfError), + time.Duration(10), + time.Duration(20), + time.Duration(30), + time.Duration(40), } strategy := Wait(waitDurations...) - if now := time.Now(); !strategy(0) || timeMarginOfError < time.Since(now) { + if spy, actual := sleepSpy(); !strategy(0, spy) || 0 != *actual { t.Error("strategy expected to return true in ~0 time") } - if now := time.Now(); !strategy(1) || waitDurations[0] > time.Since(now) { - t.Errorf( - "strategy expected to return true in %s", - time.Duration(waitDurations[0]), - ) + if spy, actual := sleepSpy(); !strategy(1, spy) || waitDurations[0] != *actual { + t.Errorf("strategy expected to return true in %s", waitDurations[0]) } - if now := time.Now(); !strategy(3) || waitDurations[2] > time.Since(now) { - t.Errorf( - "strategy expected to return true in %s", - waitDurations[2], - ) + if spy, actual := sleepSpy(); !strategy(3, spy) || waitDurations[2] != *actual { + t.Errorf("strategy expected to return true in %s", waitDurations[2]) } - if now := time.Now(); !strategy(999) || waitDurations[len(waitDurations)-1] > time.Since(now) { - t.Errorf( - "strategy expected to return true in %s", - waitDurations[len(waitDurations)-1], - ) + if spy, actual := sleepSpy(); !strategy(999, spy) || waitDurations[len(waitDurations)-1] != *actual { + t.Errorf("strategy expected to return true in %s", waitDurations[len(waitDurations)-1]) } } func TestBackoff(t *testing.T) { - const backoffDuration = time.Duration(10 * timeMarginOfError) - const algorithmDurationBase = timeMarginOfError + const backoffDuration = time.Duration(10) + const algorithmDurationBase = time.Duration(1) algorithm := func(attempt uint) time.Duration { return backoffDuration - (algorithmDurationBase * time.Duration(attempt)) @@ -123,48 +104,42 @@ func TestBackoff(t *testing.T) { strategy := Backoff(algorithm) - if now := time.Now(); !strategy(0) || timeMarginOfError < time.Since(now) { + if spy, actual := sleepSpy(); !strategy(0, spy) || 0 != *actual { t.Error("strategy expected to return true in ~0 time") } for i := uint(1); i < 10; i++ { expectedResult := algorithm(i) - if now := time.Now(); !strategy(i) || expectedResult > time.Since(now) { - t.Errorf( - "strategy expected to return true in %s", - expectedResult, - ) + if spy, actual := sleepSpy(); !strategy(i, spy) || expectedResult != *actual { + t.Errorf("strategy expected to return true in %s", expectedResult) } } } func TestBackoffWithJitter(t *testing.T) { - const backoffDuration = time.Duration(10 * timeMarginOfError) - const algorithmDurationBase = timeMarginOfError + const backoffDuration = time.Duration(20) + const algorithmDurationBase = time.Duration(1) algorithm := func(attempt uint) time.Duration { return backoffDuration - (algorithmDurationBase * time.Duration(attempt)) } transformation := func(duration time.Duration) time.Duration { - return duration - time.Duration(10*timeMarginOfError) + return duration - time.Duration(10) } strategy := BackoffWithJitter(algorithm, transformation) - if now := time.Now(); !strategy(0) || timeMarginOfError < time.Since(now) { + if spy, actual := sleepSpy(); !strategy(0, spy) || 0 != *actual { t.Error("strategy expected to return true in ~0 time") } for i := uint(1); i < 10; i++ { expectedResult := transformation(algorithm(i)) - if now := time.Now(); !strategy(i) || expectedResult > time.Since(now) { - t.Errorf( - "strategy expected to return true in %s", - expectedResult, - ) + if spy, actual := sleepSpy(); !strategy(i, spy) || expectedResult != *actual { + t.Errorf("strategy expected to return true in %s", expectedResult) } } } @@ -173,7 +148,7 @@ func TestNoJitter(t *testing.T) { transformation := noJitter() for i := uint(0); i < 10; i++ { - duration := time.Duration(i) * timeMarginOfError + duration := time.Duration(i) result := transformation(duration) expected := duration @@ -182,3 +157,11 @@ func TestNoJitter(t *testing.T) { } } } + +// sleepSpy returns a spy for the time.Sleep function that sums the +// durations passed to it. +func sleepSpy() (func(time.Duration), *time.Duration) { + var actual time.Duration + + return func(d time.Duration) { actual += d }, &actual +}