Skip to content

Commit 300c00d

Browse files
authored
Merge pull request #84 from cdleo/main
Adding capabilities to change the error returned by the driver
2 parents 7cdac3f + 3c089fb commit 300c00d

File tree

3 files changed

+45
-56
lines changed

3 files changed

+45
-56
lines changed

conn.go

Lines changed: 25 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@ type Conn struct {
1717
// It will trigger PrePing, Ping, PostPing hooks.
1818
//
1919
// If the original connection does not satisfy "database/sql/driver".Pinger, it does nothing.
20-
func (conn *Conn) Ping(c context.Context) error {
21-
var err error
20+
func (conn *Conn) Ping(c context.Context) (err error) {
2221
var ctx interface{}
2322
hooks := conn.Proxy.getHooks(c)
2423

2524
if hooks != nil {
26-
defer func() { hooks.postPing(c, ctx, conn, err) }()
25+
defer func() { err = hooks.postPing(c, ctx, conn, err) }()
2726
if ctx, err = hooks.prePing(c, conn); err != nil {
2827
return err
2928
}
@@ -49,31 +48,30 @@ func (conn *Conn) Prepare(query string) (driver.Stmt, error) {
4948
}
5049

5150
// PrepareContext returns a prepared statement which is wrapped by Stmt.
52-
func (conn *Conn) PrepareContext(c context.Context, query string) (driver.Stmt, error) {
51+
func (conn *Conn) PrepareContext(c context.Context, query string) (stmt driver.Stmt, err error) {
5352
var ctx interface{}
54-
var stmt = &Stmt{
53+
var stmtAux = &Stmt{
5554
QueryString: query,
5655
Proxy: conn.Proxy,
5756
Conn: conn,
5857
}
59-
var err error
6058
hooks := conn.Proxy.getHooks(c)
6159
if hooks != nil {
62-
defer func() { hooks.postPrepare(c, ctx, stmt, err) }()
63-
if ctx, err = hooks.prePrepare(c, stmt); err != nil {
60+
defer func() { err = hooks.postPrepare(c, ctx, stmtAux, err) }()
61+
if ctx, err = hooks.prePrepare(c, stmtAux); err != nil {
6462
return nil, err
6563
}
6664
}
6765

6866
if connCtx, ok := conn.Conn.(driver.ConnPrepareContext); ok {
69-
stmt.Stmt, err = connCtx.PrepareContext(c, stmt.QueryString)
67+
stmtAux.Stmt, err = connCtx.PrepareContext(c, stmtAux.QueryString)
7068
} else {
71-
stmt.Stmt, err = conn.Conn.Prepare(stmt.QueryString)
69+
stmtAux.Stmt, err = conn.Conn.Prepare(stmtAux.QueryString)
7270
if err == nil {
7371
select {
7472
default:
7573
case <-c.Done():
76-
stmt.Stmt.Close()
74+
stmtAux.Stmt.Close()
7775
return nil, c.Err()
7876
}
7977
}
@@ -83,21 +81,20 @@ func (conn *Conn) PrepareContext(c context.Context, query string) (driver.Stmt,
8381
}
8482

8583
if hooks != nil {
86-
if err = hooks.prepare(c, ctx, stmt); err != nil {
84+
if err = hooks.prepare(c, ctx, stmtAux); err != nil {
8785
return nil, err
8886
}
8987
}
90-
return stmt, nil
88+
return stmtAux, nil
9189
}
9290

9391
// Close calls the original Close method.
94-
func (conn *Conn) Close() error {
92+
func (conn *Conn) Close() (err error) {
9593
ctx := context.Background()
96-
var err error
9794
var myctx interface{}
9895

9996
if hooks := conn.Proxy.hooks; hooks != nil {
100-
defer func() { hooks.postClose(ctx, myctx, conn, err) }()
97+
defer func() { err = hooks.postClose(ctx, myctx, conn, err) }()
10198
if myctx, err = hooks.preClose(ctx, conn); err != nil {
10299
return err
103100
}
@@ -123,14 +120,12 @@ func (conn *Conn) Begin() (driver.Tx, error) {
123120

124121
// BeginTx starts and returns a new transaction which is wrapped by Tx.
125122
// It will trigger PreBegin, Begin, PostBegin hooks.
126-
func (conn *Conn) BeginTx(c context.Context, opts driver.TxOptions) (driver.Tx, error) {
123+
func (conn *Conn) BeginTx(c context.Context, opts driver.TxOptions) (tx driver.Tx, err error) {
127124
// set the hooks.
128-
var err error
129125
var ctx interface{}
130-
var tx driver.Tx
131126
hooks := conn.Proxy.getHooks(c)
132127
if hooks != nil {
133-
defer func() { hooks.postBegin(c, ctx, conn, err) }()
128+
defer func() { err = hooks.postBegin(c, ctx, conn, err) }()
134129
if ctx, err = hooks.preBegin(c, conn); err != nil {
135130
return nil, err
136131
}
@@ -193,7 +188,7 @@ func (conn *Conn) Exec(query string, args []driver.Value) (driver.Result, error)
193188
// It will trigger PreExec, Exec, PostExec hooks.
194189
//
195190
// If the original connection does not satisfy "database/sql/driver".ExecerContext nor "database/sql/driver".Execer, it return ErrSkip error.
196-
func (conn *Conn) ExecContext(c context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
191+
func (conn *Conn) ExecContext(c context.Context, query string, args []driver.NamedValue) (drv driver.Result, err error) {
197192
execer, exOk := conn.Conn.(driver.Execer)
198193
execerCtx, exCtxOk := conn.Conn.(driver.ExecerContext)
199194
if !exOk && !exCtxOk {
@@ -207,19 +202,17 @@ func (conn *Conn) ExecContext(c context.Context, query string, args []driver.Nam
207202
Conn: conn,
208203
}
209204
var ctx interface{}
210-
var err error
211-
var result driver.Result
212205
hooks := conn.Proxy.getHooks(c)
213206
if hooks != nil {
214-
defer func() { hooks.postExec(c, ctx, stmt, args, result, err) }()
207+
defer func() { err = hooks.postExec(c, ctx, stmt, args, drv, err) }()
215208
if ctx, err = hooks.preExec(c, stmt, args); err != nil {
216209
return nil, err
217210
}
218211
}
219212

220213
// call the original method.
221214
if execerCtx != nil {
222-
result, err = execerCtx.ExecContext(c, stmt.QueryString, args)
215+
drv, err = execerCtx.ExecContext(c, stmt.QueryString, args)
223216
} else {
224217
select {
225218
default:
@@ -230,19 +223,18 @@ func (conn *Conn) ExecContext(c context.Context, query string, args []driver.Nam
230223
if err0 != nil {
231224
return nil, err0
232225
}
233-
result, err = execer.Exec(stmt.QueryString, dargs)
226+
drv, err = execer.Exec(stmt.QueryString, dargs)
234227
}
235228
if err != nil {
236229
return nil, err
237230
}
238231

239232
if hooks != nil {
240-
if err = hooks.exec(c, ctx, stmt, args, result); err != nil {
233+
if err = hooks.exec(c, ctx, stmt, args, drv); err != nil {
241234
return nil, err
242235
}
243236
}
244-
245-
return result, nil
237+
return drv, err
246238
}
247239

248240
// Query executes a query that may return rows.
@@ -258,7 +250,7 @@ func (conn *Conn) Query(query string, args []driver.Value) (driver.Rows, error)
258250
// It wil trigger PreQuery, Query, PostQuery hooks.
259251
//
260252
// If the original connection does not satisfy "database/sql/driver".QueryerContext nor "database/sql/driver".Queryer, it return ErrSkip error.
261-
func (conn *Conn) QueryContext(c context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
253+
func (conn *Conn) QueryContext(c context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
262254
queryer, qok := conn.Conn.(driver.Queryer)
263255
queryerCtx, qCtxOk := conn.Conn.(driver.QueryerContext)
264256
if !qok && !qCtxOk {
@@ -271,11 +263,9 @@ func (conn *Conn) QueryContext(c context.Context, query string, args []driver.Na
271263
Conn: conn,
272264
}
273265
var ctx interface{}
274-
var err error
275-
var rows driver.Rows
276266
hooks := conn.Proxy.getHooks(c)
277267
if hooks != nil {
278-
defer func() { hooks.postQuery(c, ctx, stmt, args, rows, err) }()
268+
defer func() { err = hooks.postQuery(c, ctx, stmt, args, rows, err) }()
279269
if ctx, err = hooks.preQuery(c, stmt, args); err != nil {
280270
return nil, err
281271
}
@@ -343,13 +333,12 @@ type sessionResetter interface {
343333
}
344334

345335
// ResetSession resets the state of Conn.
346-
func (conn *Conn) ResetSession(ctx context.Context) error {
347-
var err error
336+
func (conn *Conn) ResetSession(ctx context.Context) (err error) {
348337
var myctx interface{}
349338
hooks := conn.Proxy.getHooks(ctx)
350339

351340
if hooks != nil {
352-
defer func() { hooks.postResetSession(ctx, myctx, conn, err) }()
341+
defer func() { err = hooks.postResetSession(ctx, myctx, conn, err) }()
353342
if myctx, err = hooks.preResetSession(ctx, conn); err != nil {
354343
return err
355344
}

hooks.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ func (h *HooksContext) ping(c context.Context, ctx interface{}, conn *Conn) erro
412412

413413
func (h *HooksContext) postPing(c context.Context, ctx interface{}, conn *Conn, err error) error {
414414
if h == nil || h.PostPing == nil {
415-
return nil
415+
return err
416416
}
417417
return h.PostPing(c, ctx, conn, err)
418418
}
@@ -433,7 +433,7 @@ func (h *HooksContext) open(c context.Context, ctx interface{}, conn *Conn) erro
433433

434434
func (h *HooksContext) postOpen(c context.Context, ctx interface{}, conn *Conn, err error) error {
435435
if h == nil || h.PostOpen == nil {
436-
return nil
436+
return err
437437
}
438438
return h.PostOpen(c, ctx, conn, err)
439439
}
@@ -454,7 +454,7 @@ func (h *HooksContext) prepare(c context.Context, ctx interface{}, stmt *Stmt) e
454454

455455
func (h *HooksContext) postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error {
456456
if h == nil || h.PostPrepare == nil {
457-
return nil
457+
return err
458458
}
459459
return h.PostPrepare(c, ctx, stmt, err)
460460
}
@@ -475,7 +475,7 @@ func (h *HooksContext) exec(c context.Context, ctx interface{}, stmt *Stmt, args
475475

476476
func (h *HooksContext) postExec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error {
477477
if h == nil || h.PostExec == nil {
478-
return nil
478+
return err
479479
}
480480
return h.PostExec(c, ctx, stmt, args, result, err)
481481
}
@@ -496,7 +496,7 @@ func (h *HooksContext) query(c context.Context, ctx interface{}, stmt *Stmt, arg
496496

497497
func (h *HooksContext) postQuery(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows, err error) error {
498498
if h == nil || h.PostQuery == nil {
499-
return nil
499+
return err
500500
}
501501
return h.PostQuery(c, ctx, stmt, args, rows, err)
502502
}
@@ -517,7 +517,7 @@ func (h *HooksContext) begin(c context.Context, ctx interface{}, conn *Conn) err
517517

518518
func (h *HooksContext) postBegin(c context.Context, ctx interface{}, conn *Conn, err error) error {
519519
if h == nil || h.PostBegin == nil {
520-
return nil
520+
return err
521521
}
522522
return h.PostBegin(c, ctx, conn, err)
523523
}
@@ -538,7 +538,7 @@ func (h *HooksContext) commit(c context.Context, ctx interface{}, tx *Tx) error
538538

539539
func (h *HooksContext) postCommit(c context.Context, ctx interface{}, tx *Tx, err error) error {
540540
if h == nil || h.PostCommit == nil {
541-
return nil
541+
return err
542542
}
543543
return h.PostCommit(c, ctx, tx, err)
544544
}
@@ -559,7 +559,7 @@ func (h *HooksContext) rollback(c context.Context, ctx interface{}, tx *Tx) erro
559559

560560
func (h *HooksContext) postRollback(c context.Context, ctx interface{}, tx *Tx, err error) error {
561561
if h == nil || h.PostRollback == nil {
562-
return nil
562+
return err
563563
}
564564
return h.PostRollback(c, ctx, tx, err)
565565
}
@@ -580,7 +580,7 @@ func (h *HooksContext) close(c context.Context, ctx interface{}, conn *Conn) err
580580

581581
func (h *HooksContext) postClose(c context.Context, ctx interface{}, conn *Conn, err error) error {
582582
if h == nil || h.PostClose == nil {
583-
return nil
583+
return err
584584
}
585585
return h.PostClose(c, ctx, conn, err)
586586
}
@@ -601,7 +601,7 @@ func (h *HooksContext) resetSession(c context.Context, ctx interface{}, conn *Co
601601

602602
func (h *HooksContext) postResetSession(c context.Context, ctx interface{}, conn *Conn, err error) error {
603603
if h == nil || h.PostResetSession == nil {
604-
return nil
604+
return err
605605
}
606606
return h.PostResetSession(c, ctx, conn, err)
607607
}

logging_hook_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func (h *loggingHook) postPing(c context.Context, ctx interface{}, conn *Conn, e
3737
h.mu.Lock()
3838
defer h.mu.Unlock()
3939
fmt.Fprintln(h, "[PostPing]")
40-
return nil
40+
return err
4141
}
4242

4343
func (h *loggingHook) preOpen(c context.Context, name string) (interface{}, error) {
@@ -58,7 +58,7 @@ func (h *loggingHook) postOpen(c context.Context, ctx interface{}, conn *Conn, e
5858
h.mu.Lock()
5959
defer h.mu.Unlock()
6060
fmt.Fprintln(h, "[PostOpen]")
61-
return nil
61+
return err
6262
}
6363

6464
func (h *loggingHook) prePrepare(c context.Context, stmt *Stmt) (interface{}, error) {
@@ -79,7 +79,7 @@ func (h *loggingHook) postPrepare(c context.Context, ctx interface{}, stmt *Stmt
7979
h.mu.Lock()
8080
defer h.mu.Unlock()
8181
fmt.Fprintln(h, "[PostPrepare]")
82-
return nil
82+
return err
8383
}
8484

8585
func (h *loggingHook) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) {
@@ -100,7 +100,7 @@ func (h *loggingHook) postExec(c context.Context, ctx interface{}, stmt *Stmt, a
100100
h.mu.Lock()
101101
defer h.mu.Unlock()
102102
fmt.Fprintln(h, "[PostExec]")
103-
return nil
103+
return err
104104
}
105105

106106
func (h *loggingHook) preQuery(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) {
@@ -121,7 +121,7 @@ func (h *loggingHook) postQuery(c context.Context, ctx interface{}, stmt *Stmt,
121121
h.mu.Lock()
122122
defer h.mu.Unlock()
123123
fmt.Fprintln(h, "[PostQuery]")
124-
return nil
124+
return err
125125
}
126126

127127
func (h *loggingHook) preBegin(c context.Context, conn *Conn) (interface{}, error) {
@@ -142,7 +142,7 @@ func (h *loggingHook) postBegin(c context.Context, ctx interface{}, conn *Conn,
142142
h.mu.Lock()
143143
defer h.mu.Unlock()
144144
fmt.Fprintln(h, "[PostBegin]")
145-
return nil
145+
return err
146146
}
147147

148148
func (h *loggingHook) preCommit(c context.Context, tx *Tx) (interface{}, error) {
@@ -163,7 +163,7 @@ func (h *loggingHook) postCommit(c context.Context, ctx interface{}, tx *Tx, err
163163
h.mu.Lock()
164164
defer h.mu.Unlock()
165165
fmt.Fprintln(h, "[PostCommit]")
166-
return nil
166+
return err
167167
}
168168

169169
func (h *loggingHook) preRollback(c context.Context, tx *Tx) (interface{}, error) {
@@ -184,7 +184,7 @@ func (h *loggingHook) postRollback(c context.Context, ctx interface{}, tx *Tx, e
184184
h.mu.Lock()
185185
defer h.mu.Unlock()
186186
fmt.Fprintln(h, "[PostRollback]")
187-
return nil
187+
return err
188188
}
189189

190190
func (h *loggingHook) preClose(c context.Context, conn *Conn) (interface{}, error) {
@@ -205,7 +205,7 @@ func (h *loggingHook) postClose(c context.Context, ctx interface{}, conn *Conn,
205205
h.mu.Lock()
206206
defer h.mu.Unlock()
207207
fmt.Fprintln(h, "[PostClose]")
208-
return nil
208+
return err
209209
}
210210

211211
func (h *loggingHook) preResetSession(c context.Context, conn *Conn) (interface{}, error) {
@@ -217,7 +217,7 @@ func (h *loggingHook) resetSession(c context.Context, ctx interface{}, conn *Con
217217
}
218218

219219
func (h *loggingHook) postResetSession(c context.Context, ctx interface{}, conn *Conn, err error) error {
220-
return nil
220+
return err
221221
}
222222

223223
func (h *loggingHook) preIsValid(conn *Conn) (interface{}, error) {

0 commit comments

Comments
 (0)