Skip to content

Commit

Permalink
Switched to generic method
Browse files Browse the repository at this point in the history
  • Loading branch information
ziflex committed May 2, 2024
1 parent d8bf96c commit f1c4216
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 39 deletions.
20 changes: 13 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,31 @@ go get github.com/ziflex/throttle
package myapp

import (
"context"
"net/http"
"github.com/ziflex/throttle"
)

type ApiClient struct {
transport *http.Client
throttler *throttle.Throttler
throttler *throttle.Throttler[*http.Response]
}

func NewApiClient(rps uint64) *ApiClient {
return &ApiClient{
transport: &http.Client{},
throttler: throttle.New(rps),
}
throttler: throttle.New[*http.Response](rps),
}
}

func (c *ApiClient) Do(req *http.Request) (*http.Response, error) {
c.throttler.Wait()

return c.transport.Do(req)
func (c *ApiClient) Do(ctx context.Context, req *http.Request) (*http.Response, error) {
return c.throttler.Do(func() (*http.Response, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
return c.transport.Do(req)
}
})
}
```
62 changes: 38 additions & 24 deletions throttle.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,61 +5,75 @@ import (
"time"
)

type Throttler struct {
mu sync.Mutex
window time.Time
counter uint64
limit uint64
}

func New(limit uint64) *Throttler {
t := new(Throttler)
t.limit = limit
type (
// Fn represents a function that returns a value of type T and an error.
Fn[T any] func() (T, error)

// Throttler manages the execution of operations so that they don't exceed a specified rate limit.
Throttler[T any] struct {
mu sync.Mutex
window time.Time
counter uint64
limit uint64
}
)

return t
// New creates a new instance of Throttler with a specified limit.
func New[T any](limit uint64) *Throttler[T] {
return &Throttler[T]{
limit: limit,
}
}

func (t *Throttler) Wait() {
// Do executes the provided function fn if the rate limit has not been reached.
// It ensures that the operation respects the throttling constraints.
func (t *Throttler[T]) Do(fn Fn[T]) (T, error) {
t.mu.Lock()
defer t.mu.Unlock()
t.advance()
res, err := fn()
t.mu.Unlock()

return res, err
}

// advance updates the throttler state, advancing the window or incrementing the counter as necessary.
func (t *Throttler[T]) advance() {
now := time.Now()

// if first call
// if this is the first operation, initialize the window
if t.window.IsZero() {
t.window = now
}

sinceLastCall := now.Sub(t.window)

// if we are past the current window
// start a new one and exit
// if the current window has expired
if sinceLastCall > time.Second {
// start a new window
t.reset(now)

return
}

nextCount := t.counter + 1

// if we are in the limit and there is enough time left to process next operation
// we increase the counter and move on
// if adding another operation doesn't exceed the limit
if t.limit >= nextCount {
// increment the counter
t.counter = nextCount

return
}

leftInWindow := time.Second - sinceLastCall

// otherwise wait for the next window
time.Sleep(leftInWindow)
// if the limit is reached, wait until the current window expires
time.Sleep(time.Second - sinceLastCall)

// new window
// after sleeping, reset to a new window starting now
t.reset(time.Now())
}

func (t *Throttler) reset(window time.Time) {
// reset starts a new window from the specified start time and resets the operation counter.
func (t *Throttler[T]) reset(window time.Time) {
t.window = window
t.counter = 1
}
20 changes: 12 additions & 8 deletions throttle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ import (
"time"
)

func currentTime() (time.Time, error) {
return time.Now(), nil
}

func TestThrottler_Wait_Consistent(t *testing.T) {
useCases := []struct {
Limit uint64
Expand All @@ -35,12 +39,12 @@ func TestThrottler_Wait_Consistent(t *testing.T) {
for _, useCase := range useCases {
t.Run(fmt.Sprintf("Consistent %d RPS within %d calls", useCase.Limit, useCase.Calls), func(t *testing.T) {
calls := make(chan time.Time, useCase.Calls)
call := func(t *throttle.Throttler) {
t.Wait()
calls <- time.Now()
call := func(t *throttle.Throttler[time.Time]) {
ts, _ := t.Do(currentTime)
calls <- ts
}

throttler := throttle.New(useCase.Limit)
throttler := throttle.New[time.Time](useCase.Limit)
ts := time.Now()
wg := sync.WaitGroup{}
wg.Add(useCase.Calls)
Expand Down Expand Up @@ -143,16 +147,16 @@ func TestThrottler_Wait_Sporadic(t *testing.T) {
}

calls := make(chan time.Time, totalCalls)
call := func(t *throttle.Throttler, latency time.Duration) {
call := func(t *throttle.Throttler[time.Time], latency time.Duration) {
if latency > 0 {
time.Sleep(latency)
}

t.Wait()
calls <- time.Now()
ts, _ := t.Do(currentTime)
calls <- ts
}

throttler := throttle.New(useCase.Limit)
throttler := throttle.New[time.Time](useCase.Limit)
ts := time.Now()
var wg sync.WaitGroup
wg.Add(totalCalls)
Expand Down

0 comments on commit f1c4216

Please sign in to comment.