Skip to content

Commit

Permalink
Nested Value Codec Access
Browse files Browse the repository at this point in the history
This commit introduces the ability to use a codec to encode/decode individual
fields of a struct, including traversal to nested structs.
  • Loading branch information
EasterTheBunny committed Jan 9, 2025
1 parent 2ebd63b commit 0e9ad4a
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 30 deletions.
24 changes: 12 additions & 12 deletions pkg/codec/by_item_type_modifier.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package codec

import (
"fmt"
"reflect"

"github.com/smartcontractkit/chainlink-common/pkg/types"
)

// NewByItemTypeModifier returns a Modifier that uses modByItemType to determine which Modifier to use for a given itemType.
Expand All @@ -22,13 +19,14 @@ type byItemTypeModifier struct {
modByitemType map[string]Modifier
}

// RetypeToOffChain attempts to apply a modifier using the provided itemType. To allow access to nested fields, this
// function applies no modifications if a modifier by the specified name is not found.
func (b *byItemTypeModifier) RetypeToOffChain(onChainType reflect.Type, itemType string) (reflect.Type, error) {
mod, ok := b.modByitemType[itemType]
if !ok {
return nil, fmt.Errorf("%w: cannot find modifier for %s", types.ErrInvalidType, itemType)
if mod, ok := b.modByitemType[itemType]; ok {
return mod.RetypeToOffChain(onChainType, itemType)
}

return mod.RetypeToOffChain(onChainType, itemType)
return onChainType, nil
}

func (b *byItemTypeModifier) TransformToOnChain(offChainValue any, itemType string) (any, error) {
Expand All @@ -40,13 +38,15 @@ func (b *byItemTypeModifier) TransformToOffChain(onChainValue any, itemType stri
}

func (b *byItemTypeModifier) transform(
val any, itemType string, transform func(Modifier, any, string) (any, error)) (any, error) {
mod, ok := b.modByitemType[itemType]
if !ok {
return nil, fmt.Errorf("%w: cannot find modifier for %s", types.ErrInvalidType, itemType)
val any,
itemType string,
transform func(Modifier, any, string) (any, error),
) (any, error) {
if mod, ok := b.modByitemType[itemType]; ok {
return transform(mod, val, itemType)
}

return transform(mod, val, itemType)
return val, nil
}

var _ Modifier = &byItemTypeModifier{}
58 changes: 54 additions & 4 deletions pkg/codec/encodings/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package encodings
import (
"fmt"
"reflect"
"strings"

"github.com/smartcontractkit/chainlink-common/pkg/types"
)
Expand All @@ -24,6 +25,8 @@ func NewStructCodec(fields []NamedTypeCodec) (c TopLevelCodec, err error) {

sfs := make([]reflect.StructField, len(fields))
codecFields := make([]TypeCodec, len(fields))
lookup := make(map[string]int)

for i, field := range fields {
ft := field.Codec.GetType()
if ft.Kind() != reflect.Pointer {
Expand All @@ -35,18 +38,22 @@ func NewStructCodec(fields []NamedTypeCodec) (c TopLevelCodec, err error) {
Name: field.Name,
Type: ft,
}

codecFields[i] = field.Codec
lookup[field.Name] = i
}

return &structCodec{
fields: codecFields,
tpe: reflect.PointerTo(reflect.StructOf(sfs)),
fields: codecFields,
fieldLookup: lookup,
tpe: reflect.PointerTo(reflect.StructOf(sfs)),
}, nil
}

type structCodec struct {
fields []TypeCodec
tpe reflect.Type
fields []TypeCodec
fieldLookup map[string]int
tpe reflect.Type
}

func (s *structCodec) Encode(value any, into []byte) ([]byte, error) {
Expand Down Expand Up @@ -113,3 +120,46 @@ func (s *structCodec) SizeAtTopLevel(numItems int) (int, error) {
}
return size, nil
}

func (s *structCodec) FieldCodec(itemType string) (TypeCodec, error) {
path := extendedItemType(itemType)

// itemType could recurse into nested structs
fieldName, tail := path.next()
if fieldName == "" {
return nil, fmt.Errorf("%w: field name required", types.ErrInvalidType)
}

idx, ok := s.fieldLookup[fieldName]
if !ok {
return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
}

codec := s.fields[idx]

if tail == "" {
return codec, nil
}

structType, ok := codec.(StructTypeCodec)
if !ok {
return nil, fmt.Errorf("%w: extended path not traversable for type %s", types.ErrInvalidType, itemType)
}

return structType.FieldCodec(tail)
}

type extendedItemType string

func (t extendedItemType) next() (string, string) {
if string(t) == "" {
return "", ""
}

path := strings.Split(string(t), ".")
if len(path) == 1 {
return path[0], ""
}

return path[0], strings.Join(path[1:], ".")
}
62 changes: 48 additions & 14 deletions pkg/codec/encodings/type_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ type TopLevelCodec interface {
SizeAtTopLevel(numItems int) (int, error)
}

type StructTypeCodec interface {
TypeCodec
FieldCodec(string) (TypeCodec, error)
}

// CodecFromTypeCodec maps TypeCodec to types.RemoteCodec, using the key as the itemType
// If the TypeCodec is a TopLevelCodec, GetMaxEncodingSize and GetMaxDecodingSize will call SizeAtTopLevel instead of Size.
type CodecFromTypeCodec map[string]TypeCodec
Expand All @@ -45,9 +50,9 @@ type LenientCodecFromTypeCodec map[string]TypeCodec
var _ types.RemoteCodec = &LenientCodecFromTypeCodec{}

func (c CodecFromTypeCodec) CreateType(itemType string, _ bool) (any, error) {
ntcwt, ok := c[itemType]
if !ok {
return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
ntcwt, err := getCodec(c, itemType)
if err != nil {
return nil, err
}

tpe := ntcwt.GetType()
Expand All @@ -59,9 +64,9 @@ func (c CodecFromTypeCodec) CreateType(itemType string, _ bool) (any, error) {
}

func (c CodecFromTypeCodec) Encode(_ context.Context, item any, itemType string) ([]byte, error) {
ntcwt, ok := c[itemType]
if !ok {
return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
ntcwt, err := getCodec(c, itemType)
if err != nil {
return nil, err
}

if item != nil {
Expand All @@ -86,14 +91,15 @@ func (c CodecFromTypeCodec) Encode(_ context.Context, item any, itemType string)
}

func (c CodecFromTypeCodec) GetMaxEncodingSize(_ context.Context, n int, itemType string) (int, error) {
ntcwt, ok := c[itemType]
if !ok {
return 0, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
ntcwt, err := getCodec(c, itemType)
if err != nil {
return 0, err
}

if lp, ok := ntcwt.(TopLevelCodec); ok {
return lp.SizeAtTopLevel(n)
}

return ntcwt.Size(n)
}

Expand Down Expand Up @@ -121,11 +127,16 @@ func (c LenientCodecFromTypeCodec) Decode(ctx context.Context, raw []byte, into
return decode(c, raw, into, itemType, false)
}

func (c CodecFromTypeCodec) GetMaxDecodingSize(ctx context.Context, n int, itemType string) (int, error) {
return c.GetMaxEncodingSize(ctx, n, itemType)
}

func decode(c map[string]TypeCodec, raw []byte, into any, itemType string, exactSize bool) error {
ntcwt, ok := c[itemType]
if !ok {
return fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
ntcwt, err := getCodec(c, itemType)
if err != nil {
return err
}

val, remaining, err := ntcwt.Decode(raw)
if err != nil {
return err
Expand All @@ -138,6 +149,29 @@ func decode(c map[string]TypeCodec, raw []byte, into any, itemType string, exact
return codec.Convert(reflect.ValueOf(val), reflect.ValueOf(into), nil)
}

func (c CodecFromTypeCodec) GetMaxDecodingSize(ctx context.Context, n int, itemType string) (int, error) {
return c.GetMaxEncodingSize(ctx, n, itemType)
func getCodec(c map[string]TypeCodec, itemType string) (TypeCodec, error) {
// itemType could recurse into nested structs
path := extendedItemType(itemType)

// itemType could recurse into nested structs
head, tail := path.next()
if head == "" {
return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
}

ntcwt, ok := c[head]
if !ok {
return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
}

if tail == "" {
return ntcwt, nil
}

structType, ok := ntcwt.(StructTypeCodec)
if !ok {
return nil, fmt.Errorf("%w: extended path not traversable for type %s", types.ErrInvalidType, itemType)
}

return structType.FieldCodec(tail)
}
15 changes: 15 additions & 0 deletions pkg/codec/encodings/type_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
rawbin "encoding/binary"
"math"
"reflect"
"strings"
"testing"

"github.com/smartcontractkit/libocr/bigbigendian"
Expand Down Expand Up @@ -122,6 +123,20 @@ func TestCodecFromTypeCodecs(t *testing.T) {

assert.Equal(t, singleItemSize*2, actual)
})

t.Run("CreateType works for nested struct values", func(t *testing.T) {
itemType := strings.Join([]string{TestItemType, "NestedDynamicStruct", "Inner", "S"}, ".")
ts := CreateTestStruct(0, biit)
c := biit.GetCodec(t)

encoded, err := c.Encode(tests.Context(t), ts, itemType)
require.NoError(t, err)

var actual string
require.NoError(t, c.Decode(tests.Context(t), encoded, &actual, itemType))

assert.Equal(t, ts.NestedDynamicStruct.Inner.S, actual)
})
}

type interfaceTesterBase struct{}
Expand Down

0 comments on commit 0e9ad4a

Please sign in to comment.