Skip to content

Commit 579c986

Browse files
authored
feat: retry failed request (#1590)
1 parent 380717a commit 579c986

15 files changed

+304
-183
lines changed

plugins/wasm-go/extensions/ai-proxy/README.md

+17-8
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ description: AI 代理插件配置参考
4141
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
4242
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
4343
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
44+
| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 |
4445

4546
`context`的配置字段说明如下:
4647

@@ -78,14 +79,22 @@ custom-setting会遵循如下表格,根据`name`和协议来替换对应的字
7879

7980
`failover` 的配置字段说明如下:
8081

81-
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
82-
|------------------|--------|------|-------|-----------------------------|
83-
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
84-
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
85-
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
86-
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
87-
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
88-
| healthCheckModel | string | 必填 | | 健康检测使用的模型 |
82+
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
83+
|------------------|--------|-----------------|-------|-----------------------------|
84+
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
85+
| failureThreshold | int | 非必填 | 3 | 触发 failover 的连续请求失败次数阈值 |
86+
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
87+
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
88+
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
89+
| healthCheckModel | string | 启用 failover 时必填 | | 健康检测使用的模型 |
90+
91+
`retryOnFailure` 的配置字段说明如下:
92+
93+
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
94+
|------------------|--------|-----------------|-------|-------------|
95+
| enabled | bool | 非必填 | false | 是否启用失败请求重试 |
96+
| maxRetries | int | 非必填 | 1 | 最大重试次数 |
97+
| retryTimeout | int | 非必填 | 5000 | 重试超时时间,单位毫秒 |
8998

9099
### 提供商特有配置
91100

plugins/wasm-go/extensions/ai-proxy/main.go

+23-32
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ import (
2020
const (
2121
pluginName = "ai-proxy"
2222

23-
ctxKeyApiName = "apiName"
24-
2523
defaultMaxBodyBytes uint32 = 10 * 1024 * 1024
2624
)
2725

@@ -92,14 +90,13 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
9290
log.Warnf("[onHttpRequestHeader] unsupported path: %s", path.Path)
9391
return types.ActionContinue
9492
}
93+
94+
ctx.SetContext(provider.CtxKeyApiName, apiName)
9595
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
9696
ctx.DisableReroute()
9797

98-
ctx.SetContext(ctxKeyApiName, apiName)
99-
100-
_, needHandleBody := activeProvider.(provider.ResponseBodyHandler)
10198
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
102-
if needHandleBody || needHandleStreamingBody {
99+
if needHandleStreamingBody {
103100
proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
104101
}
105102

@@ -138,7 +135,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
138135
log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType())
139136

140137
if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
141-
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
138+
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
142139

143140
newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body)
144141
if settingErr != nil {
@@ -186,32 +183,25 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
186183
log.Errorf("unable to load :status header from response: %v", err)
187184
}
188185
ctx.DontReadResponseBody()
189-
providerConfig.OnRequestFailed(ctx, apiTokenInUse, log)
190-
191-
return types.ActionContinue
186+
return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, log)
192187
}
193188

194189
// Reset ctxApiTokenRequestFailureCount if the request is successful,
195190
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
196191
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)
197192

