Skip to content

Commit 11a5da7

Browse files
committed
feat: prompt pure sql with formatter
1 parent e07341a commit 11a5da7

File tree

7 files changed

+345
-291
lines changed

7 files changed

+345
-291
lines changed

app/composables/useAgent.ts

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
import { destr } from 'destr'
2+
import { format } from 'sql-formatter'
3+
import { toast } from 'vue-sonner'
4+
5+
interface UseSqlAgentOptions {
6+
usableTableNames: MaybeRef<string[]>
7+
limit: MaybeRef<number>
8+
beforeInitialzeModel: () => void
9+
beforeQueryRelevantTables: () => void
10+
beforeQueryPrompt: () => void
11+
beforeQuerySqlStatement: () => void
12+
beforeQuerySqlValidation: () => void
13+
}
14+
15+
async function queryCreateTableSQL(value: string, cursor?: Database): Promise<string> {
16+
if (value && cursor) {
17+
const backend = cursor?.path.split(':')[0] ?? 'mysql'
18+
const sql = Sql.SHOW_CREATE_TABLE(value)[backend]
19+
if (!sql)
20+
return ''
21+
const [row] = await cursor?.select<any[]>(sql) ?? []
22+
if (backend === 'sqlite')
23+
return row.sql
24+
return row['Create Table']
25+
}
26+
27+
return ''
28+
}
29+
30+
export function useSqlAgent(cursorInstance: MaybeRef<Database | undefined> | undefined, options: Partial<UseSqlAgentOptions> = {}) {
31+
const store = useSettingsStore()
32+
const { model } = storeToRefs(store)
33+
34+
const retries = ref(3)
35+
const output = ref('')
36+
const usableTableNames = computed(() => unref(options.usableTableNames) ?? [])
37+
const cursor = computed(() => unref(cursorInstance))
38+
const llm = useLlm(model)
39+
const includes = ref<string[]>([])
40+
const backend = useCursorBackend(cursor)
41+
const prompt = ref('')
42+
const isLoading = ref(false)
43+
const error = ref('')
44+
const question = ref('')
45+
46+
async function queryRelevantTables() {
47+
if (usableTableNames.value.length < 25) {
48+
includes.value = usableTableNames.value
49+
return
50+
}
51+
52+
const tool = {
53+
type: 'function',
54+
function: {
55+
name: 'Table',
56+
description: 'Get relevant tables from a given list',
57+
parameters: {
58+
type: 'object',
59+
properties: {
60+
rows: {
61+
type: 'array',
62+
description: 'Name list of table in SQL database',
63+
items: {
64+
type: 'string',
65+
enum: usableTableNames.value,
66+
},
67+
},
68+
},
69+
required: ['rows'],
70+
},
71+
},
72+
}
73+
74+
const prompt = usePromptTemplate(RELEVANT_TABLES_PROMPT, { tableNames: usableTableNames.value.join('\n') })
75+
const defaults = usableTableNames.value.slice(0, 25)
76+
77+
for (let i = 0; i < retries.value; i++) {
78+
if (!isLoading.value)
79+
return
80+
81+
const { choices } = await llm.chat.completions.create({
82+
messages: [
83+
{ role: 'user', content: prompt.value },
84+
{ role: 'user', content: question.value },
85+
],
86+
tools: [tool],
87+
tool_choice: 'required',
88+
})
89+
90+
const [call] = choices[0].message.tool_calls
91+
92+
if (!call)
93+
continue
94+
95+
const args: string[] = call.function.arguments
96+
const rows = destr<{ rows: string[] }>(args).rows ?? []
97+
98+
if (rows.length) {
99+
includes.value = rows
100+
return
101+
}
102+
}
103+
104+
includes.value = defaults
105+
}
106+
107+
async function queryPrompt() {
108+
if (!isLoading.value)
109+
return
110+
111+
const prompts = await Promise.all(includes.value.map(t => (async () => {
112+
if (t && cursor.value) {
113+
const [v0, v1] = await Promise.all([
114+
queryCreateTableSQL(t, cursor.value),
115+
cursor.value?.select<Record<string, any>[]>(`SELECT * FROM \`${t}\` LIMIT 3;`) ?? [],
116+
])
117+
118+
const cols = Object.keys(v1[0] ?? {})
119+
const rows: string[] = [cols.join('\t')]
120+
121+
for (const i of v1) {
122+
rows.push(cols.map(k => String(i[k])).join('\t'))
123+
}
124+
125+
return [
126+
v0,
127+
'\n/*',
128+
'3 rows from t_inf_app_application table:',
129+
rows.join('\n'),
130+
'*/',
131+
].join('\n')
132+
}
133+
})()))
134+
135+
const template = SQL_PROMPT?.[backend.value] ?? SQL_PROMPT.default as string
136+
137+
prompt.value = usePromptTemplate(template, {
138+
tableInfo: prompts.filter(Boolean).join('\n\n'),
139+
input: question.value,
140+
topK: unref(options.limit) ?? 5,
141+
dialect: backend.value,
142+
}).value
143+
}
144+
145+
async function querySqlStatement() {
146+
for (let i = 0; i < retries.value; i++) {
147+
if (!isLoading.value)
148+
return
149+
150+
const { choices } = (await llm.chat.completions.create({ messages: [{ role: 'user', content: prompt.value }] })) ?? {}
151+
output.value = choices?.[0]?.message.content ?? ''
152+
if (output.value)
153+
break
154+
}
155+
}
156+
157+
async function querySqlValidation() {
158+
let _error = ''
159+
160+
const prompt = usePromptTemplate(SQL_VALIDATE_PROMPT, { notFormattedQuery: unref(output.value) })
161+
162+
for (let i = 0; i < retries.value; i++) {
163+
if (!isLoading.value)
164+
return
165+
166+
const { choices } = (await llm.chat.completions.create({ messages: [{ role: 'user', content: prompt.value }] })) ?? {}
167+
const sql = choices?.[0]?.message.content ?? ''
168+
169+
if (!sql)
170+
continue
171+
172+
try {
173+
output.value = format(sql, {
174+
language: 'sql',
175+
tabWidth: 2,
176+
keywordCase: 'upper',
177+
})
178+
179+
_error = ''
180+
break
181+
}
182+
catch (e) {
183+
_error = e as string
184+
continue
185+
}
186+
}
187+
188+
error.value = _error
189+
}
190+
191+
async function execute(input: MaybeRef<string>) {
192+
prompt.value = ''
193+
output.value = ''
194+
includes.value = []
195+
question.value = unref(input)
196+
isLoading.value = true
197+
198+
try {
199+
options.beforeInitialzeModel?.()
200+
options.beforeQueryRelevantTables?.()
201+
await queryRelevantTables()
202+
options.beforeQueryPrompt?.()
203+
await queryPrompt()
204+
options.beforeQuerySqlStatement?.()
205+
await querySqlStatement()
206+
options.beforeQuerySqlValidation?.()
207+
await querySqlValidation()
208+
}
209+
catch (errors: any) {
210+
error.value = errors
211+
212+
if (Array.isArray(errors)) {
213+
errors.forEach(({ error }: any) => {
214+
toast(error.status, { description: error.message })
215+
})
216+
}
217+
218+
else {
219+
toast('NETWORK_ERROR', { description: 'An error occurred trying to load the resource' })
220+
}
221+
}
222+
finally {
223+
isLoading.value = false
224+
}
225+
226+
return output.value
227+
}
228+
229+
return {
230+
error,
231+
output,
232+
isLoading,
233+
execute,
234+
}
235+
}

