
Commit a387f93

refactor: change the options parameter type of PrepareRequest and CompleteLLM from ztype.Map to a function type to improve flexibility and maintainability
1 parent 002828e commit a387f93

9 files changed: 29 additions and 22 deletions

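The refactor replaces the old merge loop (copying every key of each ztype.Map option into the request body) with a call to each option function. A minimal standalone sketch of the pattern, assuming ztype.Map behaves like a plain map[string]interface{} (a local alias stands in for it here); WithTemperature is a hypothetical helper, not part of the repository:

package main

import "fmt"

// Stand-ins for ztype.Map and the new option shape used by PrepareRequest
// and CompleteLLM: a function that edits the request body and returns it.
type Map = map[string]interface{}
type RequestOption = func(Map) Map

// WithTemperature is a hypothetical helper illustrating the gain in
// flexibility: options become named, typed, and reusable instead of
// anonymous map literals.
func WithTemperature(t float64) RequestOption {
    return func(m Map) Map {
        m["temperature"] = t
        return m
    }
}

func main() {
    requestBody := Map{"model": "demo", "stream": false}
    // Apply options the same way the providers now do inside PrepareRequest.
    for _, opt := range []RequestOption{WithTemperature(0.7)} {
        requestBody = opt(requestBody)
    }
    fmt.Println(requestBody) // map[model:demo stream:false temperature:0.7]
}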

agent/deepseek.go

Lines changed: 2 additions & 4 deletions
@@ -126,7 +126,7 @@ func (p *DeepseekProvider) streamable(ctx context.Context, body []byte) (*zjson.
     return zjson.ParseBytes(json), nil
 }
 
-func (p *DeepseekProvider) PrepareRequest(messages *message.Messages, options ...ztype.Map) ([]byte, error) {
+func (p *DeepseekProvider) PrepareRequest(messages *message.Messages, options ...func(ztype.Map) ztype.Map) ([]byte, error) {
     requestBody := ztype.Map{
         "model":  p.options.Model,
         "stream": p.options.Stream,
@@ -141,9 +141,7 @@ func (p *DeepseekProvider) PrepareRequest(messages *message.Messages, options ..
     })
 
     for _, v := range options {
-        for k, v := range v {
-            requestBody[k] = v
-        }
+        v(requestBody)
     }
 
     return json.Marshal(requestBody)

agent/deepseek_test.go

Lines changed: 4 additions & 1 deletion
@@ -31,7 +31,10 @@ func TestNewDeepseekProvider(t *testing.T) {
 
     data, err := deepseek.PrepareRequest(
         messages,
-        ztype.Map{"temperature": 0.7},
+        func(m ztype.Map) ztype.Map {
+            m.Set("temperature", 0.7)
+            return m
+        },
     )
     tt.Log(string(data))
     tt.NoError(err, true)

agent/ollama.go

Lines changed: 2 additions & 4 deletions
@@ -83,7 +83,7 @@ func (p *OllamaProvider) Generate(ctx context.Context, body []byte) (json *zjson
     return
 }
 
-func (p *OllamaProvider) PrepareRequest(messages *message.Messages, options ...ztype.Map) ([]byte, error) {
+func (p *OllamaProvider) PrepareRequest(messages *message.Messages, options ...func(ztype.Map) ztype.Map) ([]byte, error) {
     requestBody := ztype.Map{
         "model":  p.options.Model,
         "stream": p.options.Stream,
@@ -98,9 +98,7 @@ func (p *OllamaProvider) PrepareRequest(messages *message.Messages, options ...z
     })
 
     for _, v := range options {
-        for k, v := range v {
-            requestBody[k] = v
-        }
+        requestBody = v(requestBody)
     }
 
     return json.Marshal(requestBody)

agent/ollama_test.go

Lines changed: 4 additions & 1 deletion
@@ -35,7 +35,10 @@ func TestNewOllamaProvider(t *testing.T) {
 
     data, err := llm.PrepareRequest(
         messages,
-        ztype.Map{"temperature": 0.3},
+        func(m ztype.Map) ztype.Map {
+            m.Set("temperature", 0.3)
+            return m
+        },
     )
     tt.Log(string(data))
     tt.NoError(err, true)

agent/openai.go

Lines changed: 2 additions & 4 deletions
@@ -162,7 +162,7 @@ func (p *OpenAIProvider) streamable(ctx context.Context, url string, header zhtt
     return zjson.ParseBytes(json), nil
 }
 
-func (p *OpenAIProvider) PrepareRequest(messages *message.Messages, options ...ztype.Map) ([]byte, error) {
+func (p *OpenAIProvider) PrepareRequest(messages *message.Messages, options ...func(ztype.Map) ztype.Map) ([]byte, error) {
     requestBody := ztype.Map{
         "model":  p.options.Model,
         "stream": p.options.Stream,
@@ -177,9 +177,7 @@ func (p *OpenAIProvider) PrepareRequest(messages *message.Messages, options ...z
     })
 
     for _, v := range options {
-        for k, v := range v {
-            requestBody[k] = v
-        }
+        requestBody = v(requestBody)
     }
 
     return json.Marshal(requestBody)

agent/openai_test.go

Lines changed: 4 additions & 1 deletion
@@ -31,7 +31,10 @@ func TestNewOpenAIProvider(t *testing.T) {
 
     data, err := openai.PrepareRequest(
         messages,
-        ztype.Map{"temperature": 0.7},
+        func(m ztype.Map) ztype.Map {
+            m.Set("temperature", 0.7)
+            return m
+        },
     )
     tt.Log(string(data))
     tt.NoError(err, true)

agent/provide.go

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ import (
 
 type LLMAgent interface {
     Generate(ctx context.Context, data []byte) (*zjson.Res, error)
-    PrepareRequest(messages *message.Messages, options ...ztype.Map) (body []byte, err error)
+    PrepareRequest(messages *message.Messages, options ...func(ztype.Map) ztype.Map) (body []byte, err error)
     ParseResponse(*zjson.Res) (*Response, error)
 }

zllm.go

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@ import (
     "github.com/zlsgo/zllm/utils"
 )
 
-func CompleteLLM[T *message.Prompt | *message.Messages](ctx context.Context, llm agent.LLMAgent, p T, options ...ztype.Map) (string, error) {
+func CompleteLLM[T *message.Prompt | *message.Messages](ctx context.Context, llm agent.LLMAgent, p T, options ...func(ztype.Map) ztype.Map) (string, error) {
     var (
         messages *message.Messages
         err      error
@@ -47,7 +47,7 @@ func CompleteLLM[T *message.Prompt | *message.Messages](ctx context.Context, llm
     return parse, err
 }
 
-func CompleteLLMJSON[T *message.Prompt | *message.Messages](ctx context.Context, llm agent.LLMAgent, p T, options ...ztype.Map) (ztype.Map, error) {
+func CompleteLLMJSON[T *message.Prompt | *message.Messages](ctx context.Context, llm agent.LLMAgent, p T, options ...func(ztype.Map) ztype.Map) (ztype.Map, error) {
     resp, err := CompleteLLM(ctx, llm, p, options...)
     if err != nil {
         return nil, err

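Because the option is now an ordinary function value, a caller of CompleteLLM can define it once and reuse it across providers and calls. A hedged sketch under assumptions: the root package is named zllm, and the agent, message, and ztype import paths follow the repository layout and the zlsgo ecosystem (only zllm/utils is confirmed by the diff); WithoutStream and Ask are hypothetical helpers:

package example

import (
    "context"

    "github.com/sohaha/zlsgo/ztype" // assumed location of ztype
    "github.com/zlsgo/zllm"
    "github.com/zlsgo/zllm/agent"
    "github.com/zlsgo/zllm/message"
)

// WithoutStream is a hypothetical named option; under the old API this would
// have been an anonymous ztype.Map{"stream": false} literal at every call site.
func WithoutStream() func(ztype.Map) ztype.Map {
    return func(m ztype.Map) ztype.Map {
        m["stream"] = false
        return m
    }
}

// Ask reuses the same option value for any provider implementing agent.LLMAgent.
func Ask(ctx context.Context, llm agent.LLMAgent, p *message.Prompt) (string, error) {
    return zllm.CompleteLLM(ctx, llm, p, WithoutStream())
}
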
zllm_test.go

Lines changed: 8 additions & 4 deletions
@@ -204,7 +204,10 @@ func TestExecutePromptMore(t *testing.T) {
         },
     )
 
-    resp, err := CompleteLLM(context.Background(), llm, p, ztype.Map{"stream": false})
+    resp, err := CompleteLLM(context.Background(), llm, p, func(m ztype.Map) ztype.Map {
+        m["stream"] = false
+        return m
+    })
     tt.Log(resp)
     tt.NoError(err, true)
 })
@@ -346,8 +349,8 @@ func TestLLM(t *testing.T) {
     utils.SetDebug(true)
     p := message.NewPrompt("北京的天气怎么样?", func(po *message.PromptOptions) {
     })
-    resp, err := CompleteLLMJSON(context.Background(), llm, p, ztype.Map{
-        "tools": ztype.Maps{
+    resp, err := CompleteLLMJSON(context.Background(), llm, p, func(m ztype.Map) ztype.Map {
+        m["tools"] = ztype.Maps{
             {
                 "type": "function",
                 "function": ztype.Map{
@@ -370,7 +373,8 @@
                 },
             },
         },
-    },
+    }
+    return m
     })
     tt.NoError(err, true)
     tt.Log(resp)
