@@ -20,8 +20,6 @@ import (
20
20
const (
21
21
pluginName = "ai-proxy"
22
22
23
- ctxKeyApiName = "apiName"
24
-
25
23
defaultMaxBodyBytes uint32 = 10 * 1024 * 1024
26
24
)
27
25
@@ -92,14 +90,13 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
92
90
log .Warnf ("[onHttpRequestHeader] unsupported path: %s" , path .Path )
93
91
return types .ActionContinue
94
92
}
93
+
94
+ ctx .SetContext (provider .CtxKeyApiName , apiName )
95
95
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
96
96
ctx .DisableReroute ()
97
97
98
- ctx .SetContext (ctxKeyApiName , apiName )
99
-
100
- _ , needHandleBody := activeProvider .(provider.ResponseBodyHandler )
101
98
_ , needHandleStreamingBody := activeProvider .(provider.StreamingResponseBodyHandler )
102
- if needHandleBody || needHandleStreamingBody {
99
+ if needHandleStreamingBody {
103
100
proxywasm .RemoveHttpRequestHeader ("Accept-Encoding" )
104
101
}
105
102
@@ -138,7 +135,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
138
135
log .Debugf ("[onHttpRequestBody] provider=%s" , activeProvider .GetProviderType ())
139
136
140
137
if handler , ok := activeProvider .(provider.RequestBodyHandler ); ok {
141
- apiName , _ := ctx .GetContext (ctxKeyApiName ).(provider.ApiName )
138
+ apiName , _ := ctx .GetContext (provider . CtxKeyApiName ).(provider.ApiName )
142
139
143
140
newBody , settingErr := pluginConfig .GetProviderConfig ().ReplaceByCustomSettings (body )
144
141
if settingErr != nil {
@@ -186,32 +183,25 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
186
183
log .Errorf ("unable to load :status header from response: %v" , err )
187
184
}
188
185
ctx .DontReadResponseBody ()
189
- providerConfig .OnRequestFailed (ctx , apiTokenInUse , log )
190
-
191
- return types .ActionContinue
186
+ return providerConfig .OnRequestFailed (activeProvider , ctx , apiTokenInUse , log )
192
187
}
193
188
194
189
// Reset ctxApiTokenRequestFailureCount if the request is successful,
195
190
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
196
191
providerConfig .ResetApiTokenRequestFailureCount (apiTokenInUse , log )
197
192
198
- if handler , ok := activeProvider .(provider.ResponseHeadersHandler ); ok {
199
- apiName , _ := ctx .GetContext (ctxKeyApiName ).(provider.ApiName )
200
- action , err := handler .OnResponseHeaders (ctx , apiName , log )
201
- if err == nil {
202
- checkStream (& ctx , log )
203
- return action
204
- }
205
- util .ErrorHandler ("ai-proxy.proc_resp_headers_failed" , fmt .Errorf ("failed to process response headers: %v" , err ))
206
- return types .ActionContinue
193
+ headers := util .GetOriginalResponseHeaders ()
194
+ if handler , ok := activeProvider .(provider.TransformResponseHeadersHandler ); ok {
195
+ apiName , _ := ctx .GetContext (provider .CtxKeyApiName ).(provider.ApiName )
196
+ handler .TransformResponseHeaders (ctx , apiName , headers , log )
197
+ } else {
198
+ providerConfig .DefaultTransformResponseHeaders (ctx , headers )
207
199
}
200
+ util .ReplaceResponseHeaders (headers )
208
201
209
202
checkStream (& ctx , log )
210
- _ , needHandleBody := activeProvider .(provider.ResponseBodyHandler )
211
203
_ , needHandleStreamingBody := activeProvider .(provider.StreamingResponseBodyHandler )
212
- if ! needHandleBody && ! needHandleStreamingBody {
213
- ctx .DontReadResponseBody ()
214
- } else if ! needHandleStreamingBody {
204
+ if ! needHandleStreamingBody {
215
205
ctx .BufferResponseBody ()
216
206
}
217
207
@@ -230,7 +220,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
230
220
log .Debugf ("isLastChunk=%v chunk: %s" , isLastChunk , string (chunk ))
231
221
232
222
if handler , ok := activeProvider .(provider.StreamingResponseBodyHandler ); ok {
233
- apiName , _ := ctx .GetContext (ctxKeyApiName ).(provider.ApiName )
223
+ apiName , _ := ctx .GetContext (provider . CtxKeyApiName ).(provider.ApiName )
234
224
modifiedChunk , err := handler .OnStreamingResponseBody (ctx , apiName , chunk , isLastChunk , log )
235
225
if err == nil && modifiedChunk != nil {
236
226
return modifiedChunk
@@ -249,16 +239,17 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
249
239
}
250
240
251
241
log .Debugf ("[onHttpResponseBody] provider=%s" , activeProvider .GetProviderType ())
252
- //log.Debugf("response body: %s", string(body))
253
242
254
- if handler , ok := activeProvider .(provider.ResponseBodyHandler ); ok {
255
- apiName , _ := ctx .GetContext (ctxKeyApiName ).(provider.ApiName )
256
- action , err := handler .OnResponseBody (ctx , apiName , body , log )
257
- if err == nil {
258
- return action
243
+ if handler , ok := activeProvider .(provider.TransformResponseBodyHandler ); ok {
244
+ apiName , _ := ctx .GetContext (provider .CtxKeyApiName ).(provider.ApiName )
245
+ body , err := handler .TransformResponseBody (ctx , apiName , body , log )
246
+ if err != nil {
247
+ util .ErrorHandler ("ai-proxy.proc_resp_body_failed" , fmt .Errorf ("failed to process response body: %v" , err ))
248
+ return types .ActionContinue
249
+ }
250
+ if err = provider .ReplaceResponseBody (body , log ); err != nil {
251
+ util .ErrorHandler ("ai-proxy.replace_resp_body_failed" , fmt .Errorf ("failed to replace response body: %v" , err ))
259
252
}
260
- util .ErrorHandler ("ai-proxy.proc_resp_body_failed" , fmt .Errorf ("failed to process response body: %v" , err ))
261
- return types .ActionContinue
262
253
}
263
254
return types .ActionContinue
264
255
}
0 commit comments