Skip to content

Commit

Permalink
feat: add SetClientBehavior method to allow users can select proxy cl…
Browse files Browse the repository at this point in the history
…ient's do behavior (#19)

* do redirect

* do redirect

* do redirect

* do redirect

* support choose proxy client do behavior

* support choose proxy client do behavior

* support choose proxy client do behavior

* support choose proxy client do behavior

* support choose proxy client do behavior

* support choose proxy client do behavior
  • Loading branch information
dragonYang200 authored May 31, 2024
1 parent 3059929 commit e589602
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 12 deletions.
58 changes: 58 additions & 0 deletions proxy_client_behavior.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright 2024 CloudWeGo Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package reverseproxy

import "time"

type clientBehaviorType int

const (
do clientBehaviorType = iota
doDeadline
doRedirects
doTimeout
)

type clientBehavior struct {
clientBehaviorType clientBehaviorType
param interface{}
}

func ClientDo() clientBehavior {
return clientBehavior{
clientBehaviorType: do,
}
}

func ClientDoRedirects(param int) clientBehavior {
return clientBehavior{
clientBehaviorType: doRedirects,
param: param,
}
}

func ClientDoDeadline(param time.Time) clientBehavior {
return clientBehavior{
clientBehaviorType: doDeadline,
param: param,
}
}

func ClientDoTimeout(param time.Duration) clientBehavior {
return clientBehavior{
clientBehaviorType: doTimeout,
param: param,
}
}
43 changes: 31 additions & 12 deletions reverse_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"reflect"
"strings"
"sync"
"time"
"unsafe"

"github.com/cloudwego/hertz/pkg/app"
Expand All @@ -45,6 +46,8 @@ import (
type ReverseProxy struct {
client *client.Client

clientBehavior clientBehavior

// target is set as a reverse proxy address
Target string

Expand Down Expand Up @@ -105,7 +108,6 @@ var hopHeaders = []string{
// To rewrite Host headers, use ReverseProxy directly with a custom
// director policy.
//
// Note: if no config.ClientOption is passed it will use the default global client.Client instance.
// When passing config.ClientOption it will initialize a local client.Client instance.
// Using ReverseProxy.SetClient if there is need for shared customized client.Client instance.
func NewSingleHostReverseProxy(target string, options ...config.ClientOption) (*ReverseProxy, error) {
Expand All @@ -116,13 +118,11 @@ func NewSingleHostReverseProxy(target string, options ...config.ClientOption) (*
req.Header.SetHostBytes(req.URI().Host())
},
}
if len(options) != 0 {
c, err := client.NewClient(options...)
if err != nil {
return nil, err
}
r.client = c
c, err := client.NewClient(options...)
if err != nil {
return nil, err
}
r.client = c
return r, nil
}

Expand Down Expand Up @@ -275,11 +275,8 @@ func (r *ReverseProxy) ServeHTTP(c context.Context, ctx *app.RequestContext) {
req.Header.Add("X-Forwarded-For", ip)
}
}
fn := client.Do
if r.client != nil {
fn = r.client.Do
}
err := fn(c, req, resp)

err := r.doClientBehavior(c, req, resp)
if err != nil {
hlog.CtxErrorf(c, "HERTZ: Client request error: %#v", err.Error())
r.getErrorHandler()(ctx, err)
Expand Down Expand Up @@ -345,13 +342,35 @@ func (r *ReverseProxy) SetSaveOriginResHeader(b bool) {
r.saveOriginResHeader = b
}

func (r *ReverseProxy) SetClientBehavior(cb clientBehavior) {
r.clientBehavior = cb
}

func (r *ReverseProxy) getErrorHandler() func(c *app.RequestContext, err error) {
if r.errorHandler != nil {
return r.errorHandler
}
return r.defaultErrorHandler
}

func (r *ReverseProxy) doClientBehavior(ctx context.Context, req *protocol.Request, resp *protocol.Response) error {
var err error
switch r.clientBehavior.clientBehaviorType {
case doDeadline:
deadline := r.clientBehavior.param.(time.Time)
err = r.client.DoDeadline(ctx, req, resp, deadline)
case doRedirects:
maxRedirectsCount := r.clientBehavior.param.(int)
err = r.client.DoRedirects(ctx, req, resp, maxRedirectsCount)
case doTimeout:
timeout := r.clientBehavior.param.(time.Duration)
err = r.client.DoTimeout(ctx, req, resp, timeout)
default:
err = r.client.Do(ctx, req, resp)
}
return err
}

// b2s converts byte slice to a string without memory allocation.
// See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ .
//
Expand Down
1 change: 1 addition & 0 deletions reverse_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ func TestReverseProxySaveRespHeader(t *testing.T) {

proxy, err := NewSingleHostReverseProxy("http://127.0.0.1:9997/proxy")
proxy.SetSaveOriginResHeader(true)
proxy.SetClientBehavior(ClientDoRedirects(2))
if err != nil {
t.Errorf("proxy error: %v", err)
}
Expand Down

0 comments on commit e589602

Please sign in to comment.