diff --git a/reverse_proxy.go b/reverse_proxy.go index 71129a8..16df778 100644 --- a/reverse_proxy.go +++ b/reverse_proxy.go @@ -47,9 +47,12 @@ type ReverseProxy struct { // target is set as a reverse proxy address Target string - // transforTrailer is whether to forward Trailer-related header + // transferTrailer is whether to forward Trailer-related header transferTrailer bool + // saveOriginResponse is whether to save the original response header + saveOriginResHeader bool + // director must be a function which modifies the request // into a new request. Its response is then redirected // back to the original client unmodified. @@ -214,6 +217,21 @@ func (r *ReverseProxy) ServeHTTP(c context.Context, ctx *app.RequestContext) { req := &ctx.Request resp := &ctx.Response + // save tmp resp header + respTmpHeader := map[string][]string{} + if r.saveOriginResHeader { + resp.Header.SetNoDefaultContentType(true) + resp.Header.VisitAll(func(key, value []byte) { + keyStr := string(key) + valueStr := string(value) + if _, ok := respTmpHeader[keyStr]; ok { + respTmpHeader[keyStr] = []string{valueStr} + } else { + respTmpHeader[keyStr] = append(respTmpHeader[keyStr], valueStr) + } + }) + } + if r.director != nil { r.director(&ctx.Request) } @@ -261,6 +279,13 @@ func (r *ReverseProxy) ServeHTTP(c context.Context, ctx *app.RequestContext) { return } + // add tmp resp header + for key, hs := range respTmpHeader { + for _, h := range hs { + resp.Header.Add(key, h) + } + } + removeResponseConnHeaders(ctx) for _, h := range hopHeaders { @@ -303,6 +328,10 @@ func (r *ReverseProxy) SetTransferTrailer(b bool) { r.transferTrailer = b } +func (r *ReverseProxy) SetSaveOriginResHeader(b bool) { + r.saveOriginResHeader = b +} + func (r *ReverseProxy) getErrorHandler() func(c *app.RequestContext, err error) { if r.errorHandler != nil { return r.errorHandler diff --git a/reverse_proxy_test.go b/reverse_proxy_test.go index a56288f..9ed83d2 100644 --- a/reverse_proxy_test.go +++ b/reverse_proxy_test.go @@ -33,6 +33,8 @@ import ( "testing" "time" + "github.com/cloudwego/hertz/pkg/common/test/assert" + "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app/client" "github.com/cloudwego/hertz/pkg/app/server" @@ -556,3 +558,35 @@ func TestReverseProxyTransferTrailer(t *testing.T) { t.Errorf("handler got X-Trailer Trailer value %q; want 'trailer_value'", c) } } + +func TestReverseProxySaveRespHeader(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 404 + r := server.New(server.WithHostPorts("127.0.0.1:9997")) + + r.GET("/proxy/backend", func(cc context.Context, ctx *app.RequestContext) { + ctx.Data(backendStatus, "application/json", []byte(backendResponse)) + }) + + proxy, err := NewSingleHostReverseProxy("http://127.0.0.1:9997/proxy") + proxy.SetSaveOriginResHeader(true) + if err != nil { + t.Errorf("proxy error: %v", err) + } + + r.GET("/backend", func(c context.Context, ctx *app.RequestContext) { + ctx.Response.Header.Set("aaa", "bbb") + proxy.ServeHTTP(c, ctx) + }) + go r.Spin() + time.Sleep(time.Second) + cli, _ := client.NewClient() + req := protocol.AcquireRequest() + res := protocol.AcquireResponse() + req.SetRequestURI("http://localhost:9997/backend") + err = cli.Do(context.Background(), req, res) + if err != nil { + t.Fatalf("Get: %v", err) + } + assert.DeepEqual(t, "bbb", res.Header.Get("aaa")) +}