forked from hashicorp/go-connlimit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconnlimit.go
233 lines (199 loc) · 7.23 KB
/
connlimit.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
package connlimit
import (
"errors"
"fmt"
"net"
"net/http"
"sync"
"sync/atomic"
"time"
)
var (
	// ErrPerClientIPLimitReached is the error Accept returns when admitting
	// one more connection from a client IP would exceed MaxConnsPerClientIP.
	ErrPerClientIPLimitReached = errors.New("client connection limit reached")

	// tooManyConnsMsg is the plain-text body sent to clients that are over
	// their per-IP connection limit.
	tooManyConnsMsg = "Your IP is issuing too many concurrent connections, please rate limit your calls\n"

	// tooManyRequestsResponse is a complete, pre-rendered HTTP/1.1 429
	// response (status line, headers and body) meant to be written directly
	// to a raw connection. Content-Length is derived from tooManyConnsMsg so
	// the two can never disagree.
	tooManyRequestsResponse = []byte(fmt.Sprintf(
		"HTTP/1.1 429 Too Many Requests\r\n"+
			"Content-Type: text/plain\r\n"+
			"Content-Length: %d\r\n"+
			"Connection: close\r\n\r\n%s",
		len(tooManyConnsMsg),
		tooManyConnsMsg,
	))
)
// Limiter counts the connections currently open from each client IP and
// refuses new ones once a configurable per-IP limit has been reached. Its zero
// value is usable, but it enforces no limit until SetConfig is called.
type Limiter struct {
	// l guards cs. cfg does not need it because cfg is accessed atomically.
	l sync.Mutex

	// cs maps a client key (as produced by addrKey) to the set of conns
	// currently counted against that client. A set of conns is kept rather
	// than a plain counter because http.Server's ConnState hook only hands us
	// the conn on each transition: when a conn closes we must be able to tell
	// whether it was previously admitted by Accept, or whether it was rejected
	// (and closed) in the ConnState hook without ever being counted.
	cs map[string]map[net.Conn]struct{}

	// cfg holds the current Config. It is loaded and stored atomically so
	// that reading it (for example from a health or metrics endpoint that
	// calls Config regularly) never blocks new connections waiting on l.
	cfg atomic.Value
}
// Config holds the tunable settings of a Limiter.
type Config struct {
	// MaxConnsPerClientIP caps the number of concurrent connections accepted
	// from a single client IP; a value of zero or less means "no limit".
	//
	// The IP is whatever the connection reports as its remote address. It
	// therefore cannot be trusted when clients connect through several proxies
	// or can spoof their source address. Conversely, every client behind the
	// same proxy, NAT gateway or similar appears to come from one IP and is
	// limited as a single client.
	MaxConnsPerClientIP int
}
// NewLimiter allocates a Limiter and installs cfg as its initial
// configuration. The configuration can be changed later with SetConfig.
func NewLimiter(cfg Config) *Limiter {
	lim := new(Limiter)
	lim.SetConfig(cfg)
	return lim
}
// Accept should be called as early as possible when handling a new conn. It
// decides, according to the Limiter's current Config, whether conn may be
// admitted.
//
// On success Accept records conn against its client and returns a free func
// and a nil error. The free func must be called once the connection is no
// longer being handled (typically deferred in the goroutine serving conn). It
// removes conn from the count for its client, and calling it more than once is
// harmless.
//
// If the client already has MaxConnsPerClientIP or more connections open,
// conn is not recorded. Accept then returns ErrPerClientIPLimitReached and a
// no-op func that need not be called. ErrPerClientIPLimitReached is currently
// the only error Accept produces. Callers should still treat any non-nil error
// as a rejection.
func (l *Limiter) Accept(conn net.Conn) (func(), error) {
	key := connKey(conn)

	// The config is published atomically and never written under l.l, so load
	// it before taking the lock rather than doing a possibly contended atomic
	// read while holding it.
	limit := l.Config().MaxConnsPerClientIP

	l.l.Lock()
	defer l.l.Unlock()

	if l.cs == nil {
		l.cs = make(map[string]map[net.Conn]struct{})
	}
	open := l.cs[key]
	if open == nil {
		open = make(map[net.Conn]struct{})
		l.cs[key] = open
	}

	// Compare with >= rather than ==: the limit can be lowered at runtime
	// while a client still holds more connections than the new limit allows.
	if limit > 0 && len(open) >= limit {
		return func() {}, ErrPerClientIPLimitReached
	}

	open[conn] = struct{}{}

	// The free func releases conn via freeConn, which re-derives the client
	// key from conn rather than reusing key.
	return func() { l.freeConn(conn) }, nil
}
// NumOpen reports how many connections from addr's client are currently
// counted by the limiter, i.e. admitted by Accept and not yet freed. addr is
// reduced to the same key Accept uses, so for TCP, UDP and IP addresses the
// port is ignored and every connection from the same IP is counted together.
func (l *Limiter) NumOpen(addr net.Addr) int {
	key := addrKey(addr)

	l.l.Lock()
	defer l.l.Unlock()

	// Indexing a nil outer map yields a nil inner map, and len of a nil map is
	// 0, so an unknown client (or a never-used Limiter) reports 0.
	return len(l.cs[key])
}
// connKey returns the limiter key for conn, derived from conn.RemoteAddr()
// by addrKey.
func connKey(conn net.Conn) string {
	return addrKey(conn.RemoteAddr())
}