app/composables/useLlm.ts

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import type OpenAI from 'openai'
12
import { createFetch } from '@vueuse/core'
23

34
const _GOOGLE_AI_MODELS = GOOGLE_AI_MODELS.map(({ model }) => model)
@@ -8,35 +9,6 @@ const _MODELS = [
89
..._DEEPSEEK_MODELS,
910
]
1011

11-
interface ChatCompletionsResponse {
12-
choices: {
13-
finish_reason: string
14-
index: 0
15-
message: {
16-
content: string
17-
role: 'assistant'
18-
19-
tool_calls: {
20-
function: {
21-
arguments: string
22-
name: string
23-
}
24-
id: string
25-
type: string
26-
}[]
27-
}
28-
}[]
29-
30-
model: string
31-
object: string
32-
33-
usage: {
34-
completion_tokens: number
35-
prompt_tokens: number
36-
total_tokens: number
37-
}
38-
}
39-
4012
export function useLlm(model: MaybeRef<string>) {
4113
let _apiKey = ''
4214
let _baseURL = ''
@@ -66,13 +38,19 @@ export function useLlm(model: MaybeRef<string>) {
6638
baseUrl: _baseURL,
6739

6840
options: {
69-
async beforeFetch({ options }: any) {
41+
beforeFetch({ options }: any) {
7042
if (!options.headers)
7143
options.headers = {}
7244
options.headers.Authorization = `Bearer ${_apiKey}`
7345

7446
return { options }
7547
},
48+
49+
onFetchError(ctx) {
50+
if (ctx.data)
51+
ctx.error = ctx.data
52+
return ctx
53+
},
7654
},
7755
})
7856
})
@@ -83,7 +61,9 @@ export function useLlm(model: MaybeRef<string>) {
8361
async create(body: Record<string, any> = {}) {
8462
if (!openai.value)
8563
return
86-
const { data } = await openai.value<ChatCompletionsResponse>('/chat/completions').post({ model: _model.value, ...body }).json()
64+
const { data, error } = await openai.value<OpenAI.Chat.ChatCompletion>('/chat/completions').post({ model: _model.value, ...body }).json()
65+
if (error.value)
66+
throw error.value
8767
return data.value
8868
},
8969
},

0 commit comments

Comments
 (0)