From 22fb4e44b4f101266a26d4bf9e8da0798206bc01 Mon Sep 17 00:00:00 2001 From: Nick Farrell Date: Sat, 15 Jun 2024 09:36:34 +1000 Subject: [PATCH 1/3] Load types using a single SQL query When loading even a single type into pgx's type map, multiple SQL queries are performed in series. Over a slow link, this is not ideal. Worse, if multiple types are being registered, this is repeated multiple times. This commit add LoadTypes, which can retrieve type mapping information for multiple types in a single SQL call, including recursive fetching of dependent types. RegisterTypes performs the second stage of this operation. --- conn.go | 21 ++- derived_types.go | 256 +++++++++++++++++++++++++++++++++++ derived_types_test.go | 37 +++++ go.sum | 4 - pgtype/derived_types_test.go | 58 ++++++++ pgtype/pgtype.go | 16 +++ pgx_test.go | 22 +++ pgxpool/pool.go | 7 +- 8 files changed, 413 insertions(+), 8 deletions(-) create mode 100644 derived_types.go create mode 100644 derived_types_test.go create mode 100644 pgtype/derived_types_test.go create mode 100644 pgx_test.go diff --git a/conn.go b/conn.go index 311721459..a16b7b500 100644 --- a/conn.go +++ b/conn.go @@ -41,11 +41,16 @@ type ConnConfig struct { DefaultQueryExecMode QueryExecMode createdByParseConfig bool // Used to enforce created by ParseConfig rule. + + // automatically call LoadTypes with these values + AutoLoadTypes []string } // ParseConfigOptions contains options that control how a config is built such as getsslpassword. type ParseConfigOptions struct { pgconn.ParseConfigOptions + + AutoLoadTypes []string } // Copy returns a deep copy of the config that is safe to use and modify. @@ -107,8 +112,10 @@ var ( ErrTooManyRows = errors.New("too many rows in result set") ) -var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") -var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") +var ( + errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") + errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") +) // Connect establishes a connection with a PostgreSQL server with a connection string. See // pgconn.Connect for details. @@ -194,6 +201,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con DescriptionCacheCapacity: descriptionCacheCapacity, DefaultQueryExecMode: defaultQueryExecMode, connString: connString, + AutoLoadTypes: options.AutoLoadTypes, } return connConfig, nil @@ -271,6 +279,14 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity) } + if c.config.AutoLoadTypes != nil { + if types, err := LoadTypes(ctx, c, c.config.AutoLoadTypes); err == nil { + c.TypeMap().RegisterTypes(types) + } else { + return nil, err + } + } + return c, nil } @@ -843,7 +859,6 @@ func (c *Conn) getStatementDescription( mode QueryExecMode, sql string, ) (sd *pgconn.StatementDescription, err error) { - switch mode { case QueryExecModeCacheStatement: if c.statementCache == nil { diff --git a/derived_types.go b/derived_types.go new file mode 100644 index 000000000..5828c2e40 --- /dev/null +++ b/derived_types.go @@ -0,0 +1,256 @@ +package pgx + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/pgtype" +) + +/* +buildLoadDerivedTypesSQL generates the correct query for retrieving type information. + + pgVersion: the major version of the PostgreSQL server + typeNames: the names of the types to load. If nil, load all types. +*/ +func buildLoadDerivedTypesSQL(pgVersion int64, typeNames []string) string { + supportsMultirange := (pgVersion >= 14) + var typeNamesClause string + + if typeNames == nil { + // collect all types. Not currently recommended. + typeNamesClause = "IS NOT NULL" + } else { + typeNamesClause = "= ANY($1)" + } + parts := make([]string, 0, 10) + + // Each of the type names provided might be found in pg_class or pg_type. + // Additionally, it may or may not include a schema portion. + parts = append(parts, ` +WITH RECURSIVE +-- find the OIDs in pg_class which match one of the provided type names +selected_classes(oid,reltype) AS ( + -- this query uses the namespace search path, so will match type names without a schema prefix + SELECT pg_class.oid, pg_class.reltype + FROM pg_catalog.pg_class + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace + WHERE pg_catalog.pg_table_is_visible(pg_class.oid) + AND relname `, typeNamesClause, ` +UNION ALL + -- this query will only match type names which include the schema prefix + SELECT pg_class.oid, pg_class.reltype + FROM pg_class + INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid) + WHERE nspname || '.' || relname `, typeNamesClause, ` +), +selected_types(oid) AS ( + -- collect the OIDs from pg_types which correspond to the selected classes + SELECT reltype AS oid + FROM selected_classes +UNION ALL + -- as well as any other type names which match our criteria + SELECT oid + FROM pg_type + WHERE typname `, typeNamesClause, ` +), +-- this builds a parent/child mapping of objects, allowing us to know +-- all the child (ie: dependent) types that a parent (type) requires +-- As can be seen, there are 3 ways this can occur (the last of which +-- is due to being a composite class, where the composite fields are children) +pc(parent, child) AS ( + SELECT parent.oid, parent.typelem + FROM pg_type parent + WHERE parent.typtype = 'b' AND parent.typelem != 0 +UNION ALL + SELECT parent.oid, parent.typbasetype + FROM pg_type parent + WHERE parent.typtypmod = -1 AND parent.typbasetype != 0 +UNION ALL + SELECT pg_type.oid, atttypid + FROM pg_attribute + INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid) + INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype) + WHERE NOT attisdropped + AND attnum > 0 +), +-- Now construct a recursive query which includes a 'depth' element. +-- This is used to ensure that the "youngest" children are registered before +-- their parents. +relationships(parent, child, depth) AS ( + SELECT DISTINCT 0::OID, selected_types.oid, 0 + FROM selected_types +UNION ALL + SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1 + FROM selected_classes c + inner join pg_type ON (c.reltype = pg_type.oid) + inner join pg_attribute on (c.oid = pg_attribute.attrelid) +UNION ALL + SELECT pc.parent, pc.child, relationships.depth + 1 + FROM pc + INNER JOIN relationships ON (pc.parent = relationships.child) +), +-- composite fields need to be encapsulated as a couple of arrays to provide the required information for registration +composite AS ( + SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids + FROM pg_attribute + INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid) + INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype) + WHERE NOT attisdropped + AND attnum > 0 + GROUP BY pg_type.oid +) +-- Bring together this information, showing all the information which might possibly be required +-- to complete the registration, applying filters to only show the items which relate to the selected +-- types/classes. +SELECT typname, + typtype, + typbasetype, + typelem, + pg_type.oid,`) + if supportsMultirange { + parts = append(parts, ` + COALESCE(multirange.rngtypid, 0) AS rngtypid,`) + } else { + parts = append(parts, ` + 0 AS rngtypid,`) + } + parts = append(parts, ` + COALESCE(pg_range.rngsubtype, 0) AS rngsubtype, + attnames, atttypids + FROM relationships + INNER JOIN pg_type ON (pg_type.oid = relationships.child) + LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`) + if supportsMultirange { + parts = append(parts, ` + LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`) + } + + parts = append(parts, ` + LEFT OUTER JOIN composite USING (oid) + WHERE NOT (typtype = 'b' AND typelem = 0)`) + parts = append(parts, ` + GROUP BY typname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`) + if supportsMultirange { + parts = append(parts, ` + multirange.rngtypid,`) + } + parts = append(parts, ` + attnames, atttypids + ORDER BY MAX(depth) desc, typname;`) + return strings.Join(parts, "") +} + +type derivedTypeInfo struct { + Oid, Typbasetype, Typelem, Rngsubtype, Rngtypid uint32 + TypeName, Typtype string + Attnames []string + Atttypids []uint32 +} + +// LoadTypes performs a single (complex) query, returning all the required +// information to register the named types, as well as any other types directly +// or indirectly required to complete the registration. +// The result of this call can be passed into RegisterTypes to complete the process. +func LoadTypes(ctx context.Context, c *Conn, typeNames []string) ([]*pgtype.Type, error) { + m := c.TypeMap().Copy() + if typeNames == nil || len(typeNames) == 0 { + return nil, fmt.Errorf("No type names were supplied.") + } + + serverVersion, err := serverVersion(c) + if err != nil { + return nil, fmt.Errorf("Unexpected server version error: %w", err) + } + sql := buildLoadDerivedTypesSQL(serverVersion, typeNames) + var rows Rows + if typeNames == nil { + rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol) + } else { + rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames) + } + if err != nil { + return nil, fmt.Errorf("While generating load types query: %w", err) + } + defer rows.Close() + result := make([]*pgtype.Type, 0, 100) + for rows.Next() { + ti := derivedTypeInfo{} + err = rows.Scan(&ti.TypeName, &ti.Typtype, &ti.Typbasetype, &ti.Typelem, &ti.Oid, &ti.Rngtypid, &ti.Rngsubtype, &ti.Attnames, &ti.Atttypids) + if err != nil { + return nil, fmt.Errorf("While scanning type information: %w", err) + } + var type_ *pgtype.Type + switch ti.Typtype { + case "b": // array + dt, ok := m.TypeForOID(ti.Typelem) + if !ok { + return nil, fmt.Errorf("Array element OID %v not registered while loading pgtype %q", ti.Typelem, ti.TypeName) + } + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.ArrayCodec{ElementType: dt}} + case "c": // composite + var fields []pgtype.CompositeCodecField + for i, fieldName := range ti.Attnames { + //if fieldOID64, err = strconv.ParseUint(composite_fields[i+1], 10, 32); err != nil { + // return nil, fmt.Errorf("While extracting OID used in composite field: %w", err) + //} + dt, ok := m.TypeForOID(ti.Atttypids[i]) + if !ok { + return nil, fmt.Errorf("Unknown field for composite type %q: field %q (OID %v) is not already registered.", ti.TypeName, fieldName, ti.Atttypids[i]) + } + fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt}) + } + + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.CompositeCodec{Fields: fields}} + case "d": // domain + dt, ok := m.TypeForOID(ti.Typbasetype) + if !ok { + return nil, fmt.Errorf("Domain base type OID %v was not already registered, needed for %q", ti.Typbasetype, ti.TypeName) + } + + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: dt.Codec} + case "e": // enum + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.EnumCodec{}} + case "r": // range + dt, ok := m.TypeForOID(ti.Rngsubtype) + if !ok { + return nil, fmt.Errorf("Range element OID %v was not already registered, needed for %q", ti.Rngsubtype, ti.TypeName) + } + + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.RangeCodec{ElementType: dt}} + case "m": // multirange + dt, ok := m.TypeForOID(ti.Rngtypid) + if !ok { + return nil, fmt.Errorf("Multirange element OID %v was not already registered, needed for %q", ti.Rngtypid, ti.TypeName) + } + + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}} + default: + return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName) + } + if type_ != nil { + m.RegisterType(type_) + result = append(result, type_) + } + } + return result, nil +} + +// serverVersion returns the postgresql server version. +func serverVersion(c *Conn) (int64, error) { + serverVersionStr := c.PgConn().ParameterStatus("server_version") + serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) + // if not PostgreSQL do nothing + if serverVersionStr == "" { + return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr) + } + + serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64) + if err != nil { + return 0, fmt.Errorf("postgres version parsing failed: %w", err) + } + return serverVersion, nil +} diff --git a/derived_types_test.go b/derived_types_test.go new file mode 100644 index 000000000..fba722b30 --- /dev/null +++ b/derived_types_test.go @@ -0,0 +1,37 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" +) + +func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, ` +drop type if exists dtype_test; +drop domain if exists anotheruint64; + +create domain anotheruint64 as numeric(20,0); +create type dtype_test as ( + a text, + b int4, + c anotheruint64, + d anotheruint64[] +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type dtype_test") + defer conn.Exec(ctx, "drop domain anotheruint64") + + types, err := pgx.LoadTypes(ctx, conn, []string{"dtype_test"}) + require.NoError(t, err) + require.Len(t, types, 3) + require.Equal(t, types[0].Name, "anotheruint64") + require.Equal(t, types[1].Name, "_anotheruint64") + require.Equal(t, types[2].Name, "dtype_test") + }) +} diff --git a/go.sum b/go.sum index 4b02a0365..29fe452b2 100644 --- a/go.sum +++ b/go.sum @@ -4,10 +4,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= diff --git a/pgtype/derived_types_test.go b/pgtype/derived_types_test.go new file mode 100644 index 000000000..6c1d9048c --- /dev/null +++ b/pgtype/derived_types_test.go @@ -0,0 +1,58 @@ +package pgtype_test + +import ( + "context" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/require" +) + +func TestDerivedTypes(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, ` +drop type if exists dt_test; +drop domain if exists dt_uint64; + +create domain dt_uint64 as numeric(20,0); +create type dt_test as ( + a text, + b dt_uint64, + c dt_uint64[] +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop domain dt_uint64") + defer conn.Exec(ctx, "drop type dt_test") + + dtypes, err := pgx.LoadTypes(ctx, conn, []string{"dt_test"}) + require.Len(t, dtypes, 3) + require.Equal(t, dtypes[0].Name, "dt_uint64") + require.Equal(t, dtypes[1].Name, "_dt_uint64") + require.Equal(t, dtypes[2].Name, "dt_test") + require.NoError(t, err) + conn.TypeMap().RegisterTypes(dtypes) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for _, format := range formats { + var a string + var b uint64 + var c *[]uint64 + + row := conn.QueryRow(ctx, "select $1::dt_test", pgx.QueryResultFormats{format.code}, pgtype.CompositeFields{"hi", uint64(42), []uint64{10, 20, 30}}) + err := row.Scan(pgtype.CompositeFields{&a, &b, &c}) + require.NoError(t, err) + require.EqualValuesf(t, "hi", a, "%v", format.name) + require.EqualValuesf(t, 42, b, "%v", format.name) + } + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 408295683..32d68f403 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -214,6 +214,15 @@ type Map struct { TryWrapScanPlanFuncs []TryWrapScanPlanFunc } +// Copy returns a new Map containing the same registered types. +func (m *Map) Copy() *Map { + newMap := NewMap() + for _, type_ := range m.oidToType { + newMap.RegisterType(type_) + } + return newMap +} + func NewMap() *Map { defaultMapInitOnce.Do(initDefaultMap) @@ -248,6 +257,13 @@ func NewMap() *Map { } } +// RegisterTypes registers multiple data types in the sequence they are provided. +func (m *Map) RegisterTypes(types []*Type) { + for _, t := range types { + m.RegisterType(t) + } +} + // RegisterType registers a data type with the Map. t must not be mutated after it is registered. func (m *Map) RegisterType(t *Type) { m.oidToType[t.OID] = t diff --git a/pgx_test.go b/pgx_test.go new file mode 100644 index 000000000..51b4bbc4e --- /dev/null +++ b/pgx_test.go @@ -0,0 +1,22 @@ +package pgx_test + +import ( + "context" + "os" + "testing" + + "github.com/jackc/pgx/v5" + _ "github.com/jackc/pgx/v5/stdlib" +) + +func skipCockroachDB(t testing.TB, msg string) { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + defer conn.Close(context.Background()) + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip(msg) + } +} diff --git a/pgxpool/pool.go b/pgxpool/pool.go index fdcba7241..c614df001 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -296,7 +296,12 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { // # Example URL // postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca&pool_max_conns=10 func ParseConfig(connString string) (*Config, error) { - connConfig, err := pgx.ParseConfig(connString) + return ParseConfigWithOptions(connString, pgx.ParseConfigOptions{}) +} + +// ParseConfigWithOptions is the same as ParseConfig, but allows additional options to be provided. +func ParseConfigWithOptions(connString string, options pgx.ParseConfigOptions) (*Config, error) { + connConfig, err := pgx.ParseConfigWithOptions(connString, options) if err != nil { return nil, err } From 59927e7fb1e6b7a80d2abe448ea2c4bd22cd1e3c Mon Sep 17 00:00:00 2001 From: Nick Farrell Date: Mon, 17 Jun 2024 23:21:23 +1000 Subject: [PATCH 2/3] Simplify custom type autoloading with pgxpool Provide a backwards-compatible configuration option for pgxpool which streamlines the use of the bulk loading and registration of types: - ReuseTypeMaps: if enabled, pgxpool will cache the typemap information, avoiding the need to perform any further queries as new connections are created. ReuseTypeMaps is disabled by default as in some situations, a connection string might resolve to a pool of servers which do not share the same type name -> OID mapping. --- derived_types.go | 2 +- pgtype/pgtype.go | 10 ++++---- pgxpool/pool.go | 59 +++++++++++++++++++++++++++++++++++++------- pgxpool/pool_test.go | 31 +++++++++++++++++++++-- 4 files changed, 85 insertions(+), 17 deletions(-) diff --git a/derived_types.go b/derived_types.go index 5828c2e40..3daff2f4d 100644 --- a/derived_types.go +++ b/derived_types.go @@ -156,7 +156,7 @@ type derivedTypeInfo struct { // or indirectly required to complete the registration. // The result of this call can be passed into RegisterTypes to complete the process. func LoadTypes(ctx context.Context, c *Conn, typeNames []string) ([]*pgtype.Type, error) { - m := c.TypeMap().Copy() + m := c.TypeMap() if typeNames == nil || len(typeNames) == 0 { return nil, fmt.Errorf("No type names were supplied.") } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 32d68f403..0de9cfe8a 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -214,13 +214,13 @@ type Map struct { TryWrapScanPlanFuncs []TryWrapScanPlanFunc } -// Copy returns a new Map containing the same registered types. -func (m *Map) Copy() *Map { - newMap := NewMap() +// Types() returns the non-default types which were registered +func (m *Map) Types() []*Type { + result := make([]*Type, 0, len(m.oidToType)) for _, type_ := range m.oidToType { - newMap.RegisterType(type_) + result = append(result, type_) } - return newMap + return result } func NewMap() *Map { diff --git a/pgxpool/pool.go b/pgxpool/pool.go index c614df001..5424a38a6 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -12,14 +12,17 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/puddle/v2" ) -var defaultMaxConns = int32(4) -var defaultMinConns = int32(0) -var defaultMaxConnLifetime = time.Hour -var defaultMaxConnIdleTime = time.Minute * 30 -var defaultHealthCheckPeriod = time.Minute +var ( + defaultMaxConns = int32(4) + defaultMinConns = int32(0) + defaultMaxConnLifetime = time.Hour + defaultMaxConnIdleTime = time.Minute * 30 + defaultHealthCheckPeriod = time.Minute +) type connResource struct { conn *pgx.Conn @@ -100,6 +103,10 @@ type Pool struct { closeOnce sync.Once closeChan chan struct{} + + reuseTypeMap bool + autoLoadTypes []*pgtype.Type + autoLoadMutex *sync.Mutex } // Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be @@ -147,6 +154,13 @@ type Config struct { // HealthCheckPeriod is the duration between checks of the health of idle connections. HealthCheckPeriod time.Duration + // ReuseTypeMaps, if enabled, will reuse the typemap information being used by AutoLoadTypes. + // This removes the need to query the database each time a new connection is created; + // only RegisterTypes will need to be called for each new connection. + // In some situations, where OID mapping can differ between pg servers in the pool, perhaps due + // to certain replication strategies, this should be left disabled. + ReuseTypeMaps bool + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -185,6 +199,7 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { config: config, beforeConnect: config.BeforeConnect, afterConnect: config.AfterConnect, + reuseTypeMap: config.ReuseTypeMaps, beforeAcquire: config.BeforeAcquire, afterRelease: config.AfterRelease, beforeClose: config.BeforeClose, @@ -196,6 +211,7 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { healthCheckPeriod: config.HealthCheckPeriod, healthCheckChan: make(chan struct{}, 1), closeChan: make(chan struct{}), + autoLoadMutex: new(sync.Mutex), } if t, ok := config.ConnConfig.Tracer.(AcquireTracer); ok { @@ -223,8 +239,12 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { return nil, err } } - - conn, err := pgx.ConnectConfig(ctx, connConfig) + var conn *pgx.Conn + if p.reuseTypeMap { + conn, err = p.ConnectConfigReusingTypeMap(ctx, connConfig) + } else { + conn, err = pgx.ConnectConfig(ctx, connConfig) + } if err != nil { return nil, err } @@ -278,6 +298,29 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { return p, nil } +func (p *Pool) ConnectConfigReusingTypeMap(ctx context.Context, connConfig *pgx.ConnConfig) (*pgx.Conn, error) { + if connConfig.AutoLoadTypes == nil || len(connConfig.AutoLoadTypes) == 0 { + return pgx.ConnectConfig(ctx, connConfig) + } + if p.autoLoadTypes == nil { + p.autoLoadMutex.Lock() + defer p.autoLoadMutex.Unlock() + if p.autoLoadTypes == nil { + conn, err := pgx.ConnectConfig(ctx, connConfig) + if err == nil { + p.autoLoadTypes = conn.TypeMap().Types() + } + return conn, err + } + } + connConfig.AutoLoadTypes = nil + conn, err := pgx.ConnectConfig(ctx, connConfig) + if err == nil { + conn.TypeMap().RegisterTypes(p.autoLoadTypes) + } + return conn, err +} + // ParseConfig builds a Config from connString. It parses connString with the same behavior as [pgx.ParseConfig] with the // addition of the following variables: // @@ -487,7 +530,6 @@ func (p *Pool) checkMinConns() error { func (p *Pool) createIdleResources(parentCtx context.Context, targetResources int) error { ctx, cancel := context.WithCancel(parentCtx) defer cancel() - errs := make(chan error, targetResources) for i := 0; i < targetResources; i++ { @@ -500,7 +542,6 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in errs <- err }() } - var firstError error for i := 0; i < targetResources; i++ { err := <-errs diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 90428931b..6e7c02f5c 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -261,6 +261,35 @@ func TestPoolBeforeConnect(t *testing.T) { assert.EqualValues(t, "pgx", str) } +func TestAutoLoadTypes(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + controllerConn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer controllerConn.Close(ctx) + pgxtest.SkipCockroachDB(t, controllerConn, "Server does not support autoloading of uint64") + db1, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer db1.Close() + db1.Exec(ctx, "DROP DOMAIN IF EXISTS autoload_uint64; CREATE DOMAIN autoload_uint64 as numeric(20,0)") + defer db1.Exec(ctx, "DROP DOMAIN autoload_uint64") + + config.ConnConfig.AutoLoadTypes = []string{"autoload_uint64"} + db2, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + + var n uint64 + err = db2.QueryRow(ctx, "select 12::autoload_uint64").Scan(&n) + require.NoError(t, err) + assert.EqualValues(t, uint64(12), n) +} + func TestPoolAfterConnect(t *testing.T) { t.Parallel() @@ -676,7 +705,6 @@ func TestPoolQuery(t *testing.T) { stats = pool.Stat() assert.EqualValues(t, 0, stats.AcquiredConns()) assert.EqualValues(t, 1, stats.TotalConns()) - } func TestPoolQueryRow(t *testing.T) { @@ -1104,7 +1132,6 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) { } t.Fatal("did not reach min pool size") - } func TestPoolSendBatchBatchCloseTwice(t *testing.T) { From 43edb0e9563f1996147d3969208a4e3e95e0bc43 Mon Sep 17 00:00:00 2001 From: Nick Farrell Date: Mon, 17 Jun 2024 23:21:23 +1000 Subject: [PATCH 3/3] Simplify custom type autoloading with pgxpool Provide a backwards-compatible configuration option for pgxpool which streamlines the use of the bulk loading and registration of types: - ReuseTypeMaps: if enabled, pgxpool will cache the typemap information, avoiding the need to perform any further queries as new connections are created. ReuseTypeMaps is disabled by default as in some situations, a connection string might resolve to a pool of servers which do not share the same type name -> OID mapping. --- conn.go | 6 ++++++ derived_types.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/conn.go b/conn.go index a16b7b500..b890fd875 100644 --- a/conn.go +++ b/conn.go @@ -44,6 +44,12 @@ type ConnConfig struct { // automatically call LoadTypes with these values AutoLoadTypes []string + + // TypeRegistrationMap is used to register types which require special operations. + // The type name is the key, the value is a function which will be called for each + // connection, providing the OID of that type name for that connection. + // The function will manipulate conn.TypeMap() in some way. + TypeRegistrationMap map[string]CustomRegistrationFunction } // ParseConfigOptions contains options that control how a config is built such as getsslpassword. diff --git a/derived_types.go b/derived_types.go index 3daff2f4d..a55c90d49 100644 --- a/derived_types.go +++ b/derived_types.go @@ -10,6 +10,10 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +// CustomRegistrationFunction is capable of registering whatever is necessary for +// a custom type. It is provided with the backend's OID for this type. +type CustomRegistrationFunction func(ctx context.Context, m *pgtype.Map, oid uint32) error + /* buildLoadDerivedTypesSQL generates the correct query for retrieving type information. @@ -254,3 +258,29 @@ func serverVersion(c *Conn) (int64, error) { } return serverVersion, nil } + +func fetchOidMapForCustomRegistration(ctx context.Context, conn *Conn) (map[string]uint32, error) { + sql := ` +SELECT oid, typname +FROM pg_type +WHERE typname = ANY($1)` + result := make(map[string]uint32) + typeNames := make([]string, 0, len(conn.config.TypeRegistrationMap)) + for typeName := range conn.config.TypeRegistrationMap { + typeNames = append(typeNames, typeName) + } + rows, err := conn.Query(ctx, sql, typeNames) + if err != nil { + return nil, fmt.Errorf("While collecting OIDs for custom registrations: %w", err) + } + defer rows.Close() + var typeName string + var oid uint32 + for rows.Next() { + if err := rows.Scan(&typeName, &oid); err != nil { + return nil, fmt.Errorf("While scanning a row for custom registrations: %w", err) + } + result[typeName] = oid + } + return result, nil +}