Skip to content

Commit

Permalink
fix: use driver.Conn as Tx
Browse files Browse the repository at this point in the history
  • Loading branch information
proullon committed Aug 17, 2023
1 parent d851e3f commit a7e15a6
Showing 1 changed file with 74 additions and 16 deletions.
90 changes: 74 additions & 16 deletions driver/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"database/sql/driver"

"github.com/proullon/ramsql/engine/executor"
"github.com/proullon/ramsql/engine/log"
)

// Conn implements sql/driver Conn interface
Expand All @@ -29,7 +30,8 @@ import (
// https://pkg.go.dev/database/sql/driver#ConnPrepareContext
// https://pkg.go.dev/database/sql/driver#ConnBeginTx
type Conn struct {
e *executor.Engine
e *executor.Engine
tx *executor.Tx
}

func newConn(e *executor.Engine) *Conn {
Expand Down Expand Up @@ -83,6 +85,11 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) {
//
// Implemented for Conn interface
func (c *Conn) Close() error {
if c.tx != nil {
c.tx.Rollback()
c.tx = nil
}

return nil
}

Expand All @@ -92,7 +99,12 @@ func (c *Conn) Close() error {
//
// Implemented for Conn interface
func (c *Conn) Begin() (driver.Tx, error) {
return executor.NewTx(context.Background(), c.e, sql.TxOptions{})
tx, err := executor.NewTx(context.Background(), c.e, sql.TxOptions{})
if err != nil {
return nil, err
}
c.tx = tx
return c, nil
}

// BeginTx starts and returns a new transaction.
Expand All @@ -103,19 +115,51 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
Isolation: sql.IsolationLevel(opts.Isolation),
ReadOnly: opts.ReadOnly,
}
return executor.NewTx(ctx, c.e, o)
tx, err := executor.NewTx(ctx, c.e, o)
if err != nil {
return nil, err
}
c.tx = tx
return c, nil
}

func (c *Conn) Rollback() error {
if c.tx == nil {
return nil
}
err := c.tx.Rollback()
c.tx = nil
return err
}

func (c *Conn) Commit() error {
if c.tx == nil {
return nil
}
err := c.tx.Commit()
c.tx = nil
return err
}

// QueryContext is the sql package prefered way to run QUERY.
//
// Implemented for QueryerContext interface
func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
var err error
autocommit := false

tx, err := c.e.Begin()
if err != nil {
return nil, err
log.Info("Conn.QueryContext: %s", query)

tx := c.tx

if tx == nil {
autocommit = true
tx, err = c.e.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
}
defer tx.Rollback()

a := make([]executor.NamedValue, len(args))
for i, arg := range args {
Expand All @@ -129,9 +173,11 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam
return nil, err
}

err = tx.Commit()
if err != nil {
return nil, err
if autocommit {
err = tx.Commit()
if err != nil {
return nil, err
}
}

return newRows(cols, tuples), nil
Expand All @@ -141,9 +187,19 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam
//
// Implemented for ExecerContext interface
func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
tx, err := c.e.Begin()
if err != nil {
return nil, err
var err error
autocommit := false
log.Info("Conn.ExecContext: %s", query)

tx := c.tx

if tx == nil {
autocommit = true
tx, err = c.e.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
}

a := make([]executor.NamedValue, len(args))
Expand All @@ -159,9 +215,11 @@ func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.Name
return r, r.err
}

err = tx.Commit()
if err != nil {
return r, r.err
if autocommit {
err = tx.Commit()
if err != nil {
return r, r.err
}
}

return r, r.err
Expand Down

0 comments on commit a7e15a6

Please sign in to comment.