Skip to content

Commit 6ec2ad4

Browse files
committed
Move some bits to internal/
Move some of the OS abstractions, parsing, and other bits to internal/ package. Also split out conn.go to rows.go and stmt.go.
1 parent b4c9a0a commit 6ec2ad4

32 files changed

+989
-1061
lines changed

conn.go

Lines changed: 58 additions & 511 deletions
Large diffs are not rendered by default.

conn_test.go

Lines changed: 32 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"testing"
1616
"time"
1717

18+
"github.com/lib/pq/internal/pgpass"
1819
"github.com/lib/pq/internal/pqtest"
1920
)
2021

@@ -82,57 +83,8 @@ func TestOpenURL(t *testing.T) {
8283
testURL("postgresql://")
8384
}
8485

85-
const pgpassFile = "/tmp/pqgotest_pgpass"
86-
8786
func TestPgpass(t *testing.T) {
88-
testAssert := func(conninfo string, expected string, reason string) {
89-
conn := pqtest.MustDB(t, conninfo)
90-
91-
txn, err := conn.Begin()
92-
if err != nil {
93-
if expected != "fail" {
94-
t.Fatalf(reason, err)
95-
}
96-
return
97-
}
98-
rows, err := txn.Query("SELECT USER")
99-
if err != nil {
100-
txn.Rollback()
101-
if expected != "fail" {
102-
t.Fatalf(reason, err)
103-
}
104-
} else {
105-
rows.Close()
106-
if expected != "ok" {
107-
t.Fatalf(reason, err)
108-
}
109-
}
110-
txn.Rollback()
111-
}
112-
testAssert("", "ok", "missing .pgpass, unexpected error %#v")
113-
os.Setenv("PGPASSFILE", pgpassFile)
114-
defer os.Unsetenv("PGPASSFILE")
115-
testAssert("host=/tmp", "fail", ", unexpected error %#v")
116-
os.Unsetenv("PGPASSFILE")
117-
118-
os.Remove(pgpassFile)
119-
pgpass, err := os.OpenFile(pgpassFile, os.O_RDWR|os.O_CREATE, 0644)
120-
if err != nil {
121-
t.Fatalf("Unexpected error writing pgpass file %#v", err)
122-
}
123-
_, err = pgpass.WriteString(`# comment
124-
server:5432:some_db:some_user:pass_A
125-
*:5432:some_db:some_user:pass_B
126-
localhost:*:*:*:pass_C
127-
*:*:*:*:pass_fallback
128-
`)
129-
if err != nil {
130-
t.Fatalf("Unexpected error writing pgpass file %#v", err)
131-
}
132-
defer os.Remove(pgpassFile)
133-
pgpass.Close()
134-
135-
assertPassword := func(extra values, expected string) {
87+
assertPassword := func(want string, extra values) {
13688
o := values{
13789
"host": "localhost",
13890
"sslmode": "disable",
@@ -146,27 +98,42 @@ localhost:*:*:*:pass_C
14698
for k, v := range extra {
14799
o[k] = v
148100
}
149-
(&conn{}).handlePgpass(o)
150-
if pw := o["password"]; pw != expected {
151-
t.Fatalf("For %v expected %s got %s", extra, expected, pw)
101+
have := pgpass.PasswordFromPgpass(o)
102+
if have != want {
103+
t.Fatalf("wrong password\nhave: %q\nwant: %q", have, want)
152104
}
153105
}
154-
// missing passfile means empty psasword
155-
assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "")
106+
107+
file := pqtest.TempFile(t, "pgpass", pqtest.NormalizeIndent(`
108+
# comment
109+
server:5432:some_db:some_user:pass_A
110+
*:5432:some_db:some_user:pass_B
111+
localhost:*:*:*:pass_C
112+
*:*:*:*:pass_fallback
113+
`))
114+
115+
// Missing passfile means empty password.
116+
assertPassword("", values{"host": "server", "dbname": "some_db", "user": "some_user"})
117+
156118
// wrong permissions for the pgpass file means it should be ignored
157-
assertPassword(values{"host": "example.com", "passfile": pgpassFile, "user": "foo"}, "")
158-
// fix the permissions and check if it has taken effect
159-
os.Chmod(pgpassFile, 0600)
119+
assertPassword("", values{"host": "example.com", "passfile": file, "user": "foo"})
120+
121+
if err := os.Chmod(file, 0600); err != nil { // Fix the permissions
122+
t.Fatal(err)
123+
}
124+
125+
assertPassword("pass_A", values{"host": "server", "passfile": file, "dbname": "some_db", "user": "some_user"})
126+
assertPassword("pass_fallback", values{"host": "example.com", "passfile": file, "user": "foo"})
127+
assertPassword("pass_B", values{"host": "example.com", "passfile": file, "dbname": "some_db", "user": "some_user"})
160128

161-
assertPassword(values{"host": "server", "passfile": pgpassFile, "dbname": "some_db", "user": "some_user"}, "pass_A")
162-
assertPassword(values{"host": "example.com", "passfile": pgpassFile, "user": "foo"}, "pass_fallback")
163-
assertPassword(values{"host": "example.com", "passfile": pgpassFile, "dbname": "some_db", "user": "some_user"}, "pass_B")
164129
// localhost also matches the default "" and UNIX sockets
165-
assertPassword(values{"host": "", "passfile": pgpassFile, "user": "some_user"}, "pass_C")
166-
assertPassword(values{"host": "/tmp", "passfile": pgpassFile, "user": "some_user"}, "pass_C")
167-
// passfile connection parameter takes precedence
130+
assertPassword("pass_C", values{"host": "", "passfile": file, "user": "some_user"})
131+
assertPassword("pass_C", values{"host": "/tmp", "passfile": file, "user": "some_user"})
132+
133+
// Connection parameter takes precedence
168134
os.Setenv("PGPASSFILE", "/tmp")
169-
assertPassword(values{"host": "server", "passfile": pgpassFile, "dbname": "some_db", "user": "some_user"}, "pass_A")
135+
defer os.Unsetenv("PGPASSFILE")
136+
assertPassword("pass_A", values{"host": "server", "passfile": file, "dbname": "some_db", "user": "some_user"})
170137
}
171138

172139
func TestExecNilSlice(t *testing.T) {
@@ -1473,58 +1440,6 @@ func TestRuntimeParameters(t *testing.T) {
14731440
}
14741441
}
14751442

1476-
func TestQuoteIdentifier(t *testing.T) {
1477-
var cases = []struct {
1478-
input string
1479-
want string
1480-
}{
1481-
{`foo`, `"foo"`},
1482-
{`foo bar baz`, `"foo bar baz"`},
1483-
{`foo"bar`, `"foo""bar"`},
1484-
{"foo\x00bar", `"foo"`},
1485-
{"\x00foo", `""`},
1486-
}
1487-
1488-
for _, test := range cases {
1489-
got := QuoteIdentifier(test.input)
1490-
if got != test.want {
1491-
t.Errorf("QuoteIdentifier(%q) = %v want %v", test.input, got, test.want)
1492-
}
1493-
}
1494-
}
1495-
1496-
func TestQuoteLiteral(t *testing.T) {
1497-
var cases = []struct {
1498-
input string
1499-
want string
1500-
}{
1501-
{`foo`, `'foo'`},
1502-
{`foo bar baz`, `'foo bar baz'`},
1503-
{`foo'bar`, `'foo''bar'`},
1504-
{`foo\bar`, ` E'foo\\bar'`},
1505-
{`foo\ba'r`, ` E'foo\\ba''r'`},
1506-
{`foo"bar`, `'foo"bar'`},
1507-
{`foo\x00bar`, ` E'foo\\x00bar'`},
1508-
{`\x00foo`, ` E'\\x00foo'`},
1509-
{`'`, `''''`},
1510-
{`''`, `''''''`},
1511-
{`\`, ` E'\\'`},
1512-
{`'abc'; DROP TABLE users;`, `'''abc''; DROP TABLE users;'`},
1513-
{`\'`, ` E'\\'''`},
1514-
{`E'\''`, ` E'E''\\'''''`},
1515-
{`e'\''`, ` E'e''\\'''''`},
1516-
{`E'\'abc\'; DROP TABLE users;'`, ` E'E''\\''abc\\''; DROP TABLE users;'''`},
1517-
{`e'\'abc\'; DROP TABLE users;'`, ` E'e''\\''abc\\''; DROP TABLE users;'''`},
1518-
}
1519-
1520-
for _, test := range cases {
1521-
got := QuoteLiteral(test.input)
1522-
if got != test.want {
1523-
t.Errorf("QuoteLiteral(%q) = %v want %v", test.input, got, test.want)
1524-
}
1525-
}
1526-
}
1527-
15281443
func TestRowsResultTag(t *testing.T) {
15291444
type ResultTag interface {
15301445
Result() driver.Result

connector.go

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@ import (
66
"errors"
77
"fmt"
88
"net"
9+
neturl "net/url"
910
"os"
1011
"path/filepath"
12+
"sort"
1113
"strings"
1214
"unicode"
15+
16+
"github.com/lib/pq/internal/pqutil"
1317
)
1418

1519
// Connector represents a fixed configuration for the pq driver with a given
@@ -104,9 +108,9 @@ func NewConnector(dsn string) (*Connector, error) {
104108
// resort is to use the current operating system provided user
105109
// name.
106110
if _, ok := o["user"]; !ok {
107-
u, err := userCurrent()
111+
u, err := pqutil.User()
108112
if err != nil {
109-
return nil, err
113+
return nil, ErrCouldNotDetectUsername
110114
}
111115
o["user"] = u
112116
}
@@ -245,6 +249,52 @@ func parseOpts(name string, o values) error {
245249
return nil
246250
}
247251

252+
func convertURL(url string) (string, error) {
253+
u, err := neturl.Parse(url)
254+
if err != nil {
255+
return "", err
256+
}
257+
258+
if u.Scheme != "postgres" && u.Scheme != "postgresql" {
259+
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
260+
}
261+
262+
var kvs []string
263+
escaper := strings.NewReplacer(`'`, `\'`, `\`, `\\`)
264+
accrue := func(k, v string) {
265+
if v != "" {
266+
kvs = append(kvs, k+"='"+escaper.Replace(v)+"'")
267+
}
268+
}
269+
270+
if u.User != nil {
271+
v := u.User.Username()
272+
accrue("user", v)
273+
274+
v, _ = u.User.Password()
275+
accrue("password", v)
276+
}
277+
278+
if host, port, err := net.SplitHostPort(u.Host); err != nil {
279+
accrue("host", u.Host)
280+
} else {
281+
accrue("host", host)
282+
accrue("port", port)
283+
}
284+
285+
if u.Path != "" {
286+
accrue("dbname", u.Path[1:])
287+
}
288+
289+
q := u.Query()
290+
for k := range q {
291+
accrue(k, q.Get(k))
292+
}
293+
294+
sort.Strings(kvs) // Makes testing easier (not a performance concern)
295+
return strings.Join(kvs, " "), nil
296+
}
297+
248298
// parseEnviron tries to mimic some of libpq's environment handling
249299
//
250300
// To ease testing, it does not directly reference os.Environ, but is
@@ -338,16 +388,14 @@ func parseEnviron(env []string) (out map[string]string) {
338388
// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
339389
func isUTF8(name string) bool {
340390
// Recognize all sorts of silly things as "UTF-8", like Postgres does
341-
s := strings.Map(alnumLowerASCII, name)
391+
s := strings.Map(func(c rune) rune {
392+
if 'A' <= c && c <= 'Z' {
393+
return c + ('a' - 'A')
394+
}
395+
if 'a' <= c && c <= 'z' || '0' <= c && c <= '9' {
396+
return c
397+
}
398+
return -1 // discard
399+
}, name)
342400
return s == "utf8" || s == "unicode"
343401
}
344-
345-
func alnumLowerASCII(ch rune) rune {
346-
if 'A' <= ch && ch <= 'Z' {
347-
return ch + ('a' - 'A')
348-
}
349-
if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' {
350-
return ch
351-
}
352-
return -1 // discard
353-
}

connector_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"os"
1010
"reflect"
1111
"testing"
12+
13+
"github.com/lib/pq/internal/pqtest"
1214
)
1315

1416
func TestNewConnector_WorksWithOpenDB(t *testing.T) {
@@ -130,3 +132,35 @@ func TestIsUTF8(t *testing.T) {
130132
}
131133
}
132134
}
135+
136+
func TestParseURL(t *testing.T) {
137+
tests := []struct {
138+
in, want, wantErr string
139+
}{
140+
{"postgres://", "", ""},
141+
{"postgres://hostname.remote", "host='hostname.remote'", ""},
142+
{"postgres://[::1]:1234", "host='::1' port='1234'", ""},
143+
{"postgres://username:top%[email protected]:1234/database",
144+
`dbname='database' host='hostname.remote' password='top secret' port='1234' user='username'`, ""},
145+
{"postgres://localhost/a%2Fb", "dbname='a/b' host='localhost'", ""},
146+
147+
{"", "", "invalid connection protocol:"},
148+
{"http://hostname.remote", "", "invalid connection protocol: http"},
149+
150+
//{"postgresql://%2Fvar%2Flib%2Fpostgresql/dbname", "", ``},
151+
//{"postgres:// host/db", "dbname='db' host='host'", ""},
152+
//{"postgres://host/db ", "dbname='db' host='host'", ""},
153+
}
154+
155+
for _, tt := range tests {
156+
t.Run("", func(t *testing.T) {
157+
have, err := ParseURL(tt.in)
158+
if !pqtest.ErrorContains(err, tt.wantErr) {
159+
t.Fatal(err)
160+
}
161+
if have != tt.want {
162+
t.Errorf("\nhave: %q\nwant: %q", have, tt.want)
163+
}
164+
})
165+
}
166+
}

deprecated.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,4 @@ func (e *Error) Get(k byte) (v string) {
6464
//
6565
// Deprecated: directly passing an URL to sql.Open("postgres", "postgres://...")
6666
// now works, and calling this manually is no longer required.
67-
func ParseURL(url string) (string, error) { return parseURL(url) }
67+
func ParseURL(url string) (string, error) { return convertURL(url) }

encode.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,23 @@ func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) any {
9292
panic("not reached")
9393
}
9494

95+
// decodeUUIDBinary interprets the binary format of a uuid, returning it in text format.
96+
func decodeUUIDBinary(src []byte) ([]byte, error) {
97+
if len(src) != 16 {
98+
return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", len(src))
99+
}
100+
101+
dst := make([]byte, 36)
102+
dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-'
103+
hex.Encode(dst[0:], src[0:4])
104+
hex.Encode(dst[9:], src[4:6])
105+
hex.Encode(dst[14:], src[6:8])
106+
hex.Encode(dst[19:], src[8:10])
107+
hex.Encode(dst[24:], src[10:16])
108+
109+
return dst, nil
110+
}
111+
95112
func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) any {
96113
switch typ {
97114
case oid.T_char, oid.T_bpchar, oid.T_varchar, oid.T_text:

0 commit comments

Comments
 (0)