diff --git a/README.md b/README.md index 43f75d7..f5575cb 100644 --- a/README.md +++ b/README.md @@ -15,25 +15,31 @@ go get github.com/ziflex/throttle package myapp import ( + "context" "net/http" "github.com/ziflex/throttle" ) type ApiClient struct { transport *http.Client - throttler *throttle.Throttler + throttler *throttle.Throttler[*http.Response] } func NewApiClient(rps uint64) *ApiClient { return &ApiClient{ transport: &http.Client{}, - throttler: throttle.New(rps), - } + throttler: throttle.New[*http.Response](rps), + } } -func (c *ApiClient) Do(req *http.Request) (*http.Response, error) { - c.throttler.Wait() - - return c.transport.Do(req) +func (c *ApiClient) Do(ctx context.Context, req *http.Request) (*http.Response, error) { + return c.throttler.Do(func() (*http.Response, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + return c.transport.Do(req) + } + }) } ``` \ No newline at end of file diff --git a/throttle.go b/throttle.go index 0b26c4f..7a1f7d1 100644 --- a/throttle.go +++ b/throttle.go @@ -5,36 +5,51 @@ import ( "time" ) -type Throttler struct { - mu sync.Mutex - window time.Time - counter uint64 - limit uint64 -} - -func New(limit uint64) *Throttler { - t := new(Throttler) - t.limit = limit +type ( + // Fn represents a function that returns a value of type T and an error. + Fn[T any] func() (T, error) + + // Throttler manages the execution of operations so that they don't exceed a specified rate limit. + Throttler[T any] struct { + mu sync.Mutex + window time.Time + counter uint64 + limit uint64 + } +) - return t +// New creates a new instance of Throttler with a specified limit. +func New[T any](limit uint64) *Throttler[T] { + return &Throttler[T]{ + limit: limit, + } } -func (t *Throttler) Wait() { +// Do executes the provided function fn if the rate limit has not been reached. +// It ensures that the operation respects the throttling constraints. +func (t *Throttler[T]) Do(fn Fn[T]) (T, error) { t.mu.Lock() - defer t.mu.Unlock() + t.advance() + res, err := fn() + t.mu.Unlock() + + return res, err +} +// advance updates the throttler state, advancing the window or incrementing the counter as necessary. +func (t *Throttler[T]) advance() { now := time.Now() - // if first call + // if this is the first operation, initialize the window if t.window.IsZero() { t.window = now } sinceLastCall := now.Sub(t.window) - // if we are past the current window - // start a new one and exit + // if the current window has expired if sinceLastCall > time.Second { + // start a new window t.reset(now) return @@ -42,24 +57,23 @@ func (t *Throttler) Wait() { nextCount := t.counter + 1 - // if we are in the limit and there is enough time left to process next operation - // we increase the counter and move on + // if adding another operation doesn't exceed the limit if t.limit >= nextCount { + // increment the counter t.counter = nextCount return } - leftInWindow := time.Second - sinceLastCall - - // otherwise wait for the next window - time.Sleep(leftInWindow) + // if the limit is reached, wait until the current window expires + time.Sleep(time.Second - sinceLastCall) - // new window + // after sleeping, reset to a new window starting now t.reset(time.Now()) } -func (t *Throttler) reset(window time.Time) { +// reset starts a new window from the specified start time and resets the operation counter. +func (t *Throttler[T]) reset(window time.Time) { t.window = window t.counter = 1 } diff --git a/throttle_test.go b/throttle_test.go index 3b7a45e..e32b482 100644 --- a/throttle_test.go +++ b/throttle_test.go @@ -9,6 +9,10 @@ import ( "time" ) +func currentTime() (time.Time, error) { + return time.Now(), nil +} + func TestThrottler_Wait_Consistent(t *testing.T) { useCases := []struct { Limit uint64 @@ -35,12 +39,12 @@ func TestThrottler_Wait_Consistent(t *testing.T) { for _, useCase := range useCases { t.Run(fmt.Sprintf("Consistent %d RPS within %d calls", useCase.Limit, useCase.Calls), func(t *testing.T) { calls := make(chan time.Time, useCase.Calls) - call := func(t *throttle.Throttler) { - t.Wait() - calls <- time.Now() + call := func(t *throttle.Throttler[time.Time]) { + ts, _ := t.Do(currentTime) + calls <- ts } - throttler := throttle.New(useCase.Limit) + throttler := throttle.New[time.Time](useCase.Limit) ts := time.Now() wg := sync.WaitGroup{} wg.Add(useCase.Calls) @@ -143,16 +147,16 @@ func TestThrottler_Wait_Sporadic(t *testing.T) { } calls := make(chan time.Time, totalCalls) - call := func(t *throttle.Throttler, latency time.Duration) { + call := func(t *throttle.Throttler[time.Time], latency time.Duration) { if latency > 0 { time.Sleep(latency) } - t.Wait() - calls <- time.Now() + ts, _ := t.Do(currentTime) + calls <- ts } - throttler := throttle.New(useCase.Limit) + throttler := throttle.New[time.Time](useCase.Limit) ts := time.Now() var wg sync.WaitGroup wg.Add(totalCalls)