diff --git a/src/client.ts b/src/client.ts index f2c4ddb..1538c34 100644 --- a/src/client.ts +++ b/src/client.ts @@ -59,11 +59,15 @@ export class RDSClient extends Operator { try { const _conn = await this.#pool.getConnection(); const conn = new RDSConnection(_conn); - if (this.beforeQueryHandler) { - conn.beforeQuery(this.beforeQueryHandler); + if (this.beforeQueryHandlers.length > 0) { + for (const handler of this.beforeQueryHandlers) { + conn.beforeQuery(handler); + } } - if (this.afterQueryHandler) { - conn.afterQuery(this.afterQueryHandler); + if (this.afterQueryHandlers.length > 0) { + for (const handler of this.afterQueryHandlers) { + conn.afterQuery(handler); + } } return conn; } catch (err) { @@ -88,11 +92,15 @@ export class RDSClient extends Operator { throw err; } const tran = new RDSTransaction(conn); - if (this.beforeQueryHandler) { - tran.beforeQuery(this.beforeQueryHandler); + if (this.beforeQueryHandlers.length > 0) { + for (const handler of this.beforeQueryHandlers) { + tran.beforeQuery(handler); + } } - if (this.afterQueryHandler) { - tran.afterQuery(this.afterQueryHandler); + if (this.afterQueryHandlers.length > 0) { + for (const handler of this.afterQueryHandlers) { + tran.afterQuery(handler); + } } return tran; } diff --git a/src/operator.ts b/src/operator.ts index 7b50b84..e5929fc 100644 --- a/src/operator.ts +++ b/src/operator.ts @@ -16,17 +16,17 @@ const debug = debuglog('ali-rds:operator'); * Operator Interface */ export abstract class Operator { - protected beforeQueryHandler?: BeforeQueryHandler; - protected afterQueryHandler?: AfterQueryHandler; + protected beforeQueryHandlers: BeforeQueryHandler[] = []; + protected afterQueryHandlers: AfterQueryHandler[] = []; get literals() { return literals; } beforeQuery(beforeQueryHandler: BeforeQueryHandler) { - this.beforeQueryHandler = beforeQueryHandler; + this.beforeQueryHandlers.push(beforeQueryHandler); } afterQuery(afterQueryHandler: AfterQueryHandler) { - this.afterQueryHandler = afterQueryHandler; + this.afterQueryHandlers.push(afterQueryHandler); } escape(value: any, stringifyObjects?: boolean, timeZone?: string): string { @@ -57,21 +57,20 @@ export abstract class Operator { if (values) { sql = this.format(sql, values); } - if (this.beforeQueryHandler) { - const newSql = this.beforeQueryHandler(sql); - if (newSql) { - sql = newSql; + if (this.beforeQueryHandlers.length > 0) { + for (const beforeQueryHandler of this.beforeQueryHandlers) { + const newSql = beforeQueryHandler(sql); + if (newSql) { + sql = newSql; + } } } debug('query %o', sql); - let execDuration: number; const queryStart = Date.now(); + let rows: any; + let lastError: Error | undefined; try { - const rows = await this._query(sql); - execDuration = Date.now() - queryStart; - if (this.afterQueryHandler) { - this.afterQueryHandler(sql, rows, execDuration); - } + rows = await this._query(sql); if (Array.isArray(rows)) { debug('query get %o rows', rows.length); } else { @@ -79,13 +78,17 @@ export abstract class Operator { } return rows; } catch (err) { - execDuration = Date.now() - queryStart; + lastError = err; err.stack = `${err.stack}\n sql: ${sql}`; - if (this.afterQueryHandler) { - this.afterQueryHandler(sql, null, execDuration, err); - } debug('query error: %o', err); throw err; + } finally { + if (this.afterQueryHandlers.length > 0) { + const execDuration = Date.now() - queryStart; + for (const afterQueryHandler of this.afterQueryHandlers) { + afterQueryHandler(sql, rows, execDuration, lastError); + } + } } } diff --git a/test/client.test.ts b/test/client.test.ts index 1c201ff..27a83e3 100644 --- a/test/client.test.ts +++ b/test/client.test.ts @@ -1209,14 +1209,22 @@ describe('test/client.test.ts', () => { const db = new RDSClient(config); let count = 0; let lastSql = ''; + let counter2Before = 0; + let counter2After = 0; db.beforeQuery(sql => { count++; lastSql = sql; }); + db.beforeQuery(() => { + counter2Before++; + }); let lastArgs: any; db.afterQuery((...args) => { lastArgs = args; }); + db.afterQuery(() => { + counter2After++; + }); await db.query('select * from ?? limit 10', [ table ]); assert.equal(lastSql, 'select * from `ali-sdk-test-user` limit 10'); assert.equal(lastArgs[0], lastSql); @@ -1256,6 +1264,8 @@ describe('test/client.test.ts', () => { assert.equal(lastArgs[0], lastSql); assert.equal(lastArgs[1].affectedRows, 1); assert.equal(count, 4); + assert.equal(counter2Before, 4); + assert.equal(counter2After, 4); }); }); }); diff --git a/test/operator.test.ts b/test/operator.test.ts index 405d3d4..839f72b 100644 --- a/test/operator.test.ts +++ b/test/operator.test.ts @@ -95,22 +95,25 @@ describe('test/operator.test.ts', () => { it('should get query result on after hook', async () => { const op = new CustomOperator(); + let called = false; op.afterQuery((sql, result, execDuration, err) => { assert.equal(sql, 'foo'); assert.deepEqual(result, { sql }); assert.equal(typeof execDuration, 'number'); assert(execDuration >= 0); assert.equal(err, undefined); + called = true; }); const result = await op.query('foo'); assert.equal(result.sql, 'foo'); + assert(called); }); it('should get query error on after hook', async () => { const op = new CustomOperator(); op.afterQuery((sql, result, execDuration, err) => { assert.equal(sql, 'error'); - assert.equal(result, null); + assert.equal(result, undefined); assert.equal(typeof execDuration, 'number'); assert(execDuration >= 0); assert(err instanceof Error);