Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for PG 17 interval infinity values #2065

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
Expand Down
12 changes: 4 additions & 8 deletions pgtype/date.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,8 @@ func (src Date) MarshalJSON() ([]byte, error) {
switch src.InfinityModifier {
case Finite:
s = src.Time.Format("2006-01-02")
case Infinity:
s = "infinity"
case NegativeInfinity:
s = "-infinity"
case Infinity, NegativeInfinity:
s = src.InfinityModifier.String()
}

return json.Marshal(s)
Expand Down Expand Up @@ -213,10 +211,8 @@ func (encodePlanDateCodecText) Encode(value any, buf []byte) (newBuf []byte, err
if bc {
buf = append(buf, " BC"...)
}
case Infinity:
buf = append(buf, "infinity"...)
case NegativeInfinity:
buf = append(buf, "-infinity"...)
case Infinity, NegativeInfinity:
buf = append(buf, date.InfinityModifier.String()...)
}

return buf, nil
Expand Down
215 changes: 128 additions & 87 deletions pgtype/interval.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql/driver"
"encoding/binary"
"fmt"
"math"
"strconv"
"strings"

Expand All @@ -27,10 +28,11 @@ type IntervalValuer interface {
}

type Interval struct {
Microseconds int64
Days int32
Months int32
Valid bool
Microseconds int64
Days int32
Months int32
InfinityModifier InfinityModifier
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth adding a field comment here along the lines of this?

// InfinityModifier for intervals is only supported in Postgres 17 or newer
InfinityModifier InfinityModifier

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't hurt.

Valid bool
}

func (interval *Interval) ScanInterval(v Interval) error {
Expand Down Expand Up @@ -63,6 +65,10 @@ func (interval Interval) Value() (driver.Value, error) {
return nil, nil
}

if interval.InfinityModifier != Finite {
return interval.InfinityModifier.String(), nil
}

buf, err := IntervalCodec{}.PlanEncode(nil, 0, TextFormatCode, interval).Encode(interval, nil)
if err != nil {
return nil, err
Expand Down Expand Up @@ -107,9 +113,21 @@ func (encodePlanIntervalCodecBinary) Encode(value any, buf []byte) (newBuf []byt
return nil, nil
}

buf = pgio.AppendInt64(buf, interval.Microseconds)
buf = pgio.AppendInt32(buf, interval.Days)
buf = pgio.AppendInt32(buf, interval.Months)
switch interval.InfinityModifier {
case Finite:
buf = pgio.AppendInt64(buf, interval.Microseconds)
buf = pgio.AppendInt32(buf, interval.Days)
buf = pgio.AppendInt32(buf, interval.Months)
case Infinity:
buf = pgio.AppendInt64(buf, math.MaxInt64)
buf = pgio.AppendInt32(buf, math.MaxInt32)
buf = pgio.AppendInt32(buf, math.MaxInt32)
case NegativeInfinity:
buf = pgio.AppendInt64(buf, math.MinInt64)
buf = pgio.AppendInt32(buf, math.MinInt32)
buf = pgio.AppendInt32(buf, math.MinInt32)
}

return buf, nil
}

Expand All @@ -125,32 +143,37 @@ func (encodePlanIntervalCodecText) Encode(value any, buf []byte) (newBuf []byte,
return nil, nil
}

if interval.Months != 0 {
buf = append(buf, strconv.FormatInt(int64(interval.Months), 10)...)
buf = append(buf, " mon "...)
}
switch interval.InfinityModifier {
case Finite:
if interval.Months != 0 {
buf = append(buf, strconv.FormatInt(int64(interval.Months), 10)...)
buf = append(buf, " mon "...)
}

if interval.Days != 0 {
buf = append(buf, strconv.FormatInt(int64(interval.Days), 10)...)
buf = append(buf, " day "...)
}
if interval.Days != 0 {
buf = append(buf, strconv.FormatInt(int64(interval.Days), 10)...)
buf = append(buf, " day "...)
}

absMicroseconds := interval.Microseconds
if absMicroseconds < 0 {
absMicroseconds = -absMicroseconds
buf = append(buf, '-')
}
absMicroseconds := interval.Microseconds
if absMicroseconds < 0 {
absMicroseconds = -absMicroseconds
buf = append(buf, '-')
}

hours := absMicroseconds / microsecondsPerHour
minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute
seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond
hours := absMicroseconds / microsecondsPerHour
minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute
seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond

timeStr := fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds)
buf = append(buf, timeStr...)
timeStr := fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds)
buf = append(buf, timeStr...)

microseconds := absMicroseconds % microsecondsPerSecond
if microseconds != 0 {
buf = append(buf, fmt.Sprintf(".%06d", microseconds)...)
microseconds := absMicroseconds % microsecondsPerSecond
if microseconds != 0 {
buf = append(buf, fmt.Sprintf(".%06d", microseconds)...)
}
case Infinity, NegativeInfinity:
buf = append(buf, interval.InfinityModifier.String()...)
}

return buf, nil
Expand Down Expand Up @@ -184,14 +207,22 @@ func (scanPlanBinaryIntervalToIntervalScanner) Scan(src []byte, dst any) error {
}

if len(src) != 16 {
return fmt.Errorf("Received an invalid size for an interval: %d", len(src))
return fmt.Errorf("received an invalid size for an interval: %d", len(src))
}

microseconds := int64(binary.BigEndian.Uint64(src))
days := int32(binary.BigEndian.Uint32(src[8:]))
months := int32(binary.BigEndian.Uint32(src[12:]))

return scanner.ScanInterval(Interval{Microseconds: microseconds, Days: days, Months: months, Valid: true})
interval := Interval{Microseconds: microseconds, Days: days, Months: months, Valid: true}

if microseconds == math.MaxInt64 && days == math.MaxInt32 && months == math.MaxInt32 {
interval.InfinityModifier = Infinity
} else if microseconds == math.MinInt64 && days == math.MinInt32 && months == math.MinInt32 {
interval.InfinityModifier = NegativeInfinity
}

return scanner.ScanInterval(interval)
}

