-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
121 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
package ratelimit | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
|
||
"github.com/dreamsxin/go-kit/endpoint" | ||
) | ||
|
||
// ErrLimited is returned in the request path when the rate limiter is | ||
// triggered and the request is rejected. | ||
var ErrLimited = errors.New("rate limit exceeded") | ||
|
||
// Allower dictates whether or not a request is acceptable to run. | ||
// The Limiter from "golang.org/x/time/rate" already implements this interface, | ||
// one is able to use that in NewErroringLimiter without any modifications. | ||
type Allower interface { | ||
Allow() bool | ||
} | ||
|
||
// NewErroringLimiter returns an endpoint.Middleware that acts as a rate | ||
// limiter. Requests that would exceed the | ||
// maximum request rate are simply rejected with an error. | ||
func NewErroringLimiter(limit Allower) endpoint.Middleware { | ||
return func(next endpoint.Endpoint) endpoint.Endpoint { | ||
return func(ctx context.Context, request interface{}) (interface{}, error) { | ||
if !limit.Allow() { | ||
return nil, ErrLimited | ||
} | ||
return next(ctx, request) | ||
} | ||
} | ||
} | ||
|
||
// Waiter dictates how long a request must be delayed. | ||
// The Limiter from "golang.org/x/time/rate" already implements this interface, | ||
// one is able to use that in NewDelayingLimiter without any modifications. | ||
type Waiter interface { | ||
Wait(ctx context.Context) error | ||
} | ||
|
||
// NewDelayingLimiter returns an endpoint.Middleware that acts as a | ||
// request throttler. Requests that would | ||
// exceed the maximum request rate are delayed via the Waiter function | ||
func NewDelayingLimiter(limit Waiter) endpoint.Middleware { | ||
return func(next endpoint.Endpoint) endpoint.Endpoint { | ||
return func(ctx context.Context, request interface{}) (interface{}, error) { | ||
if err := limit.Wait(ctx); err != nil { | ||
return nil, err | ||
} | ||
return next(ctx, request) | ||
} | ||
} | ||
} | ||
|
||
// AllowerFunc is an adapter that lets a function operate as if | ||
// it implements Allower | ||
type AllowerFunc func() bool | ||
|
||
// Allow makes the adapter implement Allower | ||
func (f AllowerFunc) Allow() bool { | ||
return f() | ||
} | ||
|
||
// WaiterFunc is an adapter that lets a function operate as if | ||
// it implements Waiter | ||
type WaiterFunc func(ctx context.Context) error | ||
|
||
// Wait makes the adapter implement Waiter | ||
func (f WaiterFunc) Wait(ctx context.Context) error { | ||
return f(ctx) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
package ratelimit_test | ||
|
||
import ( | ||
"context" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
"golang.org/x/time/rate" | ||
|
||
"github.com/dreamsxin/go-kit/endpoint" | ||
"github.com/dreamsxin/go-kit/endpoint/ratelimit" | ||
) | ||
|
||
var nopEndpoint = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | ||
|
||
func TestXRateErroring(t *testing.T) { | ||
limit := rate.NewLimiter(rate.Every(time.Minute), 1) | ||
testSuccessThenFailure( | ||
t, | ||
ratelimit.NewErroringLimiter(limit)(nopEndpoint), | ||
ratelimit.ErrLimited.Error()) | ||
} | ||
|
||
func TestXRateDelaying(t *testing.T) { | ||
limit := rate.NewLimiter(rate.Every(time.Minute), 1) | ||
testSuccessThenFailure( | ||
t, | ||
ratelimit.NewDelayingLimiter(limit)(nopEndpoint), | ||
"exceed context deadline") | ||
} | ||
|
||
func testSuccessThenFailure(t *testing.T, e endpoint.Endpoint, failContains string) { | ||
ctx, cxl := context.WithTimeout(context.Background(), 500*time.Millisecond) | ||
defer cxl() | ||
|
||
// First request should succeed. | ||
if _, err := e(ctx, struct{}{}); err != nil { | ||
t.Errorf("unexpected: %v\n", err) | ||
} | ||
|
||
// Next request should fail. | ||
if _, err := e(ctx, struct{}{}); !strings.Contains(err.Error(), failContains) { | ||
t.Errorf("expected `%s`: %v\n", failContains, err) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters