Skip to content

Commit e330e8c

Browse files
authored
fix: rename explainQuery to unsafeExplainQuery (#225)
Related discussion: #223. We already improved this in #224; this PR makes it more secure by: 1. Renaming `explainQuery` to `unsafeExplainQuery`. While the single-query check and transaction wrapping are already pretty secure, there might be edge cases where a query escapes the designated boundaries; therefore, we rename the tool to `unsafeExplainQuery`. 2. Creating a new tool called `safeExplainQuery`, which operates on a `queryId` and fetches the actual SQL from `pg_stat_statements` itself, thereby eliminating the code path that can lead to SQL injection. This is done in the stacked PR #226. 3. Using the new `safeExplainQuery` tool instead of `unsafeExplainQuery`. To make this work, we additionally had to return `queryId` from the `getSlowQueries` tool alongside the slow SQL query. This is done in the stacked PR #227.
1 parent 56d2b05 commit e330e8c

File tree

9 files changed

+90
-14
lines changed

9 files changed

+90
-14
lines changed

apps/dbagent/src/evals/chat/tool-choice.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ describe.concurrent('tool_choice', () => {
153153
{
154154
id: 'explain_query',
155155
prompt: 'Explain SELECT * FROM dogs',
156-
expectedToolCalls: ['explainQuery'],
156+
expectedToolCalls: ['safeExplainQuery'],
157157
allowOtherTools: false
158158
},
159159
{

apps/dbagent/src/lib/ai/prompts.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ If the user asks for something that is not related to PostgreSQL or database adm
1111
export const chatSystemPrompt = `
1212
Provide clear, concise, and accurate responses to questions.
1313
Use the provided tools to get context from the PostgreSQL database to answer questions.
14-
When asked why a query is slow, call the explainQuery tool and also take into account the table sizes.
14+
When asked why a query is slow, call the safeExplainQuery tool and also take into account the table sizes.
1515
During the initial assessment use the getTablesInfo, getPerfromanceAndVacuumSettings, getConnectionsStats, and getPostgresExtensions, and others if you want.
1616
When asked to run a playbook, use the getPlaybook tool to get the playbook contents. Then use the contents of the playbook
1717
as an action plan. Execute the plan step by step.

apps/dbagent/src/lib/ai/tools/db.ts

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import { Tool, tool } from 'ai';
22
import { z } from 'zod';
33
import { getPerformanceAndVacuumSettings, toolFindTableSchema } from '~/lib/tools/dbinfo';
4-
import { toolDescribeTable, toolExplainQuery, toolGetSlowQueries } from '~/lib/tools/slow-queries';
4+
import {
5+
toolDescribeTable,
6+
toolGetSlowQueries,
7+
toolSafeExplainQuery,
8+
toolUnsafeExplainQuery
9+
} from '~/lib/tools/slow-queries';
510
import {
611
toolCurrentActiveQueries,
712
toolGetConnectionsGroups,
@@ -29,7 +34,8 @@ export class DBSQLTools implements ToolsetGroup {
2934
toolset(): Record<string, Tool> {
3035
return {
3136
getSlowQueries: this.getSlowQueries(),
32-
explainQuery: this.explainQuery(),
37+
unsafeExplainQuery: this.unsafeExplainQuery(),
38+
safeExplainQuery: this.safeExplainQuery(),
3339
describeTable: this.describeTable(),
3440
findTableSchema: this.findTableSchema(),
3541
getCurrentActiveQueries: this.getCurrentActiveQueries(),
@@ -46,7 +52,7 @@ export class DBSQLTools implements ToolsetGroup {
4652
return tool({
4753
description: `Get a list of slow queries formatted as a JSON array. Contains how many times the query was called,
4854
the max execution time in seconds, the mean execution time in seconds, the total execution time
49-
(all calls together) in seconds, and the query itself.`,
55+
(all calls together) in seconds, the query itself, and the queryid for use with safeExplainQuery.`,
5056
parameters: z.object({}),
5157
execute: async () => {
5258
try {
@@ -58,7 +64,7 @@ the max execution time in seconds, the mean execution time in seconds, the total
5864
});
5965
}
6066

61-
explainQuery(): Tool {
67+
unsafeExplainQuery(): Tool {
6268
const pool = this.#pool;
6369
return tool({
6470
description: `Run explain on a a query. Returns the explain plan as received from PostgreSQL.
@@ -73,7 +79,7 @@ If you know the schema, pass it in as well.`,
7379
try {
7480
const explain = await withPoolConnection(
7581
pool,
76-
async (client) => await toolExplainQuery(client, schema, query)
82+
async (client) => await toolUnsafeExplainQuery(client, schema, query)
7783
);
7884
if (!explain) return 'Could not run EXPLAIN on the query';
7985

@@ -85,6 +91,26 @@ If you know the schema, pass it in as well.`,
8591
});
8692
}
8793

94+
safeExplainQuery(): Tool {
95+
const pool = this.#pool;
96+
return tool({
97+
description: `Safely run EXPLAIN on a query by fetching it from pg_stat_statements using queryId.
98+
This prevents SQL injection by not accepting raw SQL queries. Returns the explain plan as received from PostgreSQL.
99+
Use the queryid field from the getSlowQueries tool output as the queryId parameter.`,
100+
parameters: z.object({
101+
schema: z.string(),
102+
queryId: z.string().describe('The query ID from pg_stat_statements (use the queryid field from getSlowQueries)')
103+
}),
104+
execute: async ({ schema = 'public', queryId }) => {
105+
try {
106+
return await withPoolConnection(pool, async (client) => await toolSafeExplainQuery(client, schema, queryId));
107+
} catch (error) {
108+
return `Error running safe EXPLAIN on the query: ${error}`;
109+
}
110+
}
111+
});
112+
}
113+
88114
describeTable(): Tool {
89115
const pool = this.#pool;
90116
return tool({

apps/dbagent/src/lib/targetdb/db.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@ interface SlowQuery {
361361
mean_exec_secs: number;
362362
total_exec_secs: number;
363363
query: string;
364+
queryid: string;
364365
}
365366

366367
export async function getSlowQueries(client: ClientBase, thresholdMs: number): Promise<SlowQuery[]> {
@@ -370,7 +371,8 @@ export async function getSlowQueries(client: ClientBase, thresholdMs: number): P
370371
round(max_exec_time/1000) max_exec_secs,
371372
round(mean_exec_time/1000) mean_exec_secs,
372373
round(total_exec_time/1000) total_exec_secs,
373-
query
374+
query,
375+
queryid::text as queryid
374376
FROM pg_stat_statements
375377
WHERE max_exec_time > $1
376378
ORDER BY total_exec_time DESC
apps/dbagent/src/lib/targetdb/safe-explain.ts

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import { ClientBase } from './db';
2+
import { isSingleStatement } from './unsafe-explain';
3+
4+
export async function safeExplainQuery(client: ClientBase, schema: string, queryId: string): Promise<string> {
5+
if (!/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(schema)) {
6+
return 'Invalid schema name. Only alphanumeric characters and underscores are allowed.';
7+
}
8+
9+
// First, fetch the query from pg_stat_statements
10+
const queryResult = await client.query('SELECT query FROM pg_stat_statements WHERE queryid = $1', [queryId]);
11+
12+
if (queryResult.rows.length === 0) {
13+
return 'Query not found in pg_stat_statements for the given queryId';
14+
}
15+
16+
const query = queryResult.rows[0].query;
17+
18+
if (!isSingleStatement(query)) {
19+
return 'The query is not a single safe statement. Only SELECT, INSERT, UPDATE, DELETE, and WITH statements are allowed.';
20+
}
21+
22+
const hasPlaceholders = /\$\d+/.test(query);
23+
24+
let toReturn = '';
25+
try {
26+
await client.query('BEGIN');
27+
await client.query("SET LOCAL statement_timeout = '2000ms'");
28+
await client.query("SET LOCAL lock_timeout = '200ms'");
29+
await client.query(`SET search_path TO ${schema}`);
30+
const explainQuery = hasPlaceholders ? `EXPLAIN (GENERIC_PLAN true) ${query}` : `EXPLAIN ${query}`;
31+
console.log(schema);
32+
console.log(explainQuery);
33+
const result = await client.query(explainQuery);
34+
console.log(result.rows);
35+
toReturn = result.rows.map((row: { [key: string]: string }) => row['QUERY PLAN']).join('\n');
36+
} catch (error) {
37+
console.error('Error explaining query', error);
38+
toReturn = 'I could not run EXPLAIN on that query. Try a different method.';
39+
}
40+
await client.query('ROLLBACK');
41+
return toReturn;
42+
}

apps/dbagent/src/lib/targetdb/explain.test.ts renamed to apps/dbagent/src/lib/targetdb/unsafe-explain.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { isSingleStatement } from './explain';
1+
import { isSingleStatement } from './unsafe-explain';
22

33
describe('isSingleStatement', () => {
44
describe('positive tests - should return true', () => {

apps/dbagent/src/lib/targetdb/explain.ts renamed to apps/dbagent/src/lib/targetdb/unsafe-explain.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ class SQLParser {
253253
}
254254
}
255255

256-
export async function explainQuery(client: ClientBase, schema: string, query: string): Promise<string> {
256+
export async function unsafeExplainQuery(client: ClientBase, schema: string, query: string): Promise<string> {
257257
if (!/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(schema)) {
258258
return 'Invalid schema name. Only alphanumeric characters and underscores are allowed.';
259259
}

apps/dbagent/src/lib/tools/playbooks.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Use the tool findTableSchema to find the schema of the table involved in the slo
2424
Use the tool describeTable to describe the table you found.
2525
2626
Step 4:
27-
Use the tool explainQuery to explain the slow queries. Make sure to pass the schema you found to the tool.
27+
Use the tool safeExplainQuery to explain the slow queries. Make sure to pass the schema you found to the tool.
2828
Also, it's very important to replace the query parameters ($1, $2, etc) with the actual values. Generate your own values, but
2929
take into account the data types of the columns.
3030

apps/dbagent/src/lib/tools/slow-queries.ts

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { ClientBase, describeTable, getSlowQueries } from '../targetdb/db';
2-
import { explainQuery } from '../targetdb/explain';
2+
import { safeExplainQuery } from '../targetdb/safe-explain';
3+
import { unsafeExplainQuery } from '../targetdb/unsafe-explain';
34

45
export async function toolGetSlowQueries(client: ClientBase, thresholdMs: number): Promise<string> {
56
const slowQueries = await getSlowQueries(client, thresholdMs);
@@ -20,8 +21,13 @@ export async function toolGetSlowQueries(client: ClientBase, thresholdMs: number
2021
return JSON.stringify(filteredSlowQueries);
2122
}
2223

23-
export async function toolExplainQuery(client: ClientBase, schema: string, query: string): Promise<string> {
24-
const result = await explainQuery(client, schema, query);
24+
export async function toolUnsafeExplainQuery(client: ClientBase, schema: string, query: string): Promise<string> {
25+
const result = await unsafeExplainQuery(client, schema, query);
26+
return JSON.stringify(result);
27+
}
28+
29+
export async function toolSafeExplainQuery(client: ClientBase, schema: string, queryId: string): Promise<string> {
30+
const result = await safeExplainQuery(client, schema, queryId);
2531
return JSON.stringify(result);
2632
}
2733

0 commit comments

Comments
 (0)