Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional CBOR types #160

Merged
merged 6 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func Query[TResult any](db *DB, sql string, vars map[string]interface{}) (*[]Que
return res.Result, nil
}

func Create[TResult any, TWhat models.TableOrRecord](db *DB, what TWhat, data interface{}) (*TResult, error) {
func Create[TResult any, TWhat TableOrRecord](db *DB, what TWhat, data interface{}) (*TResult, error) {
var res connection.RPCResponse[TResult]
if err := db.con.Send(&res, "create", what, data); err != nil {
return nil, err
Expand All @@ -210,7 +210,7 @@ func Create[TResult any, TWhat models.TableOrRecord](db *DB, what TWhat, data in
return res.Result, nil
}

func Select[TResult any, TWhat models.TableOrRecord](db *DB, what TWhat) (*TResult, error) {
func Select[TResult any, TWhat TableOrRecord](db *DB, what TWhat) (*TResult, error) {
var res connection.RPCResponse[TResult]

if err := db.con.Send(&res, "select", what); err != nil {
Expand All @@ -226,16 +226,16 @@ func Patch(db *DB, what interface{}, patches []PatchData) (*[]PatchData, error)
return patchRes.Result, err
}

func Delete[TWhat models.TableOrRecord](db *DB, what TWhat) error {
func Delete[TWhat TableOrRecord](db *DB, what TWhat) error {
return db.con.Send(nil, "delete", what)
}

func Upsert[TWhat models.TableOrRecord](db *DB, what TWhat, data interface{}) error {
func Upsert[TWhat TableOrRecord](db *DB, what TWhat, data interface{}) error {
return db.con.Send(nil, "upsert", what, data)
}

// Update a table or record in the database like a PUT request.
func Update[TResult any, TWhat models.TableOrRecord](db *DB, what TWhat, data interface{}) (*TResult, error) {
func Update[TResult any, TWhat TableOrRecord](db *DB, what TWhat, data interface{}) (*TResult, error) {
var res connection.RPCResponse[TResult]
if err := db.con.Send(&res, "update", what, data); err != nil {
return nil, err
Expand Down
86 changes: 47 additions & 39 deletions pkg/models/cbor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,60 +3,68 @@ package models
import (
"io"
"reflect"
"time"

"github.com/fxamacker/cbor/v2"
"github.com/surrealdb/surrealdb.go/internal/codec"
)

type CustomCBORTag uint64

var (
NoneTag CustomCBORTag = 6
TableNameTag CustomCBORTag = 7
RecordIDTag CustomCBORTag = 8
UUIDStringTag CustomCBORTag = 9
DecimalStringTag CustomCBORTag = 10
DateTimeCompactString CustomCBORTag = 12
DurationStringTag CustomCBORTag = 13
DurationCompactTag CustomCBORTag = 14
BinaryUUIDTag CustomCBORTag = 37
GeometryPointTag CustomCBORTag = 88
GeometryLineTag CustomCBORTag = 89
GeometryPolygonTag CustomCBORTag = 90
GeometryMultiPointTag CustomCBORTag = 91
GeometryMultiLineTag CustomCBORTag = 92
GeometryMultiPolygonTag CustomCBORTag = 93
GeometryCollectionTag CustomCBORTag = 94
TagNone uint64 = 6
TagTable uint64 = 7
TagRecordID uint64 = 8
TagCustomDatetime uint64 = 12
TagCustomDuration uint64 = 14
TagFuture uint64 = 15

TagStringUUID uint64 = 9
TagStringDecimal uint64 = 10
TagStringDuration uint64 = 13

TagSpecBinaryUUID uint64 = 37

TagRange uint64 = 49
TagBoundIncluded uint64 = 50
TagBoundExcluded uint64 = 51

TagGeometryPoint uint64 = 88
TagGeometryLine uint64 = 89
TagGeometryPolygon uint64 = 90
TagGeometryMultiPoint uint64 = 91
TagGeometryMultiLine uint64 = 92
TagGeometryMultiPolygon uint64 = 93
TagGeometryCollection uint64 = 94
)

func registerCborTags() cbor.TagSet {
customTags := map[CustomCBORTag]interface{}{
GeometryPointTag: GeometryPoint{},
GeometryLineTag: GeometryLine{},
GeometryPolygonTag: GeometryPolygon{},
GeometryMultiPointTag: GeometryMultiPoint{},
GeometryMultiLineTag: GeometryMultiLine{},
GeometryMultiPolygonTag: GeometryMultiPolygon{},
GeometryCollectionTag: GeometryCollection{},

TableNameTag: Table(""),
//UUIDStringTag: UUID(""),
DecimalStringTag: Decimal(""),
BinaryUUIDTag: UUID{},
NoneTag: CustomNil{},

DateTimeCompactString: CustomDateTime(time.Now()),
DurationStringTag: CustomDurationStr("2w"),
//DurationCompactTag: CustomDuration(0),
customTags := map[uint64]interface{}{
TagNone: CustomNil{},
TagTable: Table(""),
TagRecordID: RecordID{},

TagCustomDatetime: CustomDateTime{},
TagCustomDuration: CustomDuration{},
TagFuture: Future{},

TagStringUUID: UUIDString(""),
TagStringDecimal: DecimalString(""),
TagStringDuration: CustomDurationString(""),

TagSpecBinaryUUID: UUID{},

TagGeometryPoint: GeometryPoint{},
TagGeometryLine: GeometryLine{},
TagGeometryPolygon: GeometryPolygon{},
TagGeometryMultiPoint: GeometryMultiPoint{},
TagGeometryMultiLine: GeometryMultiLine{},
TagGeometryMultiPolygon: GeometryMultiPolygon{},
TagGeometryCollection: GeometryCollection{},
}

tags := cbor.NewTagSet()
for tag, customType := range customTags {
err := tags.Add(
cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired},
reflect.TypeOf(customType),
uint64(tag),
tag,
)
if err != nil {
panic(err)
Expand Down
131 changes: 123 additions & 8 deletions pkg/models/cbor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ func TestForRequestPayload(t *testing.T) {
params := []interface{}{
"SELECT marketing, count() FROM $tb GROUP BY marketing",
map[string]interface{}{
"tb": Table("person"),
"line": GeometryLine{NewGeometryPoint(11.11, 22.22), NewGeometryPoint(33.33, 44.44)},
"datetime": time.Now(),
"testNone": None,
"testNil": nil,
"duration": time.Duration(340),
// "custom_duration": CustomDuration(340),
"custom_datetime": CustomDateTime(time.Now()),
"tb": Table("person"),
"line": GeometryLine{NewGeometryPoint(11.11, 22.22), NewGeometryPoint(33.33, 44.44)},
"datetime": time.Now(),
"testNone": None,
"testNil": nil,
"duration": time.Duration(340),
"custom_duration": CustomDuration{340},
"custom_datetime": CustomDateTime{time.Now()},
},
}

Expand All @@ -94,3 +94,118 @@ func TestForRequestPayload(t *testing.T) {

fmt.Println(diagStr)
}

func TestRange_GetJoinString(t *testing.T) {
t.Run("begin excluded, end excluded", func(s *testing.T) {
r := &Range[int, BoundExcluded[int], BoundExcluded[int]]{
Begin: &BoundExcluded[int]{0},
End: &BoundExcluded[int]{10},
}
assert.Equal(t, ">..", r.GetJoinString())
})

t.Run("begin excluded, end included", func(t *testing.T) {
r := Range[int, BoundExcluded[int], BoundIncluded[int]]{
Begin: &BoundExcluded[int]{0},
End: &BoundIncluded[int]{10},
}
assert.Equal(t, ">..=", r.GetJoinString())
})

t.Run("begin included, end excluded", func(t *testing.T) {
r := Range[int, BoundIncluded[int], BoundExcluded[int]]{
Begin: &BoundIncluded[int]{0},
End: &BoundExcluded[int]{10},
}
assert.Equal(t, "..", r.GetJoinString())
})

t.Run("begin included, end included", func(t *testing.T) {
r := Range[int, BoundIncluded[int], BoundIncluded[int]]{
Begin: &BoundIncluded[int]{0},
End: &BoundIncluded[int]{10},
}
assert.Equal(t, "..=", r.GetJoinString())
})
}

func TestRange_Bounds(t *testing.T) {
em := getCborEncoder()
dm := getCborDecoder()

t.Run("bound included should be marshaled and unmarshaled properly", func(t *testing.T) {
bi := BoundIncluded[int]{10}
encoded, err := em.Marshal(bi)
assert.NoError(t, err)

var decoded BoundIncluded[int]
err = dm.Unmarshal(encoded, &decoded)
assert.NoError(t, err)
assert.Equal(t, bi, decoded)
})

t.Run("bound excluded should be marshaled and unmarshaled properly", func(t *testing.T) {
be := BoundExcluded[int]{10}
encoded, err := em.Marshal(be)
assert.NoError(t, err)

var decoded BoundExcluded[int]
err = dm.Unmarshal(encoded, &decoded)
assert.NoError(t, err)
assert.Equal(t, be, decoded)
})
}

func TestRange_CODEC(t *testing.T) {
em := getCborEncoder()
dm := getCborDecoder()

r := Range[int, BoundIncluded[int], BoundExcluded[int]]{
Begin: &BoundIncluded[int]{0},
End: &BoundExcluded[int]{10},
}

encoded, err := em.Marshal(r)
assert.NoError(t, err)

var decoded Range[int, BoundIncluded[int], BoundExcluded[int]]
err = dm.Unmarshal(encoded, &decoded)
assert.NoError(t, err)
assert.Equal(t, r, decoded)
}

func TestCustomDateTime_String(t *testing.T) {
time1, err := time.Parse("2006-01-02 15:04:05", "2024-10-30 12:05:00")
assert.NoError(t, err)

cd := CustomDateTime{time1}
assert.Equal(t, "2024-10-30T12:05:00Z", cd.String())
}

func TestTable_String(t *testing.T) {
table := Table("mytesttable")
assert.Equal(t, "mytesttable", table.String())
}

func TestCustomDuration_String(t *testing.T) {
cd := CustomDuration{time.Duration(33333333333000000)}
assert.Equal(t, "1y2w6d19h15m33s333ms", cd.String())
}

func TestRecordID_String(t *testing.T) {
rid := RecordID{Table: "mytesttable", ID: "121212121"}
assert.Equal(t, "mytesttable:121212121", rid.String())
}

func TestFormatDurationAndParseDuration(t *testing.T) {
durationStr := "1y2w6d19h15m33s333ms"

ns, _ := ParseDuration(durationStr)
d := FormatDuration(ns)
assert.Equal(t, durationStr, d)
}

func TestFormatDuration(t *testing.T) {
d := FormatDuration(33333333333000000)
assert.Equal(t, "1y2w6d19h15m33s333ms", d)
}
54 changes: 54 additions & 0 deletions pkg/models/datetime.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package models

import (
"fmt"
"time"

"github.com/fxamacker/cbor/v2"
"github.com/surrealdb/surrealdb.go/pkg/constants"
)

// CustomDateTime embeds time.Time
type CustomDateTime struct {
time.Time
}

func (d *CustomDateTime) MarshalCBOR() ([]byte, error) {
enc := getCborEncoder()

totalNS := d.Nanosecond()

s := totalNS / constants.OneSecondToNanoSecond
ns := totalNS % constants.OneSecondToNanoSecond

return enc.Marshal(cbor.Tag{
Number: TagCustomDatetime,
Content: [2]int64{int64(s), int64(ns)},
})
}

func (d *CustomDateTime) UnmarshalCBOR(data []byte) error {
dec := getCborDecoder()

var temp [2]int64
err := dec.Unmarshal(data, &temp)
if err != nil {
return err
}

s := temp[0]
ns := temp[1]

*d = CustomDateTime{time.Unix(s, ns)}

return nil
}

func (d *CustomDateTime) String() string {
layout := "2006-01-02T15:04:05Z"
return d.Format(layout)
}

func (d *CustomDateTime) SurrealString() string {
return fmt.Sprintf("<datetime> '%s'", d.String())
}
Loading
Loading