-
Notifications
You must be signed in to change notification settings - Fork 14
/
reverse_proxy.go
394 lines (350 loc) · 11.1 KB
/
reverse_proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
/*
* Copyright 2023 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Copyright 2011 The Go Authors. All rights reserved.
* Use of this source code is governed by a BSD-style
* license that can be found in the LICENSE file.
*
* This file may have been modified by CloudWeGo Authors. All CloudWeGo
* Modifications are Copyright 2023 CloudWeGo Authors.
*/
package reverseproxy
import (
"bytes"
"context"
"fmt"
"net"
"net/textproto"
"reflect"
"strings"
"sync"
"time"
"unsafe"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/client"
"github.com/cloudwego/hertz/pkg/common/config"
"github.com/cloudwego/hertz/pkg/common/hlog"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/protocol/consts"
)
// ReverseProxy forwards incoming requests to a backend with a Hertz client
// and copies the backend's response back into the original RequestContext.
type ReverseProxy struct {
	// client is the Hertz client used to send requests to the backend.
	client *client.Client
	// clientBehavior selects which client method (Do, DoDeadline,
	// DoRedirects or DoTimeout) is used to reach the backend.
	clientBehavior clientBehavior

	// Target is the reverse proxy address requests are forwarded to.
	Target string

	// transferTrailer reports whether Trailer-related headers are forwarded
	// (the Trailer header is kept and "Te: trailers" is passed through)
	// instead of being stripped as hop-by-hop headers.
	transferTrailer bool

	// saveOriginResHeader reports whether response headers already set on
	// the RequestContext before proxying are kept and re-added to the
	// backend's response.
	saveOriginResHeader bool

	// director must be a function which modifies the request
	// into a new request. Its response is then redirected
	// back to the original client unmodified.
	// director must not access the provided Request
	// after returning.
	director func(*protocol.Request)

	// modifyResponse is an optional function that modifies the
	// Response from the backend. It is called if the backend
	// returns a response at all, with any HTTP status code.
	// If the backend is unreachable, the optional errorHandler is
	// called without any call to modifyResponse.
	//
	// If modifyResponse returns an error, errorHandler is called
	// with its error value. If errorHandler is nil, its default
	// implementation is used.
	modifyResponse func(*protocol.Response) error

	// errorHandler is an optional function that handles errors
	// reaching the backend or errors from modifyResponse.
	//
	// If nil, the default handler sets a 502 Bad Gateway status on the
	// response. Backend request errors are logged by ServeHTTP before the
	// handler is invoked; errors from modifyResponse are not logged.
	errorHandler func(*app.RequestContext, error)
}
// hopHeaders are the hop-by-hop headers that ServeHTTP strips from the
// outgoing request and from the backend's response ("Trailer" is kept when
// trailer forwarding is enabled). RFC 7230 requires hop-by-hop headers to be
// named in the Connection header field; this list holds the ones defined by
// the obsoleted RFC 2616 (section 13.5.1) for backward compatibility.
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection", // non-standard, yet still sent by libcurl and rejected by e.g. google
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",      // the canonicalized spelling of "TE"
	"Trailer", // singular, not "Trailers"; https://www.rfc-editor.org/errata_search.php?eid=4522
	"Transfer-Encoding",
	"Upgrade",
}
// NewSingleHostReverseProxy returns a new ReverseProxy that routes requests
// to target. Request URLs are joined onto target with JoinURLPath: if the
// target's path is "/base" and the incoming request was for "/dir", the
// target request will be for /base/dir.
//
// The default director also rewrites the Host header to the host of the
// rewritten request URI. To keep the original Host header, or to apply any
// other routing policy, install a custom director with SetDirector.
//
// A new client.Client is always created from options (which may be empty).
// Use ReverseProxy.SetClient if a shared, customized client.Client instance
// is needed instead.
func NewSingleHostReverseProxy(target string, options ...config.ClientOption) (*ReverseProxy, error) {
	c, err := client.NewClient(options...)
	if err != nil {
		return nil, err
	}
	return &ReverseProxy{
		client: c,
		Target: target,
		director: func(req *protocol.Request) {
			req.SetRequestURI(b2s(JoinURLPath(req, target)))
			// Point the Host header at the host of the rewritten URI.
			req.Header.SetHostBytes(req.URI().Host())
		},
	}, nil
}
// JoinURLPath builds the URI that req should be forwarded to when proxying
// to target.
//
// If target starts with "http://" or "https://", it is used as an absolute
// URL. Otherwise it is treated as a path on the incoming request's host
// (req.Host()).
//
// The request path is appended to target's path with exactly one '/'
// between them. Target's query string (everything after its first '?') is
// kept, and the incoming request's query string is appended to it with '&'.
//
// Example: target "http://backend/base?a=1" and a request for "/dir?b=2"
// yield "http://backend/base/dir?a=1&b=2".
func JoinURLPath(req *protocol.Request, target string) (path []byte) {
	reqPath := req.URI().Path()
	// Guard the index so an empty path can never panic.
	aslash := len(reqPath) > 0 && reqPath[0] == '/'

	if !strings.HasPrefix(target, "http://") && !strings.HasPrefix(target, "https://") {
		// Relative target: forward to the same host the request came in on.
		if strings.HasPrefix(target, "/") {
			target = fmt.Sprintf("%s%s", req.Host(), target)
		} else {
			target = fmt.Sprintf("%s/%s", req.Host(), target)
		}
	}

	// Split once at the first '?'. The slash check then looks only at the
	// path part, and a query containing further '?' characters is kept whole.
	base, targetQuery, hasTargetQuery := target, "", false
	if i := strings.IndexByte(target, '?'); i >= 0 {
		base, targetQuery, hasTargetQuery = target[:i], target[i+1:], true
	}
	bslash := strings.HasSuffix(base, "/")

	var buffer bytes.Buffer
	buffer.WriteString(base)
	switch {
	case aslash && bslash:
		buffer.Write(reqPath[1:])
	case !aslash && !bslash:
		buffer.WriteByte('/')
		buffer.Write(reqPath)
	default:
		buffer.Write(reqPath)
	}

	if hasTargetQuery {
		buffer.WriteByte('?')
		buffer.WriteString(targetQuery)
	}
	if reqQuery := req.QueryString(); len(reqQuery) > 0 {
		switch {
		case !hasTargetQuery:
			buffer.WriteByte('?')
		case targetQuery != "":
			// Only add a separator when there is a target query to separate from.
			buffer.WriteByte('&')
		}
		buffer.Write(reqQuery)
	}
	return buffer.Bytes()
}
// removeRequestConnHeaders removes the hop-by-hop headers that are listed in
// the request's "Connection" header.
// See RFC 7230, section 6.1.
func removeRequestConnHeaders(c *app.RequestContext) {
	// Collect the listed names first and delete them only after VisitAll has
	// finished. Deleting entries from inside the callback mutates the header
	// storage that is being iterated, which is not guaranteed to be safe.
	// string(v) copies the value so the names do not alias that storage.
	var listed []string
	c.Request.Header.VisitAll(func(k, v []byte) {
		if b2s(k) != "Connection" {
			return
		}
		for _, sf := range strings.Split(string(v), ",") {
			if sf = textproto.TrimString(sf); sf != "" {
				listed = append(listed, sf)
			}
		}
	})
	for _, name := range listed {
		c.Request.Header.DelBytes(s2b(name))
	}
}
// removeResponseConnHeaders removes the hop-by-hop headers that are listed in
// the response's "Connection" header.
// See RFC 7230, section 6.1.
func removeResponseConnHeaders(c *app.RequestContext) {
	// Collect first, delete after iteration: deleting inside the VisitAll
	// callback would mutate the header storage while it is being walked.
	// string(v) detaches the names from that storage.
	var listed []string
	c.Response.Header.VisitAll(func(k, v []byte) {
		if b2s(k) != "Connection" {
			return
		}
		for _, sf := range strings.Split(string(v), ",") {
			if sf = textproto.TrimString(sf); sf != "" {
				listed = append(listed, sf)
			}
		}
	})
	for _, name := range listed {
		c.Response.Header.DelBytes(s2b(name))
	}
}
// checkTeHeader reports whether any Te header in header lists the "trailers"
// token, e.g. "Te: trailers" or "Te: deflate;q=0.5, trailers".
//
// Each comma-separated element is compared as a whole: any ";" parameters
// are dropped, surrounding whitespace is trimmed, and case is ignored. A
// value such as "x-trailers" therefore does not count as "trailers".
// See https://github.com/golang/go/issues/21096
func checkTeHeader(header *protocol.RequestHeader) bool {
	for _, te := range header.PeekAll("Te") {
		for _, elem := range bytes.Split(te, []byte{','}) {
			if i := bytes.IndexByte(elem, ';'); i >= 0 {
				elem = elem[:i]
			}
			if bytes.EqualFold(textproto.TrimBytes(elem), []byte("trailers")) {
				return true
			}
		}
	}
	return false
}
// defaultErrorHandler is the fallback used when no error handler has been
// configured. It ignores the error and marks the response 502 Bad Gateway.
func (r *ReverseProxy) defaultErrorHandler(c *app.RequestContext, _ error) {
	header := &c.Response.Header
	header.SetStatusCode(consts.StatusBadGateway)
}
// respTmpHeaderPool recycles the scratch maps that ServeHTTP uses to hold
// response headers set before proxying (header name -> values).
var respTmpHeaderPool = sync.Pool{
	New: func() interface{} {
		fresh := make(map[string][]string)
		return fresh
	},
}
// ServeHTTP proxies the request held in ctx to the backend and writes the
// backend's response into ctx.Response.
//
// Steps:
//   - the request is rewritten by the director (if any);
//   - hop-by-hop headers are stripped and the client IP is appended to
//     X-Forwarded-For;
//   - the request is sent with the configured client behavior;
//   - hop-by-hop headers are stripped from the response.
//
// Backend errors and errors returned by modifyResponse are passed to the
// error handler.
func (r *ReverseProxy) ServeHTTP(c context.Context, ctx *app.RequestContext) {
	req := &ctx.Request
	resp := &ctx.Response

	// Response headers set before proxying are saved here and re-added to the
	// backend's response. The map is always cleared and returned to the pool,
	// including on the error path.
	respTmpHeader := respTmpHeaderPool.Get().(map[string][]string)
	defer func() {
		for k := range respTmpHeader {
			delete(respTmpHeader, k)
		}
		respTmpHeaderPool.Put(respTmpHeader)
	}()
	if r.saveOriginResHeader {
		resp.Header.SetNoDefaultContentType(true)
		resp.Header.VisitAll(func(key, value []byte) {
			keyStr := string(key)
			respTmpHeader[keyStr] = append(respTmpHeader[keyStr], string(value))
		})
	}

	if r.director != nil {
		r.director(req)
	}

	req.Header.ResetConnectionClose()

	// Te is about to be stripped as a hop-by-hop header, so note first
	// whether the client asked for trailers.
	hasTeTrailer := false
	if r.transferTrailer {
		hasTeTrailer = checkTeHeader(&req.Header)
	}

	removeRequestConnHeaders(ctx)

	// Remove hop-by-hop headers to the backend. Especially
	// important is "Connection" because we want a persistent
	// connection, regardless of what the client sent to us.
	for _, h := range hopHeaders {
		if r.transferTrailer && h == "Trailer" {
			continue
		}
		req.Header.DelBytes(s2b(h))
	}

	// Check if 'trailers' exists in te header, If exists, add an additional Te header
	if r.transferTrailer && hasTeTrailer {
		req.Header.Set("Te", "trailers")
	}

	// Append the client IP to X-Forwarded-For, keeping all prior hops in a
	// single header value.
	if clientIP, _, err := net.SplitHostPort(ctx.RemoteAddr().String()); err == nil {
		// A present-but-empty X-Forwarded-For header means "do not populate".
		if first := req.Header.Peek("X-Forwarded-For"); first == nil || len(first) > 0 {
			var hops []string
			for _, v := range req.Header.PeekAll("X-Forwarded-For") {
				if len(v) > 0 {
					// string(v) copies, so the delete below cannot clobber it.
					hops = append(hops, string(v))
				}
			}
			xff := clientIP
			if len(hops) > 0 {
				xff = fmt.Sprintf("%s, %s", strings.Join(hops, ", "), clientIP)
			}
			// Replace every existing value instead of adding another header,
			// which would duplicate the prior hops.
			req.Header.DelBytes(s2b("X-Forwarded-For"))
			req.Header.Set("X-Forwarded-For", xff)
		}
	}

	if err := r.doClientBehavior(c, req, resp); err != nil {
		hlog.CtxErrorf(c, "HERTZ: Client request error: %#v", err.Error())
		r.getErrorHandler()(ctx, err)
		return
	}

	// Re-add the response headers saved before proxying.
	for key, hs := range respTmpHeader {
		for _, h := range hs {
			resp.Header.Add(key, h)
		}
	}

	removeResponseConnHeaders(ctx)
	for _, h := range hopHeaders {
		if r.transferTrailer && h == "Trailer" {
			continue
		}
		resp.Header.DelBytes(s2b(h))
	}

	if r.modifyResponse == nil {
		return
	}
	if err := r.modifyResponse(resp); err != nil {
		r.getErrorHandler()(ctx, err)
	}
}
// SetDirector sets the function that rewrites each request before it is
// sent to the backend. A nil director leaves the request unchanged.
func (r *ReverseProxy) SetDirector(director func(req *protocol.Request)) {
	r.director = director
}

