Skip to content

Commit 1e9d6ac

Browse files
[PECO-1962] Support positional query parameters (#247)
[PECO-1962] Replaces #232 [PECO-1962]: https://databricks.atlassian.net/browse/PECO-1962?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ --------- Signed-off-by: Levko Kravets <[email protected]>
1 parent 909d73f commit 1e9d6ac

File tree

6 files changed

+155
-49
lines changed

6 files changed

+155
-49
lines changed

connection.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,11 @@ func invalidOperationState(ctx context.Context, opStatus *cli_service.TGetOperat
275275
func (c *conn) executeStatement(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, error) {
276276
ctx = driverctx.NewContextWithConnId(ctx, c.id)
277277

278+
parameters, err := convertNamedValuesToSparkParams(args)
279+
if err != nil {
280+
return nil, err
281+
}
282+
278283
req := cli_service.TExecuteStatementReq{
279284
SessionHandle: c.session.SessionHandle,
280285
Statement: query,
@@ -284,7 +289,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
284289
MaxRows: int64(c.cfg.MaxRows),
285290
},
286291
CanDecompressLZ4Result_: &c.cfg.UseLz4Compression,
287-
Parameters: convertNamedValuesToSparkParams(args),
292+
Parameters: parameters,
288293
}
289294

290295
if c.cfg.UseArrowBatches {

doc.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,17 @@ Use the driverctx package under driverctx/ctx.go to add callbacks to the query c
188188
189189
Passing parameters to a query is supported when run against servers with version DBR 14.1.
190190
191+
// Named parameters:
191192
p := dbsql.Parameter{Name: "p_bool", Value: true},
192-
rows, err1 := db.QueryContext(ctx, `select * from sometable where condition=:p_bool`,dbsql.Parameter{Name: "p_bool", Value: true})
193+
rows, err := db.QueryContext(ctx, `select * from sometable where condition=:p_bool`,dbsql.Parameter{Name: "p_bool", Value: true})
194+
195+
// Positional parameters - both `dbsql.Parameter` and plain values can be used:
196+
rows, err := db.Query(`select *, ? from sometable where field=?`,dbsql.Parameter{Value: "123.456"}, "another parameter")
193197
194198
For complex types, you can specify the SQL type using the dbsql.Parameter type field. If this field is set, the value field MUST be set to a string.
195199
200+
Please note that named and positional parameters cannot be used together in the single query.
201+
196202
# Staging Ingestion
197203
198204
The Go driver now supports staging operations. In order to use a staging operation, you first must update the context with a list of folders that you are allowing the driver to access.

errors/errors.go

+6-5
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@ import (
1010
// Error messages
1111
const (
1212
// Driver errors
13-
ErrNotImplemented = "not implemented"
14-
ErrTransactionsNotSupported = "transactions are not supported"
15-
ErrReadQueryStatus = "could not read query status"
16-
ErrSentinelTimeout = "sentinel timed out waiting for operation to complete"
17-
ErrParametersNotSupported = "query parameters are not supported by this server"
13+
ErrNotImplemented = "not implemented"
14+
ErrTransactionsNotSupported = "transactions are not supported"
15+
ErrReadQueryStatus = "could not read query status"
16+
ErrSentinelTimeout = "sentinel timed out waiting for operation to complete"
17+
ErrParametersNotSupported = "query parameters are not supported by this server"
18+
ErrMixedNamedAndPositionalParameters = "named and positional parameters cannot be used simultaneously"
1819

1920
// Request error messages (connection, authentication, network error)
2021
ErrCloseConnection = "failed to close connection"

examples/parameters/main.go

+70-29
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package main
22

33
import (
4-
"context"
54
"database/sql"
65
"fmt"
76
"log"
@@ -12,6 +11,74 @@ import (
1211
"github.com/joho/godotenv"
1312
)
1413

14+
func queryWithNamedParameters(db *sql.DB) {
15+
var p_bool bool
16+
var p_int int
17+
var p_double float64
18+
var p_float float32
19+
var p_date string
20+
21+
err := db.QueryRow(`
22+
SELECT
23+
:p_bool AS col_bool,
24+
:p_int AS col_int,
25+
:p_double AS col_double,
26+
:p_float AS col_float,
27+
:p_date AS col_date
28+
`,
29+
dbsql.Parameter{Name: "p_bool", Value: true},
30+
dbsql.Parameter{Name: "p_int", Value: int(1234)},
31+
dbsql.Parameter{Name: "p_double", Type: dbsql.SqlDouble, Value: "3.14"},
32+
dbsql.Parameter{Name: "p_float", Type: dbsql.SqlFloat, Value: "3.14"},
33+
dbsql.Parameter{Name: "p_date", Type: dbsql.SqlDate, Value: "2017-07-23 00:00:00"},
34+
).Scan(&p_bool, &p_int, &p_double, &p_float, &p_date)
35+
36+
if err != nil {
37+
if err == sql.ErrNoRows {
38+
fmt.Println("not found")
39+
return
40+
} else {
41+
fmt.Printf("err: %v\n", err)
42+
}
43+
} else {
44+
fmt.Println(p_bool, p_int, p_double, p_float, p_date)
45+
}
46+
}
47+
48+
func queryWithPositionalParameters(db *sql.DB) {
49+
var p_bool bool
50+
var p_int int
51+
var p_double float64
52+
var p_float float32
53+
var p_date string
54+
55+
err := db.QueryRow(`
56+
SELECT
57+
? AS col_bool,
58+
? AS col_int,
59+
? AS col_double,
60+
? AS col_float,
61+
? AS col_date
62+
`,
63+
true,
64+
int(1234),
65+
"3.14",
66+
dbsql.Parameter{Type: dbsql.SqlFloat, Value: "3.14"},
67+
dbsql.Parameter{Type: dbsql.SqlDate, Value: "2017-07-23 00:00:00"},
68+
).Scan(&p_bool, &p_int, &p_double, &p_float, &p_date)
69+
70+
if err != nil {
71+
if err == sql.ErrNoRows {
72+
fmt.Println("not found")
73+
return
74+
} else {
75+
fmt.Printf("err: %v\n", err)
76+
}
77+
} else {
78+
fmt.Println(p_bool, p_int, p_double, p_float, p_date)
79+
}
80+
}
81+
1582
func main() {
1683
// Opening a driver typically will not attempt to connect to the database.
1784
err := godotenv.Load()
@@ -36,33 +103,7 @@ func main() {
36103
}
37104
db := sql.OpenDB(connector)
38105
defer db.Close()
39-
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
40-
// defer cancel()
41-
ctx := context.Background()
42-
var p_bool bool
43-
var p_int int
44-
var p_double float64
45-
var p_float float32
46-
var p_date string
47-
err1 := db.QueryRowContext(ctx, `SELECT
48-
:p_bool AS col_bool,
49-
:p_int AS col_int,
50-
:p_double AS col_double,
51-
:p_float AS col_float,
52-
:p_date AS col_date`,
53-
dbsql.Parameter{Name: "p_bool", Value: true},
54-
dbsql.Parameter{Name: "p_int", Value: int(1234)},
55-
dbsql.Parameter{Name: "p_double", Type: dbsql.SqlDouble, Value: "3.14"},
56-
dbsql.Parameter{Name: "p_float", Type: dbsql.SqlFloat, Value: "3.14"},
57-
dbsql.Parameter{Name: "p_date", Type: dbsql.SqlDate, Value: "2017-07-23 00:00:00"}).Scan(&p_bool, &p_int, &p_double, &p_float, &p_date)
58-
59-
if err1 != nil {
60-
if err1 == sql.ErrNoRows {
61-
fmt.Println("not found")
62-
return
63-
} else {
64-
fmt.Printf("err: %v\n", err1)
65-
}
66-
}
67106

107+
queryWithNamedParameters(db)
108+
queryWithPositionalParameters(db)
68109
}

parameter_test.go

+43-10
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package dbsql
22

33
import (
44
"database/sql/driver"
5+
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
6+
"github.com/stretchr/testify/require"
57
"strconv"
68
"testing"
79
"time"
@@ -21,7 +23,7 @@ func TestParameter_Inference(t *testing.T) {
2123
{Name: "", Value: nil},
2224
{Name: "", Value: Parameter{Value: float64Ptr(6.2), Type: SqlUnkown}},
2325
}
24-
parameters := convertNamedValuesToSparkParams(values[:])
26+
parameters, _ := convertNamedValuesToSparkParams(values[:])
2527
assert.Equal(t, strconv.FormatFloat(float64(5.1), 'f', -1, 64), *parameters[0].Value.StringValue)
2628
assert.NotNil(t, parameters[1].Value.StringValue)
2729
assert.Equal(t, string("TIMESTAMP"), *parameters[1].Type)
@@ -34,14 +36,45 @@ func TestParameter_Inference(t *testing.T) {
3436
assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, parameters[6].Value)
3537
})
3638
}
37-
func TestParameters_Names(t *testing.T) {
38-
t.Run("Should infer types correctly", func(t *testing.T) {
39-
values := [2]driver.NamedValue{{Name: "1", Value: int(26)}, {Name: "", Value: Parameter{Name: "2", Type: SqlDecimal, Value: "6.2"}}}
40-
parameters := convertNamedValuesToSparkParams(values[:])
41-
assert.Equal(t, string("1"), *parameters[0].Name)
42-
assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("26")}, *parameters[0].Value)
43-
assert.Equal(t, string("2"), *parameters[1].Name)
44-
assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, *parameters[1].Value)
45-
assert.Equal(t, string("DECIMAL(2,1)"), *parameters[1].Type)
39+
40+
func TestParameters_ConvertToSpark(t *testing.T) {
41+
t.Run("Should convert names parameters", func(t *testing.T) {
42+
values := [2]driver.NamedValue{
43+
{Name: "1", Value: int(26)},
44+
{Name: "", Value: Parameter{Name: "2", Type: SqlDecimal, Value: "6.2"}},
45+
}
46+
parameters, err := convertNamedValuesToSparkParams(values[:])
47+
require.NoError(t, err)
48+
require.Equal(t, string("1"), *parameters[0].Name)
49+
require.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("26")}, *parameters[0].Value)
50+
require.Equal(t, string("2"), *parameters[1].Name)
51+
require.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, *parameters[1].Value)
52+
require.Equal(t, string("DECIMAL(2,1)"), *parameters[1].Type)
53+
})
54+
55+
t.Run("Should convert positional parameters", func(t *testing.T) {
56+
values := [2]driver.NamedValue{
57+
{Value: int(26)},
58+
{Name: "", Value: Parameter{Type: SqlDecimal, Value: "6.2"}},
59+
}
60+
parameters, err := convertNamedValuesToSparkParams(values[:])
61+
require.NoError(t, err)
62+
require.Nil(t, parameters[0].Name)
63+
require.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("26")}, *parameters[0].Value)
64+
require.Nil(t, parameters[1].Name)
65+
require.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, *parameters[1].Value)
66+
require.Equal(t, string("DECIMAL(2,1)"), *parameters[1].Type)
67+
})
68+
69+
t.Run("Should error out when named and positional parameters are mixed", func(t *testing.T) {
70+
values := [4]driver.NamedValue{
71+
{Name: "a", Value: int(26)},
72+
{Name: "", Value: Parameter{Type: SqlDecimal, Value: "6.2"}},
73+
{Value: "test"},
74+
{Name: "b", Value: Parameter{Type: SqlDouble, Value: 123.456}},
75+
}
76+
_, err := convertNamedValuesToSparkParams(values[:])
77+
require.Error(t, err)
78+
require.Equal(t, err.Error(), dbsqlerr.ErrMixedNamedAndPositionalParameters)
4679
})
4780
}

