Skip to content

Commit

Permalink
Feat: Stringer Identity + Optimize
Browse files Browse the repository at this point in the history
  - Reduce potential debug contention by using cmpandswap atomics

  - Add the ability to use fmt.Stringers for Identity functionality (not sure why i ever did anything else tbh)

  - Complete test coverage
  • Loading branch information
yunginnanet committed Dec 18, 2022
1 parent e008c05 commit d7d8a0c
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 57 deletions.
39 changes: 19 additions & 20 deletions debug.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
package rate5

import "fmt"
import (
"fmt"
"sync/atomic"
)

func (q *Limiter) debugPrintf(format string, a ...interface{}) {
q.debugMutex.RLock()
defer q.debugMutex.RUnlock()
if !q.debug {
if atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugDisabled) {
return
}
msg := fmt.Sprintf(format, a...)
select {
case q.debugChannel <- msg:
//
default:
println(msg)
// drop the message but increment the lost counter
atomic.AddInt64(&q.debugLost, 1)
}
}

Expand All @@ -23,26 +26,22 @@ func (q *Limiter) setDebugEvict() {
}

func (q *Limiter) SetDebug(on bool) {
q.debugMutex.Lock()
if !on {
q.debug = false
q.Patrons.OnEvicted(nil)
q.debugMutex.Unlock()
return
switch on {
case true:
atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugEnabled)
q.debugPrintf("rate5 debug enabled")
case false:
atomic.CompareAndSwapUint32(&q.debug, DebugEnabled, DebugDisabled)
}
q.debug = on
q.setDebugEvict()
q.debugMutex.Unlock()
q.debugPrintf("rate5 debug enabled")
}

