Skip to content

Commit a7f802c

Browse files
authored
Merge branch 'songquanpeng:main' into main
2 parents 9982d45 + 6ad1699 commit a7f802c

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

relay/adaptor/cloudflare/adaptor.go

+22-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"io"
77
"net/http"
8+
"strings"
89

910
"github.com/gin-gonic/gin"
1011
"github.com/songquanpeng/one-api/relay/adaptor"
@@ -28,14 +29,32 @@ func (a *Adaptor) Init(meta *meta.Meta) {
2829
a.meta = meta
2930
}
3031

32+
// WorkerAI cannot be used across accounts with AIGateWay
33+
// https://developers.cloudflare.com/ai-gateway/providers/workersai/#openai-compatible-endpoints
34+
// https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/workers-ai
35+
func (a *Adaptor) isAIGateWay(baseURL string) bool {
36+
return strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") && strings.HasSuffix(baseURL, "/workers-ai")
37+
}
38+
3139
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
40+
isAIGateWay := a.isAIGateWay(meta.BaseURL)
41+
var urlPrefix string
42+
if isAIGateWay {
43+
urlPrefix = meta.BaseURL
44+
} else {
45+
urlPrefix = fmt.Sprintf("%s/client/v4/accounts/%s/ai", meta.BaseURL, meta.Config.UserID)
46+
}
47+
3248
switch meta.Mode {
3349
case relaymode.ChatCompletions:
34-
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", meta.BaseURL, meta.Config.UserID), nil
50+
return fmt.Sprintf("%s/v1/chat/completions", urlPrefix), nil
3551
case relaymode.Embeddings:
36-
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", meta.BaseURL, meta.Config.UserID), nil
52+
return fmt.Sprintf("%s/v1/embeddings", urlPrefix), nil
3753
default:
38-
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil
54+
if isAIGateWay {
55+
return fmt.Sprintf("%s/%s", urlPrefix, meta.ActualModelName), nil
56+
}
57+
return fmt.Sprintf("%s/run/%s", urlPrefix, meta.ActualModelName), nil
3958
}
4059
}
4160

relay/adaptor/openai/main.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ import (
44
"bufio"
55
"bytes"
66
"encoding/json"
7-
"github.com/songquanpeng/one-api/common/render"
87
"io"
98
"net/http"
109
"strings"
1110

11+
"github.com/songquanpeng/one-api/common/render"
12+
1213
"github.com/gin-gonic/gin"
1314
"github.com/songquanpeng/one-api/common"
1415
"github.com/songquanpeng/one-api/common/conv"
@@ -31,6 +32,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
3132

3233
common.SetEventStreamHeaders(c)
3334

35+
doneRendered := false
3436
for scanner.Scan() {
3537
data := scanner.Text()
3638
if len(data) < dataPrefixLength { // ignore blank line or wrong format
@@ -41,6 +43,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
4143
}
4244
if strings.HasPrefix(data[dataPrefixLength:], done) {
4345
render.StringData(c, data)
46+
doneRendered = true
4447
continue
4548
}
4649
switch relayMode {
@@ -81,7 +84,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
8184
logger.SysError("error reading stream: " + err.Error())
8285
}
8386

84-
render.Done(c)
87+
if !doneRendered {
88+
render.Done(c)
89+
}
8590

8691
err := resp.Body.Close()
8792
if err != nil {

0 commit comments

Comments
 (0)