// addrKey maps a remote address to the key that connections are grouped
// under for limiting purposes.
//
//   - *net.TCPAddr, *net.UDPAddr and *net.IPAddr map to "ip:" followed by the
//     IP alone, so all ports (all connections) from one IP share a key.
//   - Any other address type maps to its full "network/address" form, on the
//     assumption that the whole address identifies the client.
//   - A nil addr (e.g. from a net.Conn whose RemoteAddr returns nil) maps to
//     the fixed key "<nil>" instead of panicking. All such conns are therefore
//     counted as one client.
//   - A typed nil pointer of one of the IP address types is not dereferenced.
//     It falls back to the generic form, e.g. "tcp/<nil>".
//
// "ip:" keys never contain "/", generic keys always do, and "<nil>" has
// neither property, so the three kinds of key cannot collide.
func addrKey(addr net.Addr) string {
	if addr == nil {
		return "<nil>"
	}
	switch a := addr.(type) {
	case *net.TCPAddr:
		if a != nil {
			return "ip:" + a.IP.String()
		}
	case *net.UDPAddr:
		if a != nil {
			return "ip:" + a.IP.String()
		}
	case *net.IPAddr:
		if a != nil {
			return "ip:" + a.IP.String()
		}
	}
	// Unknown address type or typed-nil pointer: the net package's Network and
	// String methods on these address types are safe on nil receivers.
	return addr.Network() + "/" + addr.String()
}
// freeConn stops counting conn against its client. Conns that were never
// admitted by Accept (or were already freed) are ignored, so it is safe to
// call for every closed conn. When a client's last counted conn is removed,
// its entry is dropped entirely so the map does not grow with idle clients.
func (l *Limiter) freeConn(conn net.Conn) {
	key := connKey(conn)

	l.l.Lock()
	defer l.l.Unlock()

	if open, found := l.cs[key]; found {
		delete(open, conn)
		if len(open) == 0 {
			delete(l.cs, key)
		}
	}
}
// Config returns a snapshot of the limiter's current configuration, or the
// zero Config if none has been set yet. It performs only an atomic load, so it
// is safe for concurrent use and never waits on the lock that Accept holds.
func (l *Limiter) Config() Config {
	current, _ := l.cfg.Load().(Config)
	return current
}

