Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pull] main from songquanpeng:main #52

Merged
merged 35 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
6714cf9
fix: Groq organization not auto-disabled when blocked (#1822)
kaqijiang Sep 21, 2024
58bf608
fix: postgres use COALESCE replace null (#1793)
guogeer Sep 21, 2024
8cc5448
feat: update disabled channel (#1780)
ye4293 Sep 21, 2024
89b63ca
feat: ResponseFormat support json_schema (#1759)
majian159 Sep 21, 2024
15f815c
fix: fix ali embedding model always use v1 (#1747)
leavegee Sep 21, 2024
c9ac670
feat: update stepfun models (#1740)
forrestlinfeng Sep 21, 2024
cba8240
feat: add lobechat open link options (#1741)
MrChihiro Sep 21, 2024
3a27c90
fix: getTokenById return token nil, make panic (#1728)
qinguoyi Sep 21, 2024
649ecbf
feat: support new openai models (4o 0806, chatgpt-4o-latest) (#1721)
HynoR Sep 21, 2024
99c8c77
feat: add oidc support (#1725)
Oniokey Sep 21, 2024
32f90a7
feat: support SparkDesk-v3.1-128K (#1732)
lihangfu Sep 22, 2024
a216810
feat: add siliconflow usage (#1798)
TimeTrapzz Sep 22, 2024
88acc5a
fix: return the usage info if not null (#1792)
byte911 Sep 22, 2024
29389ed
fix: modify the type of token models to be text (#1761)
xuruidong Sep 22, 2024
fdd7bf4
feat: support multipart/form-data format request (#1690)
bearslyricattack Sep 22, 2024
e32cb0b
feat: support SparkDesk-v3.5-32K (#1832)
lihangfu Oct 27, 2024
cb772e5
fix:unsuccessful lobechat redirection link (#1843)
MrChihiro Oct 27, 2024
a4d6e7a
feat: add Vertex AI gemini-1.5-pro-002 and gemini-1.5-flash-002 (#1854)
hellof20 Oct 27, 2024
3716e1b
fix: use modelMap when testing a channel (#1855)
liangjs Oct 27, 2024
6293786
feat: update groq model and price (#1864)
longkeyy Oct 27, 2024
f092eed
feat: add support for Claude Sonnet 3.5 v2 (#1888)
shaoyun Oct 27, 2024
6f13a3b
feat: update Gemini adaptor to support custom response format (#1892)
mxdlzg Oct 27, 2024
f75a17f
feat: always return usage in stream mode
songquanpeng Oct 27, 2024
7e51b04
feat: able to hide test model selector and balance col
songquanpeng Oct 27, 2024
b0b88a7
feat: added support for Claude 3.5 Haiku (#1912)
shaoyun Nov 7, 2024
8ec092b
feat: add support for xAI (#1915)
lwshen Nov 7, 2024
cbfc983
feat: add new claude models (#1910)
Laisky Nov 9, 2024
c368232
fix: changeoptional field to pointer type (#1907)
wanthigh Nov 9, 2024
a3d7df7
feat: update GeneralOpenAIRequest
songquanpeng Nov 9, 2024
2b2dc2c
fix: update Spark Lite's domain to lite (#1896)
lihangfu Nov 9, 2024
92cd46d
feat: able to use ENFORCE_INCLUDE_USAGE to enforce include usage in r…
songquanpeng Nov 9, 2024
6eb0770
feat: support set system prompt for channel (close #1920)
songquanpeng Nov 10, 2024
833fa7a
feat: support set system_prompt for theme air & berry
songquanpeng Nov 10, 2024
6ab87f8
feat: add warning in log when system prompt is reset
songquanpeng Nov 10, 2024
7c8628b
feat: support gzip decode (#1962)
Calcium-Ion Dec 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [together.ai](https://www.together.ai/)
+ [x] [novita.ai](https://www.novita.ai/)
+ [x] [硅基流动 SiliconCloud](https://siliconflow.cn/siliconcloud)
+ [x] [xAI](https://x.ai/)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
Expand Down Expand Up @@ -399,6 +400,7 @@ graph LR
26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。
29. `ENFORCE_INCLUDE_USAGE`:是否强制在 stream 模型下返回 usage,默认不开启,可选值为 `true` 和 `false`。

### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
Expand Down
10 changes: 10 additions & 0 deletions common/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var OidcEnabled = false
var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
Expand Down Expand Up @@ -70,6 +71,13 @@ var GitHubClientSecret = ""
var LarkClientId = ""
var LarkClientSecret = ""

var OidcClientId = ""
var OidcClientSecret = ""
var OidcWellKnown = ""
var OidcAuthorizationEndpoint = ""
var OidcTokenEndpoint = ""
var OidcUserinfoEndpoint = ""

var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
Expand Down Expand Up @@ -152,3 +160,5 @@ var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
var RelayProxy = env.String("RELAY_PROXY", "")
var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)

var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false)
1 change: 1 addition & 0 deletions common/ctxkey/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ const (
BaseURL = "base_url"
AvailableModels = "available_models"
KeyRequestBody = "key_request_body"
SystemPrompt = "system_prompt"
)
6 changes: 3 additions & 3 deletions common/gin.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
err = c.ShouldBind(&v)
}
if err != nil {
return err
}
// Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil
}

Expand Down
20 changes: 20 additions & 0 deletions common/helper/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,23 @@ func String2Int(str string) int {
}
return num
}

func Float64PtrMax(p *float64, maxValue float64) *float64 {
if p == nil {
return nil
}
if *p > maxValue {
return &maxValue
}
return p
}

func Float64PtrMin(p *float64, minValue float64) *float64 {
if p == nil {
return nil
}
if *p < minValue {
return &minValue
}
return p
}
225 changes: 225 additions & 0 deletions controller/auth/oidc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
package auth

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
)

type OidcResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}

type OidcUser struct {
OpenID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Picture string `json:"picture"`
}

func getOidcUserInfoByCode(code string) (*OidcUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := map[string]string{
"client_id": config.OidcClientId,
"client_secret": config.OidcClientSecret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress),
}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res.Body.Close()
var oidcResponse OidcResponse
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
var oidcUser OidcUser
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
if err != nil {
return nil, err
}
return &oidcUser, nil
}

func OidcAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
OidcBind(c)
return
}
if !config.OidcEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
err := user.FillUserByOidcId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if config.RegisterEnabled {
user.Email = oidcUser.Email
if oidcUser.PreferredUsername != "" {
user.Username = oidcUser.PreferredUsername
} else {
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
}
if oidcUser.Name != "" {
user.DisplayName = oidcUser.Name
} else {
user.DisplayName = "OIDC User"
}
err := user.Insert(0)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}

if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
controller.SetupLogin(&user, c)
}

func OidcBind(c *gin.Context) {
if !config.OidcEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 OIDC 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.OidcId = oidcUser.OpenID
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}
8 changes: 5 additions & 3 deletions controller/billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ func GetSubscription(c *gin.Context) {
if config.DisplayTokenStatEnabled {
tokenId := c.GetInt(ctxkey.TokenId)
token, err = model.GetTokenById(tokenId)
expiredTime = token.ExpiredTime
remainQuota = token.RemainQuota
usedQuota = token.UsedQuota
if err == nil {
expiredTime = token.ExpiredTime
remainQuota = token.RemainQuota
usedQuota = token.UsedQuota
}
} else {
userId := c.GetInt(ctxkey.Id)
remainQuota, err = model.GetUserQuota(userId)
Expand Down
44 changes: 44 additions & 0 deletions controller/channel-billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,26 @@ type APGC2DGPTUsageResponse struct {
TotalUsed float64 `json:"total_used"`
}

type SiliconFlowUsageResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Status bool `json:"status"`
Data struct {
ID string `json:"id"`
Name string `json:"name"`
Image string `json:"image"`
Email string `json:"email"`
IsAdmin bool `json:"isAdmin"`
Balance string `json:"balance"`
Status string `json:"status"`
Introduction string `json:"introduction"`
Role string `json:"role"`
ChargeBalance string `json:"chargeBalance"`
TotalBalance string `json:"totalBalance"`
Category string `json:"category"`
} `json:"data"`
}

// GetAuthHeader get auth header
func GetAuthHeader(token string) http.Header {
h := http.Header{}
Expand Down Expand Up @@ -203,6 +223,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
return response.TotalAvailable, nil
}

func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
url := "https://api.siliconflow.cn/v1/user/info"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
response := SiliconFlowUsageResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
if response.Code != 20000 {
return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
}
balance, err := strconv.ParseFloat(response.Data.Balance, 64)
if err != nil {
return 0, err
}
channel.UpdateBalance(balance)
return balance, nil
}

func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := channeltype.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" {
Expand All @@ -227,6 +269,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return updateChannelAPI2GPTBalance(channel)
case channeltype.AIGC2D:
return updateChannelAIGC2DBalance(channel)
case channeltype.SiliconFlow:
return updateChannelSiliconFlowBalance(channel)
default:
return 0, errors.New("尚未实现")
}
Expand Down
Loading
Loading