parameters.go

+23-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ import (
88
"strings"
99
"time"
1010

11+
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
1112
"github.com/databricks/databricks-sql-go/internal/cli_service"
13+
"github.com/pkg/errors"
1214
)
1315

1416
type Parameter struct {
@@ -162,10 +164,14 @@ func inferType(param *Parameter) {
162164
}
163165
}
164166

165-
func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.TSparkParameter {
167+
func convertNamedValuesToSparkParams(values []driver.NamedValue) ([]*cli_service.TSparkParameter, error) {
166168
var sparkParams []*cli_service.TSparkParameter
167169

168170
sqlParams := valuesToParameters(values)
171+
172+
hasNamedParams := false
173+
hasPositionalParams := false
174+
169175
inferTypes(sqlParams)
170176
for i := range sqlParams {
171177
sqlParam := sqlParams[i]
@@ -183,10 +189,24 @@ func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.
183189
} else {
184190
sparkParamType = sqlParam.Type.String()
185191
}
186-
sparkParam := cli_service.TSparkParameter{Name: &sqlParam.Name, Type: &sparkParamType, Value: sparkValue}
192+
193+
var sparkParamName *string
194+
if sqlParam.Name != "" {
195+
sparkParamName = &sqlParam.Name
196+
hasNamedParams = true
197+
} else {
198+
sparkParamName = nil
199+
hasPositionalParams = true
200+
}
201+
202+
if hasNamedParams && hasPositionalParams {
203+
return nil, errors.New(dbsqlerr.ErrMixedNamedAndPositionalParameters)
204+
}
205+
206+
sparkParam := cli_service.TSparkParameter{Name: sparkParamName, Type: &sparkParamType, Value: sparkValue}
187207
sparkParams = append(sparkParams, &sparkParam)
188208
}
189-
return sparkParams
209+
return sparkParams, nil
190210
}
191211

192212
func inferDecimalType(d string) (t string) {

0 commit comments

Comments
 (0)