Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Default Options Value Override Feature #1859

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@ logs
data
/web/node_modules
cmd.md
.env
.env
one-api
one-api-motor
1 change: 1 addition & 0 deletions common/ctxkey/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ const (
OriginalModel = "original_model"
Group = "group"
ModelMapping = "model_mapping"
ParamsOverride = "params_override"
ChannelName = "channel_name"
TokenId = "token_id"
TokenName = "token_name"
Expand Down
1 change: 1 addition & 0 deletions middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set(ctxkey.ChannelId, channel.Id)
c.Set(ctxkey.ChannelName, channel.Name)
c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
c.Set(ctxkey.ParamsOverride, channel.GetParamsOverride())
c.Set(ctxkey.OriginalModel, modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set(ctxkey.BaseURL, channel.GetBaseURL())
Expand Down
15 changes: 15 additions & 0 deletions model/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
ParamsOverride *string `json:"default_params_override" gorm:"type:text;default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
Config string `json:"config"`
}
Expand Down Expand Up @@ -123,6 +124,20 @@ func (channel *Channel) GetModelMapping() map[string]string {
return modelMapping
}

func (channel *Channel) GetParamsOverride() map[string]map[string]interface{} {
if channel.ParamsOverride == nil || *channel.ParamsOverride == "" || *channel.ParamsOverride == "{}" {
return nil
}
paramsOverride := make(map[string]map[string]interface{})
err := json.Unmarshal([]byte(*channel.ParamsOverride), &paramsOverride)
if err != nil {
logger.SysError(fmt.Sprintf("failed to unmarshal params override for channel %d, error: %s", channel.Id, err.Error()))
return nil
}
return paramsOverride
}


func (channel *Channel) Insert() error {
var err error
err = DB.Create(channel).Error
Expand Down
104 changes: 97 additions & 7 deletions relay/controller/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"fmt"
"io"
"net/http"
"io/ioutil"
"context"

"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
Expand All @@ -23,13 +25,34 @@ import (
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
ctx := c.Request.Context()
meta := meta.GetByContext(c)
// get & validate textRequest
textRequest, err := getAndValidateTextRequest(c, meta.Mode)
if err != nil {
logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
}
meta.IsStream = textRequest.Stream

// Read the original request body
bodyBytes, err := ioutil.ReadAll(c.Request.Body)
if err != nil {
logger.Errorf(ctx, "Failed to read request body: %s", err.Error())
return openai.ErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
}

// Restore the request body for `getAndValidateTextRequest`
c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))

// Call `getAndValidateTextRequest`
textRequest, err := getAndValidateTextRequest(c, meta.Mode)
if err != nil {
logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
}
meta.IsStream = textRequest.Stream

// Parse the request body into a map
var rawRequest map[string]interface{}
if err := json.Unmarshal(bodyBytes, &rawRequest); err != nil {
logger.Errorf(ctx, "Failed to parse request body into map: %s", err.Error())
return openai.ErrorWrapper(err, "invalid_json", http.StatusBadRequest)
}

// Apply parameter overrides
applyParameterOverrides(ctx, meta, textRequest, rawRequest)

// map model name
meta.OriginModelName = textRequest.Model
Expand Down Expand Up @@ -105,3 +128,70 @@ func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralO
requestBody = bytes.NewBuffer(jsonData)
return requestBody, nil
}

func applyParameterOverrides(ctx context.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, rawRequest map[string]interface{}) {
if meta.ParamsOverride != nil {
modelName := meta.OriginModelName
if overrideParams, exists := meta.ParamsOverride[modelName]; exists {
logger.Infof(ctx, "Applying parameter overrides for model %s on channel %d", modelName, meta.ChannelId)
for key, value := range overrideParams {
if _, userSpecified := rawRequest[key]; !userSpecified {
// Apply the override since the user didn't specify this parameter
switch key {
case "temperature":
if v, ok := value.(float64); ok {
textRequest.Temperature = v
} else if v, ok := value.(int); ok {
textRequest.Temperature = float64(v)
}
case "max_tokens":
if v, ok := value.(float64); ok {
textRequest.MaxTokens = int(v)
} else if v, ok := value.(int); ok {
textRequest.MaxTokens = v
}
case "top_p":
if v, ok := value.(float64); ok {
textRequest.TopP = v
} else if v, ok := value.(int); ok {
textRequest.TopP = float64(v)
}
case "frequency_penalty":
if v, ok := value.(float64); ok {
textRequest.FrequencyPenalty = v
} else if v, ok := value.(int); ok {
textRequest.FrequencyPenalty = float64(v)
}
case "presence_penalty":
if v, ok := value.(float64); ok {
textRequest.PresencePenalty = v
} else if v, ok := value.(int); ok {
textRequest.PresencePenalty = float64(v)
}
case "stop":
textRequest.Stop = value
case "n":
if v, ok := value.(float64); ok {
textRequest.N = int(v)
} else if v, ok := value.(int); ok {
textRequest.N = v
}
case "stream":
if v, ok := value.(bool); ok {
textRequest.Stream = v
}
case "num_ctx":
if v, ok := value.(float64); ok {
textRequest.NumCtx = int(v)
} else if v, ok := value.(int); ok {
textRequest.NumCtx = v
}
// Handle other parameters as needed
default:
logger.Warnf(ctx, "Unknown parameter override key: %s", key)
}
}
}
}
}
}
6 changes: 6 additions & 0 deletions relay/meta/relay_meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type Meta struct {
UserId int
Group string
ModelMapping map[string]string
ParamsOverride map[string]map[string]interface{}
// BaseURL is the proxy url set in the channel config
BaseURL string
APIKey string
Expand Down Expand Up @@ -47,6 +48,11 @@ func GetByContext(c *gin.Context) *Meta {
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
RequestURLPath: c.Request.URL.String(),
}
// Retrieve ParamsOverride
paramsOverride, exists := c.Get(ctxkey.ParamsOverride)
if exists && paramsOverride != nil {
meta.ParamsOverride = paramsOverride.(map[string]map[string]interface{})
}
cfg, ok := c.Get(ctxkey.Config)
if ok {
meta.Config = cfg.(model.ChannelConfig)
Expand Down
21 changes: 20 additions & 1 deletion web/default/src/pages/Channel/EditChannel.js
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ const EditChannel = () => {
showInfo('模型映射必须是合法的 JSON 格式!');
return;
}
if (inputs.default_params_override !== '' && !verifyJSON(inputs.default_params_override)) {
showInfo('默认参数Override必须是合法的 JSON 格式!');
return;
}
let localInputs = {...inputs};
if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
Expand Down Expand Up @@ -425,7 +429,7 @@ const EditChannel = () => {
)
}
{
inputs.type !== 43 && (
inputs.type !== 43 && (
<Form.Field>
<Form.TextArea
label='模型重定向'
Expand All @@ -439,6 +443,21 @@ const EditChannel = () => {
</Form.Field>
)
}
{
inputs.type !== 43 && (
<Form.Field>
<Form.TextArea
label='默认参数Override'
placeholder={`此项可选,用于修改请求体中的默认参数,为一个 JSON 字符串,键为请求中模型名称,值为要替换的默认参数,例如:\n${JSON.stringify({ 'llama3:70b': { 'num_ctx': 11520, 'temperature': 0.2 }, 'qwen2:72b': { 'num_ctx': 11520, 'temperature': 0.8 } }, null, 2)}`}
name='default_params_override'
onChange={handleInputChange}
value={inputs.default_params_override}
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password'
/>
</Form.Field>
)
}
{
inputs.type === 33 && (
<Form.Field>
Expand Down