Skip to content

Commit

Permalink
Functions without context are no longer supported
Browse files Browse the repository at this point in the history
Default to a function that accepts a context.
Functions that do not accept a context have been removed.
  • Loading branch information
noborus committed Aug 12, 2024
1 parent 0c69696 commit 7b39139
Show file tree
Hide file tree
Showing 16 changed files with 131 additions and 156 deletions.
11 changes: 5 additions & 6 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,8 @@ func run(writer io.Writer, cfg *dbConfig, args []string) error {

outFormat := strToFormat(outFormat)
if outFormat == trdsql.GUESS {
if outWithoutGuess {
outFormat = trdsql.CSV
} else {
outFormat = trdsql.CSV
if !outWithoutGuess {
outFormat = outGuessFormat(outFile)
}
}
Expand All @@ -207,6 +206,7 @@ func run(writer io.Writer, cfg *dbConfig, args []string) error {
} else {
writer = cWriter
}

w := trdsql.NewWriter(
trdsql.OutDelimiter(outDelimiter),
trdsql.OutFormat(outFormat),
Expand Down Expand Up @@ -240,13 +240,12 @@ func run(writer io.Writer, cfg *dbConfig, args []string) error {

ctx := context.Background()

if err := trd.ExecContext(ctx, query); err != nil {
if err := trd.Exec(ctx, query); err != nil {
return err
}

if wc, ok := writer.(io.Closer); ok {
err := wc.Close()
if err != nil {
if err := wc.Close(); err != nil {
return err
}
}
Expand Down
28 changes: 6 additions & 22 deletions database.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,7 @@ func (db *DB) Disconnect() error {

// CreateTable is create a (temporary) table in the database.
// The arguments are the table name, column name, column type, and temporary flag.
func (db *DB) CreateTable(tableName string, columnNames []string, columnTypes []string, isTemporary bool) error {
return db.CreateTableContext(context.Background(), tableName, columnNames, columnTypes, isTemporary)
}

// CreateTableContext is create a (temporary) table in the database.
// The arguments are the table name, column name, column type, and temporary flag.
func (db *DB) CreateTableContext(ctx context.Context, tableName string, columnNames []string, columnTypes []string, isTemporary bool) error {
func (db *DB) CreateTable(ctx context.Context, tableName string, columnNames []string, columnTypes []string, isTemporary bool) error {
if db.Tx == nil {
return ErrNoTransaction
}
Expand Down Expand Up @@ -107,12 +101,7 @@ type importTable struct {
}

// Import is imports data into a table.
func (db *DB) Import(tableName string, columnNames []string, reader Reader) error {
return db.ImportContext(context.Background(), tableName, columnNames, reader)
}

// ImportContext is imports data into a table.
func (db *DB) ImportContext(ctx context.Context, tableName string, columnNames []string, reader Reader) error {
func (db *DB) Import(ctx context.Context, tableName string, columnNames []string, reader Reader) error {
if db.Tx == nil {
return ErrNoTransaction
}
Expand Down Expand Up @@ -355,22 +344,17 @@ func (db *DB) QuotedName(orgName string) string {
return buf.String()
}

// Select is executes SQL select statements.
func (db *DB) Select(query string) (*sql.Rows, error) {
return db.SelectContext(context.Background(), query)
}

// SelectContext is executes SQL select statements with context.
// SelectContext is a wrapper for QueryContext.
func (db *DB) SelectContext(ctx context.Context, query string) (*sql.Rows, error) {
// Select is executes SQL select statements with context.
// Select is a wrapper for QueryContext.
func (db *DB) Select(ctx context.Context, query string) (*sql.Rows, error) {
rows, err := db.Tx.QueryContext(ctx, query)
if err != nil {
return rows, fmt.Errorf("%w [%s]", err, query)
}
return rows, nil
}

func (db *DB) OtherExecContext(ctx context.Context, query string) error {
func (db *DB) OtherExec(ctx context.Context, query string) error {
_, err := db.Tx.ExecContext(ctx, query)
if err != nil {
return fmt.Errorf("%w [%s]", err, query)
Expand Down
13 changes: 9 additions & 4 deletions database_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package trdsql

import (
"context"
"testing"

_ "github.com/go-sql-driver/mysql"
Expand Down Expand Up @@ -154,6 +155,7 @@ func TestDB_CreateTable(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
db, err := Connect(tt.fields.driver, tt.fields.dsn)
if err != nil {
t.Fatal(err)
Expand All @@ -162,7 +164,7 @@ func TestDB_CreateTable(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if err := db.CreateTable(tt.args.tableName, tt.args.names, tt.args.types, tt.args.isTemporary); (err != nil) != tt.wantErr {
if err := db.CreateTable(ctx, tt.args.tableName, tt.args.names, tt.args.types, tt.args.isTemporary); (err != nil) != tt.wantErr {
t.Errorf("DB.CreateTable() error = %v, wantErr %v", err, tt.wantErr)
}
if err := db.Tx.Commit(); err != nil {
Expand Down Expand Up @@ -204,6 +206,7 @@ func TestDB_Select(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
db, err := Connect(tt.fields.driver, tt.fields.dsn)
if err != nil {
t.Fatal(err)
Expand All @@ -212,7 +215,7 @@ func TestDB_Select(t *testing.T) {
if err != nil {
t.Fatal(err)
}
_, err = db.Select(tt.args.query)
_, err = db.Select(ctx, tt.args.query)
if (err != nil) != tt.wantErr {
t.Errorf("DB.Select() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down Expand Up @@ -258,6 +261,7 @@ func TestDB_Func(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
db, err := Connect(tt.fields.driver, tt.fields.dsn)
if err != nil {
t.Fatal(err)
Expand All @@ -266,7 +270,7 @@ func TestDB_Func(t *testing.T) {
if err != nil {
t.Fatal(err)
}
_, err = db.Select(tt.args.query)
_, err = db.Select(ctx, tt.args.query)
if (err != nil) != tt.wantErr {
t.Errorf("DB.Select() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down Expand Up @@ -320,7 +324,8 @@ func TestDB_Import(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if err := db.Import(tt.args.tableName, tt.args.columnNames, tt.args.reader); (err != nil) != tt.wantErr {
ctx := context.Background()
if err := db.Import(ctx, tt.args.tableName, tt.args.columnNames, tt.args.reader); (err != nil) != tt.wantErr {
t.Errorf("DB.Import() error = %v, wantErr %v", err, tt.wantErr)
}
if err := db.Tx.Commit(); err != nil {
Expand Down
18 changes: 11 additions & 7 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package trdsql_test

import (
"bytes"
"context"
"fmt"
"log"
"os"
Expand All @@ -22,7 +23,7 @@ Ken,Thompson,ken
defer func() {
defer os.Remove(tmpfile.Name())
}()

ctx := context.Background()
if _, err := tmpfile.Write(in); err != nil {
log.Print(err)
return
Expand All @@ -33,7 +34,7 @@ Ken,Thompson,ken
)
// #nosec G201
query := fmt.Sprintf("SELECT c1 FROM %s ORDER BY c1", tmpfile.Name())
if err := trd.Exec(query); err != nil {
if err := trd.Exec(ctx, query); err != nil {
log.Print(err)
return
}
Expand Down Expand Up @@ -62,7 +63,7 @@ Ken,Thompson,ken
log.Print(err)
return
}

ctx := context.Background()
// NewImporter
importer := trdsql.NewImporter(
trdsql.InFormat(trdsql.CSV),
Expand All @@ -78,7 +79,7 @@ Ken,Thompson,ken
trd := trdsql.NewTRDSQL(importer, exporter)
// #nosec G201
query := fmt.Sprintf("SELECT * FROM %s ORDER BY username", tmpfile.Name())
err = trd.Exec(query)
err = trd.Exec(ctx, query)
if err != nil {
log.Print(err)
return
Expand Down Expand Up @@ -115,8 +116,9 @@ func ExampleSliceImporter() {
tableName := "slice"
importer := trdsql.NewSliceImporter(tableName, data)
trd := trdsql.NewTRDSQL(importer, trdsql.NewExporter(trdsql.NewWriter()))
ctx := context.Background()

err := trd.Exec("SELECT name,id FROM slice ORDER BY id DESC")
err := trd.Exec(ctx, "SELECT name,id FROM slice ORDER BY id DESC")
if err != nil {
log.Print(err)
return
Expand All @@ -140,8 +142,9 @@ func ExampleSliceWriter() {
importer := trdsql.NewSliceImporter(tableName, data)
writer := trdsql.NewSliceWriter()
trd := trdsql.NewTRDSQL(importer, trdsql.NewExporter(writer))
ctx := context.Background()

err := trd.Exec("SELECT name,id FROM slice ORDER BY id DESC")
err := trd.Exec(ctx, "SELECT name,id FROM slice ORDER BY id DESC")
if err != nil {
log.Print(err)
return
Expand Down Expand Up @@ -202,6 +205,7 @@ func ExampleBufferImporter() {
}
]
`
ctx := context.Background()
r := bytes.NewBufferString(jsonString)
importer, err := trdsql.NewBufferImporter("test", r, trdsql.InFormat(trdsql.JSON))
if err != nil {
Expand All @@ -213,7 +217,7 @@ func ExampleBufferImporter() {
trdsql.OutDelimiter("\t"),
)
trd := trdsql.NewTRDSQL(importer, trdsql.NewExporter(writer))
err = trd.Exec("SELECT name,gender,company FROM test")
err = trd.Exec(ctx, "SELECT name,gender,company FROM test")
if err != nil {
log.Print(err)
return
Expand Down
30 changes: 11 additions & 19 deletions exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ import (
// Exporter is the interface for processing query results.
// Exporter executes SQL and outputs to Writer.
type Exporter interface {
Export(db *DB, sql string) error
ExportContext(ctx context.Context, db *DB, sql string) error
Export(ctx context.Context, db *DB, sql string) error
}

// WriteFormat represents a structure that satisfies Exporter.
Expand All @@ -33,30 +32,23 @@ func NewExporter(writer Writer) *WriteFormat {
}

// Export is execute SQL(Select) and the result is written out by the writer.
// Export is called from Exec.
func (e *WriteFormat) Export(db *DB, sql string) error {
ctx := context.Background()
return e.ExportContext(ctx, db, sql)
}

// ExportContext is execute SQL(Select) and the result is written out by the writer.
// ExportContext is called from ExecContext.
func (e *WriteFormat) ExportContext(ctx context.Context, db *DB, sqlQuery string) error {
// Export is called from ExecContext.
func (e *WriteFormat) Export(ctx context.Context, db *DB, sqlQuery string) error {
queries := sqlss.SplitQueries(sqlQuery)
if !multi || len(queries) == 1 {
return e.exportContext(ctx, db, sqlQuery)
return e.export(ctx, db, sqlQuery)
}

e.multi = true
for _, query := range queries {
if err := e.exportContext(ctx, db, query); err != nil {
if err := e.export(ctx, db, query); err != nil {
return err
}
}
return nil
}

func (e *WriteFormat) exportContext(ctx context.Context, db *DB, query string) error {
func (e *WriteFormat) export(ctx context.Context, db *DB, query string) error {
if db.Tx == nil {
return ErrNoTransaction
}
Expand All @@ -67,11 +59,11 @@ func (e *WriteFormat) exportContext(ctx context.Context, db *DB, query string) e
}
debug.Printf(query)

if db.isExecContext(query) {
return db.OtherExecContext(ctx, query)
if db.isExec(query) {
return db.OtherExec(ctx, query)
}

rows, err := db.SelectContext(ctx, query)
rows, err := db.Select(ctx, query)
if err != nil {
return err
}
Expand Down Expand Up @@ -138,9 +130,9 @@ func (e *WriteFormat) write(ctx context.Context, rows *sql.Rows) error {
return e.Writer.PostWrite()
}

// isExecContext returns true if the query is not a SELECT statement.
// isExec returns true if the query is not a SELECT statement.
// Queries that return no rows in SQlite should use ExecContext and therefore return true.
func (db *DB) isExecContext(query string) bool {
func (db *DB) isExec(query string) bool {
if db.driver == "sqlite3" || db.driver == "sqlite" {
return !strings.HasPrefix(strings.ToUpper(query), "SELECT")
}
Expand Down
4 changes: 3 additions & 1 deletion exporter_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package trdsql

import (
"context"
"testing"
)

Expand Down Expand Up @@ -55,8 +56,9 @@ func TestWriteFormat_Export(t *testing.T) {
if err != nil {
t.Fatal("Connect error")
}
ctx := context.Background()
e := NewExporter(nil)
if err := e.Export(db, tt.args.query); (err != nil) != tt.wantErr {
if err := e.Export(ctx, db, tt.args.query); (err != nil) != tt.wantErr {
t.Errorf("WriteFormat.Export() error = %v, wantErr %v", err, tt.wantErr)
}
})
Expand Down
12 changes: 6 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ require (
github.com/spf13/cobra v1.8.1
github.com/spf13/viper v1.18.2
github.com/ulikunitz/xz v0.5.12
golang.org/x/term v0.22.0
modernc.org/sqlite v1.31.1
golang.org/x/term v0.23.0
modernc.org/sqlite v1.32.0
)

require (
Expand Down Expand Up @@ -53,16 +53,16 @@ require (
github.com/subosito/gotenv v1.6.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/crypto v0.25.0 // indirect
golang.org/x/crypto v0.26.0 // indirect
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect
golang.org/x/sys v0.23.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/sys v0.24.0 // indirect
golang.org/x/text v0.17.0 // indirect
golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect
gonum.org/v1/gonum v0.15.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/gc/v3 v3.0.0-20240801135723-a856999a2e4a // indirect
modernc.org/libc v1.56.0 // indirect
modernc.org/libc v1.59.1 // indirect
modernc.org/mathutil v1.6.0 // indirect
modernc.org/memory v1.8.0 // indirect
modernc.org/strutil v1.2.0 // indirect
Expand Down
Loading

0 comments on commit 7b39139

Please sign in to comment.