// SetConfig replaces the limiter's configuration and is safe for concurrent
// use. The new settings affect only later admission decisions. Lowering the
// limit does not close connections that are already open, so a client may
// remain above the limit until enough of its connections close on their own.
func (l *Limiter) SetConfig(c Config) {
	l.cfg.Store(c)
}
// HTTPConnStateFuncWithErrorHandler returns a func that can be used as the
// ConnState field of an http.Server. It runs Accept on every new connection.
// It frees a connection from the limiter once the connection is closed or
// hijacked.
//
// When Accept rejects a connection, errorHandler is called with the error and
// the conn. errorHandler MUST close the connection itself; the limiter does not
// close it. If errorHandler is nil, rejected connections are simply closed,
// exactly as HTTPConnStateFunc does. errorHandler runs synchronously inside the
// ConnState hook, so a slow handler delays the server for that duration.
//
// A hijacked conn is freed in the limiter as if it had been closed. Servers
// that hijack connections must do their own accounting if they need to keep
// limiting the number of concurrent hijacked connections.
func (l *Limiter) HTTPConnStateFuncWithErrorHandler(errorHandler func(error, net.Conn)) func(net.Conn, http.ConnState) {
	if errorHandler == nil {
		// Without this default a nil handler would panic the first time a
		// client hit its limit.
		errorHandler = func(_ error, conn net.Conn) {
			_ = conn.Close()
		}
	}
	return func(conn net.Conn, state http.ConnState) {
		switch state {
		case http.StateNew:
			// The free func returned by Accept is deliberately discarded:
			// admitted conns are released by the cases below instead.
			if _, err := l.Accept(conn); err != nil {
				errorHandler(err, conn)
			}
		case http.StateHijacked, http.StateClosed:
			// freeConn is a no-op for conns that were never counted, such as
			// one rejected (and closed) in the StateNew case above.
			l.freeConn(conn)
		}
	}
}
// HTTPConnStateFunc returns a ConnState hook that immediately closes any
// connection the limiter rejects. It is kept for backward compatibility. It is
// equivalent to HTTPConnStateFuncWithErrorHandler with a handler that only
// closes the conn.
func (l *Limiter) HTTPConnStateFunc() func(net.Conn, http.ConnState) {
	closeRejected := func(_ error, conn net.Conn) {
		// Best effort: there is no caller to report a Close failure to.
		_ = conn.Close()
	}
	return l.HTTPConnStateFuncWithErrorHandler(closeRejected)
}
// HTTPConnStateFuncWithDefault429Handler returns a ConnState hook for a client
// that is over its per-IP limit. The hook writes a minimal
// "HTTP/1.1 429 Too Many Requests" response to the raw connection and then
// closes it. Connections rejected for any other reason are closed without a
// response.
//
// If writeDeadlineMaxDelay is positive, it bounds how long writing the 429 may
// take, so a slow client cannot stall the write indefinitely. With a zero or
// negative value no deadline is set, and the write may block for as long as the
// client takes to read it.
//
// BEWARE: the 429 is written synchronously from the ConnState hook, i.e. on the
// critical path. Use HTTPConnStateFuncWithErrorHandler with your own handler if
// you need a non-blocking strategy.
func (l *Limiter) HTTPConnStateFuncWithDefault429Handler(writeDeadlineMaxDelay time.Duration) func(net.Conn, http.ConnState) {
	return l.HTTPConnStateFuncWithErrorHandler(func(err error, conn net.Conn) {
		if errors.Is(err, ErrPerClientIPLimitReached) {
			if writeDeadlineMaxDelay > 0 {
				// Only a write happens before Close, so a write deadline is
				// all that is needed to avoid waiting on slow clients.
				_ = conn.SetWriteDeadline(time.Now().Add(writeDeadlineMaxDelay))
			}
			// Best effort: the conn is closed below whether or not the
			// client received the response.
			_, _ = conn.Write(tooManyRequestsResponse)
		}
		_ = conn.Close()
	})
}