diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go
index 3aa4974b83..3746bcdaf2 100644
--- a/relay/adaptor/gemini/main.go
+++ b/relay/adaptor/gemini/main.go
@@ -34,6 +34,12 @@ var mimeTypeMap = map[string]string{
 	"text": "text/plain",
 }
 
+var toolChoiceTypeMap = map[string]string{
+	"none":     "NONE",
+	"auto":     "AUTO",
+	"required": "ANY",
+}
+
 // Setting safety to the lowest possible values since Gemini is already powerless enough
 func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 	geminiRequest := ChatRequest{
@@ -92,7 +98,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 			},
 		}
 	}
-	shouldAddDummyModelMessage := false
+	if textRequest.ToolChoice != nil {
+		geminiRequest.ToolConfig = &ToolConfig{
+			FunctionCallingConfig: FunctionCallingConfig{
+				Mode: "auto",
+			},
+		}
+		switch mode := textRequest.ToolChoice.(type) {
+		case string:
+			geminiRequest.ToolConfig.FunctionCallingConfig.Mode = toolChoiceTypeMap[mode]
+		case map[string]interface{}:
+			geminiRequest.ToolConfig.FunctionCallingConfig.Mode = "ANY"
+			if fn, ok := mode["function"].(map[string]interface{}); ok {
+				if name, ok := fn["name"].(string); ok {
+					geminiRequest.ToolConfig.FunctionCallingConfig.AllowedFunctionNames = []string{name}
+				}
+			}
+		}
+	}
 	for _, message := range textRequest.Messages {
 		content := ChatContent{
 			Role: message.Role,
@@ -130,25 +153,12 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 		if content.Role == "assistant" {
 			content.Role = "model"
 		}
-		// Converting system prompt to prompt from user for the same reason
+		// Converting system prompt to SystemInstructions
 		if content.Role == "system" {
-			content.Role = "user"
-			shouldAddDummyModelMessage = true
+			geminiRequest.SystemInstruction = &content
+			continue
 		}
 		geminiRequest.Contents = append(geminiRequest.Contents, content)
-
-		// If a system message is the last message, we need to add a dummy model message to make gemini happy
-		if shouldAddDummyModelMessage {
-			geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
-				Role: "model",
-				Parts: []Part{
-					{
-						Text: "Okay",
-					},
-				},
-			})
-			shouldAddDummyModelMessage = false
-		}
 	}
 
 	return &geminiRequest
@@ -186,10 +196,16 @@ func (g *ChatResponse) GetResponseText() string {
 	if g == nil {
 		return ""
 	}
-	if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
-		return g.Candidates[0].Content.Parts[0].Text
+	var builder strings.Builder
+	for _, candidate := range g.Candidates {
+		for idx, part := range candidate.Content.Parts {
+			if idx > 0 {
+				builder.WriteString("\n")
+			}
+			builder.WriteString(part.Text)
+		}
 	}
-	return ""
+	return builder.String()
 }
 
 type ChatCandidate struct {
@@ -252,8 +268,8 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
 			choice.Message.ToolCalls = getToolCalls(&candidate)
 		} else {
 			var builder strings.Builder
-			for _, part := range candidate.Content.Parts {
-				if i > 0 {
+			for idx, part := range candidate.Content.Parts {
+				if idx > 0 {
 					builder.WriteString("\n")
 				}
 				builder.WriteString(part.Text)
diff --git a/relay/adaptor/gemini/model.go b/relay/adaptor/gemini/model.go
index 720cb65d19..a19248bc90 100644
--- a/relay/adaptor/gemini/model.go
+++ b/relay/adaptor/gemini/model.go
@@ -1,10 +1,12 @@
 package gemini
 
 type ChatRequest struct {
-	Contents         []ChatContent        `json:"contents"`
-	SafetySettings   []ChatSafetySettings `json:"safety_settings,omitempty"`
-	GenerationConfig ChatGenerationConfig `json:"generation_config,omitempty"`
-	Tools            []ChatTools          `json:"tools,omitempty"`
+	Contents          []ChatContent        `json:"contents"`
+	SystemInstruction *ChatContent         `json:"system_instruction,omitempty"`
+	SafetySettings    []ChatSafetySettings `json:"safety_settings,omitempty"`
+	GenerationConfig  ChatGenerationConfig `json:"generation_config,omitempty"`
+	Tools             []ChatTools          `json:"tools,omitempty"`
+	ToolConfig        *ToolConfig          `json:"tool_config,omitempty"`
 }
 
 type EmbeddingRequest struct {
@@ -74,3 +76,12 @@ type ChatGenerationConfig struct {
 	CandidateCount int      `json:"candidateCount,omitempty"`
 	StopSequences  []string `json:"stopSequences,omitempty"`
 }
+
+type FunctionCallingConfig struct {
+	Mode                 string   `json:"mode,omitempty"`
+	AllowedFunctionNames []string `json:"allowed_function_names,omitempty"`
+}
+
+type ToolConfig struct {
+	FunctionCallingConfig FunctionCallingConfig `json:"function_calling_config"`
+}