Skip to content

Commit 2677a07

Browse files
authored
Merge pull request #77 from stefantds/st/support-execer-context
Support context interfaces for Exec and Query
2 parents 8ec7bf3 + 0580e9d commit 2677a07

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

conn.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -189,13 +189,14 @@ func (conn *Conn) Exec(query string, args []driver.Value) (driver.Result, error)
189189
panic("not supported")
190190
}
191191

192-
// ExecContext calls the original Exec method of the connection.
192+
// ExecContext calls the original ExecContext (or Exec as a fallback) method of the connection.
193193
// It will trigger PreExec, Exec, PostExec hooks.
194194
//
195-
// If the original connection does not satisfy "database/sql/driver".Execer, it return ErrSkip error.
195+
// If the original connection does not satisfy "database/sql/driver".ExecerContext nor "database/sql/driver".Execer, it return ErrSkip error.
196196
func (conn *Conn) ExecContext(c context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
197-
execer, ok := conn.Conn.(driver.Execer)
198-
if !ok {
197+
execer, exOk := conn.Conn.(driver.Execer)
198+
execerCtx, exCtxOk := conn.Conn.(driver.ExecerContext)
199+
if !exOk && !exCtxOk {
199200
return nil, driver.ErrSkip
200201
}
201202

@@ -217,7 +218,7 @@ func (conn *Conn) ExecContext(c context.Context, query string, args []driver.Nam
217218
}
218219

219220
// call the original method.
220-
if execerCtx, ok := execer.(driver.ExecerContext); ok {
221+
if execerCtx != nil {
221222
result, err = execerCtx.ExecContext(c, stmt.QueryString, args)
222223
} else {
223224
select {
@@ -256,10 +257,11 @@ func (conn *Conn) Query(query string, args []driver.Value) (driver.Rows, error)
256257
// QueryContext executes a query that may return rows.
257258
// It wil trigger PreQuery, Query, PostQuery hooks.
258259
//
259-
// If the original connection does not satisfy "database/sql/driver".Queryer, it return ErrSkip error.
260+
// If the original connection does not satisfy "database/sql/driver".QueryerContext nor "database/sql/driver".Queryer, it return ErrSkip error.
260261
func (conn *Conn) QueryContext(c context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
261-
queryer, ok := conn.Conn.(driver.Queryer)
262-
if !ok {
262+
queryer, qok := conn.Conn.(driver.Queryer)
263+
queryerCtx, qCtxOk := conn.Conn.(driver.QueryerContext)
264+
if !qok && !qCtxOk {
263265
return nil, driver.ErrSkip
264266
}
265267

@@ -280,7 +282,7 @@ func (conn *Conn) QueryContext(c context.Context, query string, args []driver.Na
280282
}
281283

282284
// call the original method.
283-
if queryerCtx, ok := conn.Conn.(driver.QueryerContext); ok {
285+
if queryerCtx != nil {
284286
rows, err = queryerCtx.QueryContext(c, stmt.QueryString, args)
285287
} else {
286288
select {

0 commit comments

Comments
 (0)