198-
if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok {
199-
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
200-
action, err := handler.OnResponseHeaders(ctx, apiName, log)
201-
if err == nil {
202-
checkStream(&ctx, log)
203-
return action
204-
}
205-
util.ErrorHandler("ai-proxy.proc_resp_headers_failed", fmt.Errorf("failed to process response headers: %v", err))
206-
return types.ActionContinue
193+
headers := util.GetOriginalResponseHeaders()
194+
if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok {
195+
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
196+
handler.TransformResponseHeaders(ctx, apiName, headers, log)
197+
} else {
198+
providerConfig.DefaultTransformResponseHeaders(ctx, headers)
207199
}
200+
util.ReplaceResponseHeaders(headers)
208201

209202
checkStream(&ctx, log)
210-
_, needHandleBody := activeProvider.(provider.ResponseBodyHandler)
211203
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
212-
if !needHandleBody && !needHandleStreamingBody {
213-
ctx.DontReadResponseBody()
214-
} else if !needHandleStreamingBody {
204+
if !needHandleStreamingBody {
215205
ctx.BufferResponseBody()
216206
}
217207

@@ -230,7 +220,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
230220
log.Debugf("isLastChunk=%v chunk: %s", isLastChunk, string(chunk))
231221

232222
if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok {
233-
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
223+
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
234224
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk, log)
235225
if err == nil && modifiedChunk != nil {
236226
return modifiedChunk
@@ -249,16 +239,17 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
249239
}
250240

251241
log.Debugf("[onHttpResponseBody] provider=%s", activeProvider.GetProviderType())
252-
//log.Debugf("response body: %s", string(body))
253242

254-
if handler, ok := activeProvider.(provider.ResponseBodyHandler); ok {
255-
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
256-
action, err := handler.OnResponseBody(ctx, apiName, body, log)
257-
if err == nil {
258-
return action
243+
if handler, ok := activeProvider.(provider.TransformResponseBodyHandler); ok {
244+
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
245+
body, err := handler.TransformResponseBody(ctx, apiName, body, log)
246+
if err != nil {
247+
util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
248+
return types.ActionContinue
249+
}
250+
if err = provider.ReplaceResponseBody(body, log); err != nil {
251+
util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
259252
}
260-
util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
261-
return types.ActionContinue
262253
}
263254
return types.ActionContinue
264255
}

plugins/wasm-go/extensions/ai-proxy/provider/claude.go

+4-16
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010

1111
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
1212
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
13-
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
1413
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
1514
)
1615

@@ -139,27 +138,16 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A
139138
return json.Marshal(claudeRequest)
140139
}
141140

142-
func (c *claudeProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
141+
func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
143142
claudeResponse := &claudeTextGenResponse{}
144143
if err := json.Unmarshal(body, claudeResponse); err != nil {
145-
return types.ActionContinue, fmt.Errorf("unable to unmarshal claude response: %v", err)
144+
return nil, fmt.Errorf("unable to unmarshal claude response: %v", err)
146145
}
147146
if claudeResponse.Error != nil {
148-
return types.ActionContinue, fmt.Errorf("claude response error, error_type: %s, error_message: %s", claudeResponse.Error.Type, claudeResponse.Error.Message)
147+
return nil, fmt.Errorf("claude response error, error_type: %s, error_message: %s", claudeResponse.Error.Type, claudeResponse.Error.Message)
149148
}
150149
response := c.responseClaude2OpenAI(ctx, claudeResponse)
151-
return types.ActionContinue, replaceJsonResponseBody(response, log)
152-
}
153-
154-
func (c *claudeProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
155-
// use original protocol, skip OnStreamingResponseBody() and OnResponseBody()
156-
if c.config.protocol == protocolOriginal {
157-
ctx.DontReadResponseBody()
158-
return types.ActionContinue, nil
159-
}
160-
161-
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
162-
return types.ActionContinue, nil
150+
return json.Marshal(response)
163151
}
164152

165153
func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {

plugins/wasm-go/extensions/ai-proxy/provider/context.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ func insertContext(provider Provider, content string, err error, body []byte, lo
151151
if err != nil {
152152
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), fmt.Errorf("failed to insert context message: %v", err))
153153
}
154-
if err := replaceHttpJsonRequestBody(body, log); err != nil {
154+
if err := replaceRequestBody(body, log); err != nil {
155155
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), fmt.Errorf("failed to replace request body: %v", err))
156156
}
157157
}

plugins/wasm-go/extensions/ai-proxy/provider/deepl.go

+3-9
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010

1111
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
1212
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
13-
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
1413
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
1514
)
1615

@@ -112,18 +111,13 @@ func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, api
112111
return json.Marshal(baiduRequest)
113112
}
114113

115-
func (d *deeplProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
116-
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
117-
return types.ActionContinue, nil
118-
}
119-
120-
func (d *deeplProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
114+
func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
121115
deeplResponse := &deeplResponse{}
122116
if err := json.Unmarshal(body, deeplResponse); err != nil {
123-
return types.ActionContinue, fmt.Errorf("unable to unmarshal deepl response: %v", err)
117+
return nil, fmt.Errorf("unable to unmarshal deepl response: %v", err)
124118
}
125119
response := d.responseDeepl2OpenAI(ctx, deeplResponse)
126-
return types.ActionContinue, replaceJsonResponseBody(response, log)
120+
return json.Marshal(response)
127121
}
128122

