Skip to content

Commit

Permalink
support deepseek
Browse files Browse the repository at this point in the history
  • Loading branch information
LemonHX committed May 9, 2024
1 parent c661d52 commit 21cad8f
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 1 deletion.
6 changes: 5 additions & 1 deletion grpcServer/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ const AZURE_OPENAI_LLM_API = "azure_openai"
const BAICHUAN_LLM_API = "baichuan"
const GEMINI_LLM_API = "gemini"
const MOONSHOT_LLM_API = "moonshot"
const DEEPSEEK_LLM_API = "deepseek"

func (uno *UnoForwardServer) BlockingRequestLLM(ctx context.Context, rs *model.LLMRequestSchema) (*model.LLMResponseSchema, error) {
info := rs.GetLlmRequestInfo()
switch info.GetLlmApiType() {
case OPENAI_LLM_API:
cli := NewOpenAIClient(info)
return OpenAIChatCompletion(cli, rs)

case DEEPSEEK_LLM_API:
fallthrough
case MOONSHOT_LLM_API:
cli := NewOpenAIClient(info)
if functionCallingRequestMake(rs) {
Expand Down Expand Up @@ -82,6 +84,8 @@ func (uno *UnoForwardServer) StreamRequestLLM(rs *model.LLMRequestSchema, sv mod
case OPENAI_LLM_API:
cli := NewOpenAIClient(info)
return OpenAIChatCompletionStreaming(cli, rs, sv)
case DEEPSEEK_LLM_API:
fallthrough
case MOONSHOT_LLM_API:
cli := NewOpenAIClient(info)
if functionCallingRequestMake(rs) {
Expand Down
3 changes: 3 additions & 0 deletions grpcServer/relay_gin.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ func getProvider(m string) (string, error) {
if strings.Contains(m, "moonshot") {
return "moonshot", nil
}
if strings.Contains(m, "deepseek") {
return "deepseek", nil
}
return "", errors.New("could not get provider")
}

Expand Down
2 changes: 2 additions & 0 deletions relay/reqTransformer/ChatGPT.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ func ChatGPTToGrpcRequest(api string, model_type string, token string, req opena
switch api {
case "moonshot":
url = "https://api.moonshot.cn/v1"
case "deepseek":
url = "https://api.deepseek.com/v1"
}
return &model.LLMRequestSchema{
Messages: messages,
Expand Down
133 changes: 133 additions & 0 deletions tests_grpc/deepseek_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package tests_grpc_test

import (
"context"
"log"
"os"
"testing"

"github.com/joho/godotenv"
"go.limit.dev/unollm/grpcServer"
"go.limit.dev/unollm/model"
"go.limit.dev/unollm/utils"
)

// TestDeepSeek sends one blocking chat-completion request through the
// DeepSeek relay path (which reuses the OpenAI-compatible client) and logs
// the response. It is a live-API test: it needs TEST_DEEPSEEK_API set in the
// environment or in ../.env.
func TestDeepSeek(t *testing.T) {
	// Best-effort: the .env file may be absent (e.g. in CI, where the
	// variable is exported directly), so the load error is intentionally
	// ignored.
	_ = godotenv.Load("../.env")

	messages := []*model.LLMChatCompletionMessage{
		{
			Role:    "user",
			Content: "假如今天下大雨,我是否需要带伞?",
		},
	}

	apiKey := os.Getenv("TEST_DEEPSEEK_API")
	if apiKey == "" {
		// Don't fire an unauthenticated live request; skip instead.
		t.Skip("TEST_DEEPSEEK_API not set; skipping live DeepSeek test")
	}

	reqInfo := model.LLMRequestInfo{
		LlmApiType:  grpcServer.DEEPSEEK_LLM_API,
		Model:       "deepseek-chat",
		Temperature: 0.9,
		TopP:        0.9,
		TopK:        1,
		Url:         "https://api.deepseek.com/v1",
		Token:       apiKey,
	}
	req := model.LLMRequestSchema{
		Messages:       messages,
		LlmRequestInfo: &reqInfo,
	}

	mockServer := grpcServer.UnoForwardServer{}
	res, err := mockServer.BlockingRequestLLM(context.Background(), &req)
	if err != nil {
		// Fatal (not Error): res is meaningless on a non-nil error, so do
		// not fall through and log it.
		t.Fatal(err)
	}
	log.Println("res: ", res)
}

// TestDeepSeekStreaming exercises the streaming DeepSeek relay path: it
// starts a stream request against a mock server stream and drains partial
// responses until the final token-count message arrives. Live-API test;
// requires TEST_DEEPSEEK_API in the environment or ../.env.
func TestDeepSeekStreaming(t *testing.T) {
	// Best-effort: .env may be absent when the variable is exported directly.
	_ = godotenv.Load("../.env")

	messages := []*model.LLMChatCompletionMessage{
		{
			Role:    "user",
			Content: "假如今天下大雨,我是否需要带伞?",
		},
	}

	apiKey := os.Getenv("TEST_DEEPSEEK_API")
	if apiKey == "" {
		t.Skip("TEST_DEEPSEEK_API not set; skipping live DeepSeek test")
	}

	reqInfo := model.LLMRequestInfo{
		LlmApiType:  grpcServer.DEEPSEEK_LLM_API,
		Model:       "deepseek-chat",
		Temperature: 0.9,
		TopP:        0.9,
		TopK:        1,
		Url:         "https://api.deepseek.com/v1",
		Token:       apiKey,
	}
	req := model.LLMRequestSchema{
		Messages:       messages,
		LlmRequestInfo: &reqInfo,
	}

	mockServer := grpcServer.UnoForwardServer{}
	mockServerPipe := utils.MockServerStream{
		Stream: make(chan *model.PartialLLMResponse, 1000),
	}
	if err := mockServer.StreamRequestLLM(&req, &mockServerPipe); err != nil {
		t.Fatal(err)
	}

	// Range terminates cleanly if the sender closes the channel; the old
	// bare `<-` loop would receive nil from a closed channel and panic on
	// res.LlmTokenCount (or spin forever).
	for res := range mockServerPipe.Stream {
		if res == nil {
			t.Fatal("received nil partial response")
		}
		log.Println(res)
		if res.LlmTokenCount != nil {
			// The token-count message marks the end of the stream.
			log.Println(res.LlmTokenCount)
			return
		}
	}
	t.Fatal("stream closed without a final token-count message")
}

// TestDeepSeekFunctionCalling sends a blocking request with a weather tool
// declared and UseFunctionCalling enabled, then logs the first tool call the
// model produced. Live-API test; requires TEST_DEEPSEEK_API in the
// environment or ../.env.
func TestDeepSeekFunctionCalling(t *testing.T) {
	// Best-effort: .env may be absent when the variable is exported directly.
	_ = godotenv.Load("../.env")

	messages := []*model.LLMChatCompletionMessage{
		{
			Role: "user",
			// "Boston" (the original "Poston" typo could stop the model from
			// recognizing a real location and emitting the tool call).
			Content: "What's the weather like in Boston?",
		},
	}

	apiKey := os.Getenv("TEST_DEEPSEEK_API")
	if apiKey == "" {
		t.Skip("TEST_DEEPSEEK_API not set; skipping live DeepSeek test")
	}

	reqInfo := model.LLMRequestInfo{
		LlmApiType:  grpcServer.DEEPSEEK_LLM_API,
		Model:       "deepseek-chat",
		Temperature: 0.9,
		TopP:        0.9,
		TopK:        1,
		Url:         "https://api.deepseek.com/v1",
		Token:       apiKey,
		Functions: []*model.Function{
			{
				Name:        "get_weather",
				Description: "Get the weather of a location",
				Parameters: []*model.FunctionCallingParameter{
					{
						Name:        "location",
						Type:        "string",
						Description: "The city and state, e.g. San Francisco, CA",
					},
					{
						Name:  "unit",
						Type:  "string",
						Enums: []string{"celsius", "fahrenheit"},
					},
				},
				Requireds: []string{"location", "unit"},
			},
		},
		UseFunctionCalling: true,
	}
	req := model.LLMRequestSchema{
		Messages:       messages,
		LlmRequestInfo: &reqInfo,
	}

	mockServer := grpcServer.UnoForwardServer{}
	res, err := mockServer.BlockingRequestLLM(context.Background(), &req)
	if err != nil {
		t.Fatal(err)
	}
	// Guard the index: a model that answers in plain text instead of calling
	// the tool would otherwise panic on ToolCalls[0].
	if len(res.ToolCalls) == 0 {
		t.Fatalf("expected at least one tool call, got none; res: %#v", res)
	}
	log.Printf("res: %#v", res.ToolCalls[0])
}

0 comments on commit 21cad8f

Please sign in to comment.