Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ability to use ':' in named args #2178

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions named_args.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,19 @@ func rawState(l *sqlLexer) stateFn {
return singleQuoteState
case '"':
return doubleQuoteState
case ':':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
prevRune := rune(0)
if l.pos > 1 {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this check to avoid panic when : is at the beginning of a line.

prevRune, _ = utf8.DecodeRuneInString(l.src[l.pos-2:])
}
if nextRune != ':' && prevRune != ':' && (isLetter(nextRune) || nextRune == '_') {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me that we can omit the first check nextRune != ':' as there will be a more specific check next, but in addition I would like to say that this check is much easier than the next ones, and will cut off cast types a little faster.... But type casts are not done so often to leave this prevenient check.

What do you think, should I remove nextRune != ':' && ?

if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width])
}
l.start = l.pos
return namedArgState
Comment on lines +117 to +121
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy from '@' case

}
case '@':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if isLetter(nextRune) || nextRune == '_' {
Expand Down
152 changes: 152 additions & 0 deletions named_args_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,155 @@ func TestStrictNamedArgsRewriteQuery(t *testing.T) {
}
}
}

func TestNamedArgsRewriteQuery2(t *testing.T) {
Copy link
Author

@KoNekoD KoNekoD Dec 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also added a tests that should confirm the clarity of my implementation

t.Parallel()

for i, tt := range []struct {
sql string
args []any
namedArgs pgx.NamedArgs
expectedSQL string
expectedArgs []any
}{
{
sql: "select * from users where id = :id",
namedArgs: pgx.NamedArgs{"id": int32(42)},
expectedSQL: "select * from users where id = $1",
expectedArgs: []any{int32(42)},
},
{
sql: "select * from t where foo < :abc and baz = :def and bar < :abc",
namedArgs: pgx.NamedArgs{"abc": int32(42), "def": int32(1)},
expectedSQL: "select * from t where foo < $1 and baz = $2 and bar < $1",
expectedArgs: []any{int32(42), int32(1)},
},
{
sql: "select :a::int, :b::text",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "select $1::int, $2::text",
expectedArgs: []any{int32(42), "foo"},
},
{
sql: "select :Abc::int, :b_4::text, :_c::int",
namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo", "_c": int32(1)},
expectedSQL: "select $1::int, $2::text, $3::int",
expectedArgs: []any{int32(42), "foo", int32(1)},
},
{
sql: "at end :",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "at end :",
expectedArgs: []any{},
},
{
sql: "ignores without valid character after : foo bar",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "ignores without valid character after : foo bar",
expectedArgs: []any{},
},
{
sql: "name cannot start with number :1 foo bar",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "name cannot start with number :1 foo bar",
expectedArgs: []any{},
},
{
sql: `select *, ':foo' as ":bar" from users where id = :id`,
namedArgs: pgx.NamedArgs{"id": int32(42)},
expectedSQL: `select *, ':foo' as ":bar" from users where id = $1`,
expectedArgs: []any{int32(42)},
},
{
sql: `select * -- :foo
from users -- :single line comments
where id = :id;`,
namedArgs: pgx.NamedArgs{"id": int32(42)},
expectedSQL: `select * -- :foo
from users -- :single line comments
where id = $1;`,
expectedArgs: []any{int32(42)},
},
{
sql: `select * /* :multi line
:comment
*/
/* /* with :nesting */ */
from users
where id = :id;`,
namedArgs: pgx.NamedArgs{"id": int32(42)},
expectedSQL: `select * /* :multi line
:comment
*/
/* /* with :nesting */ */
from users
where id = $1;`,
expectedArgs: []any{int32(42)},
},
{
sql: "extra provided argument",
namedArgs: pgx.NamedArgs{"extra": int32(1)},
expectedSQL: "extra provided argument",
expectedArgs: []any{},
},
{
sql: ":missing argument",
namedArgs: pgx.NamedArgs{},
expectedSQL: "$1 argument",
expectedArgs: []any{nil},
},

// test comments and quotes
} {
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
require.NoError(t, err)
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
}
}

func TestStrictNamedArgsRewriteQuery2(t *testing.T) {
t.Parallel()

for i, tt := range []struct {
sql string
namedArgs pgx.StrictNamedArgs
expectedSQL string
expectedArgs []any
isExpectedError bool
}{
{
sql: "no arguments",
namedArgs: pgx.StrictNamedArgs{},
expectedSQL: "no arguments",
expectedArgs: []any{},
isExpectedError: false,
},
{
sql: ":all :matches",
namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)},
expectedSQL: "$1 $2",
expectedArgs: []any{int32(1), int32(2)},
isExpectedError: false,
},
{
sql: "extra provided argument",
namedArgs: pgx.StrictNamedArgs{"extra": int32(1)},
isExpectedError: true,
},
{
sql: ":missing argument",
namedArgs: pgx.StrictNamedArgs{},
isExpectedError: true,
},
} {
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil)
if tt.isExpectedError {
assert.Errorf(t, err, "%d", i)
} else {
require.NoErrorf(t, err, "%d", i)
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
}
}
}