Skip to content

Commit

Permalink
perf: optimize limiter middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
fufuok committed Jul 2, 2024
1 parent 5482759 commit d1a8c79
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 28 deletions.
43 changes: 29 additions & 14 deletions web/fiber/middleware/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,25 @@ const (
XRateLimitReset = "X-RateLimit-Reset"
)

// DefaultLimitMessage 触发限制时的 429 消息
var DefaultLimitMessage = "处理中的请求数过多"
var (
// DefaultLimitMessage 触发限制时的 429 消息
DefaultLimitMessage = "处理中的请求数过多"

// DefaultLimitReached 达到限制时执行的 Hook
DefaultLimitReached = func(c *fiber.Ctx) error {
return response.APIException(c, http.StatusTooManyRequests, DefaultLimitMessage, nil)
}
)

// RequestsLimiter 同时处理的请求数限制
// 请求处理中 +1, 请求已返回: -1
// 请求超限返回 429 错误
type RequestsLimiter struct {
count atomic.Uint64
limit atomic.Int32
running atomic.Int32
limited atomic.Uint64

// 达到限制时的响应错误消息
msg string
limitReached fiber.Handler
}

func (r *RequestsLimiter) Allow() bool {
Expand All @@ -43,14 +49,22 @@ func (r *RequestsLimiter) Allow() bool {
return false
}
if r.running.CompareAndSwap(n, n+1) {
r.count.Add(1)
return true
}
}
}

func (r *RequestsLimiter) Count() uint64 {
return r.count.Load()
func (r *RequestsLimiter) Stats() map[string]int {
return map[string]int{
"Limit": int(r.Limit()),
"Limited": int(r.Limited()),
"Running": int(r.Running()),
"Remaining": int(r.Remaining()),
}
}

func (r *RequestsLimiter) Limited() uint64 {
return r.limited.Load()
}

func (r *RequestsLimiter) SetLimit(n int32) {
Expand All @@ -76,7 +90,8 @@ func (r *RequestsLimiter) Remaining() int32 {
func (r *RequestsLimiter) Handler() fiber.Handler {
return func(c *fiber.Ctx) error {
if !r.Allow() {
return response.APIException(c, http.StatusTooManyRequests, r.msg, nil)
r.limited.Add(1)
return r.limitReached(c)
}

defer r.running.Add(-1)
Expand All @@ -89,15 +104,15 @@ func (r *RequestsLimiter) Handler() fiber.Handler {
// NewDefaultRequestsLimiter 使用配置文件参数创建限制器
// app.Use(middleware.NewDefaultRequestsLimiter().Handler())
func NewDefaultRequestsLimiter() *RequestsLimiter {
return NewRequestsLimiter(config.Config().WebConf.RequestsLimit, DefaultLimitMessage)
return NewRequestsLimiter(config.Config().WebConf.RequestsLimit, DefaultLimitReached)
}

func NewRequestsLimiter(limit int32, msg string) *RequestsLimiter {
if msg == "" {
msg = DefaultLimitMessage
func NewRequestsLimiter(limit int32, limitReached fiber.Handler) *RequestsLimiter {
if limitReached == nil {
limitReached = DefaultLimitReached
}
r := &RequestsLimiter{
msg: msg,
limitReached: limitReached,
}
r.limit.Store(limit)
return r
Expand Down
43 changes: 29 additions & 14 deletions web/gin/middleware/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,25 @@ const (
XRateLimitReset = "X-RateLimit-Reset"
)

// DefaultLimitMessage 触发限制时的 429 消息
var DefaultLimitMessage = "处理中的请求数过多"
var (
// DefaultLimitMessage 触发限制时的 429 消息
DefaultLimitMessage = "处理中的请求数过多"

// DefaultLimitReached 达到限制时执行的 Hook
DefaultLimitReached = func(c *gin.Context) {
response.APIException(c, http.StatusTooManyRequests, DefaultLimitMessage, nil)
}
)

// RequestsLimiter 同时处理的请求数限制
// 请求处理中 +1, 请求已返回: -1
// 请求超限返回 429 错误
type RequestsLimiter struct {
count atomic.Uint64
limit atomic.Int32
running atomic.Int32
limited atomic.Uint64

// 达到限制时的响应错误消息
msg string
limitReached gin.HandlerFunc
}

func (r *RequestsLimiter) Allow() bool {
Expand All @@ -43,14 +49,22 @@ func (r *RequestsLimiter) Allow() bool {
return false
}
if r.running.CompareAndSwap(n, n+1) {
r.count.Add(1)
return true
}
}
}

func (r *RequestsLimiter) Count() uint64 {
return r.count.Load()
func (r *RequestsLimiter) Stats() map[string]int {
return map[string]int{
"Limit": int(r.Limit()),
"Limited": int(r.Limited()),
"Running": int(r.Running()),
"Remaining": int(r.Remaining()),
}
}

func (r *RequestsLimiter) Limited() uint64 {
return r.limited.Load()
}

func (r *RequestsLimiter) SetLimit(n int32) {
Expand All @@ -76,7 +90,8 @@ func (r *RequestsLimiter) Remaining() int32 {
func (r *RequestsLimiter) Handler() gin.HandlerFunc {
return func(c *gin.Context) {
if !r.Allow() {
response.APIException(c, http.StatusTooManyRequests, r.msg, nil)
r.limited.Add(1)
r.limitReached(c)
return
}

Expand All @@ -90,15 +105,15 @@ func (r *RequestsLimiter) Handler() gin.HandlerFunc {
// NewDefaultRequestsLimiter 使用配置文件参数创建限制器
// app.Use(middleware.NewDefaultRequestsLimiter().Handler())
func NewDefaultRequestsLimiter() *RequestsLimiter {
return NewRequestsLimiter(config.Config().WebConf.RequestsLimit, DefaultLimitMessage)
return NewRequestsLimiter(config.Config().WebConf.RequestsLimit, DefaultLimitReached)
}

func NewRequestsLimiter(limit int32, msg string) *RequestsLimiter {
if msg == "" {
msg = DefaultLimitMessage
func NewRequestsLimiter(limit int32, limitReached gin.HandlerFunc) *RequestsLimiter {
if limitReached == nil {
limitReached = DefaultLimitReached
}
r := &RequestsLimiter{
msg: msg,
limitReached: limitReached,
}
r.limit.Store(limit)
return r
Expand Down

0 comments on commit d1a8c79

Please sign in to comment.