diff --git a/README.md b/README.md index 5f9947b0a7..fb137c2311 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [together.ai](https://www.together.ai/) + [x] [novita.ai](https://www.novita.ai/) + [x] [硅基流动 SiliconCloud](https://siliconflow.cn/siliconcloud) + + [x] [xAI](https://x.ai/) 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 @@ -399,6 +400,7 @@ graph LR 26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。 +29. `ENFORCE_INCLUDE_USAGE`:是否强制在 stream 模式下返回 usage,默认不开启,可选值为 `true` 和 `false`。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/config/config.go b/common/config/config.go index 11da0b967d..2eb894ef72 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -35,6 +35,7 @@ var PasswordLoginEnabled = true var PasswordRegisterEnabled = true var EmailVerificationEnabled = false var GitHubOAuthEnabled = false +var OidcEnabled = false var WeChatAuthEnabled = false var TurnstileCheckEnabled = false var RegisterEnabled = true @@ -70,6 +71,13 @@ var GitHubClientSecret = "" var LarkClientId = "" var LarkClientSecret = "" +var OidcClientId = "" +var OidcClientSecret = "" +var OidcWellKnown = "" +var OidcAuthorizationEndpoint = "" +var OidcTokenEndpoint = "" +var OidcUserinfoEndpoint = "" + var WeChatServerAddress = "" var WeChatServerToken = "" var WeChatAccountQRCodeImageURL = "" @@ -152,3 +160,5 @@ var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false) var RelayProxy = env.String("RELAY_PROXY", "") var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) + +var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false) diff --git 
a/common/ctxkey/key.go b/common/ctxkey/key.go index 90556b3af6..115558a51c 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -20,4 +20,5 @@ const ( BaseURL = "base_url" AvailableModels = "available_models" KeyRequestBody = "key_request_body" + SystemPrompt = "system_prompt" ) diff --git a/common/gin.go b/common/gin.go index 549d3279c9..815b4ee54a 100644 --- a/common/gin.go +++ b/common/gin.go @@ -31,15 +31,15 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { contentType := c.Request.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { err = json.Unmarshal(requestBody, &v) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) } else { - // skip for now - // TODO: someday non json request have variant model, we will need to implementation this + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + err = c.ShouldBind(&v) } if err != nil { return err } // Reset request body - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) return nil } diff --git a/common/helper/helper.go b/common/helper/helper.go index e06dfb6e64..df7b0a5f9c 100644 --- a/common/helper/helper.go +++ b/common/helper/helper.go @@ -137,3 +137,23 @@ func String2Int(str string) int { } return num } + +func Float64PtrMax(p *float64, maxValue float64) *float64 { + if p == nil { + return nil + } + if *p > maxValue { + return &maxValue + } + return p +} + +func Float64PtrMin(p *float64, minValue float64) *float64 { + if p == nil { + return nil + } + if *p < minValue { + return &minValue + } + return p +} diff --git a/controller/auth/oidc.go b/controller/auth/oidc.go new file mode 100644 index 0000000000..7b4ad4b9ee --- /dev/null +++ b/controller/auth/oidc.go @@ -0,0 +1,225 @@ +package auth + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + 
"github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/model" + "net/http" + "strconv" + "time" +) + +type OidcResponse struct { + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` +} + +type OidcUser struct { + OpenID string `json:"sub"` + Email string `json:"email"` + Name string `json:"name"` + PreferredUsername string `json:"preferred_username"` + Picture string `json:"picture"` +} + +func getOidcUserInfoByCode(code string) (*OidcUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + values := map[string]string{ + "client_id": config.OidcClientId, + "client_secret": config.OidcClientSecret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress), + } + jsonData, err := json.Marshal(values) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") + } + defer res.Body.Close() + var oidcResponse OidcResponse + err = json.NewDecoder(res.Body).Decode(&oidcResponse) + if err != nil { + return nil, err + } + req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken) + res2, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") + } + var oidcUser OidcUser + err = json.NewDecoder(res2.Body).Decode(&oidcUser) + if err != 
nil { + return nil, err + } + return &oidcUser, nil +} + +func OidcAuth(c *gin.Context) { + session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } + username := session.Get("username") + if username != nil { + OidcBind(c) + return + } + if !config.OidcEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 OIDC 登录以及注册", + }) + return + } + code := c.Query("code") + oidcUser, err := getOidcUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + OidcId: oidcUser.OpenID, + } + if model.IsOidcIdAlreadyTaken(user.OidcId) { + err := user.FillUserByOidcId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if config.RegisterEnabled { + user.Email = oidcUser.Email + if oidcUser.PreferredUsername != "" { + user.Username = oidcUser.PreferredUsername + } else { + user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1) + } + if oidcUser.Name != "" { + user.DisplayName = oidcUser.Name + } else { + user.DisplayName = "OIDC User" + } + err := user.Insert(0) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != model.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + controller.SetupLogin(&user, c) +} + +func OidcBind(c *gin.Context) { + if !config.OidcEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 OIDC 登录以及注册", + }) + return + } + code := c.Query("code") + 
oidcUser, err := getOidcUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + OidcId: oidcUser.OpenID, + } + if model.IsOidcIdAlreadyTaken(user.OidcId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该 OIDC 账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.OidcId = oidcUser.OpenID + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/billing.go b/controller/billing.go index 0d03e4c189..e837157f03 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -17,9 +17,11 @@ func GetSubscription(c *gin.Context) { if config.DisplayTokenStatEnabled { tokenId := c.GetInt(ctxkey.TokenId) token, err = model.GetTokenById(tokenId) - expiredTime = token.ExpiredTime - remainQuota = token.RemainQuota - usedQuota = token.UsedQuota + if err == nil { + expiredTime = token.ExpiredTime + remainQuota = token.RemainQuota + usedQuota = token.UsedQuota + } } else { userId := c.GetInt(ctxkey.Id) remainQuota, err = model.GetUserQuota(userId) diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 535927444e..a6ffaafe7c 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -81,6 +81,26 @@ type APGC2DGPTUsageResponse struct { TotalUsed float64 `json:"total_used"` } +type SiliconFlowUsageResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Status bool `json:"status"` + Data struct { + ID string `json:"id"` + Name string `json:"name"` + Image string 
`json:"image"` + Email string `json:"email"` + IsAdmin bool `json:"isAdmin"` + Balance string `json:"balance"` + Status string `json:"status"` + Introduction string `json:"introduction"` + Role string `json:"role"` + ChargeBalance string `json:"chargeBalance"` + TotalBalance string `json:"totalBalance"` + Category string `json:"category"` + } `json:"data"` +} + // GetAuthHeader get auth header func GetAuthHeader(token string) http.Header { h := http.Header{} @@ -203,6 +223,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { return response.TotalAvailable, nil } +func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) { + url := "https://api.siliconflow.cn/v1/user/info" + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + response := SiliconFlowUsageResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + if response.Code != 20000 { + return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message) + } + balance, err := strconv.ParseFloat(response.Data.Balance, 64) + if err != nil { + return 0, err + } + channel.UpdateBalance(balance) + return balance, nil +} + func updateChannelBalance(channel *model.Channel) (float64, error) { baseURL := channeltype.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() == "" { @@ -227,6 +269,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { return updateChannelAPI2GPTBalance(channel) case channeltype.AIGC2D: return updateChannelAIGC2DBalance(channel) + case channeltype.SiliconFlow: + return updateChannelSiliconFlowBalance(channel) default: return 0, errors.New("尚未实现") } diff --git a/controller/channel-test.go b/controller/channel-test.go index f8327284c0..971f53826e 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -76,9 +76,9 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques if 
len(modelNames) > 0 { modelName = modelNames[0] } - if modelMap != nil && modelMap[modelName] != "" { - modelName = modelMap[modelName] - } + } + if modelMap != nil && modelMap[modelName] != "" { + modelName = modelMap[modelName] } meta.OriginModelName, meta.ActualModelName = request.Model, modelName request.Model = modelName diff --git a/controller/misc.go b/controller/misc.go index 2928b8fb33..ae90087017 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -18,24 +18,30 @@ func GetStatus(c *gin.Context) { "success": true, "message": "", "data": gin.H{ - "version": common.Version, - "start_time": common.StartTime, - "email_verification": config.EmailVerificationEnabled, - "github_oauth": config.GitHubOAuthEnabled, - "github_client_id": config.GitHubClientId, - "lark_client_id": config.LarkClientId, - "system_name": config.SystemName, - "logo": config.Logo, - "footer_html": config.Footer, - "wechat_qrcode": config.WeChatAccountQRCodeImageURL, - "wechat_login": config.WeChatAuthEnabled, - "server_address": config.ServerAddress, - "turnstile_check": config.TurnstileCheckEnabled, - "turnstile_site_key": config.TurnstileSiteKey, - "top_up_link": config.TopUpLink, - "chat_link": config.ChatLink, - "quota_per_unit": config.QuotaPerUnit, - "display_in_currency": config.DisplayInCurrencyEnabled, + "version": common.Version, + "start_time": common.StartTime, + "email_verification": config.EmailVerificationEnabled, + "github_oauth": config.GitHubOAuthEnabled, + "github_client_id": config.GitHubClientId, + "lark_client_id": config.LarkClientId, + "system_name": config.SystemName, + "logo": config.Logo, + "footer_html": config.Footer, + "wechat_qrcode": config.WeChatAccountQRCodeImageURL, + "wechat_login": config.WeChatAuthEnabled, + "server_address": config.ServerAddress, + "turnstile_check": config.TurnstileCheckEnabled, + "turnstile_site_key": config.TurnstileSiteKey, + "top_up_link": config.TopUpLink, + "chat_link": config.ChatLink, + "quota_per_unit": 
config.QuotaPerUnit, + "display_in_currency": config.DisplayInCurrencyEnabled, + "oidc": config.OidcEnabled, + "oidc_client_id": config.OidcClientId, + "oidc_well_known": config.OidcWellKnown, + "oidc_authorization_endpoint": config.OidcAuthorizationEndpoint, + "oidc_token_endpoint": config.OidcTokenEndpoint, + "oidc_userinfo_endpoint": config.OidcUserinfoEndpoint, }, }) return diff --git a/middleware/distributor.go b/middleware/distributor.go index 0c4b04c341..0aceb29dd8 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -12,7 +12,7 @@ import ( ) type ModelRequest struct { - Model string `json:"model"` + Model string `json:"model" form:"model"` } func Distribute() func(c *gin.Context) { @@ -61,6 +61,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set(ctxkey.Channel, channel.Type) c.Set(ctxkey.ChannelId, channel.Id) c.Set(ctxkey.ChannelName, channel.Name) + if channel.SystemPrompt != nil && *channel.SystemPrompt != "" { + c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt) + } c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) c.Set(ctxkey.OriginalModel, modelName) // for retry c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) diff --git a/middleware/gzip.go b/middleware/gzip.go new file mode 100644 index 0000000000..4d4ce0c255 --- /dev/null +++ b/middleware/gzip.go @@ -0,0 +1,27 @@ +package middleware + +import ( + "compress/gzip" + "github.com/gin-gonic/gin" + "io" + "net/http" +) + +func GzipDecodeMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + if c.GetHeader("Content-Encoding") == "gzip" { + gzipReader, err := gzip.NewReader(c.Request.Body) + if err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + defer gzipReader.Close() + + // Replace the request body with the decompressed data + c.Request.Body = io.NopCloser(gzipReader) + } + + // Continue processing the request + c.Next() + } +} diff --git a/model/channel.go b/model/channel.go index 
759dfd4fed..4b0f4b01aa 100644 --- a/model/channel.go +++ b/model/channel.go @@ -37,6 +37,7 @@ type Channel struct { ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` Config string `json:"config"` + SystemPrompt *string `json:"system_prompt" gorm:"type:text"` } type ChannelConfig struct { diff --git a/model/log.go b/model/log.go index 6fba776a53..58fdd513cd 100644 --- a/model/log.go +++ b/model/log.go @@ -3,6 +3,7 @@ package model import ( "context" "fmt" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" @@ -152,7 +153,11 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { } func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { - tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)") + ifnull := "ifnull" + if common.UsingPostgreSQL { + ifnull = "COALESCE" + } + tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull)) if username != "" { tx = tx.Where("username = ?", username) } @@ -176,7 +181,11 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa } func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { - tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") + ifnull := "ifnull" + if common.UsingPostgreSQL { + ifnull = "COALESCE" + } + tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull)) if username != "" { tx = tx.Where("username = ?", username) } diff --git a/model/option.go b/model/option.go index bed8d4c37d..8fd30aee2a 100644 --- a/model/option.go +++ b/model/option.go @@ -28,6 +28,7 @@ func InitOptionMap() { 
config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) + config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled) config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) @@ -130,6 +131,8 @@ func updateOptionMap(key string, value string) (err error) { config.EmailVerificationEnabled = boolValue case "GitHubOAuthEnabled": config.GitHubOAuthEnabled = boolValue + case "OidcEnabled": + config.OidcEnabled = boolValue case "WeChatAuthEnabled": config.WeChatAuthEnabled = boolValue case "TurnstileCheckEnabled": @@ -176,6 +179,18 @@ func updateOptionMap(key string, value string) (err error) { config.LarkClientId = value case "LarkClientSecret": config.LarkClientSecret = value + case "OidcClientId": + config.OidcClientId = value + case "OidcClientSecret": + config.OidcClientSecret = value + case "OidcWellKnown": + config.OidcWellKnown = value + case "OidcAuthorizationEndpoint": + config.OidcAuthorizationEndpoint = value + case "OidcTokenEndpoint": + config.OidcTokenEndpoint = value + case "OidcUserinfoEndpoint": + config.OidcUserinfoEndpoint = value case "Footer": config.Footer = value case "SystemName": diff --git a/model/token.go b/model/token.go index 96e6b4918a..91e72a82ad 100644 --- a/model/token.go +++ b/model/token.go @@ -30,7 +30,7 @@ type Token struct { RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"` UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota - Models *string `json:"models" gorm:"default:''"` // allowed models + 
Models *string `json:"models" gorm:"type:text"` // allowed models Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet } @@ -121,30 +121,40 @@ func GetTokenById(id int) (*Token, error) { return &token, err } -func (token *Token) Insert() error { +func (t *Token) Insert() error { var err error - err = DB.Create(token).Error + err = DB.Create(t).Error return err } // Update Make sure your token's fields is completed, because this will update non-zero values -func (token *Token) Update() error { +func (t *Token) Update() error { var err error - err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error + err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error return err } -func (token *Token) SelectUpdate() error { +func (t *Token) SelectUpdate() error { // This can update zero values - return DB.Model(token).Select("accessed_time", "status").Updates(token).Error + return DB.Model(t).Select("accessed_time", "status").Updates(t).Error } -func (token *Token) Delete() error { +func (t *Token) Delete() error { var err error - err = DB.Delete(token).Error + err = DB.Delete(t).Error return err } +func (t *Token) GetModels() string { + if t == nil { + return "" + } + if t.Models == nil { + return "" + } + return *t.Models +} + func DeleteTokenById(id int, userId int) (err error) { // Why we need userId here? In case user want to delete other's token. 
if id == 0 || userId == 0 { @@ -254,14 +264,14 @@ func PreConsumeTokenQuota(tokenId int, quota int64) (err error) { func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { token, err := GetTokenById(tokenId) + if err != nil { + return err + } if quota > 0 { err = DecreaseUserQuota(token.UserId, quota) } else { err = IncreaseUserQuota(token.UserId, -quota) } - if err != nil { - return err - } if !token.UnlimitedQuota { if quota > 0 { err = DecreaseTokenQuota(tokenId, quota) diff --git a/model/user.go b/model/user.go index 924d72f940..a964a0d7d2 100644 --- a/model/user.go +++ b/model/user.go @@ -39,6 +39,7 @@ type User struct { GitHubId string `json:"github_id" gorm:"column:github_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` LarkId string `json:"lark_id" gorm:"column:lark_id;index"` + OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! 
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management Quota int64 `json:"quota" gorm:"bigint;default:0"` @@ -245,6 +246,14 @@ func (user *User) FillUserByLarkId() error { return nil } +func (user *User) FillUserByOidcId() error { + if user.OidcId == "" { + return errors.New("oidc id 为空!") + } + DB.Where(User{OidcId: user.OidcId}).First(user) + return nil +} + func (user *User) FillUserByWeChatId() error { if user.WeChatId == "" { return errors.New("WeChat id 为空!") @@ -277,6 +286,10 @@ func IsLarkIdAlreadyTaken(githubId string) bool { return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 } +func IsOidcIdAlreadyTaken(oidcId string) bool { + return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1 +} + func IsUsernameAlreadyTaken(username string) bool { return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 } diff --git a/monitor/manage.go b/monitor/manage.go index 946e78afe9..44c13612d3 100644 --- a/monitor/manage.go +++ b/monitor/manage.go @@ -1,10 +1,11 @@ package monitor import ( - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/relay/model" "net/http" "strings" + + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/relay/model" ) func ShouldDisableChannel(err *model.Error, statusCode int) bool { @@ -18,31 +19,23 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool { return true } switch err.Type { - case "insufficient_quota": - return true - // https://docs.anthropic.com/claude/reference/errors - case "authentication_error": - return true - case "permission_error": - return true - case "forbidden": + case "insufficient_quota", "authentication_error", "permission_error", "forbidden": return true } if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { return true } - if strings.HasPrefix(err.Message, "Your credit balance is too low") { // 
anthropic - return true - } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { - return true - } - //if strings.Contains(err.Message, "quota") { - // return true - //} - if strings.Contains(err.Message, "credit") { - return true - } - if strings.Contains(err.Message, "balance") { + + lowerMessage := strings.ToLower(err.Message) + if strings.Contains(lowerMessage, "your access was terminated") || + strings.Contains(lowerMessage, "violation of our policies") || + strings.Contains(lowerMessage, "your credit balance is too low") || + strings.Contains(lowerMessage, "organization has been disabled") || + strings.Contains(lowerMessage, "credit") || + strings.Contains(lowerMessage, "balance") || + strings.Contains(lowerMessage, "permission denied") || + strings.Contains(lowerMessage, "organization has been restricted") || // groq + strings.Contains(lowerMessage, "已欠费") { return true } return false diff --git a/one-api b/one-api new file mode 100755 index 0000000000..4c9190bb93 Binary files /dev/null and b/one-api differ diff --git a/relay/adaptor/ali/main.go b/relay/adaptor/ali/main.go index f9039dbe49..6a73c7072f 100644 --- a/relay/adaptor/ali/main.go +++ b/relay/adaptor/ali/main.go @@ -3,6 +3,7 @@ package ali import ( "bufio" "encoding/json" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/render" "io" "net/http" @@ -35,9 +36,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { enableSearch = true aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) } - if request.TopP >= 1 { - request.TopP = 0.9999 - } + request.TopP = helper.Float64PtrMax(request.TopP, 0.9999) return &ChatRequest{ Model: aliModel, Input: Input{ @@ -59,7 +58,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { return &EmbeddingRequest{ - Model: "text-embedding-v1", + Model: request.Model, 
Input: struct { Texts []string `json:"texts"` }{ @@ -102,8 +101,9 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStat StatusCode: resp.StatusCode, }, nil } - + requestModel := c.GetString(ctxkey.RequestModel) fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) + fullTextResponse.Model = requestModel jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/adaptor/ali/model.go b/relay/adaptor/ali/model.go index 450b5f5292..a680c7e24b 100644 --- a/relay/adaptor/ali/model.go +++ b/relay/adaptor/ali/model.go @@ -16,13 +16,13 @@ type Input struct { } type Parameters struct { - TopP float64 `json:"top_p,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` Seed uint64 `json:"seed,omitempty"` EnableSearch bool `json:"enable_search,omitempty"` IncrementalOutput bool `json:"incremental_output,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` ResultFormat string `json:"result_format,omitempty"` Tools []model.Tool `json:"tools,omitempty"` } diff --git a/relay/adaptor/anthropic/constants.go b/relay/adaptor/anthropic/constants.go index 143d1efc25..cb574706d4 100644 --- a/relay/adaptor/anthropic/constants.go +++ b/relay/adaptor/anthropic/constants.go @@ -3,7 +3,11 @@ package anthropic var ModelList = []string{ "claude-instant-1.2", "claude-2.0", "claude-2.1", "claude-3-haiku-20240307", + "claude-3-5-haiku-20241022", "claude-3-sonnet-20240229", "claude-3-opus-20240229", "claude-3-5-sonnet-20240620", + "claude-3-5-sonnet-20241022", + "claude-3-5-sonnet-latest", + "claude-3-5-haiku-20241022", } diff --git a/relay/adaptor/anthropic/model.go b/relay/adaptor/anthropic/model.go index 47f766291d..47f193faa0 100644 --- a/relay/adaptor/anthropic/model.go +++ 
b/relay/adaptor/anthropic/model.go @@ -48,8 +48,8 @@ type Request struct { MaxTokens int `json:"max_tokens,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` Tools []Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` diff --git a/relay/adaptor/aws/claude/main.go b/relay/adaptor/aws/claude/main.go index 7142e46f72..3fe3dfd8fa 100644 --- a/relay/adaptor/aws/claude/main.go +++ b/relay/adaptor/aws/claude/main.go @@ -29,10 +29,13 @@ var AwsModelIDMap = map[string]string{ "claude-instant-1.2": "anthropic.claude-instant-v1", "claude-2.0": "anthropic.claude-v2", "claude-2.1": "anthropic.claude-v2:1", + "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", - "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", - "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", + "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0", + "claude-3-5-sonnet-latest": "anthropic.claude-3-5-sonnet-20241022-v2:0", + "claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0", } func awsModelID(requestModel string) (string, error) { diff --git a/relay/adaptor/aws/claude/model.go b/relay/adaptor/aws/claude/model.go index 6d00b68865..106228877b 100644 --- a/relay/adaptor/aws/claude/model.go +++ b/relay/adaptor/aws/claude/model.go @@ -11,8 +11,8 @@ type Request struct { Messages []anthropic.Message `json:"messages"` System string `json:"system,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` - 
Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` Tools []anthropic.Tool `json:"tools,omitempty"` diff --git a/relay/adaptor/aws/llama3/model.go b/relay/adaptor/aws/llama3/model.go index 7b86c3b8ff..6cb64cdeac 100644 --- a/relay/adaptor/aws/llama3/model.go +++ b/relay/adaptor/aws/llama3/model.go @@ -4,10 +4,10 @@ package aws // // https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html type Request struct { - Prompt string `json:"prompt"` - MaxGenLen int `json:"max_gen_len,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Prompt string `json:"prompt"` + MaxGenLen int `json:"max_gen_len,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` } // Response is the response from AWS Llama3 diff --git a/relay/adaptor/baidu/main.go b/relay/adaptor/baidu/main.go index ebe70c3241..ac8a562544 100644 --- a/relay/adaptor/baidu/main.go +++ b/relay/adaptor/baidu/main.go @@ -35,9 +35,9 @@ type Message struct { type ChatRequest struct { Messages []Message `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - PenaltyScore float64 `json:"penalty_score,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + PenaltyScore *float64 `json:"penalty_score,omitempty"` Stream bool `json:"stream,omitempty"` System string `json:"system,omitempty"` DisableSearch bool `json:"disable_search,omitempty"` diff --git a/relay/adaptor/cloudflare/model.go b/relay/adaptor/cloudflare/model.go index 0d3bafe098..8e382ba7ad 100644 --- a/relay/adaptor/cloudflare/model.go +++ b/relay/adaptor/cloudflare/model.go @@ -9,5 +9,5 @@ type Request struct 
{ Prompt string `json:"prompt,omitempty"` Raw bool `json:"raw,omitempty"` Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` } diff --git a/relay/adaptor/cohere/main.go b/relay/adaptor/cohere/main.go index 45db437b6b..736c5a8d86 100644 --- a/relay/adaptor/cohere/main.go +++ b/relay/adaptor/cohere/main.go @@ -43,7 +43,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { K: textRequest.TopK, Stream: textRequest.Stream, FrequencyPenalty: textRequest.FrequencyPenalty, - PresencePenalty: textRequest.FrequencyPenalty, + PresencePenalty: textRequest.PresencePenalty, Seed: int(textRequest.Seed), } if cohereRequest.Model == "" { diff --git a/relay/adaptor/cohere/model.go b/relay/adaptor/cohere/model.go index 64fa9c9403..3a8bc99dc7 100644 --- a/relay/adaptor/cohere/model.go +++ b/relay/adaptor/cohere/model.go @@ -10,15 +10,15 @@ type Request struct { PromptTruncation string `json:"prompt_truncation,omitempty"` // 默认值为"AUTO" Connectors []Connector `json:"connectors,omitempty"` Documents []Document `json:"documents,omitempty"` - Temperature float64 `json:"temperature,omitempty"` // 默认值为0.3 + Temperature *float64 `json:"temperature,omitempty"` // 默认值为0.3 MaxTokens int `json:"max_tokens,omitempty"` MaxInputTokens int `json:"max_input_tokens,omitempty"` K int `json:"k,omitempty"` // 默认值为0 - P float64 `json:"p,omitempty"` // 默认值为0.75 + P *float64 `json:"p,omitempty"` // 默认值为0.75 Seed int `json:"seed,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0 - PresencePenalty float64 `json:"presence_penalty,omitempty"` // 默认值为0.0 + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0 + PresencePenalty *float64 `json:"presence_penalty,omitempty"` // 默认值为0.0 Tools []Tool `json:"tools,omitempty"` ToolResults []ToolResult `json:"tool_results,omitempty"` } diff --git 
a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go index 51fd6aa801..d6ab45d489 100644 --- a/relay/adaptor/gemini/main.go +++ b/relay/adaptor/gemini/main.go @@ -4,11 +4,12 @@ import ( "bufio" "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" + "github.com/songquanpeng/one-api/common/render" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" @@ -28,6 +29,11 @@ const ( VisionMaxImageNum = 16 ) +var mimeTypeMap = map[string]string{ + "json_object": "application/json", + "text": "text/plain", +} + // Setting safety to the lowest possible values since Gemini is already powerless enough func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { geminiRequest := ChatRequest{ @@ -56,6 +62,15 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { MaxOutputTokens: textRequest.MaxTokens, }, } + if textRequest.ResponseFormat != nil { + if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok { + geminiRequest.GenerationConfig.ResponseMimeType = mimeType + } + if textRequest.ResponseFormat.JsonSchema != nil { + geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema + geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"] + } + } if textRequest.Tools != nil { functions := make([]model.Function, 0, len(textRequest.Tools)) for _, tool := range textRequest.Tools { diff --git a/relay/adaptor/gemini/model.go b/relay/adaptor/gemini/model.go index f7179ea48e..720cb65d19 100644 --- a/relay/adaptor/gemini/model.go +++ b/relay/adaptor/gemini/model.go @@ -65,10 +65,12 @@ type ChatTools struct { } type ChatGenerationConfig struct { - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK float64 `json:"topK,omitempty"` - MaxOutputTokens int `json:"maxOutputTokens,omitempty"` - CandidateCount int 
`json:"candidateCount,omitempty"` - StopSequences []string `json:"stopSequences,omitempty"` + ResponseMimeType string `json:"responseMimeType,omitempty"` + ResponseSchema any `json:"responseSchema,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` } diff --git a/relay/adaptor/groq/constants.go b/relay/adaptor/groq/constants.go index 559851eea0..0864ebe75e 100644 --- a/relay/adaptor/groq/constants.go +++ b/relay/adaptor/groq/constants.go @@ -4,14 +4,24 @@ package groq var ModelList = []string{ "gemma-7b-it", - "mixtral-8x7b-32768", - "llama3-8b-8192", - "llama3-70b-8192", "gemma2-9b-it", - "llama-3.1-405b-reasoning", "llama-3.1-70b-versatile", "llama-3.1-8b-instant", + "llama-3.2-11b-text-preview", + "llama-3.2-11b-vision-preview", + "llama-3.2-1b-preview", + "llama-3.2-3b-preview", + "llama-3.2-11b-vision-preview", + "llama-3.2-90b-text-preview", + "llama-3.2-90b-vision-preview", + "llama-guard-3-8b", + "llama3-70b-8192", + "llama3-8b-8192", "llama3-groq-70b-8192-tool-use-preview", "llama3-groq-8b-8192-tool-use-preview", + "llava-v1.5-7b-4096-preview", + "mixtral-8x7b-32768", + "distil-whisper-large-v3-en", "whisper-large-v3", + "whisper-large-v3-turbo", } diff --git a/relay/adaptor/ollama/model.go b/relay/adaptor/ollama/model.go index 7039984fcc..94f2ab7332 100644 --- a/relay/adaptor/ollama/model.go +++ b/relay/adaptor/ollama/model.go @@ -1,14 +1,14 @@ package ollama type Options struct { - Seed int `json:"seed,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopK int `json:"top_k,omitempty"` - TopP float64 `json:"top_p,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - NumPredict int 
`json:"num_predict,omitempty"` - NumCtx int `json:"num_ctx,omitempty"` + Seed int `json:"seed,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + NumPredict int `json:"num_predict,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` } type Message struct { diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index 5dc395adfa..6946e402a8 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -75,6 +75,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if request == nil { return nil, errors.New("request is nil") } + if request.Stream { + // always return usage in stream mode + if request.StreamOptions == nil { + request.StreamOptions = &model.StreamOptions{} + } + request.StreamOptions.IncludeUsage = true + } return request, nil } diff --git a/relay/adaptor/openai/compatible.go b/relay/adaptor/openai/compatible.go index 0512f05ca7..15b4dcc032 100644 --- a/relay/adaptor/openai/compatible.go +++ b/relay/adaptor/openai/compatible.go @@ -11,9 +11,10 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/mistral" "github.com/songquanpeng/one-api/relay/adaptor/moonshot" "github.com/songquanpeng/one-api/relay/adaptor/novita" + "github.com/songquanpeng/one-api/relay/adaptor/siliconflow" "github.com/songquanpeng/one-api/relay/adaptor/stepfun" "github.com/songquanpeng/one-api/relay/adaptor/togetherai" - "github.com/songquanpeng/one-api/relay/adaptor/siliconflow" + "github.com/songquanpeng/one-api/relay/adaptor/xai" "github.com/songquanpeng/one-api/relay/channeltype" ) @@ -32,6 +33,7 @@ var CompatibleChannels = []int{ channeltype.TogetherAI, channeltype.Novita, channeltype.SiliconFlow, + channeltype.XAI, } func GetCompatibleChannelMeta(channelType int) (string, 
[]string) { @@ -64,6 +66,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) { return "novita", novita.ModelList case channeltype.SiliconFlow: return "siliconflow", siliconflow.ModelList + case channeltype.XAI: + return "xai", xai.ModelList default: return "openai", ModelList } diff --git a/relay/adaptor/openai/constants.go b/relay/adaptor/openai/constants.go index 156a50e7b0..aacdba1ad3 100644 --- a/relay/adaptor/openai/constants.go +++ b/relay/adaptor/openai/constants.go @@ -8,6 +8,8 @@ var ModelList = []string{ "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4o", "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "chatgpt-4o-latest", "gpt-4o-mini", "gpt-4o-mini-2024-07-18", "gpt-4-vision-preview", "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go index 9ee547b371..970807384f 100644 --- a/relay/adaptor/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -55,8 +55,8 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E render.StringData(c, data) // if error happened, pass the data to client continue // just ignore the error } - if len(streamResponse.Choices) == 0 { - // but for empty choice, we should not pass it to client, this is for azure + if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil { + // but for empty choice and no usage, we should not pass it to client, this is for azure continue // just ignore empty choice } render.StringData(c, data) diff --git a/relay/adaptor/palm/model.go b/relay/adaptor/palm/model.go index f653022c3e..2bdd8f298b 100644 --- a/relay/adaptor/palm/model.go +++ b/relay/adaptor/palm/model.go @@ -19,11 +19,11 @@ type Prompt struct { } type ChatRequest struct { - Prompt Prompt `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - TopP 
float64 `json:"topP,omitempty"` - TopK int `json:"topK,omitempty"` + Prompt Prompt `json:"prompt"` + Temperature *float64 `json:"temperature,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + TopP *float64 `json:"topP,omitempty"` + TopK int `json:"topK,omitempty"` } type Error struct { diff --git a/relay/adaptor/stepfun/constants.go b/relay/adaptor/stepfun/constants.go index a82e562b2b..6a2346cac5 100644 --- a/relay/adaptor/stepfun/constants.go +++ b/relay/adaptor/stepfun/constants.go @@ -1,7 +1,13 @@ package stepfun var ModelList = []string{ + "step-1-8k", "step-1-32k", + "step-1-128k", + "step-1-256k", + "step-1-flash", + "step-2-16k", + "step-1v-8k", "step-1v-32k", - "step-1-200k", + "step-1x-medium", } diff --git a/relay/adaptor/tencent/constants.go b/relay/adaptor/tencent/constants.go index be415a94c8..e8631e5f47 100644 --- a/relay/adaptor/tencent/constants.go +++ b/relay/adaptor/tencent/constants.go @@ -5,4 +5,5 @@ var ModelList = []string{ "hunyuan-standard", "hunyuan-standard-256K", "hunyuan-pro", + "hunyuan-vision", } diff --git a/relay/adaptor/tencent/main.go b/relay/adaptor/tencent/main.go index 365e33aef6..827c8a46dd 100644 --- a/relay/adaptor/tencent/main.go +++ b/relay/adaptor/tencent/main.go @@ -39,8 +39,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { Model: &request.Model, Stream: &request.Stream, Messages: messages, - TopP: &request.TopP, - Temperature: &request.Temperature, + TopP: request.TopP, + Temperature: request.Temperature, } } diff --git a/relay/adaptor/vertexai/claude/adapter.go b/relay/adaptor/vertexai/claude/adapter.go index b39e2ddab1..cb911cfea0 100644 --- a/relay/adaptor/vertexai/claude/adapter.go +++ b/relay/adaptor/vertexai/claude/adapter.go @@ -13,7 +13,12 @@ import ( ) var ModelList = []string{ - "claude-3-haiku@20240307", "claude-3-opus@20240229", "claude-3-5-sonnet@20240620", "claude-3-sonnet@20240229", + "claude-3-haiku@20240307", + "claude-3-sonnet@20240229", + 
"claude-3-opus@20240229", + "claude-3-5-sonnet@20240620", + "claude-3-5-sonnet-v2@20241022", + "claude-3-5-haiku@20241022", } const anthropicVersion = "vertex-2023-10-16" diff --git a/relay/adaptor/vertexai/claude/model.go b/relay/adaptor/vertexai/claude/model.go index e1bd5dd48d..c08ba460d9 100644 --- a/relay/adaptor/vertexai/claude/model.go +++ b/relay/adaptor/vertexai/claude/model.go @@ -11,8 +11,8 @@ type Request struct { MaxTokens int `json:"max_tokens,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` Tools []anthropic.Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` diff --git a/relay/adaptor/vertexai/gemini/adapter.go b/relay/adaptor/vertexai/gemini/adapter.go index 43e6cbcde3..ceff1ed2a0 100644 --- a/relay/adaptor/vertexai/gemini/adapter.go +++ b/relay/adaptor/vertexai/gemini/adapter.go @@ -15,7 +15,7 @@ import ( ) var ModelList = []string{ - "gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", + "gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-002", "gemini-1.5-flash-002", } type Adaptor struct { diff --git a/relay/adaptor/xai/constants.go b/relay/adaptor/xai/constants.go new file mode 100644 index 0000000000..9082b999a3 --- /dev/null +++ b/relay/adaptor/xai/constants.go @@ -0,0 +1,5 @@ +package xai + +var ModelList = []string{ + "grok-beta", +} diff --git a/relay/adaptor/xunfei/constants.go b/relay/adaptor/xunfei/constants.go index 12a5621099..5b82ac292f 100644 --- a/relay/adaptor/xunfei/constants.go +++ b/relay/adaptor/xunfei/constants.go @@ -5,6 +5,8 @@ var ModelList = []string{ "SparkDesk-v1.1", "SparkDesk-v2.1", "SparkDesk-v3.1", + "SparkDesk-v3.1-128K", "SparkDesk-v3.5", + 
"SparkDesk-v3.5-32K", "SparkDesk-v4.0", } diff --git a/relay/adaptor/xunfei/main.go b/relay/adaptor/xunfei/main.go index ef6120e570..3984ba5a98 100644 --- a/relay/adaptor/xunfei/main.go +++ b/relay/adaptor/xunfei/main.go @@ -272,9 +272,9 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, } func parseAPIVersionByModelName(modelName string) string { - parts := strings.Split(modelName, "-") - if len(parts) == 2 { - return parts[1] + index := strings.IndexAny(modelName, "-") + if index != -1 { + return modelName[index+1:] } return "" } @@ -283,13 +283,17 @@ func parseAPIVersionByModelName(modelName string) string { func apiVersion2domain(apiVersion string) string { switch apiVersion { case "v1.1": - return "general" + return "lite" case "v2.1": return "generalv2" case "v3.1": return "generalv3" + case "v3.1-128K": + return "pro-128k" case "v3.5": return "generalv3.5" + case "v3.5-32K": + return "max-32k" case "v4.0": return "4.0Ultra" } @@ -297,7 +301,17 @@ func apiVersion2domain(apiVersion string) string { } func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) { + var authUrl string domain := apiVersion2domain(apiVersion) - authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) + switch apiVersion { + case "v3.1-128K": + authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/pro-128k"), apiKey, apiSecret) + break + case "v3.5-32K": + authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/max-32k"), apiKey, apiSecret) + break + default: + authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) + } return domain, authUrl } diff --git a/relay/adaptor/xunfei/model.go b/relay/adaptor/xunfei/model.go index 1f37c04655..c9fb1bb8f2 100644 --- a/relay/adaptor/xunfei/model.go +++ b/relay/adaptor/xunfei/model.go @@ -19,11 +19,11 @@ type ChatRequest struct { } 
`json:"header"` Parameter struct { Chat struct { - Domain string `json:"domain,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopK int `json:"top_k,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Auditing bool `json:"auditing,omitempty"` + Domain string `json:"domain,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Auditing bool `json:"auditing,omitempty"` } `json:"chat"` } `json:"parameter"` Payload struct { diff --git a/relay/adaptor/zhipu/adaptor.go b/relay/adaptor/zhipu/adaptor.go index 78b01fb3f7..660bd37960 100644 --- a/relay/adaptor/zhipu/adaptor.go +++ b/relay/adaptor/zhipu/adaptor.go @@ -4,13 +4,13 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" "io" - "math" "net/http" "strings" ) @@ -65,13 +65,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request) return baiduEmbeddingRequest, err default: - // TopP (0.0, 1.0) - request.TopP = math.Min(0.99, request.TopP) - request.TopP = math.Max(0.01, request.TopP) + // TopP [0.0, 1.0] + request.TopP = helper.Float64PtrMax(request.TopP, 1) + request.TopP = helper.Float64PtrMin(request.TopP, 0) - // Temperature (0.0, 1.0) - request.Temperature = math.Min(0.99, request.Temperature) - request.Temperature = math.Max(0.01, request.Temperature) + // Temperature [0.0, 1.0] + request.Temperature = helper.Float64PtrMax(request.Temperature, 1) + request.Temperature = helper.Float64PtrMin(request.Temperature, 0) a.SetVersionByModeName(request.Model) if a.APIVersion == "v4" { return request, nil diff --git 
a/relay/adaptor/zhipu/model.go b/relay/adaptor/zhipu/model.go index f91de1dced..06e22dc153 100644 --- a/relay/adaptor/zhipu/model.go +++ b/relay/adaptor/zhipu/model.go @@ -12,8 +12,8 @@ type Message struct { type Request struct { Prompt []Message `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` RequestId string `json:"request_id,omitempty"` Incremental bool `json:"incremental,omitempty"` } diff --git a/relay/billing/ratio/image.go b/relay/billing/ratio/image.go index ced0c6678c..c8c42a15c0 100644 --- a/relay/billing/ratio/image.go +++ b/relay/billing/ratio/image.go @@ -30,6 +30,14 @@ var ImageSizeRatios = map[string]map[string]float64{ "720x1280": 1, "1280x720": 1, }, + "step-1x-medium": { + "256x256": 1, + "512x512": 1, + "768x768": 1, + "1024x1024": 1, + "1280x800": 1, + "800x1280": 1, + }, } var ImageGenerationAmounts = map[string][2]int{ @@ -39,6 +47,7 @@ var ImageGenerationAmounts = map[string][2]int{ "ali-stable-diffusion-v1.5": {1, 4}, // Ali "wanx-v1": {1, 4}, // Ali "cogview-3": {1, 1}, + "step-1x-medium": {1, 1}, } var ImagePromptLengthLimitations = map[string]int{ @@ -48,6 +57,7 @@ var ImagePromptLengthLimitations = map[string]int{ "ali-stable-diffusion-v1.5": 4000, "wanx-v1": 4000, "cogview-3": 833, + "step-1x-medium": 4000, } var ImageOriginModelName = map[string]string{ diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index 7bc6cd5420..1b58ec0902 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -34,7 +34,9 @@ var ModelRatio = map[string]float64{ "gpt-4-turbo": 5, // $0.01 / 1K tokens "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens "gpt-4o": 2.5, // $0.005 / 1K tokens + "chatgpt-4o-latest": 2.5, // $0.005 / 1K tokens "gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens + "gpt-4o-2024-08-06": 1.25, // $0.0025 / 1K tokens "gpt-4o-mini": 0.075, // 
$0.00015 / 1K tokens "gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens "gpt-4-vision-preview": 5, // $0.01 / 1K tokens @@ -77,8 +79,10 @@ var ModelRatio = map[string]float64{ "claude-2.0": 8.0 / 1000 * USD, "claude-2.1": 8.0 / 1000 * USD, "claude-3-haiku-20240307": 0.25 / 1000 * USD, + "claude-3-5-haiku-20241022": 1.0 / 1000 * USD, "claude-3-sonnet-20240229": 3.0 / 1000 * USD, "claude-3-5-sonnet-20240620": 3.0 / 1000 * USD, + "claude-3-5-sonnet-20241022": 3.0 / 1000 * USD, "claude-3-opus-20240229": 15.0 / 1000 * USD, // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 "ERNIE-4.0-8K": 0.120 * RMB, @@ -126,7 +130,9 @@ var ModelRatio = map[string]float64{ "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v3.1-128K": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v3.5-32K": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens @@ -158,23 +164,34 @@ var ModelRatio = map[string]float64{ "mistral-embed": 0.1 / 1000 * USD, // https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed "gemma-7b-it": 0.07 / 1000000 * USD, - "mixtral-8x7b-32768": 0.24 / 1000000 * USD, - "llama3-8b-8192": 0.05 / 1000000 * USD, - "llama3-70b-8192": 0.59 / 1000000 * USD, "gemma2-9b-it": 0.20 / 1000000 * USD, - "llama-3.1-405b-reasoning": 0.89 / 1000000 * USD, "llama-3.1-70b-versatile": 0.59 / 1000000 * USD, "llama-3.1-8b-instant": 0.05 / 1000000 * USD, + "llama-3.2-11b-text-preview": 0.05 / 1000000 * USD, + "llama-3.2-11b-vision-preview": 0.05 / 1000000 * USD, + "llama-3.2-1b-preview": 0.05 / 1000000 * USD, + "llama-3.2-3b-preview": 0.05 / 1000000 * USD, + "llama-3.2-90b-text-preview": 0.59 / 1000000 * USD, + "llama-guard-3-8b": 0.05 / 1000000 * USD, + "llama3-70b-8192": 0.59 / 1000000 * USD, + 
"llama3-8b-8192": 0.05 / 1000000 * USD, "llama3-groq-70b-8192-tool-use-preview": 0.89 / 1000000 * USD, "llama3-groq-8b-8192-tool-use-preview": 0.19 / 1000000 * USD, + "mixtral-8x7b-32768": 0.24 / 1000000 * USD, + // https://platform.lingyiwanwu.com/docs#-计费单元 "yi-34b-chat-0205": 2.5 / 1000 * RMB, "yi-34b-chat-200k": 12.0 / 1000 * RMB, "yi-vl-plus": 6.0 / 1000 * RMB, - // stepfun todo - "step-1v-32k": 0.024 * RMB, - "step-1-32k": 0.024 * RMB, - "step-1-200k": 0.15 * RMB, + // https://platform.stepfun.com/docs/pricing/details + "step-1-8k": 0.005 / 1000 * RMB, + "step-1-32k": 0.015 / 1000 * RMB, + "step-1-128k": 0.040 / 1000 * RMB, + "step-1-256k": 0.095 / 1000 * RMB, + "step-1-flash": 0.001 / 1000 * RMB, + "step-2-16k": 0.038 / 1000 * RMB, + "step-1v-8k": 0.005 / 1000 * RMB, + "step-1v-32k": 0.015 / 1000 * RMB, // aws llama3 https://aws.amazon.com/cn/bedrock/pricing/ "llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens "llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens @@ -192,6 +209,8 @@ var ModelRatio = map[string]float64{ "deepl-zh": 25.0 / 1000 * USD, "deepl-en": 25.0 / 1000 * USD, "deepl-ja": 25.0 / 1000 * USD, + // https://console.x.ai/ + "grok-beta": 5.0 / 1000 * USD, } var CompletionRatio = map[string]float64{ @@ -200,8 +219,10 @@ var CompletionRatio = map[string]float64{ "llama3-70b-8192(33)": 0.0035 / 0.00265, } -var DefaultModelRatio map[string]float64 -var DefaultCompletionRatio map[string]float64 +var ( + DefaultModelRatio map[string]float64 + DefaultCompletionRatio map[string]float64 +) func init() { DefaultModelRatio = make(map[string]float64) @@ -313,7 +334,7 @@ func GetCompletionRatio(name string, channelType int) float64 { return 4.0 / 3.0 } if strings.HasPrefix(name, "gpt-4") { - if strings.HasPrefix(name, "gpt-4o-mini") { + if strings.HasPrefix(name, "gpt-4o-mini") || name == "gpt-4o-2024-08-06" { return 4 } if strings.HasPrefix(name, "gpt-4-turbo") || @@ -323,6 +344,9 @@ func GetCompletionRatio(name string, channelType int) 
float64 { } return 2 } + if name == "chatgpt-4o-latest" { + return 3 + } if strings.HasPrefix(name, "claude-3") { return 5 } @@ -351,6 +375,8 @@ func GetCompletionRatio(name string, channelType int) float64 { return 3 case "command-r-plus": return 5 + case "grok-beta": + return 3 } return 1 } diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go index a261cff85d..98316959a1 100644 --- a/relay/channeltype/define.go +++ b/relay/channeltype/define.go @@ -46,5 +46,6 @@ const ( VertextAI Proxy SiliconFlow + XAI Dummy ) diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go index 8727faea15..b8bd61f89e 100644 --- a/relay/channeltype/url.go +++ b/relay/channeltype/url.go @@ -45,7 +45,8 @@ var ChannelBaseURLs = []string{ "https://api.novita.ai/v3/openai", // 41 "", // 42 "", // 43 - "https://api.siliconflow.cn", // 44 + "https://api.siliconflow.cn", // 44 + "https://api.x.ai", // 45 } func init() { diff --git a/relay/constant/role/define.go b/relay/constant/role/define.go index 972488c5c9..5097c97e21 100644 --- a/relay/constant/role/define.go +++ b/relay/constant/role/define.go @@ -1,5 +1,6 @@ package role const ( + System = "system" Assistant = "assistant" ) diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 87d22f13f6..567dee7c5a 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/songquanpeng/one-api/relay/constant/role" "math" "net/http" "strings" @@ -90,7 +91,7 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR return preConsumedQuota, nil } -func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { +func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, 
preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) { if usage == nil { logger.Error(ctx, "usage is nil, which is unexpected") return @@ -118,7 +119,11 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M if err != nil { logger.Error(ctx, "error update user quota cache: "+err.Error()) } - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) + var extraLog string + if systemPromptReset { + extraLog = " (注意系统提示词已被重置)" + } + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f%s", modelRatio, groupRatio, completionRatio, extraLog) model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) model.UpdateChannelUsedQuota(meta.ChannelId, quota) @@ -154,3 +159,23 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool { } return false } + +func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) (reset bool) { + if prompt == "" { + return false + } + if len(request.Messages) == 0 { + return false + } + if request.Messages[0].Role == role.System { + request.Messages[0].Content = prompt + logger.Infof(ctx, "rewrite system prompt") + return true + } + request.Messages = append([]relaymodel.Message{{ + Role: role.System, + Content: prompt, + }}, request.Messages...) 
+ logger.Infof(ctx, "add system prompt") + return true +} diff --git a/relay/controller/text.go b/relay/controller/text.go index 52ee9949ae..9a47c58bc2 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/config" "io" "net/http" @@ -35,6 +36,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { meta.OriginModelName = textRequest.Model textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping) meta.ActualModelName = textRequest.Model + // set system prompt if not empty + systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt) // get model ratio & group ratio modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) groupRatio := billingratio.GetGroupRatio(meta.Group) @@ -79,12 +82,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { return respErr } // post-consume quota - go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) + go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset) return nil } func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) { - if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan { + if !config.EnforceIncludeUsage && meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan { // no need to convert request for openai return c.Request.Body, nil } diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go index b1761e9a7c..bcbe10453a 100644 --- a/relay/meta/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -30,6 +30,7 @@ type Meta struct { ActualModelName string RequestURLPath string PromptTokens int // only for 
DoResponse + SystemPrompt string } func GetByContext(c *gin.Context) *Meta { @@ -46,6 +47,7 @@ func GetByContext(c *gin.Context) *Meta { BaseURL: c.GetString(ctxkey.BaseURL), APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), RequestURLPath: c.Request.URL.String(), + SystemPrompt: c.GetString(ctxkey.SystemPrompt), } cfg, ok := c.Get(ctxkey.Config) if ok { diff --git a/relay/model/constant.go b/relay/model/constant.go index f6cf1924d1..c9d6d645c6 100644 --- a/relay/model/constant.go +++ b/relay/model/constant.go @@ -1,6 +1,7 @@ package model const ( - ContentTypeText = "text" - ContentTypeImageURL = "image_url" + ContentTypeText = "text" + ContentTypeImageURL = "image_url" + ContentTypeInputAudio = "input_audio" ) diff --git a/relay/model/general.go b/relay/model/general.go index c34c1c2d5d..288c07ffb5 100644 --- a/relay/model/general.go +++ b/relay/model/general.go @@ -1,35 +1,70 @@ package model type ResponseFormat struct { - Type string `json:"type,omitempty"` + Type string `json:"type,omitempty"` + JsonSchema *JSONSchema `json:"json_schema,omitempty"` +} + +type JSONSchema struct { + Description string `json:"description,omitempty"` + Name string `json:"name"` + Schema map[string]interface{} `json:"schema,omitempty"` + Strict *bool `json:"strict,omitempty"` +} + +type Audio struct { + Voice string `json:"voice,omitempty"` + Format string `json:"format,omitempty"` +} + +type StreamOptions struct { + IncludeUsage bool `json:"include_usage,omitempty"` } type GeneralOpenAIRequest struct { - Messages []Message `json:"messages,omitempty"` - Model string `json:"model,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - ResponseFormat *ResponseFormat `json:"response_format,omitempty"` - Seed float64 `json:"seed,omitempty"` - Stop any `json:"stop,omitempty"` - Stream bool 
`json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Tools []Tool `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - FunctionCall any `json:"function_call,omitempty"` - Functions any `json:"functions,omitempty"` - User string `json:"user,omitempty"` - Prompt any `json:"prompt,omitempty"` - Input any `json:"input,omitempty"` - EncodingFormat string `json:"encoding_format,omitempty"` - Dimensions int `json:"dimensions,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - NumCtx int `json:"num_ctx,omitempty"` + // https://platform.openai.com/docs/api-reference/chat/create + Messages []Message `json:"messages,omitempty"` + Model string `json:"model,omitempty"` + Store *bool `json:"store,omitempty"` + Metadata any `json:"metadata,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + LogitBias any `json:"logit_bias,omitempty"` + Logprobs *bool `json:"logprobs,omitempty"` + TopLogprobs *int `json:"top_logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` + N int `json:"n,omitempty"` + Modalities []string `json:"modalities,omitempty"` + Prediction any `json:"prediction,omitempty"` + Audio *Audio `json:"audio,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + ResponseFormat *ResponseFormat `json:"response_format,omitempty"` + Seed float64 `json:"seed,omitempty"` + ServiceTier *string `json:"service_tier,omitempty"` + Stop any `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *StreamOptions `json:"stream_options,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + 
ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"` + User string `json:"user,omitempty"` + FunctionCall any `json:"function_call,omitempty"` + Functions any `json:"functions,omitempty"` + // https://platform.openai.com/docs/api-reference/embeddings/create + Input any `json:"input,omitempty"` + EncodingFormat string `json:"encoding_format,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + // https://platform.openai.com/docs/api-reference/images/create + Prompt any `json:"prompt,omitempty"` + Quality *string `json:"quality,omitempty"` + Size string `json:"size,omitempty"` + Style *string `json:"style,omitempty"` + // Others + Instruction string `json:"instruction,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` } func (r GeneralOpenAIRequest) ParseInput() []string { diff --git a/router/api.go b/router/api.go index d2ada4ebda..6d00c6eaa1 100644 --- a/router/api.go +++ b/router/api.go @@ -23,6 +23,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth) + apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), auth.OidcAuth) apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth) apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode) apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth) diff --git a/router/relay.go b/router/relay.go index 094ea5fb51..8f3c73030d 100644 --- a/router/relay.go +++ b/router/relay.go @@ -9,6 +9,7 @@ import ( func SetRelayRouter(router *gin.Engine) { router.Use(middleware.CORS()) + router.Use(middleware.GzipDecodeMiddleware()) // https://platform.openai.com/docs/api-reference/introduction modelsRouter := router.Group("/v1/models") 
modelsRouter.Use(middleware.TokenAuth()) diff --git a/web/air/src/components/TokensTable.js b/web/air/src/components/TokensTable.js index 0853ddfbee..48836c859a 100644 --- a/web/air/src/components/TokensTable.js +++ b/web/air/src/components/TokensTable.js @@ -11,12 +11,14 @@ import EditToken from '../pages/Token/EditToken'; const COPY_OPTIONS = [ { key: 'next', text: 'ChatGPT Next Web', value: 'next' }, { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' }, - { key: 'opencat', text: 'OpenCat', value: 'opencat' } + { key: 'opencat', text: 'OpenCat', value: 'opencat' }, + { key: 'lobechat', text: 'LobeChat', value: 'lobechat' }, ]; const OPEN_LINK_OPTIONS = [ { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' }, - { key: 'opencat', text: 'OpenCat', value: 'opencat' } + { key: 'opencat', text: 'OpenCat', value: 'opencat' }, + { key: 'lobechat', text: 'LobeChat', value: 'lobechat' } ]; function renderTimestamp(timestamp) { @@ -60,7 +62,12 @@ const TokensTable = () => { onOpenLink('next-mj'); } }, - { node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' } + { node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' }, + { + node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => { + onOpenLink('lobechat'); + } + } ]; const columns = [ @@ -177,6 +184,11 @@ const TokensTable = () => { node: 'item', key: 'opencat', name: 'OpenCat', onClick: () => { onOpenLink('opencat', record.key); } + }, + { + node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => { + onOpenLink('lobechat', record.key); /* pass this row's token, as the OpenCat entry does, so the link carries a real sk- key */ + } } ] } @@ -382,6 +394,9 @@ const TokensTable = () => { case 'next-mj': url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; break; + case 'lobechat': + url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}/v1"}}}`; + break; default: if (!chatLink) { showError('管理员未设置聊天链接'); diff --git a/web/air/src/constants/channel.constants.js b/web/air/src/constants/channel.constants.js 
index 04fe94f17a..a7e984ecf5 100644 --- a/web/air/src/constants/channel.constants.js +++ b/web/air/src/constants/channel.constants.js @@ -30,6 +30,7 @@ export const CHANNEL_OPTIONS = [ { key: 42, text: 'VertexAI', value: 42, color: 'blue' }, { key: 43, text: 'Proxy', value: 43, color: 'blue' }, { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' }, + { key: 45, text: 'xAI', value: 45, color: 'blue' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, diff --git a/web/air/src/pages/Channel/EditChannel.js b/web/air/src/pages/Channel/EditChannel.js index 73fd2da200..4a810830bd 100644 --- a/web/air/src/pages/Channel/EditChannel.js +++ b/web/air/src/pages/Channel/EditChannel.js @@ -43,6 +43,7 @@ const EditChannel = (props) => { base_url: '', other: '', model_mapping: '', + system_prompt: '', models: [], auto_ban: 1, groups: ['default'] @@ -63,7 +64,7 @@ const EditChannel = (props) => { let localModels = []; switch (value) { case 14: - localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"]; + localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-haiku-20241022", "claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20241022"]; break; case 11: localModels = ['PaLM-2']; @@ -78,7 +79,7 @@ const EditChannel = (props) => { localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; break; case 18: - localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']; + localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.1-128K', 'SparkDesk-v3.5', 'SparkDesk-v3.5-32K', 'SparkDesk-v4.0']; break; case 19: localModels 
= ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; @@ -304,163 +305,163 @@ const EditChannel = (props) => { width={isMobile() ? '100%' : 600} > -
+
类型:
{ - handleInputChange('base_url', value) - }} - value={inputs.base_url} - autoComplete='new-password' - /> -
- 默认 API 版本: -
- { - handleInputChange('other', value) - }} - value={inputs.other} - autoComplete='new-password' - /> - - ) + inputs.type === 3 && ( + <> +
+ + 注意,模型部署名称必须和模型名称保持一致,因为 One API 会把请求体中的 + model + 参数替换为你的部署名称(模型名称中的点会被剔除),图片演示。 + + }> + +
+
+ AZURE_OPENAI_ENDPOINT: +
+ { + handleInputChange('base_url', value) + }} + value={inputs.base_url} + autoComplete='new-password' + /> +
+ 默认 API 版本: +
+ { + handleInputChange('other', value) + }} + value={inputs.other} + autoComplete='new-password' + /> + + ) } { - inputs.type === 8 && ( - <> -
- Base URL: -
- { - handleInputChange('base_url', value) - }} - value={inputs.base_url} - autoComplete='new-password' - /> - - ) + inputs.type === 8 && ( + <> +
+ Base URL: +
+ { + handleInputChange('base_url', value) + }} + value={inputs.base_url} + autoComplete='new-password' + /> + + ) } -
+
名称:
{ - handleInputChange('name', value) - }} - value={inputs.name} - autoComplete='new-password' + required + name='name' + placeholder={'请为渠道命名'} + onChange={value => { + handleInputChange('name', value) + }} + value={inputs.name} + autoComplete='new-password' /> -
+
分组:
{ - handleInputChange('other', value) - }} - value={inputs.other} - autoComplete='new-password' - /> - - ) + inputs.type === 18 && ( + <> +
+ 模型版本: +
+ { + handleInputChange('other', value) + }} + value={inputs.other} + autoComplete='new-password' + /> + + ) } { - inputs.type === 21 && ( - <> -
- 知识库 ID: -
- { - handleInputChange('other', value) - }} - value={inputs.other} - autoComplete='new-password' - /> - - ) + inputs.type === 21 && ( + <> +
+ 知识库 ID: +
+ { + handleInputChange('other', value) + }} + value={inputs.other} + autoComplete='new-password' + /> + + ) } -
+
模型:
填入 - } - placeholder='输入自定义模型名称' - value={customModel} - onChange={(value) => { - setCustomModel(value.trim()); - }} + addonAfter={ + + } + placeholder='输入自定义模型名称' + value={customModel} + onChange={(value) => { + setCustomModel(value.trim()); + }} />
-
+
模型重定向: