diff --git a/conn.go b/conn.go index a16b7b500..b890fd875 100644 --- a/conn.go +++ b/conn.go @@ -44,6 +44,12 @@ type ConnConfig struct { // automatically call LoadTypes with these values AutoLoadTypes []string + + // TypeRegistrationMap is used to register types which require special operations. + // The type name is the key, the value is a function which will be called for each + // connection, providing the OID of that type name for that connection. + // The function will manipulate conn.TypeMap() in some way. + TypeRegistrationMap map[string]CustomRegistrationFunction } // ParseConfigOptions contains options that control how a config is built such as getsslpassword. diff --git a/derived_types.go b/derived_types.go index 3daff2f4d..a55c90d49 100644 --- a/derived_types.go +++ b/derived_types.go @@ -10,6 +10,10 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +// CustomRegistrationFunction is capable of registering whatever is necessary for +// a custom type. It is provided with the backend's OID for this type. +type CustomRegistrationFunction func(ctx context.Context, m *pgtype.Map, oid uint32) error + /* buildLoadDerivedTypesSQL generates the correct query for retrieving type information. @@ -254,3 +258,29 @@ func serverVersion(c *Conn) (int64, error) { } return serverVersion, nil } + +func fetchOidMapForCustomRegistration(ctx context.Context, conn *Conn) (map[string]uint32, error) { + sql := ` +SELECT oid, typname +FROM pg_type +WHERE typname = ANY($1)` + result := make(map[string]uint32) + typeNames := make([]string, 0, len(conn.config.TypeRegistrationMap)) + for typeName := range conn.config.TypeRegistrationMap { + typeNames = append(typeNames, typeName) + } + rows, err := conn.Query(ctx, sql, typeNames) + if err != nil { + return nil, fmt.Errorf("While collecting OIDs for custom registrations: %w", err) + } + defer rows.Close() + var typeName string + var oid uint32 + for rows.Next() { + if err := rows.Scan(&typeName, &oid); err != nil { + return nil, fmt.Errorf("While scanning a row for custom registrations: %w", err) + } + result[typeName] = oid + } + return result, nil +}