diff --git a/pkg/codec/by_item_type_modifier.go b/pkg/codec/by_item_type_modifier.go index f68e73e9d7..12967caace 100644 --- a/pkg/codec/by_item_type_modifier.go +++ b/pkg/codec/by_item_type_modifier.go @@ -1,10 +1,7 @@ package codec import ( - "fmt" "reflect" - - "github.com/smartcontractkit/chainlink-common/pkg/types" ) // NewByItemTypeModifier returns a Modifier that uses modByItemType to determine which Modifier to use for a given itemType. @@ -22,13 +19,14 @@ type byItemTypeModifier struct { modByitemType map[string]Modifier } +// RetypeToOffChain attempts to apply a modifier using the provided itemType. To allow access to nested fields, this +// function applies no modifications if a modifier by the specified name is not found. func (b *byItemTypeModifier) RetypeToOffChain(onChainType reflect.Type, itemType string) (reflect.Type, error) { - mod, ok := b.modByitemType[itemType] - if !ok { - return nil, fmt.Errorf("%w: cannot find modifier for %s", types.ErrInvalidType, itemType) + if mod, ok := b.modByitemType[itemType]; ok { + return mod.RetypeToOffChain(onChainType, itemType) } - return mod.RetypeToOffChain(onChainType, itemType) + return onChainType, nil } func (b *byItemTypeModifier) TransformToOnChain(offChainValue any, itemType string) (any, error) { @@ -40,13 +38,15 @@ func (b *byItemTypeModifier) TransformToOffChain(onChainValue any, itemType stri } func (b *byItemTypeModifier) transform( - val any, itemType string, transform func(Modifier, any, string) (any, error)) (any, error) { - mod, ok := b.modByitemType[itemType] - if !ok { - return nil, fmt.Errorf("%w: cannot find modifier for %s", types.ErrInvalidType, itemType) + val any, + itemType string, + transform func(Modifier, any, string) (any, error), +) (any, error) { + if mod, ok := b.modByitemType[itemType]; ok { + return transform(mod, val, itemType) } - return transform(mod, val, itemType) + return val, nil } var _ Modifier = &byItemTypeModifier{} diff --git a/pkg/codec/encodings/struct.go b/pkg/codec/encodings/struct.go index 946936c457..a7a4747777 100644 --- a/pkg/codec/encodings/struct.go +++ b/pkg/codec/encodings/struct.go @@ -3,6 +3,7 @@ package encodings import ( "fmt" "reflect" + "strings" "github.com/smartcontractkit/chainlink-common/pkg/types" ) @@ -24,6 +25,8 @@ func NewStructCodec(fields []NamedTypeCodec) (c TopLevelCodec, err error) { sfs := make([]reflect.StructField, len(fields)) codecFields := make([]TypeCodec, len(fields)) + lookup := make(map[string]int) + for i, field := range fields { ft := field.Codec.GetType() if ft.Kind() != reflect.Pointer { @@ -35,18 +38,22 @@ func NewStructCodec(fields []NamedTypeCodec) (c TopLevelCodec, err error) { Name: field.Name, Type: ft, } + codecFields[i] = field.Codec + lookup[field.Name] = i } return &structCodec{ - fields: codecFields, - tpe: reflect.PointerTo(reflect.StructOf(sfs)), + fields: codecFields, + fieldLookup: lookup, + tpe: reflect.PointerTo(reflect.StructOf(sfs)), }, nil } type structCodec struct { - fields []TypeCodec - tpe reflect.Type + fields []TypeCodec + fieldLookup map[string]int + tpe reflect.Type } func (s *structCodec) Encode(value any, into []byte) ([]byte, error) { @@ -113,3 +120,46 @@ func (s *structCodec) SizeAtTopLevel(numItems int) (int, error) { } return size, nil } + +func (s *structCodec) FieldCodec(itemType string) (TypeCodec, error) { + path := extendedItemType(itemType) + + // itemType could recurse into nested structs + fieldName, tail := path.next() + if fieldName == "" { + return nil, fmt.Errorf("%w: field name required", types.ErrInvalidType) + } + + idx, ok := s.fieldLookup[fieldName] + if !ok { + return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + } + + codec := s.fields[idx] + + if tail == "" { + return codec, nil + } + + structType, ok := codec.(StructTypeCodec) + if !ok { + return nil, fmt.Errorf("%w: extended path not traversable for type %s", types.ErrInvalidType, itemType) + } + + return structType.FieldCodec(tail) +} + +type extendedItemType string + +func (t extendedItemType) next() (string, string) { + if string(t) == "" { + return "", "" + } + + path := strings.Split(string(t), ".") + if len(path) == 1 { + return path[0], "" + } + + return path[0], strings.Join(path[1:], ".") +} diff --git a/pkg/codec/encodings/type_codec.go b/pkg/codec/encodings/type_codec.go index 1807df8c1a..5b0d35b281 100644 --- a/pkg/codec/encodings/type_codec.go +++ b/pkg/codec/encodings/type_codec.go @@ -33,6 +33,11 @@ type TopLevelCodec interface { SizeAtTopLevel(numItems int) (int, error) } +type StructTypeCodec interface { + TypeCodec + FieldCodec(string) (TypeCodec, error) +} + // CodecFromTypeCodec maps TypeCodec to types.RemoteCodec, using the key as the itemType // If the TypeCodec is a TopLevelCodec, GetMaxEncodingSize and GetMaxDecodingSize will call SizeAtTopLevel instead of Size. type CodecFromTypeCodec map[string]TypeCodec @@ -45,9 +50,9 @@ type LenientCodecFromTypeCodec map[string]TypeCodec var _ types.RemoteCodec = &LenientCodecFromTypeCodec{} func (c CodecFromTypeCodec) CreateType(itemType string, _ bool) (any, error) { - ntcwt, ok := c[itemType] - if !ok { - return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + ntcwt, err := getCodec(c, itemType) + if err != nil { + return nil, err } tpe := ntcwt.GetType() @@ -59,9 +64,9 @@ func (c CodecFromTypeCodec) CreateType(itemType string, _ bool) (any, error) { } func (c CodecFromTypeCodec) Encode(_ context.Context, item any, itemType string) ([]byte, error) { - ntcwt, ok := c[itemType] - if !ok { - return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + ntcwt, err := getCodec(c, itemType) + if err != nil { + return nil, err } if item != nil { @@ -86,14 +91,15 @@ func (c CodecFromTypeCodec) Encode(_ context.Context, item any, itemType string) } func (c CodecFromTypeCodec) GetMaxEncodingSize(_ context.Context, n int, itemType string) (int, error) { - ntcwt, ok := c[itemType] - if !ok { - return 0, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + ntcwt, err := getCodec(c, itemType) + if err != nil { + return 0, err } if lp, ok := ntcwt.(TopLevelCodec); ok { return lp.SizeAtTopLevel(n) } + return ntcwt.Size(n) } @@ -121,11 +127,16 @@ func (c LenientCodecFromTypeCodec) Decode(ctx context.Context, raw []byte, into return decode(c, raw, into, itemType, false) } +func (c CodecFromTypeCodec) GetMaxDecodingSize(ctx context.Context, n int, itemType string) (int, error) { + return c.GetMaxEncodingSize(ctx, n, itemType) +} + func decode(c map[string]TypeCodec, raw []byte, into any, itemType string, exactSize bool) error { - ntcwt, ok := c[itemType] - if !ok { - return fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + ntcwt, err := getCodec(c, itemType) + if err != nil { + return err } + val, remaining, err := ntcwt.Decode(raw) if err != nil { return err @@ -138,6 +149,29 @@ func decode(c map[string]TypeCodec, raw []byte, into any, itemType string, exact return codec.Convert(reflect.ValueOf(val), reflect.ValueOf(into), nil) } -func (c CodecFromTypeCodec) GetMaxDecodingSize(ctx context.Context, n int, itemType string) (int, error) { - return c.GetMaxEncodingSize(ctx, n, itemType) +func getCodec(c map[string]TypeCodec, itemType string) (TypeCodec, error) { + // itemType could recurse into nested structs + path := extendedItemType(itemType) + + // itemType could recurse into nested structs + head, tail := path.next() + if head == "" { + return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + } + + ntcwt, ok := c[head] + if !ok { + return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + } + + if tail == "" { + return ntcwt, nil + } + + structType, ok := ntcwt.(StructTypeCodec) + if !ok { + return nil, fmt.Errorf("%w: extended path not traversable for type %s", types.ErrInvalidType, itemType) + } + + return structType.FieldCodec(tail) } diff --git a/pkg/codec/encodings/type_codec_test.go b/pkg/codec/encodings/type_codec_test.go index 874819ff23..35e942c525 100644 --- a/pkg/codec/encodings/type_codec_test.go +++ b/pkg/codec/encodings/type_codec_test.go @@ -4,6 +4,7 @@ import ( rawbin "encoding/binary" "math" "reflect" + "strings" "testing" "github.com/smartcontractkit/libocr/bigbigendian" @@ -122,6 +123,20 @@ func TestCodecFromTypeCodecs(t *testing.T) { assert.Equal(t, singleItemSize*2, actual) }) + + t.Run("CreateType works for nested struct values", func(t *testing.T) { + itemType := strings.Join([]string{TestItemType, "NestedDynamicStruct", "Inner", "S"}, ".") + ts := CreateTestStruct(0, biit) + c := biit.GetCodec(t) + + encoded, err := c.Encode(tests.Context(t), ts, itemType) + require.NoError(t, err) + + var actual string + require.NoError(t, c.Decode(tests.Context(t), encoded, &actual, itemType)) + + assert.Equal(t, ts.NestedDynamicStruct.Inner.S, actual) + }) } type interfaceTesterBase struct{}