diff --git a/throttle.go b/throttle.go index f7d5317..b63a3cd 100644 --- a/throttle.go +++ b/throttle.go @@ -30,10 +30,9 @@ func New[T any](limit uint64) *Throttler[T] { func (t *Throttler[T]) Do(fn Fn[T]) (T, error) { t.mu.Lock() t.advance() - res, err := fn() t.mu.Unlock() - return res, err + return fn() } // advance updates the throttler state, advancing the window or incrementing the counter as necessary. @@ -42,7 +41,7 @@ func (t *Throttler[T]) advance() { if t.limit == 0 { return } - + now := time.Now() // if this is the first operation, initialize the window diff --git a/throttle_test.go b/throttle_test.go index 93c60f8..aee6cc0 100644 --- a/throttle_test.go +++ b/throttle_test.go @@ -9,11 +9,11 @@ import ( "time" ) -func currentTime() (time.Time, error) { - return time.Now(), nil +func seconds(fraction float64) time.Duration { + return time.Duration(float64(time.Second) * fraction) } -func TestThrottler_Wait_Consistent(t *testing.T) { +func TestThrottler_Do_Consistent(t *testing.T) { useCases := []struct { Limit uint64 Calls int @@ -43,19 +43,19 @@ func TestThrottler_Wait_Consistent(t *testing.T) { for _, useCase := range useCases { t.Run(fmt.Sprintf("Consistent %d RPS within %d calls", useCase.Limit, useCase.Calls), func(t *testing.T) { calls := make(chan time.Time, useCase.Calls) - call := func(t *throttle.Throttler[time.Time]) { - ts, _ := t.Do(currentTime) - calls <- ts - } - throttler := throttle.New[time.Time](useCase.Limit) ts := time.Now() - wg := sync.WaitGroup{} + + var wg sync.WaitGroup wg.Add(useCase.Calls) for range useCase.Calls { go func() { - call(throttler) + res, _ := throttler.Do(func() (time.Time, error) { + return time.Now(), nil + }) + + calls <- res wg.Done() }() } @@ -86,16 +86,13 @@ func TestThrottler_Wait_Consistent(t *testing.T) { } } -func TestThrottler_Wait_Sporadic(t *testing.T) { +func TestThrottler_Do_Sporadic(t *testing.T) { type Burst struct { Warmup time.Duration Latency time.Duration Calls int } - seconds := func(fraction float64) time.Duration { - return time.Duration(float64(time.Second) * fraction) - } useCases := []struct { Limit uint64 Calls []Burst @@ -148,39 +145,46 @@ func TestThrottler_Wait_Sporadic(t *testing.T) { for _, useCase := range useCases { t.Run(fmt.Sprintf("Sporadic %d RPS within %d calls", useCase.Limit, useCase.Calls), func(t *testing.T) { - var totalCalls int + var buffer int for _, tp := range useCase.Calls { - totalCalls += tp.Calls - } - - calls := make(chan time.Time, totalCalls) - call := func(t *throttle.Throttler[time.Time], latency time.Duration) { - if latency > 0 { - time.Sleep(latency) - } - - ts, _ := t.Do(currentTime) - calls <- ts + buffer += tp.Calls } + calls := make(chan time.Time, buffer) throttler := throttle.New[time.Time](useCase.Limit) ts := time.Now() + var wg sync.WaitGroup - wg.Add(totalCalls) + wg.Add(len(useCase.Calls)) go func() { for _, tpl := range useCase.Calls { - if tpl.Warmup > 0 { - time.Sleep(tpl.Warmup) + warmup := tpl.Warmup + latency := tpl.Latency + callNum := tpl.Calls + + if warmup > 0 { + time.Sleep(warmup) } - for range tpl.Calls { + for range callNum { + res, _ := throttler.Do(func() (time.Time, error) { + if latency > 0 { + time.Sleep(latency) + } + + return time.Now(), nil + }) + + calls <- res + //ts := time.Now() - call(throttler, tpl.Latency) - wg.Done() + //fmt.Println(fmt.Sprintf("Call %dms", time.Since(ts).Milliseconds())) } + + wg.Done() } }() @@ -194,7 +198,145 @@ func TestThrottler_Wait_Sporadic(t *testing.T) { dur := math.Abs(math.Floor(diff.Seconds())) groups[dur]++ - //fmt.Println(fmt.Sprintf("Elapsed %ds", int64(dur))) + // fmt.Println(fmt.Sprintf("Elapsed %ds", int64(dur))) + } + + for sec, actual := range groups { + expected, found := useCase.Expected[sec] + + if !found { + t.Fatal(fmt.Sprintf("Expected to have calls within %ds time range", int64(sec))) + } + + if actual != expected { + t.Fatal(fmt.Sprintf("Expected %d per second, but got %d", expected, actual)) + } + } + }) + } +} + +func TestThrottler_Do_Parallel(t *testing.T) { + type Call struct { + Latency time.Duration + } + + useCases := []struct { + Limit uint64 + Calls []Call + Expected map[float64]uint64 + }{ + { + Limit: 1, + Calls: []Call{ + {}, + {}, + {}, + {}, + {}, + }, + Expected: map[float64]uint64{ + 0: 1, + 1: 1, + 2: 1, + 3: 1, + 4: 1, + }, + }, + { + Limit: 5, + Calls: []Call{ + { + Latency: seconds(0.99), + }, + { + Latency: seconds(0.99), + }, + { + Latency: seconds(0.99), + }, + { + Latency: seconds(0.99), + }, + { + Latency: seconds(0.99), + }, + }, + Expected: map[float64]uint64{ + 0: 5, + }, + }, + + { + Limit: 5, + Calls: []Call{ + { + Latency: seconds(0.5), + }, + { + Latency: seconds(0.5), + }, + { + Latency: seconds(0.5), + }, + { + Latency: seconds(0.5), + }, + { + Latency: seconds(0.99), + }, + { + Latency: seconds(1), + }, + {}, + }, + Expected: map[float64]uint64{ + 0: 5, + 1: 1, + 2: 1, + }, + }, + } + + for _, useCase := range useCases { + t.Run(fmt.Sprintf("Parallel %d RPS", useCase.Limit), func(t *testing.T) { + calls := make(chan time.Time, len(useCase.Calls)) + throttler := throttle.New[time.Time](useCase.Limit) + ts := time.Now() + + var wg sync.WaitGroup + wg.Add(len(useCase.Calls)) + + for _, tpl := range useCase.Calls { + go func(latency time.Duration) { + defer wg.Done() + + // callTs := time.Now() + res, _ := throttler.Do(func() (time.Time, error) { + if latency > 0 { + time.Sleep(latency) + } + + return time.Now(), nil + }) + + calls <- res + + // fmt.Println(fmt.Sprintf("Call %dms", time.Since(callTs).Milliseconds())) + }(tpl.Latency) + } + + wg.Wait() + close(calls) + + groups := map[float64]uint64{} + + for c := range calls { + diff := c.Sub(ts) + dur := math.Abs(math.Floor(diff.Seconds())) + groups[dur]++ + + // fmt.Println(fmt.Sprintf("Elapsed %ds", int64(dur))) } for sec, actual := range groups {