@@ -14,6 +14,8 @@ import (
14
14
15
15
"github.com/jmoiron/sqlx"
16
16
"go.uber.org/multierr"
17
+
18
+ "github.com/smartcontractkit/chainlink-common/pkg/utils"
17
19
)
18
20
19
21
// txdb is a simplified version of https://github.com/DATA-DOG/go-txdb
@@ -32,7 +34,7 @@ import (
32
34
// store to use the raw DialectPostgres dialect and setup a one-use database.
33
35
// See heavyweight.FullTestDB() as a convenience function to help you do this,
34
36
// but please use sparingly because as it's name implies, it is expensive.
35
- func RegisterTxDb (dbURL string ) error {
37
+ func RegisterTxDb (ctx context. Context , dbURL string ) error {
36
38
drivers := sql .Drivers ()
37
39
for _ , driver := range drivers {
38
40
if driver == string (TransactionWrappedPostgres ) {
@@ -58,8 +60,15 @@ func RegisterTxDb(dbURL string) error {
58
60
if ! strings .HasSuffix (parsed .Path , "_test" ) {
59
61
return fmt .Errorf ("cannot run tests against database named `%s`. Note that the test database MUST end in `_test` to differentiate from a possible production DB. HINT: Try postgresql://postgres@localhost:5432/chainlink_test?sslmode=disable" , parsed .Path [1 :])
60
62
}
63
+ abort := make (chan struct {})
64
+ go func () {
65
+ <- ctx .Done ()
66
+ abort <- struct {}{} // abort all queries when context is cancelled
67
+ }()
68
+
61
69
name := string (TransactionWrappedPostgres )
62
70
sql .Register (name , & txDriver {
71
+ abort : abort ,
63
72
dbURL : dbURL ,
64
73
conns : make (map [string ]* conn ),
65
74
})
@@ -76,6 +85,7 @@ var _ driver.SessionResetter = &conn{}
76
85
// When `Close` is called, transaction is rolled back.
77
86
type txDriver struct {
78
87
sync.Mutex
88
+ abort <- chan struct {}
79
89
db * sql.DB
80
90
conns map [string ]* conn
81
91
@@ -99,7 +109,7 @@ func (d *txDriver) Open(dsn string) (driver.Conn, error) {
99
109
if err != nil {
100
110
return nil , err
101
111
}
102
- c = & conn {tx : tx , opened : 1 , dsn : dsn }
112
+ c = & conn {abort : d . abort , tx : tx , opened : 1 , dsn : dsn }
103
113
c .removeSelf = func () error {
104
114
return d .deleteConn (c )
105
115
}
@@ -130,6 +140,7 @@ func (d *txDriver) deleteConn(c *conn) error {
130
140
131
141
type conn struct {
132
142
sync.Mutex
143
+ abort <- chan struct {}
133
144
dsn string
134
145
tx * sql.Tx // tx may be shared by many conns, definitive one lives in the map keyed by DSN on the txDriver. Do not modify from conn
135
146
closed bool
@@ -156,26 +167,32 @@ func (c *conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, err
156
167
157
168
// Prepare returns a prepared statement, bound to this connection.
158
169
func (c * conn ) Prepare (query string ) (driver.Stmt , error ) {
159
- return c .PrepareContext (context .Background (), query )
170
+ ctx , cancel := utils .ContextFromChan (c .abort )
171
+ defer cancel ()
172
+ return c .PrepareContext (ctx , query )
160
173
}
161
174
162
175
// Implement the "ConnPrepareContext" interface
163
- func (c * conn ) PrepareContext (ctx context.Context , query string ) (driver.Stmt , error ) {
176
+ func (c * conn ) PrepareContext (_ context.Context , query string ) (driver.Stmt , error ) {
164
177
c .Lock ()
165
178
defer c .Unlock ()
166
179
if c .closed {
167
180
panic ("conn is closed" )
168
181
}
169
182
170
- // TODO: Fix context handling
171
- // FIXME: It is not safe to give the passed in context to the tx directly
183
+ // It is not safe to give the passed in context to the tx directly
172
184
// because the tx is shared by many conns and cancelling the context will
173
- // destroy the tx which can affect other conns
174
- st , err := c .tx .PrepareContext (context .Background (), query )
185
+ // destroy the tx which can affect other conns. Instead, we pass the context
186
+ // passed to NewSqlxDb when the database was set up so the operation can at
187
+ // least be aborted immediately if the whole test is interrupted.
188
+ ctx , cancel := utils .ContextFromChan (c .abort )
189
+ defer cancel ()
190
+
191
+ st , err := c .tx .PrepareContext (ctx , query )
175
192
if err != nil {
176
193
return nil , err
177
194
}
178
- return & stmt {st , c }, nil
195
+ return & stmt {c . abort , st , c }, nil
179
196
}
180
197
181
198
// IsValid is called prior to placing the connection into the
@@ -212,8 +229,10 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
212
229
panic ("conn is closed" )
213
230
}
214
231
215
- // TODO: Fix context handling
216
- rs , err := c .tx .QueryContext (context .Background (), query , mapNamedArgs (args )... )
232
+ ctx , cancel := utils .ContextFromChan (c .abort )
233
+ defer cancel ()
234
+
235
+ rs , err := c .tx .QueryContext (ctx , query , mapNamedArgs (args )... )
217
236
if err != nil {
218
237
return nil , err
219
238
}
@@ -229,8 +248,10 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
229
248
if c .closed {
230
249
return nil , fmt .Errorf ("conn is closed" )
231
250
}
232
- // TODO: Fix context handling
233
- return c .tx .ExecContext (context .Background (), query , mapNamedArgs (args )... )
251
+ ctx , cancel := utils .ContextFromChan (c .abort )
252
+ defer cancel ()
253
+
254
+ return c .tx .ExecContext (ctx , query , mapNamedArgs (args )... )
234
255
}
235
256
236
257
// tryOpen attempts to increment the open count, but returns false if closed.
@@ -305,8 +326,9 @@ func (tx tx) Rollback() error {
305
326
}
306
327
307
328
type stmt struct {
308
- st * sql.Stmt
309
- conn * conn
329
+ abort <- chan struct {}
330
+ st * sql.Stmt
331
+ conn * conn
310
332
}
311
333
312
334
func (s stmt ) Exec (args []driver.Value ) (driver.Result , error ) {
@@ -325,8 +347,11 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
325
347
if s .conn .closed {
326
348
panic ("conn is closed" )
327
349
}
328
- // TODO: Fix context handling
329
- return s .st .ExecContext (context .Background (), mapNamedArgs (args )... )
350
+
351
+ ctx , cancel := utils .ContextFromChan (s .abort )
352
+ defer cancel ()
353
+
354
+ return s .st .ExecContext (ctx , mapNamedArgs (args )... )
330
355
}
331
356
332
357
func mapArgs (args []driver.Value ) (res []interface {}) {
@@ -358,14 +383,17 @@ func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
358
383
}
359
384
360
385
// Implement the "StmtQueryContext" interface
361
- func (s * stmt ) QueryContext (ctx context.Context , args []driver.NamedValue ) (driver.Rows , error ) {
386
+ func (s * stmt ) QueryContext (_ context.Context , args []driver.NamedValue ) (driver.Rows , error ) {
362
387
s .conn .Lock ()
363
388
defer s .conn .Unlock ()
364
389
if s .conn .closed {
365
390
panic ("conn is closed" )
366
391
}
367
- // TODO: Fix context handling
368
- rows , err := s .st .QueryContext (context .Background (), mapNamedArgs (args )... )
392
+
393
+ ctx , cancel := utils .ContextFromChan (s .abort )
394
+ defer cancel ()
395
+
396
+ rows , err := s .st .QueryContext (ctx , mapNamedArgs (args )... )
369
397
if err != nil {
370
398
return nil , err
371
399
}
0 commit comments