-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.ts
81 lines (70 loc) · 2.51 KB
/
models.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import { createAnthropic } from '@ai-sdk/anthropic'
import { createGoogleGenerativeAI } from '@ai-sdk/google'
import { createMistral } from '@ai-sdk/mistral'
import { createOpenAI } from '@ai-sdk/openai'
import { createOllama } from 'ollama-ai-provider'
import { createVertex } from '@ai-sdk/google-vertex'
/** Descriptor identifying a language model and the provider that serves it. */
export type LLMModel = {
  /** Registry key of the backing provider, e.g. `'openai'` or `'fireworks'`. */
  providerId: string
  /** Model identifier handed to the provider factory as the model name. */
  id: string
  /** Provider label; not read by the functions in this module (presumably for display). */
  provider: string
  /** Model label; not read by the functions in this module (presumably for display). */
  name: string
}
/**
 * Optional connection and sampling overrides for a model request.
 * Every field may be omitted.
 */
export type LLMModelConfig = Partial<{
  /** API key override; read by `getModelClient`. */
  apiKey: string
  /** Endpoint override; read by `getModelClient`. */
  baseURL: string
  /** Model name; not read in this module. */
  model: string
  /** Sampling temperature; not read in this module (presumably forwarded by callers). */
  temperature: number
  /** Nucleus-sampling threshold; not read in this module (presumably forwarded by callers). */
  topP: number
  /** Top-k sampling limit; not read in this module (presumably forwarded by callers). */
  topK: number
  /** Frequency penalty; not read in this module (presumably forwarded by callers). */
  frequencyPenalty: number
  /** Presence penalty; not read in this module (presumably forwarded by callers). */
  presencePenalty: number
  /** Output token cap; not read in this module (presumably forwarded by callers). */
  maxTokens: number
}>
/**
 * Parses Google Vertex AI credentials from the `GOOGLE_VERTEX_CREDENTIALS`
 * environment variable.
 *
 * An unset or empty variable yields `{}`, the same fallback as before.
 * A set but malformed value throws an `Error` that names the variable,
 * instead of a bare `SyntaxError` from `JSON.parse`.
 * The parsed value is not schema-validated.
 */
function parseVertexCredentials() {
  const raw = process.env.GOOGLE_VERTEX_CREDENTIALS || '{}'
  try {
    return JSON.parse(raw)
  } catch (err) {
    const reason = err instanceof Error ? err.message : String(err)
    throw new Error(`GOOGLE_VERTEX_CREDENTIALS is not valid JSON: ${reason}`)
  }
}

/**
 * Builds an AI SDK model instance for `model`, using the provider factory
 * registered under `model.providerId`.
 *
 * `config.apiKey` and `config.baseURL` take precedence where a provider
 * accepts them. Otherwise the provider-specific environment variable or
 * default endpoint listed below is used. Factories are wrapped in thunks,
 * so only the selected provider is constructed.
 *
 * @param model - Model descriptor; `id` is passed to the provider as the model name.
 * @param config - Connection overrides; only `apiKey` and `baseURL` are read here.
 * @returns The provider's model instance for `model.id`.
 * @throws Error `Unsupported provider: <providerId>` when no factory is registered.
 */
export function getModelClient(model: LLMModel, config: LLMModelConfig) {
  const { id: modelNameString, providerId } = model
  const { apiKey, baseURL } = config

  const providerConfigs = {
    anthropic: () => createAnthropic({ apiKey, baseURL })(modelNameString),
    openai: () => createOpenAI({ apiKey, baseURL })(modelNameString),
    // Previously the env var was used unconditionally, silently ignoring
    // config.apiKey; now the caller's key wins, as for the other providers.
    google: () =>
      createGoogleGenerativeAI({
        apiKey: apiKey || process.env.GOOGLE_GENERATIVE_AI_API_KEY,
        baseURL,
      })(modelNameString),
    mistral: () => createMistral({ apiKey, baseURL })(modelNameString),
    // The OpenAI-compatible providers below fall back to their own env var
    // and endpoint when the caller supplies none.
    groq: () =>
      createOpenAI({
        apiKey: apiKey || process.env.GROQ_API_KEY,
        baseURL: baseURL || 'https://api.groq.com/openai/v1',
      })(modelNameString),
    togetherai: () =>
      createOpenAI({
        apiKey: apiKey || process.env.TOGETHER_AI_API_KEY,
        baseURL: baseURL || 'https://api.together.xyz/v1',
      })(modelNameString),
    ollama: () => createOllama({ baseURL })(modelNameString),
    fireworks: () =>
      createOpenAI({
        apiKey: apiKey || process.env.FIREWORKS_API_KEY,
        baseURL: baseURL || 'https://api.fireworks.ai/inference/v1',
      })(modelNameString),
    vertex: () =>
      createVertex({
        googleAuthOptions: { credentials: parseVertexCredentials() },
      })(modelNameString),
    // NOTE(review): the default URL is Baidu's OCR REST path, which does not
    // look like an OpenAI-compatible chat endpoint. Confirm it works with
    // createOpenAI, or override it via config.baseURL.
    baidu: () =>
      createOpenAI({
        apiKey: apiKey || process.env.BAIDU_API_KEY,
        baseURL: baseURL || 'https://aip.baidubce.com/rest/2.0/ocr/v1',
      })(modelNameString),
  }

  // Own-property check: a plain `providerConfigs[providerId]` lookup would
  // resolve inherited Object.prototype members such as "toString" or
  // "constructor", and call them instead of rejecting the provider.
  if (!Object.prototype.hasOwnProperty.call(providerConfigs, providerId)) {
    throw new Error(`Unsupported provider: ${providerId}`)
  }
  const createClient = providerConfigs[providerId as keyof typeof providerConfigs]
  return createClient()
}
/**
 * Returns the default mode string for `model`: `'json'` for the Fireworks
 * provider and `'auto'` for every other provider.
 *
 * NOTE(review): the original comment called the Fireworks branch a
 * "monkey patch". It is presumably a workaround for Fireworks not handling
 * `'auto'`, and the value is presumably consumed as the AI SDK
 * object-generation `mode`. Verify both against callers.
 *
 * @param model - Model descriptor; only `providerId` is read.
 * @returns `'json'` when `model.providerId === 'fireworks'`, otherwise `'auto'`.
 */
export function getDefaultMode(model: LLMModel): 'json' | 'auto' {
  return model.providerId === 'fireworks' ? 'json' : 'auto'
}