Skip to content

Commit

Permalink
Allow users of the driver to disable query formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
ohaibbq committed Apr 10, 2024
1 parent 5cfbbca commit e040cf6
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 26 deletions.
10 changes: 10 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ import (
"github.com/goccy/go-zetasqlite/internal"
)

// DisableQueryFormattingKey use to disable query formatting for queries that require raw SQLite access
type DisableQueryFormattingKey = internal.DisableQueryFormattingKey

// WithQueryFormattingDisabled for queries that require raw SQLite SQL
// This is useful for queries that do not require additional functionality from go-zetasqlite
// Utilizing this option often allows the SQLite query planner to generate more efficient plans
func WithQueryFormattingDisabled(ctx context.Context) context.Context {
return context.WithValue(ctx, internal.DisableQueryFormattingKey{}, true)
}

// WithCurrentTime use to replace the current time with the specified time.
// To replace the time, you need to pass the returned context as an argument to QueryContext.
// `CURRENT_DATE`, `CURRENT_DATETIME`, `CURRENT_TIME`, `CURRENT_TIMESTAMP` functions are targeted.
Expand Down
31 changes: 28 additions & 3 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ CREATE TABLE IF NOT EXISTS Singers (
t.Fatal("found unexpected row; expected no rows")
}
})
t.Run("prepared insert", func(t *testing.T) {
t.Run("prepared insert with named values", func(t *testing.T) {
db, err := sql.Open("zetasqlite", ":memory:")
if err != nil {
t.Fatal(err)
Expand All @@ -224,11 +224,11 @@ CREATE TABLE IF NOT EXISTS Singers (
t.Fatal("expected error when inserting without args; got no error")
}

stmt, err := db.Prepare("INSERT `Items` (`ItemId`) VALUES (?)")
stmt, err := db.Prepare("INSERT `Items` (`ItemId`) VALUES (@itemID)")
if err != nil {
t.Fatal(err)
}
if _, err := stmt.Exec(456); err != nil {
if _, err := stmt.Exec(sql.Named("itemID", 456)); err != nil {
t.Fatal(err)
}

Expand All @@ -248,4 +248,29 @@ CREATE TABLE IF NOT EXISTS Singers (
t.Fatal("expected no rows; expected one row")
}
})

t.Run("prepared select with named values, formatting disabled, uppercased parameter", func(t *testing.T) {
db, err := sql.Open("zetasqlite", ":memory:")
ctx := zetasqlite.WithQueryFormattingDisabled(context.Background())
if err != nil {
t.Fatal(err)
}
if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS Items (ItemId INT64 NOT NULL)`); err != nil {
t.Fatal(err)
}
if _, err := db.Exec("INSERT `Items` (`ItemId`) VALUES (123)"); err != nil {
t.Fatal(err)
}

stmt, err := db.PrepareContext(ctx, "SELECT `ItemID` FROM `Items` WHERE `ItemID` = @itemID AND @bool = TRUE")
if err != nil {
t.Fatal("unexpected error when preparing stmt; got %w", err)
}

var itemID string
err = stmt.QueryRowContext(ctx, sql.Named("itemID", 123), sql.Named("bool", true)).Scan(&itemID)
if err != nil {
t.Fatal("expected one row; got error %w", err)
}
})
}
26 changes: 22 additions & 4 deletions internal/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ type Analyzer struct {
opt *zetasql.AnalyzerOptions
}

type DisableQueryFormattingKey struct{}

func NewAnalyzer(catalog *Catalog) (*Analyzer, error) {
opt, err := newAnalyzerOptions()
if err != nil {
Expand Down Expand Up @@ -511,14 +513,30 @@ func (a *Analyzer) newQueryStmtAction(ctx context.Context, query string, args []
Type: newType(col.Column().Type()),
})
}
formattedQuery, err := newNode(node).FormatSQL(ctx)
if err != nil {
return nil, fmt.Errorf("failed to format query %s: %w", query, err)
var formattedQuery string
params := getParamsFromNode(node)
if disabledFormatting, ok := ctx.Value(DisableQueryFormattingKey{}).(bool); ok && disabledFormatting {
formattedQuery = query
// ZetaSQL will always lowercase parameter names, so we must match it in the query
queryBytes := []byte(query)
for _, param := range params {
location := param.ParseLocationRange()
start := location.Start().ByteOffset()
end := location.End().ByteOffset()
// Finds the parameter including its prefix i.e. @itemID
parameter := string(queryBytes[start:end])
formattedQuery = strings.ReplaceAll(formattedQuery, parameter, strings.ToLower(parameter))
}
} else {
var err error
formattedQuery, err = newNode(node).FormatSQL(ctx)
if err != nil {
return nil, fmt.Errorf("failed to format query %s: %w", query, err)
}
}
if formattedQuery == "" {
return nil, fmt.Errorf("failed to format query %s", query)
}
params := getParamsFromNode(node)
queryArgs, err := getArgsFromParams(args, params)
if err != nil {
return nil, err
Expand Down
42 changes: 23 additions & 19 deletions internal/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,15 @@ func (s *DMLStmt) NumInput() int {
}

func (s *DMLStmt) Exec(args []driver.Value) (driver.Result, error) {
values := make([]interface{}, 0, len(args))
for _, arg := range args {
values = append(values, arg)
}
newArgs, err := EncodeGoValues(values, s.args)
return s.ExecContext(context.Background(), valuesToNamedValues(args))
}

func (s *DMLStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
newArgs, err := getArgsFromParams(args, s.args)
if err != nil {
return nil, err
}
result, err := s.stmt.Exec(newArgs...)
result, err := s.stmt.ExecContext(ctx, newArgs...)
if err != nil {
return nil, fmt.Errorf(
"failed to execute query %s: args %v: %w",
Expand All @@ -172,10 +172,6 @@ func (s *DMLStmt) Exec(args []driver.Value) (driver.Result, error) {
return result, nil
}

func (s *DMLStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
return nil, fmt.Errorf("unimplemented ExecContext for DMLStmt")
}

func (s *DMLStmt) Query(args []driver.Value) (driver.Rows, error) {
return nil, fmt.Errorf("unsupported query for DMLStmt")
}
Expand Down Expand Up @@ -224,16 +220,28 @@ func (s *QueryStmt) ExecContext(ctx context.Context, query string, args []driver
return nil, fmt.Errorf("unsupported exec for QueryStmt")
}

func (s *QueryStmt) Query(args []driver.Value) (driver.Rows, error) {
values := make([]interface{}, 0, len(args))
func valuesToNamedValues(args []driver.Value) []driver.NamedValue {
values := make([]driver.NamedValue, 0, len(args))
for _, arg := range args {
values = append(values, arg)
if namedValue, ok := arg.(driver.NamedValue); ok {
values = append(values, namedValue)
}
values = append(values, driver.NamedValue{Value: arg})
}
newArgs, err := EncodeGoValues(values, s.args)

return values
}

func (s *QueryStmt) Query(args []driver.Value) (driver.Rows, error) {
return s.QueryContext(context.Background(), valuesToNamedValues(args))
}

func (s *QueryStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
newArgs, err := getArgsFromParams(args, s.args)
if err != nil {
return nil, err
}
rows, err := s.stmt.Query(newArgs...)
rows, err := s.stmt.QueryContext(ctx, newArgs...)
if err != nil {
return nil, fmt.Errorf(
"failed to query %s: args: %v: %w",
Expand All @@ -244,7 +252,3 @@ func (s *QueryStmt) Query(args []driver.Value) (driver.Rows, error) {
}
return &Rows{rows: rows, columns: s.outputColumns}, nil
}

func (s *QueryStmt) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
return nil, fmt.Errorf("unimplemented QueryContext for QueryStmt")
}

0 comments on commit e040cf6

Please sign in to comment.