Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nested Value Codec Access #990

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions pkg/codec/by_item_type_modifier.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package codec

import (
"fmt"
"reflect"

"github.com/smartcontractkit/chainlink-common/pkg/types"
)

// NewByItemTypeModifier returns a Modifier that uses modByItemType to determine which Modifier to use for a given itemType.
Expand All @@ -22,13 +19,14 @@ type byItemTypeModifier struct {
modByitemType map[string]Modifier
}

// RetypeToOffChain attempts to apply a modifier using the provided itemType. To allow access to nested fields, this
// function applies no modifications if a modifier by the specified name is not found.
func (b *byItemTypeModifier) RetypeToOffChain(onChainType reflect.Type, itemType string) (reflect.Type, error) {
mod, ok := b.modByitemType[itemType]
if !ok {
return nil, fmt.Errorf("%w: cannot find modifier for %s", types.ErrInvalidType, itemType)
if mod, ok := b.modByitemType[itemType]; ok {
return mod.RetypeToOffChain(onChainType, itemType)
}

return mod.RetypeToOffChain(onChainType, itemType)
return onChainType, nil
}

func (b *byItemTypeModifier) TransformToOnChain(offChainValue any, itemType string) (any, error) {
Expand All @@ -40,13 +38,15 @@ func (b *byItemTypeModifier) TransformToOffChain(onChainValue any, itemType stri
}

func (b *byItemTypeModifier) transform(
val any, itemType string, transform func(Modifier, any, string) (any, error)) (any, error) {
mod, ok := b.modByitemType[itemType]
if !ok {
return nil, fmt.Errorf("%w: cannot find modifier for %s", types.ErrInvalidType, itemType)
val any,
itemType string,
transform func(Modifier, any, string) (any, error),
) (any, error) {
if mod, ok := b.modByitemType[itemType]; ok {
return transform(mod, val, itemType)
}

return transform(mod, val, itemType)
return val, nil
}

var _ Modifier = &byItemTypeModifier{}
58 changes: 54 additions & 4 deletions pkg/codec/encodings/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package encodings
import (
"fmt"
"reflect"
"strings"

"github.com/smartcontractkit/chainlink-common/pkg/types"
)
Expand All @@ -24,6 +25,8 @@ func NewStructCodec(fields []NamedTypeCodec) (c TopLevelCodec, err error) {

sfs := make([]reflect.StructField, len(fields))
codecFields := make([]TypeCodec, len(fields))
lookup := make(map[string]int)

for i, field := range fields {
ft := field.Codec.GetType()
if ft.Kind() != reflect.Pointer {
Expand All @@ -35,18 +38,22 @@ func NewStructCodec(fields []NamedTypeCodec) (c TopLevelCodec, err error) {
Name: field.Name,
Type: ft,
}

codecFields[i] = field.Codec
lookup[field.Name] = i
}

return &structCodec{
fields: codecFields,
tpe: reflect.PointerTo(reflect.StructOf(sfs)),
fields: codecFields,
fieldLookup: lookup,
tpe: reflect.PointerTo(reflect.StructOf(sfs)),
}, nil
}

type structCodec struct {
fields []TypeCodec
tpe reflect.Type
fields []TypeCodec
fieldLookup map[string]int
tpe reflect.Type
}

func (s *structCodec) Encode(value any, into []byte) ([]byte, error) {
Expand Down Expand Up @@ -113,3 +120,46 @@ func (s *structCodec) SizeAtTopLevel(numItems int) (int, error) {
}
return size, nil
}

func (s *structCodec) FieldCodec(itemType string) (TypeCodec, error) {
path := extendedItemType(itemType)

// itemType could recurse into nested structs
fieldName, tail := path.next()
if fieldName == "" {
return nil, fmt.Errorf("%w: field name required", types.ErrInvalidType)
}

idx, ok := s.fieldLookup[fieldName]
if !ok {
return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
}

codec := s.fields[idx]

if tail == "" {
return codec, nil
}

structType, ok := codec.(StructTypeCodec)
if !ok {
return nil, fmt.Errorf("%w: extended path not traversable for type %s", types.ErrInvalidType, itemType)
}

return structType.FieldCodec(tail)
}

type extendedItemType string

func (t extendedItemType) next() (string, string) {
if string(t) == "" {
return "", ""
}

path := strings.Split(string(t), ".")
if len(path) == 1 {
return path[0], ""
}

return path[0], strings.Join(path[1:], ".")
}
62 changes: 48 additions & 14 deletions pkg/codec/encodings/type_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ type TopLevelCodec interface {
SizeAtTopLevel(numItems int) (int, error)
}

type StructTypeCodec interface {
TypeCodec
FieldCodec(string) (TypeCodec, error)
}

// CodecFromTypeCodec maps TypeCodec to types.RemoteCodec, using the key as the itemType
// If the TypeCodec is a TopLevelCodec, GetMaxEncodingSize and GetMaxDecodingSize will call SizeAtTopLevel instead of Size.
type CodecFromTypeCodec map[string]TypeCodec
Expand All @@ -45,9 +50,9 @@ type LenientCodecFromTypeCodec map[string]TypeCodec
var _ types.RemoteCodec = &LenientCodecFromTypeCodec{}

func (c CodecFromTypeCodec) CreateType(itemType string, _ bool) (any, error) {
ntcwt, ok := c[itemType]
if !ok {
return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
ntcwt, err := getCodec(c, itemType)
if err != nil {
return nil, err
}

tpe := ntcwt.GetType()
Expand All @@ -59,9 +64,9 @@ func (c CodecFromTypeCodec) CreateType(itemType string, _ bool) (any, error) {
}

func (c CodecFromTypeCodec) Encode(_ context.Context, item any, itemType string) ([]byte, error) {
ntcwt, ok := c[itemType]
if !ok {
return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
ntcwt, err := getCodec(c, itemType)
if err != nil {
return nil, err
}

if item != nil {
Expand All @@ -86,14 +91,15 @@ func (c CodecFromTypeCodec) Encode(_ context.Context, item any, itemType string)
}

func (c CodecFromTypeCodec) GetMaxEncodingSize(_ context.Context, n int, itemType string) (int, error) {
ntcwt, ok := c[itemType]
if !ok {
return 0, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
ntcwt, err := getCodec(c, itemType)
if err != nil {
return 0, err
}

if lp, ok := ntcwt.(TopLevelCodec); ok {
return lp.SizeAtTopLevel(n)
}

return ntcwt.Size(n)
}

Expand Down Expand Up @@ -121,11 +127,16 @@ func (c LenientCodecFromTypeCodec) Decode(ctx context.Context, raw []byte, into
return decode(c, raw, into, itemType, false)
}

func (c CodecFromTypeCodec) GetMaxDecodingSize(ctx context.Context, n int, itemType string) (int, error) {
return c.GetMaxEncodingSize(ctx, n, itemType)
}

func decode(c map[string]TypeCodec, raw []byte, into any, itemType string, exactSize bool) error {
ntcwt, ok := c[itemType]
if !ok {
return fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
ntcwt, err := getCodec(c, itemType)
if err != nil {
return err
}

val, remaining, err := ntcwt.Decode(raw)
if err != nil {
return err
Expand All @@ -138,6 +149,29 @@ func decode(c map[string]TypeCodec, raw []byte, into any, itemType string, exact
return codec.Convert(reflect.ValueOf(val), reflect.ValueOf(into), nil)
}

func (c CodecFromTypeCodec) GetMaxDecodingSize(ctx context.Context, n int, itemType string) (int, error) {
return c.GetMaxEncodingSize(ctx, n, itemType)
func getCodec(c map[string]TypeCodec, itemType string) (TypeCodec, error) {
// itemType could recurse into nested structs
path := extendedItemType(itemType)

// itemType could recurse into nested structs
head, tail := path.next()
if head == "" {
return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
}

ntcwt, ok := c[head]
if !ok {
return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
}

if tail == "" {
return ntcwt, nil
}

structType, ok := ntcwt.(StructTypeCodec)
if !ok {
return nil, fmt.Errorf("%w: extended path not traversable for type %s", types.ErrInvalidType, itemType)
}

return structType.FieldCodec(tail)
}
29 changes: 29 additions & 0 deletions pkg/codec/encodings/type_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
rawbin "encoding/binary"
"math"
"reflect"
"strings"
"testing"

"github.com/smartcontractkit/libocr/bigbigendian"
Expand Down Expand Up @@ -122,6 +123,34 @@ func TestCodecFromTypeCodecs(t *testing.T) {

assert.Equal(t, singleItemSize*2, actual)
})

t.Run("CreateType works for nested struct values", func(t *testing.T) {
itemType := strings.Join([]string{TestItemType, "AccountStruct", "Account"}, ".")
ts := CreateTestStruct(0, biit)
c := biit.GetCodec(t)

encoded, err := c.Encode(tests.Context(t), ts.AccountStruct.Account, itemType)
require.NoError(t, err)

var actual []byte
require.NoError(t, c.Decode(tests.Context(t), encoded, &actual, itemType))

assert.Equal(t, ts.AccountStruct.Account, actual)
})

t.Run("CreateType works for nested struct values and modifiers", func(t *testing.T) {
itemType := strings.Join([]string{TestItemWithConfigExtra, "NestedDynamicStruct", "Inner", "S"}, ".")
ts := CreateTestStruct(0, biit)
c := biit.GetCodec(t)

encoded, err := c.Encode(tests.Context(t), ts.NestedDynamicStruct.Inner.S, itemType)
require.NoError(t, err)

var actual string
require.NoError(t, c.Decode(tests.Context(t), encoded, &actual, itemType))

assert.Equal(t, ts.NestedDynamicStruct.Inner.S, actual)
})
}

type interfaceTesterBase struct{}
Expand Down
2 changes: 1 addition & 1 deletion pkg/codec/modifier_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type modifierBase[T any] struct {
addFieldForInput func(pkgPath, name string, change T) reflect.StructField
}

func (m *modifierBase[T]) RetypeToOffChain(onChainType reflect.Type, itemType string) (tpe reflect.Type, err error) {
func (m *modifierBase[T]) RetypeToOffChain(onChainType reflect.Type, _ string) (tpe reflect.Type, err error) {
defer func() {
// StructOf can panic if the fields are not valid
if r := recover(); r != nil {
Expand Down
Loading