Skip to content

Commit

Permalink
Simplify custom type autoloading with pgxpool
Browse files Browse the repository at this point in the history
Provide a backwards-compatible configuration option for pgxpool
which streamlines the use of the bulk loading and registration of types:
- ReuseTypeMaps: if enabled, pgxpool will cache the typemap information,
  avoiding the need to perform any further queries as new connections
  are created.

ReuseTypeMaps is disabled by default as in some situations, a
connection string might resolve to a pool of servers which do not share
the same type name -> OID mapping.
  • Loading branch information
nicois committed Jun 27, 2024
1 parent 80ddeed commit 44eb7b7
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 17 deletions.
6 changes: 6 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ type ConnConfig struct {

// automatically call LoadTypes with these values
AutoLoadTypes []string

// TypeRegistrationMap is used to register types which require special operations.
// The type name is the key, the value is a function which will be called for each
// connection, providing the OID of that type name for that connection.
// The function will manipulate conn.TypeMap() in some way.
TypeRegistrationMap map[string]CustomRegistrationFunction
}

// ParseConfigOptions contains options that control how a config is built such as getsslpassword.
Expand Down
32 changes: 31 additions & 1 deletion derived_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)

// CustomRegistrationFunction is capable of registering whatever is necessary for
// a custom type. It is provided with the backend's OID for this type.
type CustomRegistrationFunction func(ctx context.Context, m *pgtype.Map, oid uint32) error

