Skip to content

Commit 0580e9d

Browse files
author
Stefan Tudose
committed
support QueryerContext interface
1 parent 489a697 commit 0580e9d

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

conn.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,9 @@ func (conn *Conn) Exec(query string, args []driver.Value) (driver.Result, error)
190190
}
191191

192192
// ExecContext calls the original ExecContext (or Exec as a fallback) method of the connection.
193-
// It will trigger PreExec, PostExec hooks.
193+
// It will trigger PreExec, Exec, PostExec hooks.
194194
//
195-
// If the original connection doesn't satisfy "database/sql/driver".ExecerContext nor "database/sql/driver".Execer, it return ErrSkip error.
195+
// If the original connection does not satisfy "database/sql/driver".ExecerContext nor "database/sql/driver".Execer, it return ErrSkip error.
196196
func (conn *Conn) ExecContext(c context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
197197
execer, exOk := conn.Conn.(driver.Execer)
198198
execerCtx, exCtxOk := conn.Conn.(driver.ExecerContext)
@@ -257,10 +257,11 @@ func (conn *Conn) Query(query string, args []driver.Value) (driver.Rows, error)
257257
// QueryContext executes a query that may return rows.
258258
// It wil trigger PreQuery, Query, PostQuery hooks.
259259
//
260-
// If the original connection does not satisfy "database/sql/driver".Queryer, it return ErrSkip error.
260+
// If the original connection does not satisfy "database/sql/driver".QueryerContext nor "database/sql/driver".Queryer, it return ErrSkip error.
261261
func (conn *Conn) QueryContext(c context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
262-
queryer, ok := conn.Conn.(driver.Queryer)
263-
if !ok {
262+
queryer, qok := conn.Conn.(driver.Queryer)
263+
queryerCtx, qCtxOk := conn.Conn.(driver.QueryerContext)
264+
if !qok && !qCtxOk {
264265
return nil, driver.ErrSkip
265266
}
266267

@@ -281,7 +282,7 @@ func (conn *Conn) QueryContext(c context.Context, query string, args []driver.Na
281282
}
282283

283284
// call the original method.
284-
if queryerCtx, ok := conn.Conn.(driver.QueryerContext); ok {
285+
if queryerCtx != nil {
285286
rows, err = queryerCtx.QueryContext(c, stmt.QueryString, args)
286287
} else {
287288
select {

0 commit comments

Comments
 (0)