Skip to content

Commit

Permalink
Refactored API
Browse files Browse the repository at this point in the history
  • Loading branch information
ziflex committed Jun 29, 2024
1 parent 7cf20b2 commit 3b2fdc0
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 62 deletions.
46 changes: 35 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,25 @@ import (

type ApiClient struct {
transport *http.Client
throttler *throttle.Throttler[*http.Response]
throttler *throttle.Throttler
}

func NewApiClient(rps uint64) *ApiClient {
return &ApiClient{
transport: &http.Client{},
throttler: throttle.New[*http.Response](rps),
throttler: throttle.New(rps),
}
}

func (c *ApiClient) Do(ctx context.Context, req *http.Request) (*http.Response, error) {
return c.throttler.Do(func() (*http.Response, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
return c.transport.Do(req)
}
})
c.throttler.Acquire()
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
return c.transport.Do(req)
}
}
```

Expand Down Expand Up @@ -72,6 +72,30 @@ func (c *MyClock) Sleep(dur time.Duration) {
}

func main() {
throttler := throttle.New[any](10, throttle.WithClock(&MyClock{time.Millisecond * 250}))
throttler := throttle.New(10, throttle.WithClock(&MyClock{time.Millisecond * 250}))
}
```

## Helpers
### RoundTripper
The package contains a helper that wraps the standard `http.RoundTripper` interface and provides a throttling mechanism.

```go
package myapp

import (
"context"
"net/http"
"github.com/ziflex/throttle"
)

func main() {
transport := &http.Transport{}
client := &http.Client{
Transport: throttle.NewRoundTripper(transport, 10),
}

req, _ := http.NewRequest(http.MethodGet, "https://example.com", nil)
client.Do(req)
}
```
36 changes: 14 additions & 22 deletions throttle.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,34 @@ import (

const windowSize = time.Second

type (
// Fn represents a function that returns a value of type T and an error.
Fn[T any] func() (T, error)

// Throttler manages the execution of operations so that they don't exceed a specified rate limit.
Throttler[T any] struct {
mu sync.Mutex
window time.Time
clock Clock
counter uint64
limit uint64
}
)
// Throttler manages the execution of operations so that they don't exceed a specified rate limit.
type Throttler struct {
mu sync.Mutex
window time.Time
clock Clock
counter uint64
limit uint64
}

// New creates a new instance of Throttler with a specified limit.
func New[T any](limit uint64, setters ...Option) *Throttler[T] {
func New(limit uint64, setters ...Option) *Throttler {
opts := buildOptions(setters)

return &Throttler[T]{
return &Throttler{
limit: limit,
clock: opts.clock,
}
}

// Do executes the provided function fn if the rate limit has not been reached.
// It ensures that the operation respects the throttling constraints.
func (t *Throttler[T]) Do(fn Fn[T]) (T, error) {
// Acquire blocks until the operation can be executed within the rate limit.
func (t *Throttler) Acquire() {
t.mu.Lock()
t.advance()
t.mu.Unlock()

return fn()
}

// advance updates the throttler state, advancing the window or incrementing the counter as necessary.
func (t *Throttler[T]) advance() {
func (t *Throttler) advance() {
// pass through
if t.limit == 0 {
return
Expand Down Expand Up @@ -87,7 +79,7 @@ func (t *Throttler[T]) advance() {
}

// reset starts a new window from the specified start time and resets the operation counter.
func (t *Throttler[T]) reset(window time.Time) {
func (t *Throttler) reset(window time.Time) {
t.window = window
t.counter = 1
}
44 changes: 15 additions & 29 deletions throttle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,16 @@ func TestThrottler_Do_Consistent(t *testing.T) {
for _, useCase := range useCases {
t.Run(fmt.Sprintf("Consistent %d RPS within %d calls", useCase.Limit, useCase.Calls), func(t *testing.T) {
calls := make(chan time.Time, useCase.Calls)
throttler := throttle.New[time.Time](useCase.Limit)
throttler := throttle.New(useCase.Limit)
ts := time.Now()

var wg sync.WaitGroup
wg.Add(useCase.Calls)

for range useCase.Calls {
go func() {
res, _ := throttler.Do(func() (time.Time, error) {
return time.Now(), nil
})

calls <- res
throttler.Acquire()
calls <- time.Now()
wg.Done()
}()
}
Expand Down Expand Up @@ -152,7 +149,7 @@ func TestThrottler_Do_Sporadic(t *testing.T) {
}

calls := make(chan time.Time, buffer)
throttler := throttle.New[time.Time](useCase.Limit)
throttler := throttle.New(useCase.Limit)
ts := time.Now()

var wg sync.WaitGroup
Expand All @@ -169,19 +166,13 @@ func TestThrottler_Do_Sporadic(t *testing.T) {
}

for range callNum {
res, _ := throttler.Do(func() (time.Time, error) {
if latency > 0 {
time.Sleep(latency)
}

return time.Now(), nil
})

calls <- res
throttler.Acquire()

//ts := time.Now()
if latency > 0 {
time.Sleep(latency)
}

//fmt.Println(fmt.Sprintf("Call %dms", time.Since(ts).Milliseconds()))
calls <- time.Now()
}

wg.Done()
Expand Down Expand Up @@ -301,7 +292,7 @@ func TestThrottler_Do_Parallel(t *testing.T) {
for _, useCase := range useCases {
t.Run(fmt.Sprintf("Parallel %d RPS", useCase.Limit), func(t *testing.T) {
calls := make(chan time.Time, len(useCase.Calls))
throttler := throttle.New[time.Time](useCase.Limit)
throttler := throttle.New(useCase.Limit)
ts := time.Now()

var wg sync.WaitGroup
Expand All @@ -311,18 +302,13 @@ func TestThrottler_Do_Parallel(t *testing.T) {
go func(latency time.Duration) {
defer wg.Done()

// callTs := time.Now()
res, _ := throttler.Do(func() (time.Time, error) {
if latency > 0 {
time.Sleep(latency)
}
throttler.Acquire()

return time.Now(), nil
})

calls <- res
if latency > 0 {
time.Sleep(latency)
}

// fmt.Println(fmt.Sprintf("Call %dms", time.Since(callTs).Milliseconds()))
calls <- time.Now()
}(tpl.Latency)
}

Expand Down
27 changes: 27 additions & 0 deletions transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package throttle

import (
"net/http"
)

type throttledRoundTripper struct {
transport http.RoundTripper
throttler *Throttler
}

func (t *throttledRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
t.throttler.Acquire()

return t.transport.RoundTrip(request)
}

func NewRoundTripper(transport http.RoundTripper, limit uint64, setters ...Option) http.RoundTripper {
return NewRoundTripperWith(transport, New(limit, setters...))
}

func NewRoundTripperWith(transport http.RoundTripper, throttler *Throttler) http.RoundTripper {
return &throttledRoundTripper{
transport: transport,
throttler: throttler,
}
}

0 comments on commit 3b2fdc0

Please sign in to comment.