// SetClient replaces the client used to reach the backend. For example, one
// customized client.Client can be shared between several proxies.
func (r *ReverseProxy) SetClient(client *client.Client) {
	r.client = client
}

// SetModifyResponse sets an optional function that is called with the
// backend's response after hop-by-hop headers have been removed. If it
// returns an error, the error handler is invoked with that error.
func (r *ReverseProxy) SetModifyResponse(mr func(*protocol.Response) error) {
	r.modifyResponse = mr
}

// SetErrorHandler sets the function called when the backend request fails
// or modifyResponse returns an error. If it is never set, the response
// status is set to 502 Bad Gateway.
func (r *ReverseProxy) SetErrorHandler(eh func(c *app.RequestContext, err error)) {
	r.errorHandler = eh
}

// SetTransferTrailer controls trailer forwarding. When enabled, the Trailer
// header is not stripped, and "Te: trailers" is passed to the backend if the
// client sent it.
func (r *ReverseProxy) SetTransferTrailer(b bool) {
	r.transferTrailer = b
}

// SetSaveOriginResHeader controls whether response headers already present
// on the RequestContext before proxying are kept and added to the backend's
// response.
func (r *ReverseProxy) SetSaveOriginResHeader(b bool) {
	r.saveOriginResHeader = b
}

