Skip to content

Commit 9035c2a

Browse files
committed
Only abort tx's when last connection is closed
Also: convert rest of panic isn't ordinary errors
1 parent 8e479fe commit 9035c2a

File tree

3 files changed

+28
-32
lines changed

3 files changed

+28
-32
lines changed

pkg/pg/pg.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414

1515
func NewSqlxDB(t testing.TB, dbURL string) *sqlx.DB {
1616
tests.SkipShortDB(t)
17-
err := RegisterTxDb(tests.Context(t), dbURL)
17+
err := RegisterTxDb(dbURL)
1818
if err != nil {
1919
t.Errorf("failed to register txdb dialect: %s", err.Error())
2020
return nil

pkg/pg/txdb.go

+25-29
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"database/sql"
66
"database/sql/driver"
7-
"flag"
87
"fmt"
98
"io"
109
"net/url"
@@ -34,18 +33,14 @@ import (
3433
// store to use the raw DialectPostgres dialect and setup a one-use database.
3534
// See heavyweight.FullTestDB() as a convenience function to help you do this,
3635
// but please use sparingly because as it's name implies, it is expensive.
37-
func RegisterTxDb(ctx context.Context, dbURL string) error {
36+
func RegisterTxDb(dbURL string) error {
3837
drivers := sql.Drivers()
3938
for _, driver := range drivers {
4039
if driver == string(TransactionWrappedPostgres) {
4140
// TxDB driver already registered
4241
return nil
4342
}
4443
}
45-
testing.Init()
46-
if !flag.Parsed() {
47-
flag.Parse()
48-
}
4944
if testing.Short() {
5045
// -short tests don't need a DB
5146
return nil
@@ -60,15 +55,10 @@ func RegisterTxDb(ctx context.Context, dbURL string) error {
6055
if !strings.HasSuffix(parsed.Path, "_test") {
6156
return fmt.Errorf("cannot run tests against database named `%s`. Note that the test database MUST end in `_test` to differentiate from a possible production DB. HINT: Try postgresql://postgres@localhost:5432/chainlink_test?sslmode=disable", parsed.Path[1:])
6257
}
63-
abort := make(chan struct{})
64-
go func() {
65-
<-ctx.Done()
66-
abort <- struct{}{} // abort all queries when context is cancelled
67-
}()
6858

6959
name := string(TransactionWrappedPostgres)
7060
sql.Register(name, &txDriver{
71-
abort: abort,
61+
abort: make(chan struct{}),
7262
dbURL: dbURL,
7363
conns: make(map[string]*conn),
7464
})
@@ -85,7 +75,7 @@ var _ driver.SessionResetter = &conn{}
8575
// When `Close` is called, transaction is rolled back.
8676
type txDriver struct {
8777
sync.Mutex
88-
abort <-chan struct{}
78+
abort chan struct{}
8979
db *sql.DB
9080
conns map[string]*conn
9181

@@ -130,6 +120,7 @@ func (d *txDriver) deleteConn(c *conn) error {
130120
}
131121
delete(d.conns, c.dsn)
132122
if len(d.conns) == 0 && d.db != nil {
123+
close(d.abort)
133124
if err := d.db.Close(); err != nil {
134125
return err
135126
}
@@ -152,7 +143,7 @@ func (c *conn) Begin() (driver.Tx, error) {
152143
c.Lock()
153144
defer c.Unlock()
154145
if c.closed {
155-
panic("conn is closed")
146+
return nil, fmt.Errorf("conn is closed")
156147
}
157148
// Begin is a noop because the transaction was already opened
158149
return tx{c.tx}, nil
@@ -177,7 +168,7 @@ func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, err
177168
c.Lock()
178169
defer c.Unlock()
179170
if c.closed {
180-
panic("conn is closed")
171+
return nil, fmt.Errorf("conn is closed")
181172
}
182173

183174
// It is not safe to give the passed in context to the tx directly
@@ -226,7 +217,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
226217
c.Lock()
227218
defer c.Unlock()
228219
if c.closed {
229-
panic("conn is closed")
220+
return nil, fmt.Errorf("conn is closed")
230221
}
231222

232223
ctx, cancel := utils.ContextFromChan(c.abort)
@@ -277,38 +268,43 @@ func (c *conn) tryOpen() bool {
277268
// Drivers must ensure all network calls made by Close
278269
// do not block indefinitely (e.g. apply a timeout).
279270
func (c *conn) Close() (err error) {
280-
if !c.close() {
281-
return
271+
newlyClosed, err := c.close()
272+
if err != nil {
273+
return err
274+
}
275+
if !newlyClosed {
276+
return nil
282277
}
278+
283279
// Wait to remove self to avoid nesting locks.
284-
if err := c.removeSelf(); err != nil {
285-
panic(err)
280+
if err = c.removeSelf(); err != nil {
281+
return err
286282
}
287283
return
288284
}
289285

290-
func (c *conn) close() bool {
286+
func (c *conn) close() (bool, error) {
291287
c.Lock()
292288
defer c.Unlock()
293289
if c.closed {
294290
// Double close, should be a safe to make this a noop
295291
// PGX allows double close
296292
// See: https://github.com/jackc/pgx/blob/a457da8bffa4f90ad672fa093ee87f20cf06687b/conn.go#L249
297-
return false
293+
return false, nil
298294
}
299295

300296
c.opened--
301297
if c.opened > 0 {
302-
return false
298+
return false, nil
303299
}
304300
if c.tx != nil {
305301
if err := c.tx.Rollback(); err != nil {
306-
panic(err)
302+
return false, err
307303
}
308304
c.tx = nil
309305
}
310306
c.closed = true
311-
return true
307+
return true, nil
312308
}
313309

314310
type tx struct {
@@ -335,7 +331,7 @@ func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
335331
s.conn.Lock()
336332
defer s.conn.Unlock()
337333
if s.conn.closed {
338-
panic("conn is closed")
334+
return nil, fmt.Errorf("conn is closed")
339335
}
340336
return s.st.Exec(mapArgs(args)...)
341337
}
@@ -345,7 +341,7 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
345341
s.conn.Lock()
346342
defer s.conn.Unlock()
347343
if s.conn.closed {
348-
panic("conn is closed")
344+
return nil, fmt.Errorf("conn is closed")
349345
}
350346

351347
ctx, cancel := utils.ContextFromChan(s.abort)
@@ -370,7 +366,7 @@ func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
370366
s.conn.Lock()
371367
defer s.conn.Unlock()
372368
if s.conn.closed {
373-
panic("conn is closed")
369+
return nil, fmt.Errorf("conn is closed")
374370
}
375371
rows, err := s.st.Query(mapArgs(args)...)
376372
defer func() {
@@ -387,7 +383,7 @@ func (s *stmt) QueryContext(_ context.Context, args []driver.NamedValue) (driver
387383
s.conn.Lock()
388384
defer s.conn.Unlock()
389385
if s.conn.closed {
390-
panic("conn is closed")
386+
return nil, fmt.Errorf("conn is closed")
391387
}
392388

393389
ctx, cancel := utils.ContextFromChan(s.abort)

pkg/pg/txdb_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ func TestTxDBDriver(t *testing.T) {
5555
})
5656

5757
t.Run("Make sure calling sql.Register() can be called twice", func(t *testing.T) {
58-
require.NoError(t, RegisterTxDb(tests.Context(t), "foo"))
59-
require.NoError(t, RegisterTxDb(tests.Context(t), "bar"))
58+
require.NoError(t, RegisterTxDb("foo"))
59+
require.NoError(t, RegisterTxDb("bar"))
6060
drivers := sql.Drivers()
6161
assert.Contains(t, drivers, "txdb")
6262
})

0 commit comments

Comments
 (0)