Skip to content

Commit 86478ad

Browse files
committed
Create ctx from testing context, instead of using context.Background
1 parent 95fa041 commit 86478ad

File tree

1 file changed

+48
-20
lines changed

1 file changed

+48
-20
lines changed

pkg/pg/txdb.go

+48-20
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414

1515
"github.com/jmoiron/sqlx"
1616
"go.uber.org/multierr"
17+
18+
"github.com/smartcontractkit/chainlink-common/pkg/utils"
1719
)
1820

1921
// txdb is a simplified version of https://github.com/DATA-DOG/go-txdb
@@ -32,7 +34,7 @@ import (
3234
// store to use the raw DialectPostgres dialect and setup a one-use database.
3335
// See heavyweight.FullTestDB() as a convenience function to help you do this,
3436
// but please use sparingly because as it's name implies, it is expensive.
35-
func RegisterTxDb(dbURL string) error {
37+
func RegisterTxDb(ctx context.Context, dbURL string) error {
3638
drivers := sql.Drivers()
3739
for _, driver := range drivers {
3840
if driver == string(TransactionWrappedPostgres) {
@@ -58,8 +60,15 @@ func RegisterTxDb(dbURL string) error {
5860
if !strings.HasSuffix(parsed.Path, "_test") {
5961
return fmt.Errorf("cannot run tests against database named `%s`. Note that the test database MUST end in `_test` to differentiate from a possible production DB. HINT: Try postgresql://postgres@localhost:5432/chainlink_test?sslmode=disable", parsed.Path[1:])
6062
}
63+
abort := make(chan struct{})
64+
go func() {
65+
<-ctx.Done()
66+
abort <- struct{}{} // abort all queries when context is cancelled
67+
}()
68+
6169
name := string(TransactionWrappedPostgres)
6270
sql.Register(name, &txDriver{
71+
abort: abort,
6372
dbURL: dbURL,
6473
conns: make(map[string]*conn),
6574
})
@@ -76,6 +85,7 @@ var _ driver.SessionResetter = &conn{}
7685
// When `Close` is called, transaction is rolled back.
7786
type txDriver struct {
7887
sync.Mutex
88+
abort <-chan struct{}
7989
db *sql.DB
8090
conns map[string]*conn
8191

@@ -99,7 +109,7 @@ func (d *txDriver) Open(dsn string) (driver.Conn, error) {
99109
if err != nil {
100110
return nil, err
101111
}
102-
c = &conn{tx: tx, opened: 1, dsn: dsn}
112+
c = &conn{abort: d.abort, tx: tx, opened: 1, dsn: dsn}
103113
c.removeSelf = func() error {
104114
return d.deleteConn(c)
105115
}
@@ -130,6 +140,7 @@ func (d *txDriver) deleteConn(c *conn) error {
130140

131141
type conn struct {
132142
sync.Mutex
143+
abort <-chan struct{}
133144
dsn string
134145
tx *sql.Tx // tx may be shared by many conns, definitive one lives in the map keyed by DSN on the txDriver. Do not modify from conn
135146
closed bool
@@ -156,26 +167,32 @@ func (c *conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, err
156167

157168
// Prepare returns a prepared statement, bound to this connection.
158169
func (c *conn) Prepare(query string) (driver.Stmt, error) {
159-
return c.PrepareContext(context.Background(), query)
170+
ctx, cancel := utils.ContextFromChan(c.abort)
171+
defer cancel()
172+
return c.PrepareContext(ctx, query)
160173
}
161174

162175
// Implement the "ConnPrepareContext" interface
163-
func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
176+
func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
164177
c.Lock()
165178
defer c.Unlock()
166179
if c.closed {
167180
panic("conn is closed")
168181
}
169182

170-
// TODO: Fix context handling
171-
// FIXME: It is not safe to give the passed in context to the tx directly
183+
// It is not safe to give the passed in context to the tx directly
172184
// because the tx is shared by many conns and cancelling the context will
173-
// destroy the tx which can affect other conns
174-
st, err := c.tx.PrepareContext(context.Background(), query)
185+
// destroy the tx which can affect other conns. Instead, we pass the context
186+
// passed to NewSqlxDb when the database was set up so the operation can at
187+
// least be aborted immediately if the whole test is interrupted.
188+
ctx, cancel := utils.ContextFromChan(c.abort)
189+
defer cancel()
190+
191+
st, err := c.tx.PrepareContext(ctx, query)
175192
if err != nil {
176193
return nil, err
177194
}
178-
return &stmt{st, c}, nil
195+
return &stmt{c.abort, st, c}, nil
179196
}
180197

181198
// IsValid is called prior to placing the connection into the
@@ -212,8 +229,10 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
212229
panic("conn is closed")
213230
}
214231

215-
// TODO: Fix context handling
216-
rs, err := c.tx.QueryContext(context.Background(), query, mapNamedArgs(args)...)
232+
ctx, cancel := utils.ContextFromChan(c.abort)
233+
defer cancel()
234+
235+
rs, err := c.tx.QueryContext(ctx, query, mapNamedArgs(args)...)
217236
if err != nil {
218237
return nil, err
219238
}
@@ -229,8 +248,10 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
229248
if c.closed {
230249
return nil, fmt.Errorf("conn is closed")
231250
}
232-
// TODO: Fix context handling
233-
return c.tx.ExecContext(context.Background(), query, mapNamedArgs(args)...)
251+
ctx, cancel := utils.ContextFromChan(c.abort)
252+
defer cancel()
253+
254+
return c.tx.ExecContext(ctx, query, mapNamedArgs(args)...)
234255
}
235256

236257
// tryOpen attempts to increment the open count, but returns false if closed.
@@ -305,8 +326,9 @@ func (tx tx) Rollback() error {
305326
}
306327

307328
type stmt struct {
308-
st *sql.Stmt
309-
conn *conn
329+
abort <-chan struct{}
330+
st *sql.Stmt
331+
conn *conn
310332
}
311333

312334
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
@@ -325,8 +347,11 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
325347
if s.conn.closed {
326348
panic("conn is closed")
327349
}
328-
// TODO: Fix context handling
329-
return s.st.ExecContext(context.Background(), mapNamedArgs(args)...)
350+
351+
ctx, cancel := utils.ContextFromChan(s.abort)
352+
defer cancel()
353+
354+
return s.st.ExecContext(ctx, mapNamedArgs(args)...)
330355
}
331356

332357
func mapArgs(args []driver.Value) (res []interface{}) {
@@ -358,14 +383,17 @@ func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
358383
}
359384

360385
// Implement the "StmtQueryContext" interface
361-
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
386+
func (s *stmt) QueryContext(_ context.Context, args []driver.NamedValue) (driver.Rows, error) {
362387
s.conn.Lock()
363388
defer s.conn.Unlock()
364389
if s.conn.closed {
365390
panic("conn is closed")
366391
}
367-
// TODO: Fix context handling
368-
rows, err := s.st.QueryContext(context.Background(), mapNamedArgs(args)...)
392+
393+
ctx, cancel := utils.ContextFromChan(s.abort)
394+
defer cancel()
395+
396+
rows, err := s.st.QueryContext(ctx, mapNamedArgs(args)...)
369397
if err != nil {
370398
return nil, err
371399
}

0 commit comments

Comments
 (0)