// DebugChannel enables debug mode and returns a channel where debug messages are sent.
// NOTE: You must read from this channel if created via this function or it will block
//
// NOTE: If you do not read from this channel, the debug messages will eventually be lost.
// If this happens,
func (q *Limiter) DebugChannel() chan string {
defer func() {
q.debugMutex.Lock()
q.debug = true
q.debugMutex.Unlock()
atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugEnabled)
}()
q.debugMutex.RLock()
if q.debugChannel != nil {
Expand All @@ -52,7 +51,7 @@ func (q *Limiter) DebugChannel() chan string {
q.debugMutex.RUnlock()
q.debugMutex.Lock()
defer q.debugMutex.Unlock()
q.debugChannel = make(chan string, 25)
q.debugChannel = make(chan string, 55)
q.setDebugEvict()
return q.debugChannel
}
21 changes: 18 additions & 3 deletions models.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rate5

import (
"fmt"
"sync"

"github.com/patrickmn/go-cache"
Expand All @@ -18,19 +19,33 @@ type Identity interface {
UniqueKey() string
}

// IdentityStringer is an implentation of Identity that acts as a shim for types that implement fmt.Stringer.
type IdentityStringer struct {
stringer fmt.Stringer
}

func (i IdentityStringer) UniqueKey() string {
return i.stringer.String()
}

const (
DebugDisabled uint32 = iota
DebugEnabled
)

// Limiter implements an Enforcer to create an arbitrary ratelimiter.
type Limiter struct {
// Source is the implementation of the Identity interface. It is used to create a unique key for each request.
Source Identity
// Patrons gives access to the underlying cache type that powers the ratelimiter.
// It is exposed for testing purposes.
Patrons *cache.Cache

// Ruleset determines the Policy which is used to determine whether or not to ratelimit.
// It consists of a Window and Burst, see Policy for more details.
Ruleset Policy

debug bool
debug uint32
debugChannel chan string
debugLost int64
known map[interface{}]*int64
debugMutex *sync.RWMutex
*sync.RWMutex
Expand Down
19 changes: 16 additions & 3 deletions ratelimiter.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rate5

import (
"fmt"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -57,10 +58,12 @@ func NewStrictLimiter(window int, burst int) *Limiter {
})
}

/*NewHardcoreLimiter returns a custom limiter with Strict + Hardcore modes enabled.
/*
NewHardcoreLimiter returns a custom limiter with Strict + Hardcore modes enabled.
Hardcore mode causes the time limited to be multiplied by the number of hits.
This differs from strict mode which is only using addition instead of multiplication.*/
This differs from strict mode which is only using addition instead of multiplication.
*/
func NewHardcoreLimiter(window int, burst int) *Limiter {
l := NewStrictLimiter(window, burst)
l.Ruleset.Hardcore = true
Expand All @@ -80,7 +83,7 @@ func newLimiter(policy Policy) *Limiter {
known: make(map[interface{}]*int64),
RWMutex: &sync.RWMutex{},
debugMutex: &sync.RWMutex{},
debug: false,
debug: DebugDisabled,
}
}

Expand Down Expand Up @@ -122,6 +125,11 @@ func (q *Limiter) strictLogic(src string, count int64) {
q.debugPrintf("%s ratelimit for %s: last count %d. time: %s", prefix, src, count, exttime)
}

func (q *Limiter) CheckStringer(from fmt.Stringer) bool {
targ := IdentityStringer{stringer: from}
return q.Check(targ)
}

// Check checks and increments an Identities UniqueKey() output against a list of cached strings to determine and raise it's ratelimitting status.
func (q *Limiter) Check(from Identity) (limited bool) {
var count int64
Expand Down Expand Up @@ -159,3 +167,8 @@ func (q *Limiter) Peek(from Identity) bool {
}
return false
}

func (q *Limiter) PeekStringer(from fmt.Stringer) bool {
targ := IdentityStringer{stringer: from}
return q.Peek(targ)
}
93 changes: 62 additions & 31 deletions ratelimiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ var (
)

func watchDebug(ctx context.Context, r *Limiter, t *testing.T) {
t.Helper()
watchDebugMutex.Lock()
defer watchDebugMutex.Unlock()
rd := r.DebugChannel()
Expand All @@ -68,25 +67,28 @@ func watchDebug(ctx context.Context, r *Limiter, t *testing.T) {
}
}

func peekCheckLimited(t *testing.T, limiter *Limiter, shouldbe bool) {
t.Helper()
func peekCheckLimited(t *testing.T, limiter *Limiter, shouldbe, stringer bool) {
limited := limiter.Peek(dummyTicker)
if stringer {
limited = limiter.PeekStringer(dummyTicker)
}
switch {
case limiter.Peek(dummyTicker) && !shouldbe:
case limited && !shouldbe:
if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok {
t.Errorf("Should not have been limited. Ratelimiter count: %d", ct)
} else {
t.Fatalf("dummyTicker does not exist in ratelimiter at all!")
}
case !limiter.Peek(dummyTicker) && shouldbe:
case !limited && shouldbe:
if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok {
t.Errorf("Should have been limited. Ratelimiter count: %d", ct)
} else {
t.Fatalf("dummyTicker does not exist in ratelimiter at all!")
}
case limiter.Peek(dummyTicker) && shouldbe:
t.Logf("dummyTicker is limited")
case !limiter.Peek(dummyTicker) && !shouldbe:
t.Logf("dummyTicker is not limited")
case limited && shouldbe:
t.Logf("dummyTicker is limited as expected.")
case !limited && !shouldbe:
t.Logf("dummyTicker is not limited as expected.")
}
}

Expand All @@ -105,6 +107,10 @@ func (tick *ticker) UniqueKey() string {
return "TestItem"
}

func (tick *ticker) String() string {
return "TestItem"
}

func Test_ResetItem(t *testing.T) {
limiter := NewLimiter(500, 1)
ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -114,26 +120,36 @@ func Test_ResetItem(t *testing.T) {
limiter.Check(dummyTicker)
}
limiter.ResetItem(dummyTicker)
peekCheckLimited(t, limiter, false)
peekCheckLimited(t, limiter, false, false)
cancel()
}

func Test_NewDefaultLimiter(t *testing.T) {
limiter := NewDefaultLimiter()
limiter.Check(dummyTicker)
peekCheckLimited(t, limiter, false)
peekCheckLimited(t, limiter, false, false)
for n := 0; n != DefaultBurst; n++ {
limiter.Check(dummyTicker)
}
peekCheckLimited(t, limiter, true)
peekCheckLimited(t, limiter, true, false)
}

func Test_CheckAndPeekStringer(t *testing.T) {
limiter := NewDefaultLimiter()
limiter.CheckStringer(dummyTicker)
peekCheckLimited(t, limiter, false, true)
for n := 0; n != DefaultBurst; n++ {
limiter.CheckStringer(dummyTicker)
}
peekCheckLimited(t, limiter, true, true)
}

func Test_NewLimiter(t *testing.T) {
limiter := NewLimiter(5, 1)
limiter.Check(dummyTicker)
peekCheckLimited(t, limiter, false)
peekCheckLimited(t, limiter, false, false)
limiter.Check(dummyTicker)
peekCheckLimited(t, limiter, true)
peekCheckLimited(t, limiter, true, false)
}

func Test_NewDefaultStrictLimiter(t *testing.T) {
Expand All @@ -144,9 +160,9 @@ func Test_NewDefaultStrictLimiter(t *testing.T) {
for n := 0; n < 25; n++ {
limiter.Check(dummyTicker)
}
peekCheckLimited(t, limiter, false)
peekCheckLimited(t, limiter, false, false)
limiter.Check(dummyTicker)
peekCheckLimited(t, limiter, true)
peekCheckLimited(t, limiter, true, false)
cancel()
limiter = nil
}
Expand All @@ -156,23 +172,23 @@ func Test_NewStrictLimiter(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
go watchDebug(ctx, limiter, t)
limiter.Check(dummyTicker)
peekCheckLimited(t, limiter, false)
peekCheckLimited(t, limiter, false, false)
limiter.Check(dummyTicker)
peekCheckLimited(t, limiter, true)
peekCheckLimited(t, limiter, true, false)
limiter.Check(dummyTicker)
// for coverage, first we give the debug messages a couple seconds to be safe,
// then we wait for the cache eviction to trigger a debug message.
time.Sleep(2 * time.Second)
t.Logf(<-limiter.DebugChannel())
peekCheckLimited(t, limiter, false)
peekCheckLimited(t, limiter, false, false)
for n := 0; n != 6; n++ {
limiter.Check(dummyTicker)
}
peekCheckLimited(t, limiter, true)
peekCheckLimited(t, limiter, true, false)
time.Sleep(5 * time.Second)
peekCheckLimited(t, limiter, true)
peekCheckLimited(t, limiter, true, false)
time.Sleep(8 * time.Second)
peekCheckLimited(t, limiter, false)
peekCheckLimited(t, limiter, false, false)
cancel()
limiter = nil
}
Expand All @@ -184,35 +200,35 @@ func Test_NewHardcoreLimiter(t *testing.T) {
for n := 0; n != 4; n++ {
limiter.Check(dummyTicker)
}
peekCheckLimited(t, limiter, false)
peekCheckLimited(t, limiter, false, false)
if !limiter.Check(dummyTicker) {
t.Errorf("Should have been limited")
}
t.Logf("limited once, waiting for cache eviction")
time.Sleep(2 * time.Second)
peekCheckLimited(t, limiter, false)
peekCheckLimited(t, limiter, false, false)
for n := 0; n != 4; n++ {
limiter.Check(dummyTicker)
}
peekCheckLimited(t, limiter, false)
peekCheckLimited(t, limiter, false, false)
if !limiter.Check(dummyTicker) {
t.Errorf("Should have been limited")
}
limiter.Check(dummyTicker)
limiter.Check(dummyTicker)
time.Sleep(3 * time.Second)
peekCheckLimited(t, limiter, true)
peekCheckLimited(t, limiter, true, false)
time.Sleep(5 * time.Second)
peekCheckLimited(t, limiter, false)
peekCheckLimited(t, limiter, false, false)
for n := 0; n != 4; n++ {
limiter.Check(dummyTicker)
}
peekCheckLimited(t, limiter, false)
peekCheckLimited(t, limiter, false, false)
for n := 0; n != 10; n++ {
limiter.Check(dummyTicker)
}
time.Sleep(10 * time.Second)
peekCheckLimited(t, limiter, true)
peekCheckLimited(t, limiter, true, false)
cancel()
// for coverage, triggering the switch statement case for hardcore logic
limiter2 := NewHardcoreLimiter(2, 5)
Expand All @@ -221,9 +237,9 @@ func Test_NewHardcoreLimiter(t *testing.T) {
for n := 0; n != 6; n++ {
limiter2.Check(dummyTicker)
}
peekCheckLimited(t, limiter2, true)
peekCheckLimited(t, limiter2, true, false)
time.Sleep(4 * time.Second)
peekCheckLimited(t, limiter2, false)
peekCheckLimited(t, limiter2, false, false)
cancel2()
}

Expand Down Expand Up @@ -314,3 +330,18 @@ func Test_ConcurrentShouldLimit(t *testing.T) {
concurrentTest(t, 50, 21, 20, true)
concurrentTest(t, 50, 51, 50, true)
}

func Test_debugChannelOverflow(t *testing.T) {
limiter := NewDefaultLimiter()
_ = limiter.DebugChannel()
for n := 0; n != 78; n++ {
limiter.Check(dummyTicker)
if limiter.debugLost > 0 {
t.Fatalf("debug channel overflowed")
}
}
limiter.Check(dummyTicker)
if limiter.debugLost == 0 {
t.Fatalf("debug channel did not overflow")
}
}

0 comments on commit d7d8a0c

Please sign in to comment.