diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 079966724..f02f90cf9 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -104,10 +104,13 @@ type Pool struct { closeOnce sync.Once closeChan chan struct{} - autoLoadTypeNames []string - reuseTypeMap bool - autoLoadMutex *sync.Mutex - autoLoadTypes []*pgtype.Type + autoLoadTypeNames []string + reuseTypeMap bool + autoLoadMutex *sync.Mutex + autoLoadTypes []*pgtype.Type + customRegistrationMap map[string]CustomRegistrationFunction + customRegistrationMutex *sync.Mutex + customRegistrationOidMap map[string]uint32 } // Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be @@ -198,6 +201,10 @@ func New(ctx context.Context, connString string) (*Pool, error) { return NewWithConfig(ctx, config) } +// CustomRegistrationFunction is capable of registering whatever is necessary for +// a custom type. It is provided with the backend's OID for this type. +type CustomRegistrationFunction func(ctx context.Context, m *pgtype.Map, oid uint32) error + // NewWithConfig creates a new Pool. config must have been created by [ParseConfig]. func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from @@ -207,23 +214,25 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { } p := &Pool{ - config: config, - beforeConnect: config.BeforeConnect, - afterConnect: config.AfterConnect, - autoLoadTypeNames: config.AutoLoadTypes, - reuseTypeMap: config.ReuseTypeMaps, - beforeAcquire: config.BeforeAcquire, - afterRelease: config.AfterRelease, - beforeClose: config.BeforeClose, - minConns: config.MinConns, - maxConns: config.MaxConns, - maxConnLifetime: config.MaxConnLifetime, - maxConnLifetimeJitter: config.MaxConnLifetimeJitter, - maxConnIdleTime: config.MaxConnIdleTime, - healthCheckPeriod: config.HealthCheckPeriod, - healthCheckChan: make(chan struct{}, 1), - closeChan: make(chan struct{}), - autoLoadMutex: new(sync.Mutex), + config: config, + beforeConnect: config.BeforeConnect, + afterConnect: config.AfterConnect, + autoLoadTypeNames: config.AutoLoadTypes, + reuseTypeMap: config.ReuseTypeMaps, + beforeAcquire: config.BeforeAcquire, + afterRelease: config.AfterRelease, + beforeClose: config.BeforeClose, + minConns: config.MinConns, + maxConns: config.MaxConns, + maxConnLifetime: config.MaxConnLifetime, + maxConnLifetimeJitter: config.MaxConnLifetimeJitter, + maxConnIdleTime: config.MaxConnIdleTime, + healthCheckPeriod: config.HealthCheckPeriod, + healthCheckChan: make(chan struct{}, 1), + closeChan: make(chan struct{}), + autoLoadMutex: new(sync.Mutex), + customRegistrationMap: make(map[string]CustomRegistrationFunction), + customRegistrationMutex: new(sync.Mutex), } if t, ok := config.ConnConfig.Tracer.(AcquireTracer); ok { @@ -265,6 +274,24 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { } } + if len(p.customRegistrationMap) > 0 { + oidMap, err := p.getOidMapForCustomRegistration(ctx, conn) + if err != nil { + conn.Close(ctx) + return nil, fmt.Errorf("While retrieving OIDs for custom type registration: %w", err) + } + for typeName, f := range p.customRegistrationMap { + if oid, exists := oidMap[typeName]; exists { + if err := f(ctx, conn.TypeMap(), oid); err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("Type %q does not have an associated OID.", typeName) + } + } + + } + if p.autoLoadTypeNames != nil && len(p.autoLoadTypeNames) > 0 { types, err := p.loadTypes(ctx, conn, p.autoLoadTypeNames) if err != nil { @@ -315,6 +342,51 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { return p, nil } +func (p *Pool) getOidMapForCustomRegistration(ctx context.Context, conn *pgx.Conn) (map[string]uint32, error) { + if p.reuseTypeMap { + p.customRegistrationMutex.Lock() + defer p.customRegistrationMutex.Unlock() + if p.customRegistrationOidMap != nil { + return p.customRegistrationOidMap, nil + } + oidMap, err := p.fetchOidMapForCustomRegistration(ctx, conn) + if err != nil { + return nil, err + } + p.customRegistrationOidMap = oidMap + return oidMap, nil + } + // Avoid needing to acquire the mutex and allow connections to initialise in parallel + // if we have chosen to not reuse the type mapping + return p.fetchOidMapForCustomRegistration(ctx, conn) +} + +func (p *Pool) fetchOidMapForCustomRegistration(ctx context.Context, conn *pgx.Conn) (map[string]uint32, error) { + sql := ` +SELECT oid, typname +FROM pg_type +WHERE typname = ANY($1)` + result := make(map[string]uint32) + typeNames := make([]string, 0, len(p.customRegistrationMap)) + for typeName := range p.customRegistrationMap { + typeNames = append(typeNames, typeName) + } + rows, err := conn.Query(ctx, sql, typeNames) + if err != nil { + return nil, fmt.Errorf("While collecting OIDs for custom registrations: %w", err) + } + defer rows.Close() + var typeName string + var oid uint32 + for rows.Next() { + if err := rows.Scan(&typeName, &oid); err != nil { + return nil, fmt.Errorf("While scanning a row for custom registrations: %w", err) + } + result[typeName] = oid + } + return result, nil +} + // ParseConfig builds a Config from connString. It parses connString with the same behavior as [pgx.ParseConfig] with the // addition of the following variables: // @@ -425,6 +497,12 @@ func (p *Pool) Close() { }) } +// RegisterCustomType is used to provide a function capable of performing +// type registration for situations where the autoloader is unable to do so on its own +func (p *Pool) RegisterCustomType(typeName string, f CustomRegistrationFunction) { + p.customRegistrationMap[typeName] = f +} + // loadTypes is used internally to autoload the custom types for a connection, // potentially reusing previously-loaded typemap information. func (p *Pool) loadTypes(ctx context.Context, conn *pgx.Conn, typeNames []string) ([]*pgtype.Type, error) {