129123
func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplResponse *deeplResponse) *chatCompletionResponse {

plugins/wasm-go/extensions/ai-proxy/provider/failover.go

+11-6
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919

2020
type failover struct {
2121
// @Title zh-CN 是否启用 apiToken 的 failover 机制
22-
enabled bool `required:"true" yaml:"enabled" json:"enabled"`
22+
enabled bool `required:"false" yaml:"enabled" json:"enabled"`
2323
// @Title zh-CN 触发 failover 连续请求失败的阈值
2424
failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"`
2525
// @Title zh-CN 健康检测的成功阈值
@@ -29,7 +29,7 @@ type failover struct {
2929
// @Title zh-CN 健康检测的超时时间,单位毫秒
3030
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
3131
// @Title zh-CN 健康检测使用的模型
32-
healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"`
32+
healthCheckModel string `required:"false" yaml:"healthCheckModel" json:"healthCheckModel"`
3333
// @Title zh-CN 本次请求使用的 apiToken
3434
ctxApiTokenInUse string
3535
// @Title zh-CN 记录 apiToken 请求失败的次数,key 为 apiToken,value 为失败次数
@@ -184,9 +184,9 @@ func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext,
184184
if handler, ok := activeProvider.(TransformRequestBodyHandler); ok {
185185
body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log)
186186
} else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok {
187-
headers := util.GetOriginalHttpHeaders()
187+
headers := util.GetOriginalRequestHeaders()
188188
body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, originalHeaders, log)
189-
util.ReplaceOriginalHttpHeaders(headers)
189+
util.ReplaceRequestHeaders(headers)
190190
} else {
191191
body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log)
192192
}
@@ -539,10 +539,15 @@ func (c *ProviderConfig) resetSharedData() {
539539
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0)
540540
}
541541

542-
func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) {
542+
func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) types.Action {
543543
if c.isFailoverEnabled() {
544544
c.handleUnavailableApiToken(ctx, apiTokenInUse, log)
545545
}
546+
if c.isRetryOnFailureEnabled() && ctx.GetContext(ctxKeyIsStreaming) != nil && !ctx.GetContext(ctxKeyIsStreaming).(bool) {
547+
c.retryFailedRequest(activeProvider, ctx, log)
548+
return types.HeaderStopAllIterationAndWatermark
549+
}
550+
return types.ActionContinue
546551
}
547552

548553
func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
@@ -557,7 +562,7 @@ func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.L
557562
} else {
558563
apiToken = c.GetRandomToken()
559564
}
560-
log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiToken)
565+
log.Debugf("Use apiToken %s to send request", apiToken)
561566
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
562567
}
563568

plugins/wasm-go/extensions/ai-proxy/provider/gemini.go

+10-21
Original file line numberDiff line numberDiff line change
@@ -105,16 +105,6 @@ func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
105105
return json.Marshal(geminiRequest)
106106
}
107107

108-
func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
109-
if g.config.protocol == protocolOriginal {
110-
ctx.DontReadResponseBody()
111-
return types.ActionContinue, nil
112-
}
113-
114-
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
115-
return types.ActionContinue, nil
116-
}
117-
118108
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
119109
log.Infof("chunk body:%s", string(chunk))
120110
if isLastChunk || len(chunk) == 0 {
@@ -148,39 +138,38 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
148138
return []byte(modifiedResponseChunk), nil
149139
}
150140

151-
func (g *geminiProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
141+
func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
152142
if apiName == ApiNameChatCompletion {
153143
return g.onChatCompletionResponseBody(ctx, body, log)
154-
} else if apiName == ApiNameEmbeddings {
144+
} else {
155145
return g.onEmbeddingsResponseBody(ctx, body, log)
156146
}
157-
return types.ActionContinue, errUnsupportedApiName
158147
}
159148

160-
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
149+
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
161150
geminiResponse := &geminiChatResponse{}
162151
if err := json.Unmarshal(body, geminiResponse); err != nil {
163-
return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
152+
return nil, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
164153
}
165154
if geminiResponse.Error != nil {
166-
return types.ActionContinue, fmt.Errorf("gemini chat completion response error, error_code: %d, error_status:%s, error_message: %s",
155+
return nil, fmt.Errorf("gemini chat completion response error, error_code: %d, error_status:%s, error_message: %s",
167156
geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message)
168157
}
169158
response := g.buildChatCompletionResponse(ctx, geminiResponse)
170-
return types.ActionContinue, replaceJsonResponseBody(response, log)
159+
return json.Marshal(response)
171160
}
172161

173-
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
162+
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
174163
geminiResponse := &geminiEmbeddingResponse{}
175164
if err := json.Unmarshal(body, geminiResponse); err != nil {
176-
return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err)
165+
return nil, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err)
177166
}
178167
if geminiResponse.Error != nil {
179-
return types.ActionContinue, fmt.Errorf("gemini embeddings response error, error_code: %d, error_status:%s, error_message: %s",
168+
return nil, fmt.Errorf("gemini embeddings response error, error_code: %d, error_status:%s, error_message: %s",
180169
geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message)
181170
}
182171
response := g.buildEmbeddingsResponse(ctx, geminiResponse)
183-
return types.ActionContinue, replaceJsonResponseBody(response, log)
172+
return json.Marshal(response)
184173
}
185174

186175
func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string {

0 commit comments

Comments
 (0)