type scanPlanTextAnyToIntervalScanner struct{}
Expand All @@ -203,80 +234,90 @@ func (scanPlanTextAnyToIntervalScanner) Scan(src []byte, dst any) error {
return scanner.ScanInterval(Interval{})
}

var microseconds int64
var days int32
var months int32

parts := strings.Split(string(src), " ")

for i := 0; i < len(parts)-1; i += 2 {
scalar, err := strconv.ParseInt(parts[i], 10, 64)
if err != nil {
return fmt.Errorf("bad interval format")
}

switch parts[i+1] {
case "year", "years":
months += int32(scalar * 12)
case "mon", "mons":
months += int32(scalar)
case "day", "days":
days = int32(scalar)
}
}
var interval Interval
sbuf := string(src)
switch sbuf {
case "infinity":
interval = Interval{InfinityModifier: Infinity, Valid: true}
case "-infinity":
interval = Interval{InfinityModifier: NegativeInfinity, Valid: true}
default:
var microseconds int64
var days int32
var months int32

parts := strings.Split(sbuf, " ")

for i := 0; i < len(parts)-1; i += 2 {
scalar, err := strconv.ParseInt(parts[i], 10, 64)
if err != nil {
return fmt.Errorf("bad interval format")
}

if len(parts)%2 == 1 {
timeParts := strings.SplitN(parts[len(parts)-1], ":", 3)
if len(timeParts) != 3 {
return fmt.Errorf("bad interval format")
switch parts[i+1] {
case "year", "years":
months += int32(scalar * 12)
case "mon", "mons":
months += int32(scalar)
case "day", "days":
days = int32(scalar)
}
}

var negative bool
if timeParts[0][0] == '-' {
negative = true
timeParts[0] = timeParts[0][1:]
}
if len(parts)%2 == 1 {
timeParts := strings.SplitN(parts[len(parts)-1], ":", 3)
if len(timeParts) != 3 {
return fmt.Errorf("bad interval format")
}

hours, err := strconv.ParseInt(timeParts[0], 10, 64)
if err != nil {
return fmt.Errorf("bad interval hour format: %s", timeParts[0])
}
var negative bool
if timeParts[0][0] == '-' {
negative = true
timeParts[0] = timeParts[0][1:]
}

minutes, err := strconv.ParseInt(timeParts[1], 10, 64)
if err != nil {
return fmt.Errorf("bad interval minute format: %s", timeParts[1])
}
hours, err := strconv.ParseInt(timeParts[0], 10, 64)
if err != nil {
return fmt.Errorf("bad interval hour format: %s", timeParts[0])
}

sec, secFrac, secFracFound := strings.Cut(timeParts[2], ".")
minutes, err := strconv.ParseInt(timeParts[1], 10, 64)
if err != nil {
return fmt.Errorf("bad interval minute format: %s", timeParts[1])
}

seconds, err := strconv.ParseInt(sec, 10, 64)
if err != nil {
return fmt.Errorf("bad interval second format: %s", sec)
}
sec, secFrac, secFracFound := strings.Cut(timeParts[2], ".")

var uSeconds int64
if secFracFound {
uSeconds, err = strconv.ParseInt(secFrac, 10, 64)
seconds, err := strconv.ParseInt(sec, 10, 64)
if err != nil {
return fmt.Errorf("bad interval decimal format: %s", secFrac)
return fmt.Errorf("bad interval second format: %s", sec)
}

for i := 0; i < 6-len(secFrac); i++ {
uSeconds *= 10
var uSeconds int64
if secFracFound {
uSeconds, err = strconv.ParseInt(secFrac, 10, 64)
if err != nil {
return fmt.Errorf("bad interval decimal format: %s", secFrac)
}

for i := 0; i < 6-len(secFrac); i++ {
uSeconds *= 10
}
}
}

microseconds = hours * microsecondsPerHour
microseconds += minutes * microsecondsPerMinute
microseconds += seconds * microsecondsPerSecond
microseconds += uSeconds
microseconds = hours * microsecondsPerHour
microseconds += minutes * microsecondsPerMinute
microseconds += seconds * microsecondsPerSecond
microseconds += uSeconds

if negative {
microseconds = -microseconds
if negative {
microseconds = -microseconds
}
}
interval = Interval{Months: months, Days: days, Microseconds: microseconds, Valid: true}
}

return scanner.ScanInterval(Interval{Months: months, Days: days, Microseconds: microseconds, Valid: true})
return scanner.ScanInterval(interval)
}

func (c IntervalCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
Expand Down
12 changes: 12 additions & 0 deletions pgtype/interval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ func TestIntervalCodec(t *testing.T) {
new(pgtype.Interval),
isExpectedEq(pgtype.Interval{Months: -13, Valid: true}),
},
{
"infinity",
new(pgtype.Interval),
isExpectedEq(pgtype.Interval{InfinityModifier: pgtype.Infinity, Valid: true}),
},
{
"-infinity",
new(pgtype.Interval),
isExpectedEq(pgtype.Interval{InfinityModifier: pgtype.NegativeInfinity, Valid: true}),
},
{time.Hour, new(time.Duration), isExpectedEq(time.Hour)},
{
pgtype.Interval{Months: 1, Days: 1, Valid: true},
Expand All @@ -149,6 +159,8 @@ func TestIntervalTextEncode(t *testing.T) {
{source: pgtype.Interval{Months: 0, Days: 0, Microseconds: 0, Valid: true}, result: "00:00:00"},
{source: pgtype.Interval{Months: 0, Days: 0, Microseconds: 6 * 60 * 1000000, Valid: true}, result: "00:06:00"},
{source: pgtype.Interval{Months: 0, Days: 1, Microseconds: 6*60*1000000 + 30, Valid: true}, result: "1 day 00:06:00.000030"},
{source: pgtype.Interval{InfinityModifier: pgtype.Infinity, Valid: true}, result: "infinity"},
{source: pgtype.Interval{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "-infinity"},
}
for i, tt := range successfulTests {
buf, err := m.Encode(pgtype.DateOID, pgtype.TextFormatCode, tt.source, nil)
Expand Down
12 changes: 4 additions & 8 deletions pgtype/timestamp.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,8 @@ func (ts Timestamp) MarshalJSON() ([]byte, error) {
switch ts.InfinityModifier {
case Finite:
s = ts.Time.Format(time.RFC3339Nano)
case Infinity:
s = "infinity"
case NegativeInfinity:
s = "-infinity"
case Infinity, NegativeInfinity:
s = ts.InfinityModifier.String()
}

return json.Marshal(s)
Expand Down Expand Up @@ -205,10 +203,8 @@ func (encodePlanTimestampCodecText) Encode(value any, buf []byte) (newBuf []byte
if bc {
s = s + " BC"
}
case Infinity:
s = "infinity"
case NegativeInfinity:
s = "-infinity"
case Infinity, NegativeInfinity:
s = ts.InfinityModifier.String()
}

buf = append(buf, s...)
Expand Down
Loading
Loading