Add Config to Disable Model Mapping (#41)
Introduce a new environment variable, DISABLE_MODEL_MAPPING, which,
when enabled, removes the OpenAI -> Gemini model mapping and exposes
the named Gemini models directly.

Move the model mapping logic into its own `models.go` file inside the
adapter package.

Additionally, fix a bug where responses would return the Gemini model
name even when model mapping was enabled.
ekatiyar authored Aug 8, 2024
1 parent da78e75 commit fddc1e2
Showing 5 changed files with 164 additions and 52 deletions.
43 changes: 36 additions & 7 deletions README.md
@@ -30,9 +30,25 @@ go build -o gemini main.go

We recommend deploying Gemini-OpenAI-Proxy using Docker for a straightforward setup. Follow these steps to deploy with Docker:

-```bash
-docker run --restart=always -it -d -p 8080:8080 --name gemini zhu327/gemini-openai-proxy:latest
-```
+You can either do this on the command line:
+```bash
+docker run --restart=unless-stopped -it -d -p 8080:8080 --name gemini zhu327/gemini-openai-proxy:latest
+```
+
+Or with the following docker-compose config:
+```yaml
+version: '3'
+services:
+  gemini:
+    container_name: gemini
+    environment: # Set Environment Variables here. Defaults listed below
+      - GPT_4_VISION_PREVIEW=gemini-1.5-flash-latest
+      - DISABLE_MODEL_MAPPING=0
+    ports:
+      - "8080:8080"
+    image: zhu327/gemini-openai-proxy:latest
+    restart: unless-stopped
+```

Adjust the port mapping (e.g., `-p 8080:8080`) as needed, and ensure that the Docker image version (`zhu327/gemini-openai-proxy:latest`) aligns with your requirements.

@@ -83,6 +99,7 @@ Gemini-OpenAI-Proxy offers a straightforward way to integrate OpenAI functionali
"temperature": 0.7
}'
```
+If you wish to map `gpt-4-vision-preview` to `gemini-1.5-pro-latest`, configure the environment variable `GPT_4_VISION_PREVIEW=gemini-1.5-pro-latest`. This works because `gemini-1.5-pro-latest` now also supports multi-modal data. Otherwise, the default is the `gemini-1.5-flash-latest` model.

If you already have access to the Gemini 1.5 Pro API, you can use:

@@ -104,7 +121,7 @@ Gemini-OpenAI-Proxy offers a straightforward way to integrate OpenAI functionali
-H "Content-Type: application/json" \
-H "Authorization: Bearer $YOUR_GOOGLE_AI_STUDIO_API_KEY" \
-d '{
"model": "ada-002",
"model": "text-embedding-ada-002",
"input": "This is a test sentence."
}'
```
@@ -116,7 +133,7 @@ Gemini-OpenAI-Proxy offers a straightforward way to integrate OpenAI functionali
-H "Content-Type: application/json" \
-H "Authorization: Bearer $YOUR_GOOGLE_AI_STUDIO_API_KEY" \
-d '{
"model": "ada-002",
"model": "text-embedding-ada-002",
"input": ["This is a test sentence.", "This is another test sentence"]
}'
```
@@ -129,9 +146,21 @@ Gemini-OpenAI-Proxy offers a straightforward way to integrate OpenAI functionali
| gpt-4 | gemini-1.5-flash-latest |
| gpt-4-turbo-preview | gemini-1.5-pro-latest |
| gpt-4-vision-preview | gemini-1.0-pro-vision-latest |
-| ada-002 | text-embedding-004 |
+| text-embedding-ada-002 | text-embedding-004 |

If you wish to map `gpt-4-vision-preview` to `gemini-1.5-pro-latest`, configure the environment variable `GPT_4_VISION_PREVIEW=gemini-1.5-pro-latest`. This works because `gemini-1.5-pro-latest` now also supports multi-modal data.
+If you want to disable model mapping, configure the environment variable `DISABLE_MODEL_MAPPING=1`. This will allow you to refer to the Gemini models directly.
+
+Here is an example API request with model mapping disabled:
+```bash
+curl http://localhost:8080/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer $YOUR_GOOGLE_AI_STUDIO_API_KEY" \
+  -d '{
+    "model": "gemini-1.0-pro-latest",
+    "messages": [{"role": "user", "content": "Say this is a test!"}],
+    "temperature": 0.7
+  }'
+```

4. **Handle Responses:**
Process the responses from the Gemini-OpenAI-Proxy in the same way you would handle responses from OpenAI.
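
   For example, the assistant text can be pulled straight out of the JSON response. A minimal sketch, assuming `jq` is installed and the proxy is listening on `localhost:8080`:

   ```bash
   # Extract the assistant reply from the OpenAI-shaped chat completion response
   curl -s http://localhost:8080/v1/chat/completions \
     -H "Content-Type: application/json" \
     -H "Authorization: Bearer $YOUR_GOOGLE_AI_STUDIO_API_KEY" \
     -d '{
       "model": "gpt-3.5-turbo",
       "messages": [{"role": "user", "content": "Say this is a test!"}],
       "temperature": 0.7
     }' | jq -r '.choices[0].message.content'
   ```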
24 changes: 13 additions & 11 deletions api/handler.go
@@ -21,50 +21,52 @@ func IndexHandler(c *gin.Context) {
}

func ModelListHandler(c *gin.Context) {
+    owner := adapter.GetOwner()
    c.JSON(http.StatusOK, gin.H{
        "object": "list",
        "data": []any{
            openai.Model{
                CreatedAt: 1686935002,
-                ID:        openai.GPT3Dot5Turbo,
+                ID:        adapter.GetModel(openai.GPT3Dot5Turbo),
                Object:    "model",
-                OwnedBy:   "openai",
+                OwnedBy:   owner,
            },
            openai.Model{
                CreatedAt: 1686935002,
-                ID:        openai.GPT4,
+                ID:        adapter.GetModel(openai.GPT4),
                Object:    "model",
-                OwnedBy:   "openai",
+                OwnedBy:   owner,
            },
            openai.Model{
                CreatedAt: 1686935002,
-                ID:        openai.GPT4TurboPreview,
+                ID:        adapter.GetModel(openai.GPT4TurboPreview),
                Object:    "model",
-                OwnedBy:   "openai",
+                OwnedBy:   owner,
            },
            openai.Model{
                CreatedAt: 1686935002,
-                ID:        openai.GPT4VisionPreview,
+                ID:        adapter.GetModel(openai.GPT4VisionPreview),
                Object:    "model",
-                OwnedBy:   "openai",
+                OwnedBy:   owner,
            },
            openai.Model{
                CreatedAt: 1686935002,
-                ID:        openai.GPT3Ada002,
+                ID:        adapter.GetModel(string(openai.AdaEmbeddingV2)),
                Object:    "model",
-                OwnedBy:   "openai",
+                OwnedBy:   owner,
            },
        },
    })
}

func ModelRetrieveHandler(c *gin.Context) {
    model := c.Param("model")
+    owner := adapter.GetOwner()
    c.JSON(http.StatusOK, openai.Model{
        CreatedAt: 1686935002,
        ID:        model,
        Object:    "model",
-        OwnedBy:   "openai",
+        OwnedBy:   owner,
    })
}
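
A quick way to observe the new `GetOwner`/`GetModel` behavior is to list models with mapping enabled and again with `DISABLE_MODEL_MAPPING=1`. A sketch, assuming a local deployment on port 8080 and `jq` installed:

```bash
# With mapping enabled (default), IDs are the OpenAI names owned by "openai";
# with DISABLE_MODEL_MAPPING=1 the same endpoint reports Gemini names owned by "google".
curl -s http://localhost:8080/v1/models | jq '.data[] | {id, owned_by}'
```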

11 changes: 3 additions & 8 deletions pkg/adapter/chat.go
Expand Up @@ -17,11 +17,6 @@ import (
)

const (
Gemini1Pro = "gemini-1.0-pro-latest"
Gemini1Dot5Pro = "gemini-1.5-pro-latest"
Gemini1Dot5Flash = "gemini-1.5-flash-latest"
TextEmbedding004 = "text-embedding-004"

genaiRoleUser = "user"
genaiRoleModel = "model"
)
@@ -121,7 +116,7 @@ func genaiResponseToStreamCompletionResponse(
        ID:      fmt.Sprintf("chatcmpl-%s", respID),
        Object:  "chat.completion.chunk",
        Created: created,
-        Model:   model,
+        Model:   GetMappedModel(model),
        Choices: make([]CompletionChoice, 0, len(genaiResp.Candidates)),
    }

@@ -156,7 +151,7 @@ func genaiResponseToOpenaiResponse(
        ID:      fmt.Sprintf("chatcmpl-%s", util.GetUUID()),
        Object:  "chat.completion",
        Created: time.Now().Unix(),
-        Model:   model,
+        Model:   GetMappedModel(model),
        Choices: make([]openai.ChatCompletionChoice, 0, len(genaiResp.Candidates)),
    }

@@ -260,7 +255,7 @@ func (g *GeminiAdapter) GenerateEmbedding(
    openaiResp := openai.EmbeddingResponse{
        Object: "list",
        Data:   make([]openai.Embedding, 0, len(genaiResp.Embeddings)),
-        Model:  openai.EmbeddingModel(g.model),
+        Model:  openai.EmbeddingModel(GetMappedModel(g.model)),
    }

    for i, genaiEmbedding := range genaiResp.Embeddings {
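
With `GetMappedModel` applied to the response structs above, the `model` field reports the OpenAI alias again while mapping is enabled. A quick check of the fix, assuming a local deployment and `jq` installed:

```bash
# Expect "gpt-4" in the response model field, not "gemini-1.5-flash-latest"
curl -s http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer $YOUR_GOOGLE_AI_STUDIO_API_KEY" \
  -d '{"model": "gpt-4", "messages": [{"role": "user", "content": "ping"}]}' | jq -r '.model'
```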
109 changes: 109 additions & 0 deletions pkg/adapter/models.go
@@ -0,0 +1,109 @@
package adapter

import (
    "os"
    "strings"

    openai "github.com/sashabaranov/go-openai"
)

const (
    Gemini1Pro       = "gemini-1.0-pro-latest"
    Gemini1Dot5Pro   = "gemini-1.5-pro-latest"
    Gemini1Dot5Flash = "gemini-1.5-flash-latest"
    Gemini1Dot5ProV  = "gemini-1.0-pro-vision-latest" // Converted to one of the above models in struct::ToGenaiModel
    TextEmbedding004 = "text-embedding-004"
)

var USE_MODEL_MAPPING bool = os.Getenv("DISABLE_MODEL_MAPPING") != "1"

func GetOwner() string {
    if USE_MODEL_MAPPING {
        return "openai"
    } else {
        return "google"
    }
}

func GetModel(openAiModelName string) string {
    if USE_MODEL_MAPPING {
        return openAiModelName
    } else {
        return ConvertModel(openAiModelName)
    }
}

func GetMappedModel(geminiModelName string) string {
    if !USE_MODEL_MAPPING {
        return geminiModelName
    }
    switch {
    case geminiModelName == Gemini1Dot5ProV:
        return openai.GPT4VisionPreview
    case geminiModelName == Gemini1Dot5Pro:
        return openai.GPT4TurboPreview
    case geminiModelName == Gemini1Dot5Flash:
        return openai.GPT4
    case geminiModelName == TextEmbedding004:
        return string(openai.AdaEmbeddingV2)
    default:
        return openai.GPT3Dot5Turbo
    }
}

func ConvertModel(openAiModelName string) string {
    switch {
    case openAiModelName == openai.GPT4VisionPreview:
        return Gemini1Dot5ProV
    case openAiModelName == openai.GPT4TurboPreview || openAiModelName == openai.GPT4Turbo1106 || openAiModelName == openai.GPT4Turbo0125:
        return Gemini1Dot5Pro
    case strings.HasPrefix(openAiModelName, openai.GPT4):
        return Gemini1Dot5Flash
    case openAiModelName == string(openai.AdaEmbeddingV2):
        return TextEmbedding004
    default:
        return Gemini1Pro
    }
}

func (req *ChatCompletionRequest) ToGenaiModel() string {
    if USE_MODEL_MAPPING {
        return req.ParseModelWithMapping()
    } else {
        return req.ParseModelWithoutMapping()
    }
}

func (req *ChatCompletionRequest) ParseModelWithoutMapping() string {
    switch {
    case req.Model == Gemini1Dot5ProV:
        if os.Getenv("GPT_4_VISION_PREVIEW") == Gemini1Dot5Pro {
            return Gemini1Dot5Pro
        }

        return Gemini1Dot5Flash
    default:
        return req.Model
    }
}

func (req *ChatCompletionRequest) ParseModelWithMapping() string {
    switch {
    case req.Model == openai.GPT4VisionPreview:
        if os.Getenv("GPT_4_VISION_PREVIEW") == Gemini1Dot5Pro {
            return Gemini1Dot5Pro
        }

        return Gemini1Dot5Flash
    default:
        return ConvertModel(req.Model)
    }
}

func (req *EmbeddingRequest) ToGenaiModel() string {
    if USE_MODEL_MAPPING {
        return ConvertModel(req.Model)
    } else {
        return req.Model
    }
}
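
With `DISABLE_MODEL_MAPPING=1`, `EmbeddingRequest.ToGenaiModel` passes the native name straight through, so embeddings can be requested by Gemini model name. A sketch, assuming the embeddings route mirrors OpenAI's `/v1/embeddings`:

```bash
# Request embeddings by native Gemini model name (requires DISABLE_MODEL_MAPPING=1)
curl -s http://localhost:8080/v1/embeddings \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer $YOUR_GOOGLE_AI_STUDIO_API_KEY" \
  -d '{"model": "text-embedding-004", "input": "This is a test sentence."}'
```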
29 changes: 3 additions & 26 deletions pkg/adapter/struct.go
@@ -2,8 +2,6 @@ package adapter

import (
    "encoding/json"
-    "os"
-    "strings"

    "github.com/google/generative-ai-go/genai"
    "github.com/pkg/errors"
@@ -43,27 +41,10 @@ type ChatCompletionRequest struct {
    Stop []string `json:"stop,omitempty"`
}

-func (req *ChatCompletionRequest) ToGenaiModel() string {
-    switch {
-    case req.Model == openai.GPT4VisionPreview:
-        if os.Getenv("GPT_4_VISION_PREVIEW") == Gemini1Dot5Pro {
-            return Gemini1Dot5Pro
-        }
-
-        return Gemini1Dot5Flash
-    case req.Model == openai.GPT4TurboPreview || req.Model == openai.GPT4Turbo1106 || req.Model == openai.GPT4Turbo0125:
-        return Gemini1Dot5Pro
-    case strings.HasPrefix(req.Model, openai.GPT4):
-        return Gemini1Dot5Flash
-    default:
-        return Gemini1Pro
-    }
-}
-
func (req *ChatCompletionRequest) ToGenaiMessages() ([]*genai.Content, error) {
-    if req.Model == openai.GPT4VisionPreview {
+    if req.Model == Gemini1Dot5ProV || req.Model == openai.GPT4VisionPreview {
        return req.toVisionGenaiContent()
-    } else if req.Model == openai.GPT3Ada002 {
+    } else if req.Model == TextEmbedding004 || req.Model == string(openai.AdaEmbeddingV2) {
        return nil, errors.New("Chat Completion is not supported for embedding model")
    }

Expand Down Expand Up @@ -209,7 +190,7 @@ type EmbeddingRequest struct {
}

func (req *EmbeddingRequest) ToGenaiMessages() ([]*genai.Content, error) {
if req.Model != openai.GPT3Ada002 {
if req.Model != TextEmbedding004 && req.Model != string(openai.AdaEmbeddingV2) {
return nil, errors.New("Embedding is not supported for chat model " + req.Model)
}

@@ -225,7 +206,3 @@ func (req *EmbeddingRequest) ToGenaiMessages() ([]*genai.Content, error) {

    return content, nil
}
-
-func (req *EmbeddingRequest) ToGenaiModel() string {
-    return TextEmbedding004
-}
