Skip to content

Commit

Permalink
fix: rename package / thread safety (#3)
Browse files Browse the repository at this point in the history
* fix: make bucket concurrency safe

- don't rely on the rate limiter to take care of safe access

* fix: remove unsafe code

- give buckets their own concurrency safety
- finer locking for lastRefillAt

* fix: better naming for last refill time and refill duration

* fix: use nanosecond precision

- update changelog
- add benchmarks
- rename variables
- rename package

* feat: run bench in gh actions
  • Loading branch information
vivangkumar authored Mar 30, 2023
1 parent 1dbf7ca commit 0e9cd14
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 49 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@ jobs:
version: ${{ matrix.version.golangci }}
- name: Run tests
run: make test
- name: Run benchmarks
run: make bench
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

## [0.2.0] - 2023-03-30

### Changed

- Readme updates.
- `internal` package - has no change to library functionality.
- Rename `maxTokens` to `max` when constructing new rate limiter.
- Fine-grained locking for rate limiter.
- Rename internal struct fields.
- Rename package (it was wrongly packaged under ratelimiter)
- Use nanosecnd precision for timestamps.

### Fixed

- Make rate limiter more thread safe. There was potential for race conditons prior to this.

## [0.1.0] - 2023-03-30

Expand Down
6 changes: 5 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: test build lint fmt
.PHONY: test bench build lint fmt

lint := go run github.com/golangci/golangci-lint/cmd/[email protected]

Expand All @@ -22,3 +22,7 @@ lint:
test:
@echo "Running tests..."
go test -v -race ./...

bench:
@echo "Running benchmarks..."
go test -bench=. -count 5 -run=^# ./...
20 changes: 19 additions & 1 deletion internal/bucket/bucket.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package bucket

import (
"sync"
)

// Bucket represents a token bucket.
//
// It is not safe for concurrent use.
// It is safe for concurrent use.
type Bucket struct {
m sync.Mutex
// size is the max tokens the bucket can hold.
size uint64

Expand All @@ -16,6 +21,7 @@ type Bucket struct {
// size number of max tokens.
func New(size uint64) *Bucket {
return &Bucket{
m: sync.Mutex{},
size: size,
available: size,
}
Expand All @@ -34,6 +40,9 @@ func (b *Bucket) Take() bool {
// If n tokens are not available, no tokens are removed
// from the bucket.
func (b *Bucket) TakeN(n uint64) bool {
b.m.Lock()
defer b.m.Unlock()

if b.available >= n {
b.available -= n
return true
Expand All @@ -47,6 +56,9 @@ func (b *Bucket) TakeN(n uint64) bool {
// If the quantity to refill exceeds the size of the bucket,
// the bucket is refilled upto its size.
func (b *Bucket) Refill(n uint64) {
b.m.Lock()
defer b.m.Unlock()

t := b.available + n
if t > b.size {
t = b.size
Expand All @@ -56,10 +68,16 @@ func (b *Bucket) Refill(n uint64) {

// Available returns the tokens currently available in the bucket.
func (b *Bucket) Available() uint64 {
b.m.Lock()
defer b.m.Unlock()

return b.available
}

// Size returns the max tokens a bucket can have.
func (b *Bucket) Size() uint64 {
b.m.Lock()
defer b.m.Unlock()

return b.size
}
9 changes: 9 additions & 0 deletions internal/bucket/bucket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,12 @@ func TestBucket_Refill(t *testing.T) {
b.Refill(20)
assert.EqualValues(t, b.Available(), 10)
}

func BenchmarkBucket(b *testing.B) {
b.ReportAllocs()

bucket := bucket.New(uint64(b.N) * 1000000000000)
for i := 0; i < b.N; i++ {
bucket.Take()
}
}
2 changes: 1 addition & 1 deletion opt.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ratelimiter
package ratelimit

type Opt func(r *RateLimiter)

Expand Down
56 changes: 27 additions & 29 deletions ratelimiter.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ratelimiter
package ratelimit

import (
"context"
Expand All @@ -20,23 +20,22 @@ type NowFunc = func() time.Time
// Most callers should use either Wait or WaitN to wait for tokens
// to be available.
type RateLimiter struct {
m sync.Mutex
// bucket is the underlying storage structure
// for the rate limiter.
// bucket is the underlying storage for the rate limiter.
bucket *bucket.Bucket

// lastRefillAt keeps track of the time when the last
// refresh of tokens took place.
//
// It is kept track in milliseconds.
lastRefillAt int64

// refillEvery is the duration after which tokens are refilled.
// refillDuration is the duration after which tokens are refilled.
//
// The duration is calculated based on the limit specified
// at creation time.
refillEvery time.Duration
now NowFunc
refillDuration time.Duration
now NowFunc

m sync.Mutex
// lastRefillUnixNs keeps track of the time when the last
// refresh of tokens took place.
//
// It is kept track in nanoseconds.
lastRefillUnixNs int64
}

// New constructs a rate limiter that accepts the max tokens (size) that
Expand All @@ -58,15 +57,15 @@ func New(
}

r := &RateLimiter{
bucket: bucket.New(max),
refillEvery: time.Duration(float64(per) / float64(limit)),
now: time.Now,
bucket: bucket.New(max),
refillDuration: per / time.Duration(limit),
now: time.Now,
}

for _, opt := range opts {
opt(r)
}
r.lastRefillAt = r.now().Unix()
r.lastRefillUnixNs = r.now().UnixNano()

return r, nil
}
Expand All @@ -79,13 +78,17 @@ func New(
// the token to be taken.
func (r *RateLimiter) refill() {
r.m.Lock()
defer r.m.Unlock()
lastRefill := r.lastRefillUnixNs
r.m.Unlock()

now := r.now()
t := (now.UnixMilli() - r.lastRefillAt) / r.refillEvery.Milliseconds()
if t > 0 {
r.lastRefillAt = now.UnixMilli()
r.bucket.Refill(uint64(t))
tokens := (now.UnixNano() - lastRefill) / r.refillDuration.Nanoseconds()
if tokens > 0 {
r.m.Lock()
r.lastRefillUnixNs = now.UnixNano()
r.m.Unlock()

r.bucket.Refill(uint64(tokens))
}
}

Expand All @@ -105,12 +108,7 @@ func (r *RateLimiter) Add() bool {
// Its behaviour is details in bucket.takeN.
func (r *RateLimiter) AddN(n uint64) bool {
r.refill()

r.m.Lock()
ok := r.bucket.TakeN(n)
r.m.Unlock()

return ok
return r.bucket.TakeN(n)
}

// Wait blocks until a token is available.
Expand All @@ -137,7 +135,7 @@ func (r *RateLimiter) WaitN(ctx context.Context, n uint64) error {
}

// Check refillEvery duration to see if a new token is available.
t := time.NewTicker(r.refillEvery)
t := time.NewTicker(r.refillDuration)
defer t.Stop()

for {
Expand Down
43 changes: 26 additions & 17 deletions ratelimiter_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ratelimiter_test
package ratelimit_test

import (
"context"
Expand All @@ -8,11 +8,11 @@ import (

"github.com/stretchr/testify/assert"

ratelimiter "github.com/vivangkumar/ratelimit"
"github.com/vivangkumar/ratelimit"
)

func TestRateLimiter_Add_TokensAvailable(t *testing.T) {
r, err := ratelimiter.New(100, 100, 1*time.Second)
r, err := ratelimit.New(100, 100, 1*time.Second)
assert.Nil(t, err)

assert.True(t, r.Add())
Expand All @@ -22,27 +22,27 @@ func TestRateLimiter_Add_TokensAvailable(t *testing.T) {
}

func TestRateLimiter_NewError(t *testing.T) {
_, err := ratelimiter.New(100, 0, 1*time.Second)
_, err := ratelimit.New(100, 0, 1*time.Second)
assert.Error(t, err)
}

func TestRateLimiter_AddN(t *testing.T) {
r, err := ratelimiter.New(10, 10, 1*time.Second)
r, err := ratelimit.New(10, 10, 1*time.Second)
assert.Nil(t, err)

assert.True(t, r.AddN(10))
}

func TestRateLimiter_Add_NoTokens(t *testing.T) {
r, err := ratelimiter.New(1, 1, 1*time.Second)
r, err := ratelimit.New(1, 1, 1*time.Second)
assert.Nil(t, err)

assert.True(t, r.Add())
assert.False(t, r.Add())
}

func TestRateLimiter_RefreshToken(t *testing.T) {
r, err := ratelimiter.New(10, 10, 1*time.Second)
r, err := ratelimit.New(10, 10, 1*time.Second)
assert.Nil(t, err)

assert.True(t, r.AddN(10))
Expand All @@ -58,7 +58,7 @@ func TestRateLimiter_RefreshToken(t *testing.T) {
}

func TestRateLimiter_Wait(t *testing.T) {
r, err := ratelimiter.New(100, 10, 1*time.Second)
r, err := ratelimit.New(100, 10, 1*time.Second)
assert.Nil(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
Expand All @@ -72,7 +72,7 @@ func TestRateLimiter_Wait(t *testing.T) {
}

func TestRateLimiter_WaitN(t *testing.T) {
r, err := ratelimiter.New(100, 10, 1*time.Second)
r, err := ratelimit.New(100, 10, 1*time.Second)
assert.Nil(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
Expand All @@ -87,7 +87,7 @@ func TestRateLimiter_WaitN(t *testing.T) {
}

func TestRateLimiter_WaitN_ExceedsMaxTokens(t *testing.T) {
r, err := ratelimiter.New(100, 10, 1*time.Second)
r, err := ratelimit.New(100, 10, 1*time.Second)
assert.Nil(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
Expand All @@ -102,7 +102,7 @@ func TestRateLimiter_WaitN_ExceedsMaxTokens(t *testing.T) {
}

func TestRateLimiter_WaitCancel(t *testing.T) {
r, err := ratelimiter.New(100, 10, 1*time.Minute)
r, err := ratelimit.New(100, 10, 1*time.Minute)
assert.Nil(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
Expand All @@ -115,16 +115,25 @@ func TestRateLimiter_WaitCancel(t *testing.T) {
assert.Error(t, err)
}

func BenchmarkRateLimiter(b *testing.B) {
b.ReportAllocs()

rl, _ := ratelimit.New(uint64(b.N), uint64(b.N), 1*time.Second)
for i := 0; i < b.N; i++ {
rl.Add()
}
}

func ExampleNew() {
// Create a new rate limiter instance.
_, err := ratelimiter.New(100, 10, time.Second)
_, err := ratelimit.New(100, 10, time.Second)
if err != nil {
fmt.Printf("context cancelled waiting for token: %s\n", err.Error())
}
}

func ExampleRateLimiter_Add() {
r, err := ratelimiter.New(100, 10, time.Second)
r, err := ratelimit.New(100, 10, time.Second)
if err != nil {
fmt.Printf("context cancelled waiting for token: %s\n", err.Error())
}
Expand All @@ -138,7 +147,7 @@ func ExampleRateLimiter_Add() {
}

func ExampleRateLimiter_AddN() {
r, err := ratelimiter.New(100, 10, time.Second)
r, err := ratelimit.New(100, 10, time.Second)
if err != nil {
fmt.Printf("error creating rate limiter: %s\n", err.Error())
}
Expand All @@ -152,7 +161,7 @@ func ExampleRateLimiter_AddN() {
}

func ExampleRateLimiter_Wait() {
r, err := ratelimiter.New(100, 10, time.Second)
r, err := ratelimit.New(100, 10, time.Second)
if err != nil {
fmt.Printf("error creating rate limiter: %s\n", err.Error())
}
Expand All @@ -169,7 +178,7 @@ func ExampleRateLimiter_Wait() {
}

func ExampleRateLimiter_WaitN() {
r, err := ratelimiter.New(100, 10, time.Second)
r, err := ratelimit.New(100, 10, time.Second)
if err != nil {
fmt.Printf("error creating rate limiter: %s\n", err.Error())
}
Expand All @@ -186,7 +195,7 @@ func ExampleRateLimiter_WaitN() {
}

func ExampleRateLimiter() {
r, err := ratelimiter.New(100, 10, 1*time.Second)
r, err := ratelimit.New(100, 10, 1*time.Second)
if err != nil {
fmt.Printf("error creating rate limiter: %s\n", err.Error())
}
Expand Down

0 comments on commit 0e9cd14

Please sign in to comment.