diff --git a/proxy_client_behavior.go b/proxy_client_behavior.go new file mode 100644 index 0000000..cdeb4c1 --- /dev/null +++ b/proxy_client_behavior.go @@ -0,0 +1,58 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +import "time" + +type clientBehaviorType int + +const ( + do clientBehaviorType = iota + doDeadline + doRedirects + doTimeout +) + +type clientBehavior struct { + clientBehaviorType clientBehaviorType + param interface{} +} + +func ClientDo() clientBehavior { + return clientBehavior{ + clientBehaviorType: do, + } +} + +func ClientDoRedirects(param int) clientBehavior { + return clientBehavior{ + clientBehaviorType: doRedirects, + param: param, + } +} + +func ClientDoDeadline(param time.Time) clientBehavior { + return clientBehavior{ + clientBehaviorType: doDeadline, + param: param, + } +} + +func ClientDoTimeout(param time.Duration) clientBehavior { + return clientBehavior{ + clientBehaviorType: doTimeout, + param: param, + } +} diff --git a/reverse_proxy.go b/reverse_proxy.go index 1256200..08ddca0 100644 --- a/reverse_proxy.go +++ b/reverse_proxy.go @@ -32,6 +32,7 @@ import ( "reflect" "strings" "sync" + "time" "unsafe" "github.com/cloudwego/hertz/pkg/app" @@ -45,6 +46,8 @@ import ( type ReverseProxy struct { client *client.Client + clientBehavior clientBehavior + // target is set as a reverse proxy address Target string @@ -105,7 +108,6 @@ var hopHeaders = []string{ // To rewrite Host headers, use ReverseProxy directly with a custom // director policy. // -// Note: if no config.ClientOption is passed it will use the default global client.Client instance. // When passing config.ClientOption it will initialize a local client.Client instance. // Using ReverseProxy.SetClient if there is need for shared customized client.Client instance. func NewSingleHostReverseProxy(target string, options ...config.ClientOption) (*ReverseProxy, error) { @@ -116,13 +118,11 @@ func NewSingleHostReverseProxy(target string, options ...config.ClientOption) (* req.Header.SetHostBytes(req.URI().Host()) }, } - if len(options) != 0 { - c, err := client.NewClient(options...) - if err != nil { - return nil, err - } - r.client = c + c, err := client.NewClient(options...) + if err != nil { + return nil, err } + r.client = c return r, nil } @@ -275,11 +275,8 @@ func (r *ReverseProxy) ServeHTTP(c context.Context, ctx *app.RequestContext) { req.Header.Add("X-Forwarded-For", ip) } } - fn := client.Do - if r.client != nil { - fn = r.client.Do - } - err := fn(c, req, resp) + + err := r.doClientBehavior(c, req, resp) if err != nil { hlog.CtxErrorf(c, "HERTZ: Client request error: %#v", err.Error()) r.getErrorHandler()(ctx, err) @@ -345,6 +342,10 @@ func (r *ReverseProxy) SetSaveOriginResHeader(b bool) { r.saveOriginResHeader = b } +func (r *ReverseProxy) SetClientBehavior(cb clientBehavior) { + r.clientBehavior = cb +} + func (r *ReverseProxy) getErrorHandler() func(c *app.RequestContext, err error) { if r.errorHandler != nil { return r.errorHandler @@ -352,6 +353,24 @@ func (r *ReverseProxy) getErrorHandler() func(c *app.RequestContext, err error) return r.defaultErrorHandler } +func (r *ReverseProxy) doClientBehavior(ctx context.Context, req *protocol.Request, resp *protocol.Response) error { + var err error + switch r.clientBehavior.clientBehaviorType { + case doDeadline: + deadline := r.clientBehavior.param.(time.Time) + err = r.client.DoDeadline(ctx, req, resp, deadline) + case doRedirects: + maxRedirectsCount := r.clientBehavior.param.(int) + err = r.client.DoRedirects(ctx, req, resp, maxRedirectsCount) + case doTimeout: + timeout := r.clientBehavior.param.(time.Duration) + err = r.client.DoTimeout(ctx, req, resp, timeout) + default: + err = r.client.Do(ctx, req, resp) + } + return err +} + // b2s converts byte slice to a string without memory allocation. // See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ . // diff --git a/reverse_proxy_test.go b/reverse_proxy_test.go index 6699aac..903cb41 100644 --- a/reverse_proxy_test.go +++ b/reverse_proxy_test.go @@ -569,6 +569,7 @@ func TestReverseProxySaveRespHeader(t *testing.T) { proxy, err := NewSingleHostReverseProxy("http://127.0.0.1:9997/proxy") proxy.SetSaveOriginResHeader(true) + proxy.SetClientBehavior(ClientDoRedirects(2)) if err != nil { t.Errorf("proxy error: %v", err) }