diff --git a/conn.go b/conn.go
index 311721459..b890fd875 100644
--- a/conn.go
+++ b/conn.go
@@ -41,11 +41,22 @@ type ConnConfig struct {
 	DefaultQueryExecMode QueryExecMode
 
 	createdByParseConfig bool // Used to enforce created by ParseConfig rule.
+
+	// AutoLoadTypes is a list of type names that LoadTypes is automatically called with for each new connection.
+	AutoLoadTypes []string
+
+	// TypeRegistrationMap is used to register types which require special handling.
+	// The type name is the key; the value is a function which will be called for each
+	// connection with the OID of that type name on that connection.
+	// The function is expected to modify conn.TypeMap() accordingly.
+	TypeRegistrationMap map[string]CustomRegistrationFunction
 }
 
 // ParseConfigOptions contains options that control how a config is built such as getsslpassword.
 type ParseConfigOptions struct {
 	pgconn.ParseConfigOptions
+
+	AutoLoadTypes []string
 }
 
 // Copy returns a deep copy of the config that is safe to use and modify.
@@ -107,8 +118,10 @@ var (
 	ErrTooManyRows = errors.New("too many rows in result set")
 )
 
-var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
-var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
+var (
+	errDisabledStatementCache   = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
+	errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
+)
 
 // Connect establishes a connection with a PostgreSQL server with a connection string. See
 // pgconn.Connect for details.
@@ -194,6 +207,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
 		DescriptionCacheCapacity: descriptionCacheCapacity,
 		DefaultQueryExecMode:     defaultQueryExecMode,
 		connString:               connString,
+		AutoLoadTypes:            options.AutoLoadTypes,
 	}
 
 	return connConfig, nil
@@ -271,6 +285,14 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
 		c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
 	}
 
+	if len(c.config.AutoLoadTypes) > 0 {
+		types, err := LoadTypes(ctx, c, c.config.AutoLoadTypes)
+		if err != nil {
+			return nil, err
+		}
+		c.TypeMap().RegisterTypes(types)
+	}
+
 	return c, nil
 }
 
@@ -843,7 +865,6 @@ func (c *Conn) getStatementDescription(
 	mode QueryExecMode,
 	sql string,
 ) (sd *pgconn.StatementDescription, err error) {
-
 	switch mode {
 	case QueryExecModeCacheStatement:
 		if c.statementCache == nil {
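// Illustrative usage sketch (not part of the patch): how ConnConfig.AutoLoadTypes is
// meant to be used on a single connection. The connection string and the "mytype"
// type name are assumptions made for the example.
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

func connectWithAutoLoadedTypes(ctx context.Context) (*pgx.Conn, error) {
	cfg, err := pgx.ParseConfig("postgres://user:secret@localhost:5432/mydb")
	if err != nil {
		return nil, err
	}
	// connect() calls LoadTypes for these names and registers the results on the
	// connection's type map before the connection is returned.
	cfg.AutoLoadTypes = []string{"mytype"}
	return pgx.ConnectConfig(ctx, cfg)
}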
diff --git a/derived_types.go b/derived_types.go
new file mode 100644
index 000000000..a55c90d49
--- /dev/null
+++ b/derived_types.go
@@ -0,0 +1,286 @@
+package pgx
+
+import (
+	"context"
+	"fmt"
+	"regexp"
+	"strconv"
+	"strings"
+
+	"github.com/jackc/pgx/v5/pgtype"
+)
+
+// CustomRegistrationFunction is capable of registering whatever is necessary for
+// a custom type. It is provided with the backend's OID for this type.
+type CustomRegistrationFunction func(ctx context.Context, m *pgtype.Map, oid uint32) error
+
+/*
+buildLoadDerivedTypesSQL generates the correct query for retrieving type information.
+
+	pgVersion: the major version of the PostgreSQL server
+	typeNames: the names of the types to load. If nil, load all types.
+*/
+func buildLoadDerivedTypesSQL(pgVersion int64, typeNames []string) string {
+	supportsMultirange := (pgVersion >= 14)
+	var typeNamesClause string
+
+	if typeNames == nil {
+		// collect all types. Not currently recommended.
+		typeNamesClause = "IS NOT NULL"
+	} else {
+		typeNamesClause = "= ANY($1)"
+	}
+	parts := make([]string, 0, 10)
+
+	// Each of the type names provided might be found in pg_class or pg_type.
+	// Additionally, it may or may not include a schema portion.
+	parts = append(parts, `
+WITH RECURSIVE
+-- find the OIDs in pg_class which match one of the provided type names
+selected_classes(oid,reltype) AS (
+    -- this query uses the namespace search path, so will match type names without a schema prefix
+    SELECT pg_class.oid, pg_class.reltype
+    FROM pg_catalog.pg_class
+    LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace
+    WHERE pg_catalog.pg_table_is_visible(pg_class.oid)
+      AND relname `, typeNamesClause, `
+UNION ALL
+    -- this query will only match type names which include the schema prefix
+    SELECT pg_class.oid, pg_class.reltype
+    FROM pg_class
+    INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid)
+    WHERE nspname || '.' || relname `, typeNamesClause, `
+),
+selected_types(oid) AS (
+    -- collect the OIDs from pg_types which correspond to the selected classes
+    SELECT reltype AS oid
+    FROM selected_classes
+UNION ALL
+    -- as well as any other type names which match our criteria
+    SELECT oid
+    FROM pg_type
+    WHERE typname `, typeNamesClause, `
+),
+-- this builds a parent/child mapping of objects, allowing us to know
+-- all the child (ie: dependent) types that a parent (type) requires
+-- As can be seen, there are 3 ways this can occur (the last of which
+-- is due to being a composite class, where the composite fields are children)
+pc(parent, child) AS (
+    SELECT parent.oid, parent.typelem
+    FROM pg_type parent
+    WHERE parent.typtype = 'b' AND parent.typelem != 0
+UNION ALL
+    SELECT parent.oid, parent.typbasetype
+    FROM pg_type parent
+    WHERE parent.typtypmod = -1 AND parent.typbasetype != 0
+UNION ALL
+    SELECT pg_type.oid, atttypid
+    FROM pg_attribute
+    INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
+    INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
+    WHERE NOT attisdropped
+      AND attnum > 0
+),
+-- Now construct a recursive query which includes a 'depth' element.
+-- This is used to ensure that the "youngest" children are registered before
+-- their parents.
+relationships(parent, child, depth) AS (
+    SELECT DISTINCT 0::OID, selected_types.oid, 0
+    FROM selected_types
+UNION ALL
+    SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1
+    FROM selected_classes c
+    INNER JOIN pg_type ON (c.reltype = pg_type.oid)
+    INNER JOIN pg_attribute ON (c.oid = pg_attribute.attrelid)
+UNION ALL
+    SELECT pc.parent, pc.child, relationships.depth + 1
+    FROM pc
+    INNER JOIN relationships ON (pc.parent = relationships.child)
+),
+-- composite fields need to be encapsulated as a couple of arrays to provide the required information for registration
+composite AS (
+    SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY attnum) AS atttypids
+    FROM pg_attribute
+    INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
+    INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
+    WHERE NOT attisdropped
+      AND attnum > 0
+    GROUP BY pg_type.oid
+)
+-- Bring together this information, showing all the information which might possibly be required
+-- to complete the registration, applying filters to only show the items which relate to the selected
+-- types/classes.
+SELECT typname,
+    typtype,
+    typbasetype,
+    typelem,
+    pg_type.oid,`)
+	if supportsMultirange {
+		parts = append(parts, `
+    COALESCE(multirange.rngtypid, 0) AS rngtypid,`)
+	} else {
+		parts = append(parts, `
+    0 AS rngtypid,`)
+	}
+	parts = append(parts, `
+    COALESCE(pg_range.rngsubtype, 0) AS rngsubtype,
+    attnames, atttypids
+    FROM relationships
+    INNER JOIN pg_type ON (pg_type.oid = relationships.child)
+    LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`)
+	if supportsMultirange {
+		parts = append(parts, `
+    LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`)
+	}
+
+	parts = append(parts, `
+    LEFT OUTER JOIN composite USING (oid)
+    WHERE NOT (typtype = 'b' AND typelem = 0)`)
+	parts = append(parts, `
+    GROUP BY typname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`)
+	if supportsMultirange {
+		parts = append(parts, `
+    multirange.rngtypid,`)
+	}
+	parts = append(parts, `
+    attnames, atttypids
+    ORDER BY MAX(depth) DESC, typname;`)
+	return strings.Join(parts, "")
+}
+
+type derivedTypeInfo struct {
+	Oid, Typbasetype, Typelem, Rngsubtype, Rngtypid uint32
+	TypeName, Typtype                               string
+	Attnames                                        []string
+	Atttypids                                       []uint32
+}
+
+// LoadTypes performs a single (complex) query, returning all the required
+// information to register the named types, as well as any other types directly
+// or indirectly required to complete the registration.
+// The result of this call can be passed into RegisterTypes to complete the process.
+func LoadTypes(ctx context.Context, c *Conn, typeNames []string) ([]*pgtype.Type, error) {
+	m := c.TypeMap()
+	if len(typeNames) == 0 {
+		return nil, fmt.Errorf("no type names were supplied")
+	}
+
+	serverVersion, err := serverVersion(c)
+	if err != nil {
+		return nil, fmt.Errorf("unexpected server version error: %w", err)
+	}
+	sql := buildLoadDerivedTypesSQL(serverVersion, typeNames)
+	rows, err := c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
+	if err != nil {
+		return nil, fmt.Errorf("while executing load types query: %w", err)
+	}
+	defer rows.Close()
+	result := make([]*pgtype.Type, 0, 100)
+	for rows.Next() {
+		ti := derivedTypeInfo{}
+		err = rows.Scan(&ti.TypeName, &ti.Typtype, &ti.Typbasetype, &ti.Typelem, &ti.Oid, &ti.Rngtypid, &ti.Rngsubtype, &ti.Attnames, &ti.Atttypids)
+		if err != nil {
+			return nil, fmt.Errorf("while scanning type information: %w", err)
+		}
+		var type_ *pgtype.Type
+		switch ti.Typtype {
+		case "b": // array
+			dt, ok := m.TypeForOID(ti.Typelem)
+			if !ok {
+				return nil, fmt.Errorf("array element OID %v not registered while loading pgtype %q", ti.Typelem, ti.TypeName)
+			}
+			type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}
+		case "c": // composite
+			var fields []pgtype.CompositeCodecField
+			for i, fieldName := range ti.Attnames {
+				dt, ok := m.TypeForOID(ti.Atttypids[i])
+				if !ok {
+					return nil, fmt.Errorf("unknown field for composite type %q: field %q (OID %v) is not already registered", ti.TypeName, fieldName, ti.Atttypids[i])
+				}
+				fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
+			}
+
+			type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.CompositeCodec{Fields: fields}}
+		case "d": // domain
+			dt, ok := m.TypeForOID(ti.Typbasetype)
+			if !ok {
+				return nil, fmt.Errorf("domain base type OID %v was not already registered, needed for %q", ti.Typbasetype, ti.TypeName)
+			}
+
+			type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: dt.Codec}
+		case "e": // enum
+			type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.EnumCodec{}}
+		case "r": // range
+			dt, ok := m.TypeForOID(ti.Rngsubtype)
+			if !ok {
+				return nil, fmt.Errorf("range element OID %v was not already registered, needed for %q", ti.Rngsubtype, ti.TypeName)
+			}
+
+			type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.RangeCodec{ElementType: dt}}
+		case "m": // multirange
+			dt, ok := m.TypeForOID(ti.Rngtypid)
+			if !ok {
+				return nil, fmt.Errorf("multirange element OID %v was not already registered, needed for %q", ti.Rngtypid, ti.TypeName)
+			}
+
+			type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}
+		default:
+			return nil, fmt.Errorf("unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName)
+		}
+		if type_ != nil {
+			m.RegisterType(type_)
+			result = append(result, type_)
+		}
+	}
+	if err := rows.Err(); err != nil {
+		return nil, fmt.Errorf("while iterating load types results: %w", err)
+	}
+	return result, nil
+}
+
+// serverVersion returns the major version of the PostgreSQL server.
+func serverVersion(c *Conn) (int64, error) {
+	serverVersionStr := c.PgConn().ParameterStatus("server_version")
+	serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
+	// if the version cannot be identified (e.g. the server is not PostgreSQL), report an error
+	if serverVersionStr == "" {
+		return 0, fmt.Errorf("cannot identify server version in %q", serverVersionStr)
+	}
+
+	serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64)
+	if err != nil {
+		return 0, fmt.Errorf("postgres version parsing failed: %w", err)
+	}
+	return serverVersion, nil
+}
+
+func fetchOidMapForCustomRegistration(ctx context.Context, conn *Conn) (map[string]uint32, error) {
+	sql := `
+SELECT oid, typname
+FROM pg_type
+WHERE typname = ANY($1)`
+	result := make(map[string]uint32)
+	typeNames := make([]string, 0, len(conn.config.TypeRegistrationMap))
+	for typeName := range conn.config.TypeRegistrationMap {
+		typeNames = append(typeNames, typeName)
+	}
+	rows, err := conn.Query(ctx, sql, typeNames)
+	if err != nil {
+		return nil, fmt.Errorf("while collecting OIDs for custom registrations: %w", err)
+	}
+	defer rows.Close()
+	var typeName string
+	var oid uint32
+	for rows.Next() {
+		if err := rows.Scan(&typeName, &oid); err != nil {
+			return nil, fmt.Errorf("while scanning a row for custom registrations: %w", err)
+		}
+		result[typeName] = oid
+	}
+	if err := rows.Err(); err != nil {
+		return nil, fmt.Errorf("while iterating custom registration OIDs: %w", err)
+	}
+	return result, nil
+}
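// Illustrative sketch (not part of the patch): LoadTypes can also be driven manually,
// which is what AutoLoadTypes does under the hood. The "dtype_test" name and the second
// connection are assumptions made for the example.
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

func shareLoadedTypes(ctx context.Context, source, target *pgx.Conn) error {
	// LoadTypes inspects the catalogs through source and returns the named types plus
	// everything they depend on, dependencies first; it also registers them on source.
	types, err := pgx.LoadTypes(ctx, source, []string{"dtype_test"})
	if err != nil {
		return err
	}
	// The returned slice can be replayed onto other connections, so the catalog query
	// only has to run once (assuming both servers report the same OIDs).
	target.TypeMap().RegisterTypes(types)
	return nil
}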
dtype_test") + defer conn.Exec(ctx, "drop domain anotheruint64") + + types, err := pgx.LoadTypes(ctx, conn, []string{"dtype_test"}) + require.NoError(t, err) + require.Len(t, types, 3) + require.Equal(t, types[0].Name, "anotheruint64") + require.Equal(t, types[1].Name, "_anotheruint64") + require.Equal(t, types[2].Name, "dtype_test") + }) +} diff --git a/go.sum b/go.sum index 4b02a0365..29fe452b2 100644 --- a/go.sum +++ b/go.sum @@ -4,10 +4,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= diff --git a/pgtype/derived_types_test.go b/pgtype/derived_types_test.go new file mode 100644 index 000000000..6c1d9048c --- /dev/null +++ b/pgtype/derived_types_test.go @@ -0,0 +1,58 @@ +package pgtype_test + +import ( + "context" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/require" +) + +func TestDerivedTypes(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, ` +drop type if exists dt_test; +drop domain if exists dt_uint64; + +create domain dt_uint64 as numeric(20,0); +create type dt_test as ( + a text, + b dt_uint64, + c dt_uint64[] +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop domain dt_uint64") + defer conn.Exec(ctx, "drop type dt_test") + + dtypes, err := pgx.LoadTypes(ctx, conn, []string{"dt_test"}) + require.Len(t, dtypes, 3) + require.Equal(t, dtypes[0].Name, "dt_uint64") + require.Equal(t, dtypes[1].Name, "_dt_uint64") + require.Equal(t, dtypes[2].Name, "dt_test") + require.NoError(t, err) + conn.TypeMap().RegisterTypes(dtypes) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for _, format := range formats { + var a string + var b uint64 + var c *[]uint64 + + row := conn.QueryRow(ctx, "select $1::dt_test", pgx.QueryResultFormats{format.code}, pgtype.CompositeFields{"hi", uint64(42), []uint64{10, 20, 30}}) + err := row.Scan(pgtype.CompositeFields{&a, &b, &c}) + require.NoError(t, err) + require.EqualValuesf(t, "hi", a, "%v", format.name) + require.EqualValuesf(t, 42, b, "%v", format.name) + } + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 408295683..0de9cfe8a 100644 --- 
diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go
index 408295683..0de9cfe8a 100644
--- a/pgtype/pgtype.go
+++ b/pgtype/pgtype.go
@@ -214,6 +214,15 @@ type Map struct {
 	TryWrapScanPlanFuncs []TryWrapScanPlanFunc
 }
 
+// Types returns the non-default types that have been registered with this Map.
+func (m *Map) Types() []*Type {
+	result := make([]*Type, 0, len(m.oidToType))
+	for _, type_ := range m.oidToType {
+		result = append(result, type_)
+	}
+	return result
+}
+
 func NewMap() *Map {
 	defaultMapInitOnce.Do(initDefaultMap)
 
@@ -248,6 +257,13 @@ func NewMap() *Map {
 	}
 }
 
+// RegisterTypes registers multiple data types in the sequence they are provided.
+func (m *Map) RegisterTypes(types []*Type) {
+	for _, t := range types {
+		m.RegisterType(t)
+	}
+}
+
 // RegisterType registers a data type with the Map. t must not be mutated after it is registered.
 func (m *Map) RegisterType(t *Type) {
 	m.oidToType[t.OID] = t
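// Illustrative sketch (not part of the patch): Types and RegisterTypes together let the
// registrations made on one Map be replayed onto another, which is how the pool-level
// reuse below works.
package example

import "github.com/jackc/pgx/v5/pgtype"

func copyRegistrations(src, dst *pgtype.Map) {
	// Types returns only the types registered on src, not the defaults that every Map
	// already has, so replaying them onto dst is cheap.
	dst.RegisterTypes(src.Types())
}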
diff --git a/pgx_test.go b/pgx_test.go
new file mode 100644
index 000000000..51b4bbc4e
--- /dev/null
+++ b/pgx_test.go
@@ -0,0 +1,22 @@
+package pgx_test
+
+import (
+	"context"
+	"os"
+	"testing"
+
+	"github.com/jackc/pgx/v5"
+	_ "github.com/jackc/pgx/v5/stdlib"
+)
+
+func skipCockroachDB(t testing.TB, msg string) {
+	conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer conn.Close(context.Background())
+
+	if conn.PgConn().ParameterStatus("crdb_version") != "" {
+		t.Skip(msg)
+	}
+}
diff --git a/pgxpool/pool.go b/pgxpool/pool.go
index fdcba7241..5424a38a6 100644
--- a/pgxpool/pool.go
+++ b/pgxpool/pool.go
@@ -12,14 +12,17 @@ import (
 
 	"github.com/jackc/pgx/v5"
 	"github.com/jackc/pgx/v5/pgconn"
+	"github.com/jackc/pgx/v5/pgtype"
 	"github.com/jackc/puddle/v2"
 )
 
-var defaultMaxConns = int32(4)
-var defaultMinConns = int32(0)
-var defaultMaxConnLifetime = time.Hour
-var defaultMaxConnIdleTime = time.Minute * 30
-var defaultHealthCheckPeriod = time.Minute
+var (
+	defaultMaxConns          = int32(4)
+	defaultMinConns          = int32(0)
+	defaultMaxConnLifetime   = time.Hour
+	defaultMaxConnIdleTime   = time.Minute * 30
+	defaultHealthCheckPeriod = time.Minute
+)
 
 type connResource struct {
 	conn       *pgx.Conn
@@ -100,6 +103,10 @@ type Pool struct {
 
 	closeOnce sync.Once
 	closeChan chan struct{}
+
+	reuseTypeMap  bool
+	autoLoadTypes []*pgtype.Type
+	autoLoadMutex *sync.Mutex
 }
 
 // Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be
@@ -147,6 +154,13 @@ type Config struct {
 	// HealthCheckPeriod is the duration between checks of the health of idle connections.
 	HealthCheckPeriod time.Duration
 
	// ReuseTypeMaps, if enabled, reuses the type map information loaded by AutoLoadTypes.
	// This removes the need to query the database each time a new connection is created;
	// only RegisterTypes needs to be called for each new connection.
	// In situations where OID mappings can differ between the servers behind the pool,
	// for example due to certain replication strategies, this should be left disabled.
+	ReuseTypeMaps bool
+
 	createdByParseConfig bool // Used to enforce created by ParseConfig rule.
 }
 
@@ -185,6 +199,7 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
 		config:                config,
 		beforeConnect:         config.BeforeConnect,
 		afterConnect:          config.AfterConnect,
+		reuseTypeMap:          config.ReuseTypeMaps,
 		beforeAcquire:         config.BeforeAcquire,
 		afterRelease:          config.AfterRelease,
 		beforeClose:           config.BeforeClose,
@@ -196,6 +211,7 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
 		healthCheckPeriod:     config.HealthCheckPeriod,
 		healthCheckChan:       make(chan struct{}, 1),
 		closeChan:             make(chan struct{}),
+		autoLoadMutex:         new(sync.Mutex),
 	}
 
 	if t, ok := config.ConnConfig.Tracer.(AcquireTracer); ok {
@@ -223,8 +239,12 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
 				return nil, err
 			}
 		}
-
-		conn, err := pgx.ConnectConfig(ctx, connConfig)
+		var conn *pgx.Conn
+		var err error
+		if p.reuseTypeMap {
+			conn, err = p.ConnectConfigReusingTypeMap(ctx, connConfig)
+		} else {
+			conn, err = pgx.ConnectConfig(ctx, connConfig)
+		}
 		if err != nil {
 			return nil, err
 		}
@@ -278,6 +298,29 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
 	return p, nil
 }
 
+// ConnectConfigReusingTypeMap establishes a new connection for the pool. The first
+// connection performs the AutoLoadTypes catalog queries and caches the resulting
+// types; subsequent connections skip those queries and register the cached types.
+func (p *Pool) ConnectConfigReusingTypeMap(ctx context.Context, connConfig *pgx.ConnConfig) (*pgx.Conn, error) {
+	if len(connConfig.AutoLoadTypes) == 0 {
+		return pgx.ConnectConfig(ctx, connConfig)
+	}
+
+	p.autoLoadMutex.Lock()
+	cachedTypes := p.autoLoadTypes
+	p.autoLoadMutex.Unlock()
+
+	if cachedTypes == nil {
+		conn, err := pgx.ConnectConfig(ctx, connConfig)
+		if err == nil {
+			p.autoLoadMutex.Lock()
+			p.autoLoadTypes = conn.TypeMap().Types()
+			p.autoLoadMutex.Unlock()
+		}
+		return conn, err
+	}
+
+	// Copy the config so the caller's config is not mutated.
+	connConfig = connConfig.Copy()
+	connConfig.AutoLoadTypes = nil
+	conn, err := pgx.ConnectConfig(ctx, connConfig)
+	if err == nil {
+		conn.TypeMap().RegisterTypes(cachedTypes)
+	}
+	return conn, err
+}
+
 // ParseConfig builds a Config from connString. It parses connString with the same behavior as [pgx.ParseConfig] with the
 // addition of the following variables:
 //
@@ -296,7 +339,12 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
 // # Example URL
 // postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca&pool_max_conns=10
 func ParseConfig(connString string) (*Config, error) {
-	connConfig, err := pgx.ParseConfig(connString)
+	return ParseConfigWithOptions(connString, pgx.ParseConfigOptions{})
+}
+
+// ParseConfigWithOptions is the same as ParseConfig, but allows additional options to be provided.
+func ParseConfigWithOptions(connString string, options pgx.ParseConfigOptions) (*Config, error) {
+	connConfig, err := pgx.ParseConfigWithOptions(connString, options)
 	if err != nil {
 		return nil, err
 	}
@@ -482,7 +530,6 @@ func (p *Pool) checkMinConns() error {
 func (p *Pool) createIdleResources(parentCtx context.Context, targetResources int) error {
 	ctx, cancel := context.WithCancel(parentCtx)
 	defer cancel()
-
 	errs := make(chan error, targetResources)
 
 	for i := 0; i < targetResources; i++ {
@@ -495,7 +542,6 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in
 			errs <- err
 		}()
 	}
-
 	var firstError error
 	for i := 0; i < targetResources; i++ {
 		err := <-errs
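// Illustrative sketch (not part of the patch): enabling automatic type loading for a
// pool, with the type map loaded once and reused for every later connection. The
// connection string and the "autoload_uint64" domain are assumptions for the example.
package example

import (
	"context"

	"github.com/jackc/pgx/v5/pgxpool"
)

func newPoolWithAutoLoadedTypes(ctx context.Context) (*pgxpool.Pool, error) {
	cfg, err := pgxpool.ParseConfig("postgres://user:secret@localhost:5432/mydb")
	if err != nil {
		return nil, err
	}
	cfg.ConnConfig.AutoLoadTypes = []string{"autoload_uint64"}
	// With ReuseTypeMaps the catalog queries run on the first connection only; later
	// connections just re-register the cached types. Leave it disabled if the servers
	// behind the pool may report different OIDs for the same type names.
	cfg.ReuseTypeMaps = true
	return pgxpool.NewWithConfig(ctx, cfg)
}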
diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go
index 90428931b..6e7c02f5c 100644
--- a/pgxpool/pool_test.go
+++ b/pgxpool/pool_test.go
@@ -261,6 +261,35 @@ func TestPoolBeforeConnect(t *testing.T) {
 	assert.EqualValues(t, "pgx", str)
 }
 
+func TestAutoLoadTypes(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
+	require.NoError(t, err)
+
+	controllerConn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
+	require.NoError(t, err)
+	defer controllerConn.Close(ctx)
+	pgxtest.SkipCockroachDB(t, controllerConn, "Server does not support autoloading of uint64")
+
+	db1, err := pgxpool.NewWithConfig(ctx, config)
+	require.NoError(t, err)
+	defer db1.Close()
+	_, err = db1.Exec(ctx, "DROP DOMAIN IF EXISTS autoload_uint64; CREATE DOMAIN autoload_uint64 as numeric(20,0)")
+	require.NoError(t, err)
+	defer db1.Exec(ctx, "DROP DOMAIN autoload_uint64")
+
+	config.ConnConfig.AutoLoadTypes = []string{"autoload_uint64"}
+	db2, err := pgxpool.NewWithConfig(ctx, config)
+	require.NoError(t, err)
+	defer db2.Close()
+
+	var n uint64
+	err = db2.QueryRow(ctx, "select 12::autoload_uint64").Scan(&n)
+	require.NoError(t, err)
+	assert.EqualValues(t, uint64(12), n)
+}
+
 func TestPoolAfterConnect(t *testing.T) {
 	t.Parallel()
 
@@ -676,7 +705,6 @@ func TestPoolQuery(t *testing.T) {
 	stats = pool.Stat()
 	assert.EqualValues(t, 0, stats.AcquiredConns())
 	assert.EqualValues(t, 1, stats.TotalConns())
-
 }
 
 func TestPoolQueryRow(t *testing.T) {
@@ -1104,7 +1132,6 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) {
 	}
 
 	t.Fatal("did not reach min pool size")
-
 }
 
 func TestPoolSendBatchBatchCloseTwice(t *testing.T) {
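// Illustrative sketch (not part of the patch): the new ParseConfigOptions.AutoLoadTypes
// field reaches the same behaviour through option-based parsing, here via the
// pgxpool.ParseConfigWithOptions added above. The type name is an assumption.
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
)

func newPoolFromOptions(ctx context.Context, connString string) (*pgxpool.Pool, error) {
	cfg, err := pgxpool.ParseConfigWithOptions(connString, pgx.ParseConfigOptions{
		AutoLoadTypes: []string{"dtype_test"},
	})
	if err != nil {
		return nil, err
	}
	return pgxpool.NewWithConfig(ctx, cfg)
}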