Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RetryWithContext() and respect cancellation while sleeping #6

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ logFile.Chdir() // Do something with the file
### HTTP request with strategies and backoff

```go
var response *http.Response
action := func(ctx context.Context, attempt uint) error {
var response *http.Response

action := func(attempt uint) error {
var err error

response, err = http.Get("https://api.github.com/repos/Rican7/retry")
req, err := NewRequestWithContext(ctx, "GET", "https://api.github.com/repos/Rican7/retry", nil)
if err == nil {
response, err = c.Do(req)
}

if nil == err && nil != response && response.StatusCode > 200 {
err = fmt.Errorf("failed to fetch (attempt #%d) with status code: %d", attempt, response.StatusCode)
Expand All @@ -69,7 +70,8 @@ action := func(attempt uint) error {
return err
}

err := retry.Retry(
err := retry.RetryWithContext(
context.TODO(),
action,
strategy.Limit(5),
strategy.Backoff(backoff.Fibonacci(10*time.Millisecond)),
Expand Down
50 changes: 45 additions & 5 deletions retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,73 @@
// Copyright © 2016 Trevor N. Suarez (Rican7)
package retry

import "github.com/Rican7/retry/strategy"
import (
"context"
"time"

"github.com/Rican7/retry/strategy"
)

// Action defines a callable function that package retry can handle.
type Action func(attempt uint) error

// ActionWithContext defines a callable function that package retry can handle.
type ActionWithContext func(ctx context.Context, attempt uint) error

// Retry takes an action and performs it, repetitively, until successful.
//
// Optionally, strategies may be passed that assess whether or not an attempt
// should be made.
func Retry(action Action, strategies ...strategy.Strategy) error {
return RetryWithContext(context.Background(), func(ctx context.Context, attempt uint) error { return action(attempt) }, strategies...)
}

// RetryWithContext takes an action and performs it, repetitively, until successful
// or the context is done.
//
// Optionally, strategies may be passed that assess whether or not an attempt
// should be made.
//
// Context errors take precedence over action errors so this commonplace test:
//
// err := retry.RetryWithContext(...)
// if err != nil { return err }
//
// will pass cancellation errors up the call chain.
func RetryWithContext(ctx context.Context, action ActionWithContext, strategies ...strategy.Strategy) error {
var err error

for attempt := uint(0); (0 == attempt || nil != err) && shouldAttempt(attempt, strategies...); attempt++ {
err = action(attempt)
for attempt := uint(0); (0 == attempt || nil != err) && shouldAttempt(attempt, sleepFunc(ctx), strategies...) && nil == ctx.Err(); attempt++ {
err = action(ctx, attempt)
}

if ctx.Err() != nil {
return ctx.Err()
}

return err
}

// shouldAttempt evaluates the provided strategies with the given attempt to
// determine if the Retry loop should make another attempt.
func shouldAttempt(attempt uint, strategies ...strategy.Strategy) bool {
func shouldAttempt(attempt uint, sleep func(time.Duration), strategies ...strategy.Strategy) bool {
shouldAttempt := true

for i := 0; shouldAttempt && i < len(strategies); i++ {
shouldAttempt = shouldAttempt && strategies[i](attempt)
shouldAttempt = shouldAttempt && strategies[i](attempt, sleep)
}

return shouldAttempt
}

// sleepFunc returns a function with the same signature as time.Sleep()
// that blocks for the given duration, but will return sooner if the context is
// cancelled or its deadline passes.
func sleepFunc(ctx context.Context) func(time.Duration) {
return func(d time.Duration) {
select {
case <-ctx.Done():
case <-time.After(d):
}
}
}
121 changes: 109 additions & 12 deletions retry_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package retry

import (
"context"
"errors"
"testing"
"time"
)

// timeMarginOfError represents the acceptable amount of time that may pass for
// a time-based (sleep) unit before considering invalid.
const timeMarginOfError = time.Millisecond

func TestRetry(t *testing.T) {
action := func(attempt uint) error {
return nil
Expand Down Expand Up @@ -47,8 +53,99 @@ func TestRetryRetriesUntilNoErrorReturned(t *testing.T) {
}
}

func TestRetryWithContextChecksContextAfterLastAttempt(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())

strategy := func(attempt uint, sleep func(time.Duration)) bool {
if attempt == 0 {
return true
}

cancel()
return false
}

action := func(ctx context.Context, attempt uint) error {
return errors.New("erroring")
}

err := RetryWithContext(ctx, action, strategy)

if context.Canceled != err {
t.Error("expected a context error")
}
}

func TestRetryWithContextCancelStopsAttempts(t *testing.T) {
var numCalls int

ctx, cancel := context.WithCancel(context.Background())

action := func(ctx context.Context, attempt uint) error {
numCalls++

if numCalls == 1 {
cancel()
return ctx.Err()
}

return nil
}

err := RetryWithContext(ctx, action)

if 1 != numCalls {
t.Errorf("expected the action to be tried once, not %d times", numCalls)
}

if context.Canceled != err {
t.Error("expected a context error")
}
}

func TestRetryWithContextSleepIsInterrupted(t *testing.T) {
const sleepDuration = 100 * timeMarginOfError
fullySleptBy := time.Now().Add(sleepDuration)

strategy := func(attempt uint, sleep func(time.Duration)) bool {
if attempt > 0 {
sleep(sleepDuration)
}
return attempt <= 1
}

var numCalls int

action := func(ctx context.Context, attempt uint) error {
numCalls++
return errors.New("erroring")
}

stopAfter := 10 * timeMarginOfError
deadline := time.Now().Add(stopAfter)
ctx, _ := context.WithDeadline(context.Background(), deadline)

err := RetryWithContext(ctx, action, strategy)

if time.Now().Before(deadline) {
t.Errorf("expected to stop after %s", stopAfter)
}

if time.Now().After(fullySleptBy) {
t.Errorf("expected to stop before %s", sleepDuration)
}

if 1 != numCalls {
t.Errorf("expected the action to be tried once, not %d times", numCalls)
}

if context.DeadlineExceeded != err {
t.Error("expected a context error")
}
}

func TestShouldAttempt(t *testing.T) {
shouldAttempt := shouldAttempt(1)
shouldAttempt := shouldAttempt(1, time.Sleep)

if !shouldAttempt {
t.Error("expected to return true")
Expand All @@ -58,63 +155,63 @@ func TestShouldAttempt(t *testing.T) {
func TestShouldAttemptWithStrategy(t *testing.T) {
const attemptNumberShouldReturnFalse = 7

strategy := func(attempt uint) bool {
strategy := func(attempt uint, sleep func(time.Duration)) bool {
return (attemptNumberShouldReturnFalse != attempt)
}

should := shouldAttempt(1, strategy)
should := shouldAttempt(1, time.Sleep, strategy)

if !should {
t.Error("expected to return true")
}

should = shouldAttempt(1+attemptNumberShouldReturnFalse, strategy)
should = shouldAttempt(1+attemptNumberShouldReturnFalse, time.Sleep, strategy)

if !should {
t.Error("expected to return true")
}

should = shouldAttempt(attemptNumberShouldReturnFalse, strategy)
should = shouldAttempt(attemptNumberShouldReturnFalse, time.Sleep, strategy)

if should {
t.Error("expected to return false")
}
}

func TestShouldAttemptWithMultipleStrategies(t *testing.T) {
trueStrategy := func(attempt uint) bool {
trueStrategy := func(attempt uint, sleep func(time.Duration)) bool {
return true
}

falseStrategy := func(attempt uint) bool {
falseStrategy := func(attempt uint, sleep func(time.Duration)) bool {
return false
}

should := shouldAttempt(1, trueStrategy)
should := shouldAttempt(1, time.Sleep, trueStrategy)

if !should {
t.Error("expected to return true")
}

should = shouldAttempt(1, falseStrategy)
should = shouldAttempt(1, time.Sleep, falseStrategy)

if should {
t.Error("expected to return false")
}

should = shouldAttempt(1, trueStrategy, trueStrategy, trueStrategy)
should = shouldAttempt(1, time.Sleep, trueStrategy, trueStrategy, trueStrategy)

if !should {
t.Error("expected to return true")
}

should = shouldAttempt(1, falseStrategy, falseStrategy, falseStrategy)
should = shouldAttempt(1, time.Sleep, falseStrategy, falseStrategy, falseStrategy)

if should {
t.Error("expected to return false")
}

should = shouldAttempt(1, trueStrategy, trueStrategy, falseStrategy)
should = shouldAttempt(1, time.Sleep, trueStrategy, trueStrategy, falseStrategy)

if should {
t.Error("expected to return false")
Expand Down
18 changes: 9 additions & 9 deletions strategy/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,22 @@ import (
// The strategy will be passed an "attempt" number on each successive retry
// iteration, starting with a `0` value before the first attempt is actually
// made. This allows for a pre-action delay, etc.
type Strategy func(attempt uint) bool
type Strategy func(attempt uint, sleep func(time.Duration)) bool

// Limit creates a Strategy that limits the number of attempts that Retry will
// make.
func Limit(attemptLimit uint) Strategy {
return func(attempt uint) bool {
return (attempt <= attemptLimit)
return func(attempt uint, sleep func(time.Duration)) bool {
return attempt <= attemptLimit
}
}

// Delay creates a Strategy that waits the given duration before the first
// attempt is made.
func Delay(duration time.Duration) Strategy {
return func(attempt uint) bool {
return func(attempt uint, sleep func(time.Duration)) bool {
if 0 == attempt {
time.Sleep(duration)
sleep(duration)
}

return true
Expand All @@ -44,15 +44,15 @@ func Delay(duration time.Duration) Strategy {
// the first. If the number of attempts is greater than the number of durations
// provided, then the strategy uses the last duration provided.
func Wait(durations ...time.Duration) Strategy {
return func(attempt uint) bool {
return func(attempt uint, sleep func(time.Duration)) bool {
if 0 < attempt && 0 < len(durations) {
durationIndex := int(attempt - 1)

if len(durations) <= durationIndex {
durationIndex = len(durations) - 1
}

time.Sleep(durations[durationIndex])
sleep(durations[durationIndex])
}

return true
Expand All @@ -68,9 +68,9 @@ func Backoff(algorithm backoff.Algorithm) Strategy {
// BackoffWithJitter creates a Strategy that waits before each attempt, with a
// duration as defined by the given backoff.Algorithm and jitter.Transformation.
func BackoffWithJitter(algorithm backoff.Algorithm, transformation jitter.Transformation) Strategy {
return func(attempt uint) bool {
return func(attempt uint, sleep func(time.Duration)) bool {
if 0 < attempt {
time.Sleep(transformation(algorithm(attempt)))
sleep(transformation(algorithm(attempt)))
}

return true
Expand Down
Loading