diff --git a/driver/conn.go b/driver/conn.go index ae5e006..99da5ae 100644 --- a/driver/conn.go +++ b/driver/conn.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "github.com/proullon/ramsql/engine/executor" + "github.com/proullon/ramsql/engine/log" ) // Conn implements sql/driver Conn interface @@ -29,7 +30,8 @@ import ( // https://pkg.go.dev/database/sql/driver#ConnPrepareContext // https://pkg.go.dev/database/sql/driver#ConnBeginTx type Conn struct { - e *executor.Engine + e *executor.Engine + tx *executor.Tx } func newConn(e *executor.Engine) *Conn { @@ -83,6 +85,11 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) { // // Implemented for Conn interface func (c *Conn) Close() error { + if c.tx != nil { + c.tx.Rollback() + c.tx = nil + } + return nil } @@ -92,7 +99,12 @@ func (c *Conn) Close() error { // // Implemented for Conn interface func (c *Conn) Begin() (driver.Tx, error) { - return executor.NewTx(context.Background(), c.e, sql.TxOptions{}) + tx, err := executor.NewTx(context.Background(), c.e, sql.TxOptions{}) + if err != nil { + return nil, err + } + c.tx = tx + return c, nil } // BeginTx starts and returns a new transaction. @@ -103,19 +115,51 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e Isolation: sql.IsolationLevel(opts.Isolation), ReadOnly: opts.ReadOnly, } - return executor.NewTx(ctx, c.e, o) + tx, err := executor.NewTx(ctx, c.e, o) + if err != nil { + return nil, err + } + c.tx = tx + return c, nil +} + +func (c *Conn) Rollback() error { + if c.tx == nil { + return nil + } + err := c.tx.Rollback() + c.tx = nil + return err +} + +func (c *Conn) Commit() error { + if c.tx == nil { + return nil + } + err := c.tx.Commit() + c.tx = nil + return err } // QueryContext is the sql package prefered way to run QUERY. // // Implemented for QueryerContext interface func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + var err error + autocommit := false - tx, err := c.e.Begin() - if err != nil { - return nil, err + log.Info("Conn.QueryContext: %s", query) + + tx := c.tx + + if tx == nil { + autocommit = true + tx, err = c.e.Begin() + if err != nil { + return nil, err + } + defer tx.Rollback() } - defer tx.Rollback() a := make([]executor.NamedValue, len(args)) for i, arg := range args { @@ -129,9 +173,11 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam return nil, err } - err = tx.Commit() - if err != nil { - return nil, err + if autocommit { + err = tx.Commit() + if err != nil { + return nil, err + } } return newRows(cols, tuples), nil @@ -141,9 +187,19 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam // // Implemented for ExecerContext interface func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - tx, err := c.e.Begin() - if err != nil { - return nil, err + var err error + autocommit := false + log.Info("Conn.ExecContext: %s", query) + + tx := c.tx + + if tx == nil { + autocommit = true + tx, err = c.e.Begin() + if err != nil { + return nil, err + } + defer tx.Rollback() } a := make([]executor.NamedValue, len(args)) @@ -159,9 +215,11 @@ func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.Name return r, r.err } - err = tx.Commit() - if err != nil { - return r, r.err + if autocommit { + err = tx.Commit() + if err != nil { + return r, r.err + } } return r, r.err