diff --git a/web/fiber/middleware/limiter.go b/web/fiber/middleware/limiter.go index c0dfe4a..038cf12 100644 --- a/web/fiber/middleware/limiter.go +++ b/web/fiber/middleware/limiter.go @@ -17,19 +17,25 @@ const ( XRateLimitReset = "X-RateLimit-Reset" ) -// DefaultLimitMessage 触发限制时的 429 消息 -var DefaultLimitMessage = "处理中的请求数过多" +var ( + // DefaultLimitMessage 触发限制时的 429 消息 + DefaultLimitMessage = "处理中的请求数过多" + + // DefaultLimitReached 达到限制时执行的 Hook + DefaultLimitReached = func(c *fiber.Ctx) error { + return response.APIException(c, http.StatusTooManyRequests, DefaultLimitMessage, nil) + } +) // RequestsLimiter 同时处理的请求数限制 // 请求处理中 +1, 请求已返回: -1 // 请求超限返回 429 错误 type RequestsLimiter struct { - count atomic.Uint64 limit atomic.Int32 running atomic.Int32 + limited atomic.Uint64 - // 达到限制时的响应错误消息 - msg string + limitReached fiber.Handler } func (r *RequestsLimiter) Allow() bool { @@ -43,14 +49,22 @@ func (r *RequestsLimiter) Allow() bool { return false } if r.running.CompareAndSwap(n, n+1) { - r.count.Add(1) return true } } } -func (r *RequestsLimiter) Count() uint64 { - return r.count.Load() +func (r *RequestsLimiter) Stats() map[string]int { + return map[string]int{ + "Limit": int(r.Limit()), + "Limited": int(r.Limited()), + "Running": int(r.Running()), + "Remaining": int(r.Remaining()), + } +} + +func (r *RequestsLimiter) Limited() uint64 { + return r.limited.Load() } func (r *RequestsLimiter) SetLimit(n int32) { @@ -76,7 +90,8 @@ func (r *RequestsLimiter) Remaining() int32 { func (r *RequestsLimiter) Handler() fiber.Handler { return func(c *fiber.Ctx) error { if !r.Allow() { - return response.APIException(c, http.StatusTooManyRequests, r.msg, nil) + r.limited.Add(1) + return r.limitReached(c) } defer r.running.Add(-1) @@ -89,15 +104,15 @@ func (r *RequestsLimiter) Handler() fiber.Handler { // NewDefaultRequestsLimiter 使用配置文件参数创建限制器 // app.Use(middleware.NewDefaultRequestsLimiter().Handler()) func NewDefaultRequestsLimiter() *RequestsLimiter { - return NewRequestsLimiter(config.Config().WebConf.RequestsLimit, DefaultLimitMessage) + return NewRequestsLimiter(config.Config().WebConf.RequestsLimit, DefaultLimitReached) } -func NewRequestsLimiter(limit int32, msg string) *RequestsLimiter { - if msg == "" { - msg = DefaultLimitMessage +func NewRequestsLimiter(limit int32, limitReached fiber.Handler) *RequestsLimiter { + if limitReached == nil { + limitReached = DefaultLimitReached } r := &RequestsLimiter{ - msg: msg, + limitReached: limitReached, } r.limit.Store(limit) return r diff --git a/web/gin/middleware/limiter.go b/web/gin/middleware/limiter.go index f577f5c..7ca209d 100644 --- a/web/gin/middleware/limiter.go +++ b/web/gin/middleware/limiter.go @@ -17,19 +17,25 @@ const ( XRateLimitReset = "X-RateLimit-Reset" ) -// DefaultLimitMessage 触发限制时的 429 消息 -var DefaultLimitMessage = "处理中的请求数过多" +var ( + // DefaultLimitMessage 触发限制时的 429 消息 + DefaultLimitMessage = "处理中的请求数过多" + + // DefaultLimitReached 达到限制时执行的 Hook + DefaultLimitReached = func(c *gin.Context) { + response.APIException(c, http.StatusTooManyRequests, DefaultLimitMessage, nil) + } +) // RequestsLimiter 同时处理的请求数限制 // 请求处理中 +1, 请求已返回: -1 // 请求超限返回 429 错误 type RequestsLimiter struct { - count atomic.Uint64 limit atomic.Int32 running atomic.Int32 + limited atomic.Uint64 - // 达到限制时的响应错误消息 - msg string + limitReached gin.HandlerFunc } func (r *RequestsLimiter) Allow() bool { @@ -43,14 +49,22 @@ func (r *RequestsLimiter) Allow() bool { return false } if r.running.CompareAndSwap(n, n+1) { - r.count.Add(1) return true } } } -func (r *RequestsLimiter) Count() uint64 { - return r.count.Load() +func (r *RequestsLimiter) Stats() map[string]int { + return map[string]int{ + "Limit": int(r.Limit()), + "Limited": int(r.Limited()), + "Running": int(r.Running()), + "Remaining": int(r.Remaining()), + } +} + +func (r *RequestsLimiter) Limited() uint64 { + return r.limited.Load() } func (r *RequestsLimiter) SetLimit(n int32) { @@ -76,7 +90,8 @@ func (r *RequestsLimiter) Remaining() int32 { func (r *RequestsLimiter) Handler() gin.HandlerFunc { return func(c *gin.Context) { if !r.Allow() { - response.APIException(c, http.StatusTooManyRequests, r.msg, nil) + r.limited.Add(1) + r.limitReached(c) return } @@ -90,15 +105,15 @@ func (r *RequestsLimiter) Handler() gin.HandlerFunc { // NewDefaultRequestsLimiter 使用配置文件参数创建限制器 // app.Use(middleware.NewDefaultRequestsLimiter().Handler()) func NewDefaultRequestsLimiter() *RequestsLimiter { - return NewRequestsLimiter(config.Config().WebConf.RequestsLimit, DefaultLimitMessage) + return NewRequestsLimiter(config.Config().WebConf.RequestsLimit, DefaultLimitReached) } -func NewRequestsLimiter(limit int32, msg string) *RequestsLimiter { - if msg == "" { - msg = DefaultLimitMessage +func NewRequestsLimiter(limit int32, limitReached gin.HandlerFunc) *RequestsLimiter { + if limitReached == nil { + limitReached = DefaultLimitReached } r := &RequestsLimiter{ - msg: msg, + limitReached: limitReached, } r.limit.Store(limit) return r