// SetClientBehavior selects which client method (Do, DoDeadline,
// DoRedirects or DoTimeout) is used to send requests to the backend.
func (r *ReverseProxy) SetClientBehavior(cb clientBehavior) {
	r.clientBehavior = cb
}
// getErrorHandler returns the user-supplied error handler, falling back to
// defaultErrorHandler when none has been set.
func (r *ReverseProxy) getErrorHandler() func(c *app.RequestContext, err error) {
	handler := r.errorHandler
	if handler == nil {
		handler = r.defaultErrorHandler
	}
	return handler
}
// doClientBehavior sends req to the backend with the client method chosen by
// r.clientBehavior and writes the result into resp.
//
// The behavior's param is type-asserted to the argument that method needs:
// time.Time for DoDeadline, int for DoRedirects, time.Duration for
// DoTimeout. Any other behavior type uses a plain Do.
func (r *ReverseProxy) doClientBehavior(ctx context.Context, req *protocol.Request, resp *protocol.Response) error {
	cb := r.clientBehavior
	switch cb.clientBehaviorType {
	case doDeadline:
		return r.client.DoDeadline(ctx, req, resp, cb.param.(time.Time))
	case doRedirects:
		return r.client.DoRedirects(ctx, req, resp, cb.param.(int))
	case doTimeout:
		return r.client.DoTimeout(ctx, req, resp, cb.param.(time.Duration))
	default:
		return r.client.Do(ctx, req, resp)
	}
}
// b2s reinterprets a byte slice as a string without copying.
// See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ .
//
// The returned string shares b's memory, so b must not be modified while
// the string is in use. This depends on the string header layout being a
// prefix of the slice header layout, which may change in future Go versions.
func b2s(b []byte) string {
	sp := (*string)(unsafe.Pointer(&b))
	return *sp
}
// s2b reinterprets a string as a byte slice without copying. The slice's
// length and capacity both equal len(s).
//
// The result aliases the string's (immutable) memory and must never be
// written to. This depends on the reflect header layouts, which may change
// in future Go versions.
func s2b(s string) (b []byte) {
	strHdr := (*reflect.StringHeader)(unsafe.Pointer(&s))
	sliceHdr := (*reflect.SliceHeader)(unsafe.Pointer(&b))
	sliceHdr.Data, sliceHdr.Len, sliceHdr.Cap = strHdr.Data, strHdr.Len, strHdr.Len
	return b
}