diff --git a/client.go b/client.go index 1a03efa3..8bc1ab48 100644 --- a/client.go +++ b/client.go @@ -85,6 +85,12 @@ type ( // ErrorHook type is for reacting to request errors, called after all retries were attempted ErrorHook func(*Request, error) + + // Executor executes a Request + Executor func(req *Request) (*Response, error) + + // ExecutorMiddleware type wraps the execution of a request + ExecutorMiddleware func(req *Request, next Executor) (*Response, error) ) // Client struct is used to create Resty client with client level settings, @@ -140,6 +146,7 @@ type Client struct { requestLog RequestLogCallback responseLog ResponseLogCallback errorHooks []ErrorHook + executor Executor } // User type is to hold an username and password information @@ -444,6 +451,28 @@ func (c *Client) OnError(h ErrorHook) *Client { return c } +// WrapExecutor wraps the execution of a request, granting full access to the request, response, and error. +// Runs on every request attempt, before any request hook and after any response or error hook. +// Can be useful to introduce throttling or add hooks that always fire, regardless of success or error. +// +// c.WrapExecutor(func(req *Request, next Executor) (*Response, error) { +// // do something with the Request +// // e.g. Acquire a lock +// +// resp, err := next(req) +// // do something with the Response or error +// // e.g. Release a lock +// +// return resp, err +// }) +func (c *Client) WrapExecutor(e ExecutorMiddleware) *Client { + next := c.executor + c.executor = func(req *Request) (*Response, error) { + return e(req, next) + } + return c +} + // SetPreRequestHook method sets the given pre-request function into resty client. // It is called right before the request is fired. // @@ -900,14 +929,14 @@ func (c *Client) execute(req *Request) (*Response, error) { // to modify the *resty.Request object for _, f := range c.udBeforeRequest { if err = f(c, req); err != nil { - return nil, wrapNoRetryErr(err) + return nil, err } } // resty middlewares for _, f := range c.beforeRequest { if err = f(c, req); err != nil { - return nil, wrapNoRetryErr(err) + return nil, err } } @@ -918,12 +947,12 @@ func (c *Client) execute(req *Request) (*Response, error) { // call pre-request if defined if c.preReqHook != nil { if err = c.preReqHook(c, req.RawRequest); err != nil { - return nil, wrapNoRetryErr(err) + return nil, err } } if err = requestLogger(c, req); err != nil { - return nil, wrapNoRetryErr(err) + return nil, err } req.RawRequest.Body = newRequestBodyReleaser(req.RawRequest.Body, req.bodyBuf) @@ -938,7 +967,7 @@ func (c *Client) execute(req *Request) (*Response, error) { if err != nil || req.notParseResponse || c.notParseResponse { response.setReceivedAt() - return response, err + return response, wrapTemporaryError(err) } if !req.isSaveResponse { @@ -951,7 +980,7 @@ func (c *Client) execute(req *Request) (*Response, error) { body, err = gzip.NewReader(body) if err != nil { response.setReceivedAt() - return response, err + return response, wrapTemporaryError(err) } defer closeq(body) } @@ -959,7 +988,7 @@ func (c *Client) execute(req *Request) (*Response, error) { if response.body, err = ioutil.ReadAll(body); err != nil { response.setReceivedAt() - return response, err + return response, wrapTemporaryError(err) } response.size = int64(len(response.body)) @@ -974,7 +1003,7 @@ func (c *Client) execute(req *Request) (*Response, error) { } } - return response, wrapNoRetryErr(err) + return response, err } // getting TLS client config if not exists then create one @@ -1092,6 +1121,8 @@ func createClient(hc *http.Client) *Client { // Logger c.SetLogger(createLogger()) + c.executor = c.execute + // default before request middlewares c.beforeRequest = []RequestMiddleware{ parseRequestURL, diff --git a/client_test.go b/client_test.go index 84ae715e..1e3ce2d8 100644 --- a/client_test.go +++ b/client_test.go @@ -735,6 +735,45 @@ func TestClientOnResponseError(t *testing.T) { } } +func TestWrapExecutor(t *testing.T) { + ts := createGetServer(t) + defer ts.Close() + + t.Run("abort", func(t *testing.T) { + c := dc() + c.WrapExecutor(func(req *Request, next Executor) (*Response, error) { + return nil, fmt.Errorf("abort") + }) + + resp, err := c.R().Get(ts.URL) + assertNil(t, resp) + assertEqual(t, "abort", err.Error()) + }) + + t.Run("noop", func(t *testing.T) { + c := dc() + c.WrapExecutor(func(req *Request, next Executor) (*Response, error) { + return next(req) + }) + + resp, err := c.R().Get(ts.URL) + assertNil(t, err) + assertEqual(t, 200, resp.StatusCode()) + }) + + t.Run("add error", func(t *testing.T) { + c := dc() + c.WrapExecutor(func(req *Request, next Executor) (*Response, error) { + resp, _ := next(req) + return resp, fmt.Errorf("error") + }) + + resp, err := c.R().Get(ts.URL) + assertEqual(t, "error", err.Error()) + assertEqual(t, 200, resp.StatusCode()) + }) +} + func TestResponseError(t *testing.T) { err := errors.New("error message") re := &ResponseError{ diff --git a/example_test.go b/example_test.go index 2d8d3f74..9522855c 100644 --- a/example_test.go +++ b/example_test.go @@ -12,6 +12,7 @@ import ( "net/http" "os" "strconv" + "sync" "time" "golang.org/x/net/proxy" @@ -241,3 +242,35 @@ func Example_socks5Proxy() { func printOutput(resp *resty.Response, err error) { fmt.Println(resp, err) } + +// +// Throttling +// + +func ExampleClient_throttling() { + // Consider the use of proper throttler, possibly waiting for resources to free up + // e.g. https://github.com/throttled/throttled or https://pkg.go.dev/golang.org/x/time/rate + var lock sync.Mutex + currentConcurrent := 0 + maxConcurrent := 10 + + resty.New().WrapExecutor(func(req *resty.Request, next resty.Executor) (*resty.Response, error) { + lock.Lock() + current := currentConcurrent + if current == maxConcurrent { + lock.Unlock() + return nil, fmt.Errorf("max concurrency exceeded") + } + + current++ + lock.Unlock() + + defer func() { + lock.Lock() + current-- + lock.Unlock() + }() + + return next(req) + }) +} diff --git a/request.go b/request.go index 672df88c..36ab691b 100644 --- a/request.go +++ b/request.go @@ -745,7 +745,7 @@ func (r *Request) Execute(method, url string) (*Response, error) { if r.SRV != nil { _, addrs, err = net.LookupSRV(r.SRV.Service, "tcp", r.SRV.Domain) if err != nil { - r.client.onErrorHooks(r, nil, err) + r.client.onErrorHooks(r, resp, err) return nil, err } } @@ -755,9 +755,9 @@ func (r *Request) Execute(method, url string) (*Response, error) { if r.client.RetryCount == 0 { r.Attempt = 1 - resp, err = r.client.execute(r) - r.client.onErrorHooks(r, resp, unwrapNoRetryErr(err)) - return resp, unwrapNoRetryErr(err) + resp, err = r.client.executor(r) + r.client.onErrorHooks(r, resp, err) + return resp, err } err = Backoff( @@ -766,7 +766,7 @@ func (r *Request) Execute(method, url string) (*Response, error) { r.URL = r.selectAddr(addrs, url, r.Attempt) - resp, err = r.client.execute(r) + resp, err = r.client.executor(r) if err != nil { r.client.log.Errorf("%v, Attempt %v", err, r.Attempt) } @@ -780,9 +780,9 @@ func (r *Request) Execute(method, url string) (*Response, error) { RetryHooks(r.client.RetryHooks), ) - r.client.onErrorHooks(r, resp, unwrapNoRetryErr(err)) + r.client.onErrorHooks(r, resp, err) - return resp, unwrapNoRetryErr(err) + return resp, err } //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ diff --git a/retry.go b/retry.go index 00b8514a..2117a500 100644 --- a/retry.go +++ b/retry.go @@ -111,11 +111,10 @@ func Backoff(operation func() (*Response, error), options ...Option) error { return err } - err1 := unwrapNoRetryErr(err) // raw error, it used for return users callback. - needsRetry := err != nil && err == err1 // retry on a few operation errors by default + needsRetry := isTemporaryError(err) // retry on temporary errors by default for _, condition := range opts.retryConditions { - needsRetry = condition(resp, err1) + needsRetry = condition(resp, err) if needsRetry { break } diff --git a/retry_test.go b/retry_test.go index 9f8fb387..a2991249 100644 --- a/retry_test.go +++ b/retry_test.go @@ -22,7 +22,7 @@ func TestBackoffSuccess(t *testing.T) { retryErr := Backoff(func() (*Response, error) { externalCounter++ if externalCounter < attempts { - return nil, errors.New("not yet got the number we're after") + return nil, wrapTemporaryError(errors.New("not yet got the number we're after")) } return nil, nil @@ -71,7 +71,7 @@ func TestBackoffTenAttemptsSuccess(t *testing.T) { retryErr := Backoff(func() (*Response, error) { externalCounter++ if externalCounter < attempts { - return nil, errors.New("not yet got the number we're after") + return nil, wrapTemporaryError(errors.New("not yet got the number we're after")) } return nil, nil }, Retries(attempts), WaitTime(5), MaxWaitTime(500)) diff --git a/util.go b/util.go index 1d563bef..597b346c 100644 --- a/util.go +++ b/util.go @@ -6,6 +6,7 @@ package resty import ( "bytes" + "errors" "fmt" "io" "log" @@ -368,24 +369,38 @@ func copyHeaders(hdrs http.Header) http.Header { return nh } -type noRetryErr struct { +type temporaryError struct { err error } -func (e *noRetryErr) Error() string { +func (e *temporaryError) Error() string { return e.err.Error() } -func wrapNoRetryErr(err error) error { - if err != nil { - err = &noRetryErr{err: err} +func (e *temporaryError) Unwrap() error { + return e.err +} + +func (e *temporaryError) Temporary() bool { + return true +} + +// wrapTemporaryError wraps an error to advertise it should be retryable, if it doesn't specify it already. +func wrapTemporaryError(err error) error { + if err == nil { + return nil } - return err + var tempError interface{ Temporary() bool } + if errors.As(err, &tempError) { + return err // Already exposes the method, honour it, even if false + } + return &temporaryError{err} } -func unwrapNoRetryErr(err error) error { - if e, ok := err.(*noRetryErr); ok { - err = e.err +func isTemporaryError(err error) bool { + var tempError interface{ Temporary() bool } + if errors.As(err, &tempError) { + return tempError.Temporary() } - return err + return false } diff --git a/util_test.go b/util_test.go index ef2bb915..3cdc6e1c 100644 --- a/util_test.go +++ b/util_test.go @@ -6,7 +6,9 @@ package resty import ( "bytes" + "errors" "mime/multipart" + "net" "testing" ) @@ -95,3 +97,22 @@ func TestWriteMultipartFormFileReaderError(t *testing.T) { assertNotNil(t, err) assertEqual(t, "read error", err.Error()) } + +func Test_wrapTemporaryError(t *testing.T) { + tests := []struct { + name string + base error + temp bool + }{ + {name: "nil", temp: false}, + {name: "dns temp", base: &net.DNSError{Err: "err", IsTemporary: true}, temp: true}, + {name: "dns not temp", base: &net.DNSError{Err: "err"}, temp: false}, + {name: "other", base: errors.New("foo"), temp: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := wrapTemporaryError(tt.base) + assertEqual(t, tt.temp, isTemporaryError(err)) + }) + } +}