/*
buildLoadDerivedTypesSQL generates the correct query for retrieving type information.
Expand Down Expand Up @@ -156,7 +160,7 @@ type derivedTypeInfo struct {
// or indirectly required to complete the registration.
// The result of this call can be passed into RegisterTypes to complete the process.
func LoadTypes(ctx context.Context, c *Conn, typeNames []string) ([]*pgtype.Type, error) {
m := c.TypeMap().Copy()
m := c.TypeMap()
if typeNames == nil || len(typeNames) == 0 {
return nil, fmt.Errorf("No type names were supplied.")
}
Expand Down Expand Up @@ -254,3 +258,29 @@ func serverVersion(c *Conn) (int64, error) {
}
return serverVersion, nil
}

func fetchOidMapForCustomRegistration(ctx context.Context, conn *Conn) (map[string]uint32, error) {
sql := `
SELECT oid, typname
FROM pg_type
WHERE typname = ANY($1)`
result := make(map[string]uint32)
typeNames := make([]string, 0, len(conn.config.TypeRegistrationMap))
for typeName := range conn.config.TypeRegistrationMap {
typeNames = append(typeNames, typeName)
}
rows, err := conn.Query(ctx, sql, typeNames)
if err != nil {
return nil, fmt.Errorf("While collecting OIDs for custom registrations: %w", err)
}
defer rows.Close()
var typeName string
var oid uint32
for rows.Next() {
if err := rows.Scan(&typeName, &oid); err != nil {
return nil, fmt.Errorf("While scanning a row for custom registrations: %w", err)
}
result[typeName] = oid
}
return result, nil
}
10 changes: 5 additions & 5 deletions pgtype/pgtype.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,13 @@ type Map struct {
TryWrapScanPlanFuncs []TryWrapScanPlanFunc
}

// Copy returns a new Map containing the same registered types.
func (m *Map) Copy() *Map {
newMap := NewMap()
// Types() returns the non-default types which were registered
func (m *Map) Types() []*Type {
result := make([]*Type, 0, len(m.oidToType))
for _, type_ := range m.oidToType {
newMap.RegisterType(type_)
result = append(result, type_)
}
return newMap
return result
}

func NewMap() *Map {
Expand Down
59 changes: 50 additions & 9 deletions pgxpool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@ import (

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/puddle/v2"
)

var defaultMaxConns = int32(4)
var defaultMinConns = int32(0)
var defaultMaxConnLifetime = time.Hour
var defaultMaxConnIdleTime = time.Minute * 30
var defaultHealthCheckPeriod = time.Minute
var (
defaultMaxConns = int32(4)
defaultMinConns = int32(0)
defaultMaxConnLifetime = time.Hour
defaultMaxConnIdleTime = time.Minute * 30
defaultHealthCheckPeriod = time.Minute
)

type connResource struct {
conn *pgx.Conn
Expand Down Expand Up @@ -100,6 +103,10 @@ type Pool struct {

closeOnce sync.Once
closeChan chan struct{}

reuseTypeMap bool
autoLoadTypes []*pgtype.Type
autoLoadMutex *sync.Mutex
}

// Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be
Expand Down Expand Up @@ -147,6 +154,13 @@ type Config struct {
// HealthCheckPeriod is the duration between checks of the health of idle connections.
HealthCheckPeriod time.Duration

// ReuseTypeMaps, if enabled, will reuse the typemap information being used by AutoLoadTypes.
// This removes the need to query the database each time a new connection is created;
// only RegisterTypes will need to be called for each new connection.
// In some situations, where OID mapping can differ between pg servers in the pool, perhaps due
// to certain replication strategies, this should be left disabled.
ReuseTypeMaps bool

createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}

Expand Down Expand Up @@ -185,6 +199,7 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
config: config,
beforeConnect: config.BeforeConnect,
afterConnect: config.AfterConnect,
reuseTypeMap: config.ReuseTypeMaps,
beforeAcquire: config.BeforeAcquire,
afterRelease: config.AfterRelease,
beforeClose: config.BeforeClose,
Expand All @@ -196,6 +211,7 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
healthCheckPeriod: config.HealthCheckPeriod,
healthCheckChan: make(chan struct{}, 1),
closeChan: make(chan struct{}),
autoLoadMutex: new(sync.Mutex),
}

if t, ok := config.ConnConfig.Tracer.(AcquireTracer); ok {
Expand Down Expand Up @@ -223,8 +239,12 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
return nil, err
}
}

conn, err := pgx.ConnectConfig(ctx, connConfig)
var conn *pgx.Conn
if p.reuseTypeMap {
conn, err = p.ConnectConfigReusingTypeMap(ctx, connConfig)
} else {
conn, err = pgx.ConnectConfig(ctx, connConfig)
}
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -278,6 +298,29 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
return p, nil
}

func (p *Pool) ConnectConfigReusingTypeMap(ctx context.Context, connConfig *pgx.ConnConfig) (*pgx.Conn, error) {
if connConfig.AutoLoadTypes == nil || len(connConfig.AutoLoadTypes) == 0 {
return pgx.ConnectConfig(ctx, connConfig)
}
if p.autoLoadTypes == nil {
p.autoLoadMutex.Lock()
defer p.autoLoadMutex.Unlock()
if p.autoLoadTypes == nil {
conn, err := pgx.ConnectConfig(ctx, connConfig)
if err == nil {
p.autoLoadTypes = conn.TypeMap().Types()
}
return conn, err
}
}
connConfig.AutoLoadTypes = nil
conn, err := pgx.ConnectConfig(ctx, connConfig)
if err == nil {
conn.TypeMap().RegisterTypes(p.autoLoadTypes)
}
return conn, err
}

// ParseConfig builds a Config from connString. It parses connString with the same behavior as [pgx.ParseConfig] with the
// addition of the following variables:
//
Expand Down Expand Up @@ -482,7 +525,6 @@ func (p *Pool) checkMinConns() error {
func (p *Pool) createIdleResources(parentCtx context.Context, targetResources int) error {
ctx, cancel := context.WithCancel(parentCtx)
defer cancel()

errs := make(chan error, targetResources)

for i := 0; i < targetResources; i++ {
Expand All @@ -495,7 +537,6 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in
errs <- err
}()
}

var firstError error
for i := 0; i < targetResources; i++ {
err := <-errs
Expand Down
31 changes: 29 additions & 2 deletions pgxpool/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,35 @@ func TestPoolBeforeConnect(t *testing.T) {
assert.EqualValues(t, "pgx", str)
}

func TestAutoLoadTypes(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)

controllerConn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer controllerConn.Close(ctx)
pgxtest.SkipCockroachDB(t, controllerConn, "Server does not support autoloading of uint64")
db1, err := pgxpool.NewWithConfig(ctx, config)
require.NoError(t, err)
defer db1.Close()
db1.Exec(ctx, "DROP DOMAIN IF EXISTS autoload_uint64; CREATE DOMAIN autoload_uint64 as numeric(20,0)")
defer db1.Exec(ctx, "DROP DOMAIN autoload_uint64")

config.ConnConfig.AutoLoadTypes = []string{"autoload_uint64"}
db2, err := pgxpool.NewWithConfig(ctx, config)
require.NoError(t, err)

var n uint64
err = db2.QueryRow(ctx, "select 12::autoload_uint64").Scan(&n)
require.NoError(t, err)
assert.EqualValues(t, uint64(12), n)
}

func TestPoolAfterConnect(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -676,7 +705,6 @@ func TestPoolQuery(t *testing.T) {
stats = pool.Stat()
assert.EqualValues(t, 0, stats.AcquiredConns())
assert.EqualValues(t, 1, stats.TotalConns())

}

func TestPoolQueryRow(t *testing.T) {
Expand Down Expand Up @@ -1104,7 +1132,6 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) {
}

t.Fatal("did not reach min pool size")

}

func TestPoolSendBatchBatchCloseTwice(t *testing.T) {
Expand Down

0 comments on commit 44eb7b7

Please sign in to comment.