From fddc1e2254d1585da0bf19d661f0582f3cdcca60 Mon Sep 17 00:00:00 2001 From: ekatiyar Date: Wed, 7 Aug 2024 18:11:01 -0700 Subject: [PATCH] Add Config to Disable Model Mapping (#41) Introduce a new environment variable, DISABLE_MODEL_MAPPING, when enabled removes the OpenAI -> Gemini Model mapping and provides access to the named gemini models directly. Moved model mapping logic to its own `models.go` file inside the adapter package. Additionally, fixed a bug where responses would return the gemini model name even though model mapping was enabled. --- README.md | 43 ++++++++++++++--- api/handler.go | 24 +++++----- pkg/adapter/chat.go | 11 ++--- pkg/adapter/models.go | 109 ++++++++++++++++++++++++++++++++++++++++++ pkg/adapter/struct.go | 29 ++--------- 5 files changed, 164 insertions(+), 52 deletions(-) create mode 100644 pkg/adapter/models.go diff --git a/README.md b/README.md index 15e9154..957609b 100644 --- a/README.md +++ b/README.md @@ -30,9 +30,25 @@ go build -o gemini main.go We recommend deploying Gemini-OpenAI-Proxy using Docker for a straightforward setup. Follow these steps to deploy with Docker: -```bash -docker run --restart=always -it -d -p 8080:8080 --name gemini zhu327/gemini-openai-proxy:latest -``` + You can either do this on the command line: + ```bash + docker run --restart=unless-stopped -it -d -p 8080:8080 --name gemini zhu327/gemini-openai-proxy:latest + ``` + + Or with the following docker-compose config: + ```yaml + version: '3' + services: + gemini: + container_name: gemini + environment: # Set Environment Variables here. Defaults listed below + - GPT_4_VISION_PREVIEW=gemini-1.5-flash-latest + - DISABLE_MODEL_MAPPING=0 + ports: + - "8080:8080" + image: zhu327/gemini-openai-proxy:latest + restart: unless-stopped + ``` Adjust the port mapping (e.g., `-p 8080:8080`) as needed, and ensure that the Docker image version (`zhu327/gemini-openai-proxy:latest`) aligns with your requirements. @@ -83,6 +99,7 @@ Gemini-OpenAI-Proxy offers a straightforward way to integrate OpenAI functionali "temperature": 0.7 }' ``` + If you wish to map `gpt-4-vision-preview` to `gemini-1.5-pro-latest`, you can configure the environment variable `GPT_4_VISION_PREVIEW = gemini-1.5-pro-latest`. This is because `gemini-1.5-pro-latest` now also supports multi-modal data. Otherwise, the default uses the `gemini-1.5-flash-latest` model If you already have access to the Gemini 1.5 Pro api, you can use: @@ -104,7 +121,7 @@ Gemini-OpenAI-Proxy offers a straightforward way to integrate OpenAI functionali -H "Content-Type: application/json" \ -H "Authorization: Bearer $YOUR_GOOGLE_AI_STUDIO_API_KEY" \ -d '{ - "model": "ada-002", + "model": "text-embedding-ada-002", "input": "This is a test sentence." }' ``` @@ -116,7 +133,7 @@ Gemini-OpenAI-Proxy offers a straightforward way to integrate OpenAI functionali -H "Content-Type: application/json" \ -H "Authorization: Bearer $YOUR_GOOGLE_AI_STUDIO_API_KEY" \ -d '{ - "model": "ada-002", + "model": "text-embedding-ada-002", "input": ["This is a test sentence.", "This is another test sentence"] }' ``` @@ -129,9 +146,21 @@ Gemini-OpenAI-Proxy offers a straightforward way to integrate OpenAI functionali | gpt-4 | gemini-1.5-flash-latest | | gpt-4-turbo-preview | gemini-1.5-pro-latest | | gpt-4-vision-preview | gemini-1.0-pro-vision-latest | - | ada-002 | text-embedding-004 | + | text-embedding-ada-002 | text-embedding-004 | - If you wish to map `gpt-4-vision-preview` to `gemini-1.5-pro-latest`, you can configure the environment variable `GPT_4_VISION_PREVIEW = gemini-1.5-pro-latest`. This is because `gemini-1.5-pro-latest` now also supports multi-modal data. + If you want to disable model mapping, configure the environment variable `DISABLE_MODEL_MAPPING=1`. This will allow you to refer to the Gemini models directly. + + Here is an example API request with model mapping disabled: + ```bash + curl http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $YOUR_GOOGLE_AI_STUDIO_API_KEY" \ + -d '{ + "model": "gemini-1.0-pro-latest", + "messages": [{"role": "user", "content": "Say this is a test!"}], + "temperature": 0.7 + }' + ``` 4. **Handle Responses:** Process the responses from the Gemini-OpenAI-Proxy in the same way you would handle responses from OpenAI. diff --git a/api/handler.go b/api/handler.go index 4fabfb6..b28e916 100644 --- a/api/handler.go +++ b/api/handler.go @@ -21,38 +21,39 @@ func IndexHandler(c *gin.Context) { } func ModelListHandler(c *gin.Context) { + owner := adapter.GetOwner() c.JSON(http.StatusOK, gin.H{ "object": "list", "data": []any{ openai.Model{ CreatedAt: 1686935002, - ID: openai.GPT3Dot5Turbo, + ID: adapter.GetModel(openai.GPT3Dot5Turbo), Object: "model", - OwnedBy: "openai", + OwnedBy: owner, }, openai.Model{ CreatedAt: 1686935002, - ID: openai.GPT4, + ID: adapter.GetModel(openai.GPT4), Object: "model", - OwnedBy: "openai", + OwnedBy: owner, }, openai.Model{ CreatedAt: 1686935002, - ID: openai.GPT4TurboPreview, + ID: adapter.GetModel(openai.GPT4TurboPreview), Object: "model", - OwnedBy: "openai", + OwnedBy: owner, }, openai.Model{ CreatedAt: 1686935002, - ID: openai.GPT4VisionPreview, + ID: adapter.GetModel(openai.GPT4VisionPreview), Object: "model", - OwnedBy: "openai", + OwnedBy: owner, }, openai.Model{ CreatedAt: 1686935002, - ID: openai.GPT3Ada002, + ID: adapter.GetModel(string(openai.AdaEmbeddingV2)), Object: "model", - OwnedBy: "openai", + OwnedBy: owner, }, }, }) @@ -60,11 +61,12 @@ func ModelListHandler(c *gin.Context) { func ModelRetrieveHandler(c *gin.Context) { model := c.Param("model") + owner := adapter.GetOwner() c.JSON(http.StatusOK, openai.Model{ CreatedAt: 1686935002, ID: model, Object: "model", - OwnedBy: "openai", + OwnedBy: owner, }) } diff --git a/pkg/adapter/chat.go b/pkg/adapter/chat.go index 5a19dcd..8c5f398 100644 --- a/pkg/adapter/chat.go +++ b/pkg/adapter/chat.go @@ -17,11 +17,6 @@ import ( ) const ( - Gemini1Pro = "gemini-1.0-pro-latest" - Gemini1Dot5Pro = "gemini-1.5-pro-latest" - Gemini1Dot5Flash = "gemini-1.5-flash-latest" - TextEmbedding004 = "text-embedding-004" - genaiRoleUser = "user" genaiRoleModel = "model" ) @@ -121,7 +116,7 @@ func genaiResponseToStreamCompletionResponse( ID: fmt.Sprintf("chatcmpl-%s", respID), Object: "chat.completion.chunk", Created: created, - Model: model, + Model: GetMappedModel(model), Choices: make([]CompletionChoice, 0, len(genaiResp.Candidates)), } @@ -156,7 +151,7 @@ func genaiResponseToOpenaiResponse( ID: fmt.Sprintf("chatcmpl-%s", util.GetUUID()), Object: "chat.completion", Created: time.Now().Unix(), - Model: model, + Model: GetMappedModel(model), Choices: make([]openai.ChatCompletionChoice, 0, len(genaiResp.Candidates)), } @@ -260,7 +255,7 @@ func (g *GeminiAdapter) GenerateEmbedding( openaiResp := openai.EmbeddingResponse{ Object: "list", Data: make([]openai.Embedding, 0, len(genaiResp.Embeddings)), - Model: openai.EmbeddingModel(g.model), + Model: openai.EmbeddingModel(GetMappedModel(g.model)), } for i, genaiEmbedding := range genaiResp.Embeddings { diff --git a/pkg/adapter/models.go b/pkg/adapter/models.go new file mode 100644 index 0000000..8e976ff --- /dev/null +++ b/pkg/adapter/models.go @@ -0,0 +1,109 @@ +package adapter + +import ( + "os" + "strings" + + openai "github.com/sashabaranov/go-openai" +) + +const ( + Gemini1Pro = "gemini-1.0-pro-latest" + Gemini1Dot5Pro = "gemini-1.5-pro-latest" + Gemini1Dot5Flash = "gemini-1.5-flash-latest" + Gemini1Dot5ProV = "gemini-1.0-pro-vision-latest" // Converted to one of the above models in struct::ToGenaiModel + TextEmbedding004 = "text-embedding-004" +) + +var USE_MODEL_MAPPING bool = os.Getenv("DISABLE_MODEL_MAPPING") != "1" + +func GetOwner() string { + if USE_MODEL_MAPPING { + return "openai" + } else { + return "google" + } +} + +func GetModel(openAiModelName string) string { + if USE_MODEL_MAPPING { + return openAiModelName + } else { + return ConvertModel(openAiModelName) + } +} + +func GetMappedModel(geminiModelName string) string { + if !USE_MODEL_MAPPING { + return geminiModelName + } + switch { + case geminiModelName == Gemini1Dot5ProV: + return openai.GPT4VisionPreview + case geminiModelName == Gemini1Dot5Pro: + return openai.GPT4TurboPreview + case geminiModelName == Gemini1Dot5Flash: + return openai.GPT4 + case geminiModelName == TextEmbedding004: + return string(openai.AdaEmbeddingV2) + default: + return openai.GPT3Dot5Turbo + } +} + +func ConvertModel(openAiModelName string) string { + switch { + case openAiModelName == openai.GPT4VisionPreview: + return Gemini1Dot5ProV + case openAiModelName == openai.GPT4TurboPreview || openAiModelName == openai.GPT4Turbo1106 || openAiModelName == openai.GPT4Turbo0125: + return Gemini1Dot5Pro + case strings.HasPrefix(openAiModelName, openai.GPT4): + return Gemini1Dot5Flash + case openAiModelName == string(openai.AdaEmbeddingV2): + return TextEmbedding004 + default: + return Gemini1Pro + } +} + +func (req *ChatCompletionRequest) ToGenaiModel() string { + if USE_MODEL_MAPPING { + return req.ParseModelWithMapping() + } else { + return req.ParseModelWithoutMapping() + } +} + +func (req *ChatCompletionRequest) ParseModelWithoutMapping() string { + switch { + case req.Model == Gemini1Dot5ProV: + if os.Getenv("GPT_4_VISION_PREVIEW") == Gemini1Dot5Pro { + return Gemini1Dot5Pro + } + + return Gemini1Dot5Flash + default: + return req.Model + } +} + +func (req *ChatCompletionRequest) ParseModelWithMapping() string { + switch { + case req.Model == openai.GPT4VisionPreview: + if os.Getenv("GPT_4_VISION_PREVIEW") == Gemini1Dot5Pro { + return Gemini1Dot5Pro + } + + return Gemini1Dot5Flash + default: + return ConvertModel(req.Model) + } +} + +func (req *EmbeddingRequest) ToGenaiModel() string { + if USE_MODEL_MAPPING { + return ConvertModel(req.Model) + } else { + return req.Model + } +} diff --git a/pkg/adapter/struct.go b/pkg/adapter/struct.go index 5c0378d..db79de9 100644 --- a/pkg/adapter/struct.go +++ b/pkg/adapter/struct.go @@ -2,8 +2,6 @@ package adapter import ( "encoding/json" - "os" - "strings" "github.com/google/generative-ai-go/genai" "github.com/pkg/errors" @@ -43,27 +41,10 @@ type ChatCompletionRequest struct { Stop []string `json:"stop,omitempty"` } -func (req *ChatCompletionRequest) ToGenaiModel() string { - switch { - case req.Model == openai.GPT4VisionPreview: - if os.Getenv("GPT_4_VISION_PREVIEW") == Gemini1Dot5Pro { - return Gemini1Dot5Pro - } - - return Gemini1Dot5Flash - case req.Model == openai.GPT4TurboPreview || req.Model == openai.GPT4Turbo1106 || req.Model == openai.GPT4Turbo0125: - return Gemini1Dot5Pro - case strings.HasPrefix(req.Model, openai.GPT4): - return Gemini1Dot5Flash - default: - return Gemini1Pro - } -} - func (req *ChatCompletionRequest) ToGenaiMessages() ([]*genai.Content, error) { - if req.Model == openai.GPT4VisionPreview { + if req.Model == Gemini1Dot5ProV || req.Model == openai.GPT4VisionPreview { return req.toVisionGenaiContent() - } else if req.Model == openai.GPT3Ada002 { + } else if req.Model == TextEmbedding004 || req.Model == string(openai.AdaEmbeddingV2) { return nil, errors.New("Chat Completion is not supported for embedding model") } @@ -209,7 +190,7 @@ type EmbeddingRequest struct { } func (req *EmbeddingRequest) ToGenaiMessages() ([]*genai.Content, error) { - if req.Model != openai.GPT3Ada002 { + if req.Model != TextEmbedding004 && req.Model != string(openai.AdaEmbeddingV2) { return nil, errors.New("Embedding is not supported for chat model " + req.Model) } @@ -225,7 +206,3 @@ func (req *EmbeddingRequest) ToGenaiMessages() ([]*genai.Content, error) { return content, nil } - -func (req *EmbeddingRequest) ToGenaiModel() string { - return TextEmbedding004 -}