From d7d8a0ce8762b1086f90f81e4106f8295a6cd7fd Mon Sep 17 00:00:00 2001 From: "kayos@tcp.direct" Date: Sun, 18 Dec 2022 02:38:43 -0800 Subject: [PATCH] Feat: Stringer Identity + Optimize - Reduce potential debug contention by using cmpandswap atomics - Add the ability to use fmt.Stringers for Identity functionality (not sure why i ever did anything else tbh) - Complete test coverage --- debug.go | 39 +++++++++---------- models.go | 21 ++++++++-- ratelimiter.go | 19 +++++++-- ratelimiter_test.go | 93 ++++++++++++++++++++++++++++++--------------- 4 files changed, 115 insertions(+), 57 deletions(-) diff --git a/debug.go b/debug.go index dabc386..67dd1fe 100644 --- a/debug.go +++ b/debug.go @@ -1,18 +1,21 @@ package rate5 -import "fmt" +import ( + "fmt" + "sync/atomic" +) func (q *Limiter) debugPrintf(format string, a ...interface{}) { - q.debugMutex.RLock() - defer q.debugMutex.RUnlock() - if !q.debug { + if atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugDisabled) { return } msg := fmt.Sprintf(format, a...) select { case q.debugChannel <- msg: + // default: - println(msg) + // drop the message but increment the lost counter + atomic.AddInt64(&q.debugLost, 1) } } @@ -23,26 +26,22 @@ func (q *Limiter) setDebugEvict() { } func (q *Limiter) SetDebug(on bool) { - q.debugMutex.Lock() - if !on { - q.debug = false - q.Patrons.OnEvicted(nil) - q.debugMutex.Unlock() - return + switch on { + case true: + atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugEnabled) + q.debugPrintf("rate5 debug enabled") + case false: + atomic.CompareAndSwapUint32(&q.debug, DebugEnabled, DebugDisabled) } - q.debug = on - q.setDebugEvict() - q.debugMutex.Unlock() - q.debugPrintf("rate5 debug enabled") } // DebugChannel enables debug mode and returns a channel where debug messages are sent. -// NOTE: You must read from this channel if created via this function or it will block +// +// NOTE: If you do not read from this channel, the debug messages will eventually be lost. +// If this happens, func (q *Limiter) DebugChannel() chan string { defer func() { - q.debugMutex.Lock() - q.debug = true - q.debugMutex.Unlock() + atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugEnabled) }() q.debugMutex.RLock() if q.debugChannel != nil { @@ -52,7 +51,7 @@ func (q *Limiter) DebugChannel() chan string { q.debugMutex.RUnlock() q.debugMutex.Lock() defer q.debugMutex.Unlock() - q.debugChannel = make(chan string, 25) + q.debugChannel = make(chan string, 55) q.setDebugEvict() return q.debugChannel } diff --git a/models.go b/models.go index 0da3cd1..48fd4c9 100644 --- a/models.go +++ b/models.go @@ -1,6 +1,7 @@ package rate5 import ( + "fmt" "sync" "github.com/patrickmn/go-cache" @@ -18,19 +19,33 @@ type Identity interface { UniqueKey() string } +// IdentityStringer is an implentation of Identity that acts as a shim for types that implement fmt.Stringer. +type IdentityStringer struct { + stringer fmt.Stringer +} + +func (i IdentityStringer) UniqueKey() string { + return i.stringer.String() +} + +const ( + DebugDisabled uint32 = iota + DebugEnabled +) + // Limiter implements an Enforcer to create an arbitrary ratelimiter. type Limiter struct { - // Source is the implementation of the Identity interface. It is used to create a unique key for each request. - Source Identity // Patrons gives access to the underlying cache type that powers the ratelimiter. // It is exposed for testing purposes. Patrons *cache.Cache + // Ruleset determines the Policy which is used to determine whether or not to ratelimit. // It consists of a Window and Burst, see Policy for more details. Ruleset Policy - debug bool + debug uint32 debugChannel chan string + debugLost int64 known map[interface{}]*int64 debugMutex *sync.RWMutex *sync.RWMutex diff --git a/ratelimiter.go b/ratelimiter.go index 59008c0..0b1f0c9 100644 --- a/ratelimiter.go +++ b/ratelimiter.go @@ -1,6 +1,7 @@ package rate5 import ( + "fmt" "sync" "sync/atomic" "time" @@ -57,10 +58,12 @@ func NewStrictLimiter(window int, burst int) *Limiter { }) } -/*NewHardcoreLimiter returns a custom limiter with Strict + Hardcore modes enabled. +/* +NewHardcoreLimiter returns a custom limiter with Strict + Hardcore modes enabled. Hardcore mode causes the time limited to be multiplied by the number of hits. -This differs from strict mode which is only using addition instead of multiplication.*/ +This differs from strict mode which is only using addition instead of multiplication. +*/ func NewHardcoreLimiter(window int, burst int) *Limiter { l := NewStrictLimiter(window, burst) l.Ruleset.Hardcore = true @@ -80,7 +83,7 @@ func newLimiter(policy Policy) *Limiter { known: make(map[interface{}]*int64), RWMutex: &sync.RWMutex{}, debugMutex: &sync.RWMutex{}, - debug: false, + debug: DebugDisabled, } } @@ -122,6 +125,11 @@ func (q *Limiter) strictLogic(src string, count int64) { q.debugPrintf("%s ratelimit for %s: last count %d. time: %s", prefix, src, count, exttime) } +func (q *Limiter) CheckStringer(from fmt.Stringer) bool { + targ := IdentityStringer{stringer: from} + return q.Check(targ) +} + // Check checks and increments an Identities UniqueKey() output against a list of cached strings to determine and raise it's ratelimitting status. func (q *Limiter) Check(from Identity) (limited bool) { var count int64 @@ -159,3 +167,8 @@ func (q *Limiter) Peek(from Identity) bool { } return false } + +func (q *Limiter) PeekStringer(from fmt.Stringer) bool { + targ := IdentityStringer{stringer: from} + return q.Peek(targ) +} diff --git a/ratelimiter_test.go b/ratelimiter_test.go index 1755699..72350d1 100644 --- a/ratelimiter_test.go +++ b/ratelimiter_test.go @@ -48,7 +48,6 @@ var ( ) func watchDebug(ctx context.Context, r *Limiter, t *testing.T) { - t.Helper() watchDebugMutex.Lock() defer watchDebugMutex.Unlock() rd := r.DebugChannel() @@ -68,25 +67,28 @@ func watchDebug(ctx context.Context, r *Limiter, t *testing.T) { } } -func peekCheckLimited(t *testing.T, limiter *Limiter, shouldbe bool) { - t.Helper() +func peekCheckLimited(t *testing.T, limiter *Limiter, shouldbe, stringer bool) { + limited := limiter.Peek(dummyTicker) + if stringer { + limited = limiter.PeekStringer(dummyTicker) + } switch { - case limiter.Peek(dummyTicker) && !shouldbe: + case limited && !shouldbe: if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok { t.Errorf("Should not have been limited. Ratelimiter count: %d", ct) } else { t.Fatalf("dummyTicker does not exist in ratelimiter at all!") } - case !limiter.Peek(dummyTicker) && shouldbe: + case !limited && shouldbe: if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok { t.Errorf("Should have been limited. Ratelimiter count: %d", ct) } else { t.Fatalf("dummyTicker does not exist in ratelimiter at all!") } - case limiter.Peek(dummyTicker) && shouldbe: - t.Logf("dummyTicker is limited") - case !limiter.Peek(dummyTicker) && !shouldbe: - t.Logf("dummyTicker is not limited") + case limited && shouldbe: + t.Logf("dummyTicker is limited as expected.") + case !limited && !shouldbe: + t.Logf("dummyTicker is not limited as expected.") } } @@ -105,6 +107,10 @@ func (tick *ticker) UniqueKey() string { return "TestItem" } +func (tick *ticker) String() string { + return "TestItem" +} + func Test_ResetItem(t *testing.T) { limiter := NewLimiter(500, 1) ctx, cancel := context.WithCancel(context.Background()) @@ -114,26 +120,36 @@ func Test_ResetItem(t *testing.T) { limiter.Check(dummyTicker) } limiter.ResetItem(dummyTicker) - peekCheckLimited(t, limiter, false) + peekCheckLimited(t, limiter, false, false) cancel() } func Test_NewDefaultLimiter(t *testing.T) { limiter := NewDefaultLimiter() limiter.Check(dummyTicker) - peekCheckLimited(t, limiter, false) + peekCheckLimited(t, limiter, false, false) for n := 0; n != DefaultBurst; n++ { limiter.Check(dummyTicker) } - peekCheckLimited(t, limiter, true) + peekCheckLimited(t, limiter, true, false) +} + +func Test_CheckAndPeekStringer(t *testing.T) { + limiter := NewDefaultLimiter() + limiter.CheckStringer(dummyTicker) + peekCheckLimited(t, limiter, false, true) + for n := 0; n != DefaultBurst; n++ { + limiter.CheckStringer(dummyTicker) + } + peekCheckLimited(t, limiter, true, true) } func Test_NewLimiter(t *testing.T) { limiter := NewLimiter(5, 1) limiter.Check(dummyTicker) - peekCheckLimited(t, limiter, false) + peekCheckLimited(t, limiter, false, false) limiter.Check(dummyTicker) - peekCheckLimited(t, limiter, true) + peekCheckLimited(t, limiter, true, false) } func Test_NewDefaultStrictLimiter(t *testing.T) { @@ -144,9 +160,9 @@ func Test_NewDefaultStrictLimiter(t *testing.T) { for n := 0; n < 25; n++ { limiter.Check(dummyTicker) } - peekCheckLimited(t, limiter, false) + peekCheckLimited(t, limiter, false, false) limiter.Check(dummyTicker) - peekCheckLimited(t, limiter, true) + peekCheckLimited(t, limiter, true, false) cancel() limiter = nil } @@ -156,23 +172,23 @@ func Test_NewStrictLimiter(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) go watchDebug(ctx, limiter, t) limiter.Check(dummyTicker) - peekCheckLimited(t, limiter, false) + peekCheckLimited(t, limiter, false, false) limiter.Check(dummyTicker) - peekCheckLimited(t, limiter, true) + peekCheckLimited(t, limiter, true, false) limiter.Check(dummyTicker) // for coverage, first we give the debug messages a couple seconds to be safe, // then we wait for the cache eviction to trigger a debug message. time.Sleep(2 * time.Second) t.Logf(<-limiter.DebugChannel()) - peekCheckLimited(t, limiter, false) + peekCheckLimited(t, limiter, false, false) for n := 0; n != 6; n++ { limiter.Check(dummyTicker) } - peekCheckLimited(t, limiter, true) + peekCheckLimited(t, limiter, true, false) time.Sleep(5 * time.Second) - peekCheckLimited(t, limiter, true) + peekCheckLimited(t, limiter, true, false) time.Sleep(8 * time.Second) - peekCheckLimited(t, limiter, false) + peekCheckLimited(t, limiter, false, false) cancel() limiter = nil } @@ -184,35 +200,35 @@ func Test_NewHardcoreLimiter(t *testing.T) { for n := 0; n != 4; n++ { limiter.Check(dummyTicker) } - peekCheckLimited(t, limiter, false) + peekCheckLimited(t, limiter, false, false) if !limiter.Check(dummyTicker) { t.Errorf("Should have been limited") } t.Logf("limited once, waiting for cache eviction") time.Sleep(2 * time.Second) - peekCheckLimited(t, limiter, false) + peekCheckLimited(t, limiter, false, false) for n := 0; n != 4; n++ { limiter.Check(dummyTicker) } - peekCheckLimited(t, limiter, false) + peekCheckLimited(t, limiter, false, false) if !limiter.Check(dummyTicker) { t.Errorf("Should have been limited") } limiter.Check(dummyTicker) limiter.Check(dummyTicker) time.Sleep(3 * time.Second) - peekCheckLimited(t, limiter, true) + peekCheckLimited(t, limiter, true, false) time.Sleep(5 * time.Second) - peekCheckLimited(t, limiter, false) + peekCheckLimited(t, limiter, false, false) for n := 0; n != 4; n++ { limiter.Check(dummyTicker) } - peekCheckLimited(t, limiter, false) + peekCheckLimited(t, limiter, false, false) for n := 0; n != 10; n++ { limiter.Check(dummyTicker) } time.Sleep(10 * time.Second) - peekCheckLimited(t, limiter, true) + peekCheckLimited(t, limiter, true, false) cancel() // for coverage, triggering the switch statement case for hardcore logic limiter2 := NewHardcoreLimiter(2, 5) @@ -221,9 +237,9 @@ func Test_NewHardcoreLimiter(t *testing.T) { for n := 0; n != 6; n++ { limiter2.Check(dummyTicker) } - peekCheckLimited(t, limiter2, true) + peekCheckLimited(t, limiter2, true, false) time.Sleep(4 * time.Second) - peekCheckLimited(t, limiter2, false) + peekCheckLimited(t, limiter2, false, false) cancel2() } @@ -314,3 +330,18 @@ func Test_ConcurrentShouldLimit(t *testing.T) { concurrentTest(t, 50, 21, 20, true) concurrentTest(t, 50, 51, 50, true) } + +func Test_debugChannelOverflow(t *testing.T) { + limiter := NewDefaultLimiter() + _ = limiter.DebugChannel() + for n := 0; n != 78; n++ { + limiter.Check(dummyTicker) + if limiter.debugLost > 0 { + t.Fatalf("debug channel overflowed") + } + } + limiter.Check(dummyTicker) + if limiter.debugLost == 0 { + t.Fatalf("debug channel did not overflow") + } +}