From bdac53eb35617773bc5ef5697234ab0af7b4629d Mon Sep 17 00:00:00 2001 From: Awbrey Hughlett Date: Wed, 15 Jan 2025 10:09:49 -0600 Subject: [PATCH 1/8] rename modifier functional for nested fields --- pkg/codec/by_item_type_modifier.go | 22 ++-- pkg/codec/encodings/struct.go | 58 +++++++++- pkg/codec/encodings/type_codec.go | 62 +++++++--- pkg/codec/encodings/type_codec_test.go | 31 ++++- pkg/codec/hard_coder.go | 4 +- pkg/codec/modifier_base.go | 149 +++++++++++++++++++++---- pkg/codec/renamer.go | 68 +++++++++-- pkg/codec/renamer_test.go | 39 +++++++ 8 files changed, 376 insertions(+), 57 deletions(-) diff --git a/pkg/codec/by_item_type_modifier.go b/pkg/codec/by_item_type_modifier.go index f68e73e9d..3009858ca 100644 --- a/pkg/codec/by_item_type_modifier.go +++ b/pkg/codec/by_item_type_modifier.go @@ -22,13 +22,17 @@ type byItemTypeModifier struct { modByitemType map[string]Modifier } +// RetypeToOffChain attempts to apply a modifier using the provided itemType. To allow access to nested fields, this +// function applies no modifications if a modifier by the specified name is not found. func (b *byItemTypeModifier) RetypeToOffChain(onChainType reflect.Type, itemType string) (reflect.Type, error) { - mod, ok := b.modByitemType[itemType] + head, tail := extendedItemType(itemType).next() + + mod, ok := b.modByitemType[head] if !ok { return nil, fmt.Errorf("%w: cannot find modifier for %s", types.ErrInvalidType, itemType) } - return mod.RetypeToOffChain(onChainType, itemType) + return mod.RetypeToOffChain(onChainType, tail) } func (b *byItemTypeModifier) TransformToOnChain(offChainValue any, itemType string) (any, error) { @@ -40,13 +44,17 @@ func (b *byItemTypeModifier) TransformToOffChain(onChainValue any, itemType stri } func (b *byItemTypeModifier) transform( - val any, itemType string, transform func(Modifier, any, string) (any, error)) (any, error) { - mod, ok := b.modByitemType[itemType] - if !ok { - return nil, fmt.Errorf("%w: cannot find modifier for %s", types.ErrInvalidType, itemType) + val any, + itemType string, + transform func(Modifier, any, string) (any, error), +) (any, error) { + head, tail := extendedItemType(itemType).next() + + if mod, ok := b.modByitemType[head]; ok { + return transform(mod, val, tail) } - return transform(mod, val, itemType) + return val, nil } var _ Modifier = &byItemTypeModifier{} diff --git a/pkg/codec/encodings/struct.go b/pkg/codec/encodings/struct.go index 946936c45..a7a474777 100644 --- a/pkg/codec/encodings/struct.go +++ b/pkg/codec/encodings/struct.go @@ -3,6 +3,7 @@ package encodings import ( "fmt" "reflect" + "strings" "github.com/smartcontractkit/chainlink-common/pkg/types" ) @@ -24,6 +25,8 @@ func NewStructCodec(fields []NamedTypeCodec) (c TopLevelCodec, err error) { sfs := make([]reflect.StructField, len(fields)) codecFields := make([]TypeCodec, len(fields)) + lookup := make(map[string]int) + for i, field := range fields { ft := field.Codec.GetType() if ft.Kind() != reflect.Pointer { @@ -35,18 +38,22 @@ func NewStructCodec(fields []NamedTypeCodec) (c TopLevelCodec, err error) { Name: field.Name, Type: ft, } + codecFields[i] = field.Codec + lookup[field.Name] = i } return &structCodec{ - fields: codecFields, - tpe: reflect.PointerTo(reflect.StructOf(sfs)), + fields: codecFields, + fieldLookup: lookup, + tpe: reflect.PointerTo(reflect.StructOf(sfs)), }, nil } type structCodec struct { - fields []TypeCodec - tpe reflect.Type + fields []TypeCodec + fieldLookup map[string]int + tpe reflect.Type } func (s *structCodec) Encode(value any, into []byte) ([]byte, error) { @@ -113,3 +120,46 @@ func (s *structCodec) SizeAtTopLevel(numItems int) (int, error) { } return size, nil } + +func (s *structCodec) FieldCodec(itemType string) (TypeCodec, error) { + path := extendedItemType(itemType) + + // itemType could recurse into nested structs + fieldName, tail := path.next() + if fieldName == "" { + return nil, fmt.Errorf("%w: field name required", types.ErrInvalidType) + } + + idx, ok := s.fieldLookup[fieldName] + if !ok { + return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + } + + codec := s.fields[idx] + + if tail == "" { + return codec, nil + } + + structType, ok := codec.(StructTypeCodec) + if !ok { + return nil, fmt.Errorf("%w: extended path not traversable for type %s", types.ErrInvalidType, itemType) + } + + return structType.FieldCodec(tail) +} + +type extendedItemType string + +func (t extendedItemType) next() (string, string) { + if string(t) == "" { + return "", "" + } + + path := strings.Split(string(t), ".") + if len(path) == 1 { + return path[0], "" + } + + return path[0], strings.Join(path[1:], ".") +} diff --git a/pkg/codec/encodings/type_codec.go b/pkg/codec/encodings/type_codec.go index 1807df8c1..5b0d35b28 100644 --- a/pkg/codec/encodings/type_codec.go +++ b/pkg/codec/encodings/type_codec.go @@ -33,6 +33,11 @@ type TopLevelCodec interface { SizeAtTopLevel(numItems int) (int, error) } +type StructTypeCodec interface { + TypeCodec + FieldCodec(string) (TypeCodec, error) +} + // CodecFromTypeCodec maps TypeCodec to types.RemoteCodec, using the key as the itemType // If the TypeCodec is a TopLevelCodec, GetMaxEncodingSize and GetMaxDecodingSize will call SizeAtTopLevel instead of Size. type CodecFromTypeCodec map[string]TypeCodec @@ -45,9 +50,9 @@ type LenientCodecFromTypeCodec map[string]TypeCodec var _ types.RemoteCodec = &LenientCodecFromTypeCodec{} func (c CodecFromTypeCodec) CreateType(itemType string, _ bool) (any, error) { - ntcwt, ok := c[itemType] - if !ok { - return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + ntcwt, err := getCodec(c, itemType) + if err != nil { + return nil, err } tpe := ntcwt.GetType() @@ -59,9 +64,9 @@ func (c CodecFromTypeCodec) CreateType(itemType string, _ bool) (any, error) { } func (c CodecFromTypeCodec) Encode(_ context.Context, item any, itemType string) ([]byte, error) { - ntcwt, ok := c[itemType] - if !ok { - return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + ntcwt, err := getCodec(c, itemType) + if err != nil { + return nil, err } if item != nil { @@ -86,14 +91,15 @@ func (c CodecFromTypeCodec) Encode(_ context.Context, item any, itemType string) } func (c CodecFromTypeCodec) GetMaxEncodingSize(_ context.Context, n int, itemType string) (int, error) { - ntcwt, ok := c[itemType] - if !ok { - return 0, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + ntcwt, err := getCodec(c, itemType) + if err != nil { + return 0, err } if lp, ok := ntcwt.(TopLevelCodec); ok { return lp.SizeAtTopLevel(n) } + return ntcwt.Size(n) } @@ -121,11 +127,16 @@ func (c LenientCodecFromTypeCodec) Decode(ctx context.Context, raw []byte, into return decode(c, raw, into, itemType, false) } +func (c CodecFromTypeCodec) GetMaxDecodingSize(ctx context.Context, n int, itemType string) (int, error) { + return c.GetMaxEncodingSize(ctx, n, itemType) +} + func decode(c map[string]TypeCodec, raw []byte, into any, itemType string, exactSize bool) error { - ntcwt, ok := c[itemType] - if !ok { - return fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + ntcwt, err := getCodec(c, itemType) + if err != nil { + return err } + val, remaining, err := ntcwt.Decode(raw) if err != nil { return err @@ -138,6 +149,29 @@ func decode(c map[string]TypeCodec, raw []byte, into any, itemType string, exact return codec.Convert(reflect.ValueOf(val), reflect.ValueOf(into), nil) } -func (c CodecFromTypeCodec) GetMaxDecodingSize(ctx context.Context, n int, itemType string) (int, error) { - return c.GetMaxEncodingSize(ctx, n, itemType) +func getCodec(c map[string]TypeCodec, itemType string) (TypeCodec, error) { + // itemType could recurse into nested structs + path := extendedItemType(itemType) + + // itemType could recurse into nested structs + head, tail := path.next() + if head == "" { + return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + } + + ntcwt, ok := c[head] + if !ok { + return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + } + + if tail == "" { + return ntcwt, nil + } + + structType, ok := ntcwt.(StructTypeCodec) + if !ok { + return nil, fmt.Errorf("%w: extended path not traversable for type %s", types.ErrInvalidType, itemType) + } + + return structType.FieldCodec(tail) } diff --git a/pkg/codec/encodings/type_codec_test.go b/pkg/codec/encodings/type_codec_test.go index 874819ff2..f23ed90ec 100644 --- a/pkg/codec/encodings/type_codec_test.go +++ b/pkg/codec/encodings/type_codec_test.go @@ -4,6 +4,7 @@ import ( rawbin "encoding/binary" "math" "reflect" + "strings" "testing" "github.com/smartcontractkit/libocr/bigbigendian" @@ -122,6 +123,34 @@ func TestCodecFromTypeCodecs(t *testing.T) { assert.Equal(t, singleItemSize*2, actual) }) + + t.Run("CreateType works for nested struct values and modifiers", func(t *testing.T) { + itemType := strings.Join([]string{TestItemWithConfigExtra, "AccountStruct", "Account"}, ".") + ts := CreateTestStruct(0, biit) + c := biit.GetCodec(t) + + encoded, err := c.Encode(tests.Context(t), ts.AccountStruct.Account, itemType) + require.NoError(t, err) + + var actual []byte + require.NoError(t, c.Decode(tests.Context(t), encoded, &actual, itemType)) + + assert.Equal(t, ts.AccountStruct.Account, actual) + }) + + t.Run("CreateType works for nested struct values", func(t *testing.T) { + itemType := strings.Join([]string{TestItemType, "NestedDynamicStruct", "Inner", "S"}, ".") + ts := CreateTestStruct(0, biit) + c := biit.GetCodec(t) + + encoded, err := c.Encode(tests.Context(t), ts.NestedDynamicStruct.Inner.S, itemType) + require.NoError(t, err) + + var actual string + require.NoError(t, c.Decode(tests.Context(t), encoded, &actual, itemType)) + + assert.Equal(t, ts.NestedDynamicStruct.Inner.S, actual) + }) } type interfaceTesterBase struct{} @@ -319,7 +348,7 @@ func (b *bigEndianInterfaceTester) GetCodec(t *testing.T) types.Codec { modCodec, err := codec.NewModifierCodec(c, byTypeMod, codec.BigIntHook) require.NoError(t, err) - _, err = mod.RetypeToOffChain(reflect.PointerTo(testStruct.GetType()), TestItemWithConfigExtra) + _, err = mod.RetypeToOffChain(reflect.PointerTo(testStruct.GetType()), "") require.NoError(t, err) return modCodec diff --git a/pkg/codec/hard_coder.go b/pkg/codec/hard_coder.go index 9f946fa04..60e2633e7 100644 --- a/pkg/codec/hard_coder.go +++ b/pkg/codec/hard_coder.go @@ -2,6 +2,7 @@ package codec import ( "fmt" + "log" "reflect" "strings" @@ -81,7 +82,8 @@ func verifyHardCodeKeys(values map[string]any) error { return nil } -func (o *onChainHardCoder) TransformToOnChain(offChainValue any, _ string) (any, error) { +func (o *onChainHardCoder) TransformToOnChain(offChainValue any, itemType string) (any, error) { + log.Println(itemType) return transformWithMaps(offChainValue, o.offToOnChainType, o.onChain, hardCode, o.hooks...) } diff --git a/pkg/codec/modifier_base.go b/pkg/codec/modifier_base.go index 8a092fe9b..c50fe5245 100644 --- a/pkg/codec/modifier_base.go +++ b/pkg/codec/modifier_base.go @@ -17,9 +17,16 @@ type modifierBase[T any] struct { offToOnChainType map[reflect.Type]reflect.Type modifyFieldForInput func(pkgPath string, outputField *reflect.StructField, fullPath string, change T) error addFieldForInput func(pkgPath, name string, change T) reflect.StructField + onChainStructType reflect.Type + offChainStructType reflect.Type } +// RetypeToOffChain sets the on-chain and off-chain types for modifications. If itemType is empty, the type returned +// will be the full off-chain type and all type mappings will be reset. If itemType is not empty, retyping assumes a +// sub-field is expected and the off-chain type of the sub-field is returned with no modifications to internal type +// mappings. func (m *modifierBase[T]) RetypeToOffChain(onChainType reflect.Type, itemType string) (tpe reflect.Type, err error) { + // onChainType could be the entire struct or a sub-field type defer func() { // StructOf can panic if the fields are not valid if r := recover(); r != nil { @@ -27,48 +34,71 @@ func (m *modifierBase[T]) RetypeToOffChain(onChainType reflect.Type, itemType st err = fmt.Errorf("%w: %v", types.ErrInvalidType, r) } }() + + // if itemType is empty, store the type mappings + // if itemType is not empty, assume a sub-field property is expected to be extracted + onChainStructType := onChainType + if itemType != "" { + onChainStructType = m.onChainStructType + } + + // this will only work for the full on-chain struct type unless we cache the individual + // field types too. + if cached, ok := m.onToOffChainType[onChainStructType]; ok { + return typeForPath(cached, itemType) + } + if len(m.fields) == 0 { m.offToOnChainType[onChainType] = onChainType m.onToOffChainType[onChainType] = onChainType - return onChainType, nil - } + m.onChainStructType = onChainType + m.offChainStructType = onChainType - if cached, ok := m.onToOffChainType[onChainType]; ok { - return cached, nil + return typeForPath(onChainType, itemType) } var offChainType reflect.Type - switch onChainType.Kind() { + + // the onChainStructType here should always reference the full on-chain struct type + switch onChainStructType.Kind() { case reflect.Pointer: - elm, err := m.RetypeToOffChain(onChainType.Elem(), "") - if err != nil { + var elm reflect.Type + + if elm, err = m.RetypeToOffChain(onChainStructType.Elem(), itemType); err != nil { return nil, err } offChainType = reflect.PointerTo(elm) case reflect.Slice: - elm, err := m.RetypeToOffChain(onChainType.Elem(), "") - if err != nil { + var elm reflect.Type + + if elm, err = m.RetypeToOffChain(onChainStructType.Elem(), ""); err != nil { return nil, err } offChainType = reflect.SliceOf(elm) case reflect.Array: - elm, err := m.RetypeToOffChain(onChainType.Elem(), "") - if err != nil { + var elm reflect.Type + + if elm, err = m.RetypeToOffChain(onChainStructType.Elem(), ""); err != nil { return nil, err } - offChainType = reflect.ArrayOf(onChainType.Len(), elm) + offChainType = reflect.ArrayOf(onChainStructType.Len(), elm) case reflect.Struct: - return m.getStructType(onChainType) + if offChainType, err = m.getStructType(onChainStructType); err != nil { + return nil, err + } default: - return nil, fmt.Errorf("%w: cannot retype the kind %v", types.ErrInvalidType, onChainType.Kind()) + return nil, fmt.Errorf("%w: cannot retype the kind %v", types.ErrInvalidType, onChainStructType.Kind()) } - m.onToOffChainType[onChainType] = offChainType - m.offToOnChainType[offChainType] = onChainType - return offChainType, nil + m.onToOffChainType[onChainStructType] = offChainType + m.offToOnChainType[offChainType] = onChainStructType + m.onChainStructType = onChainType + m.offChainStructType = offChainType + + return typeForPath(offChainType, itemType) } func (m *modifierBase[T]) getStructType(outputType reflect.Type) (reflect.Type, error) { @@ -78,10 +108,11 @@ func (m *modifierBase[T]) getStructType(outputType reflect.Type) (reflect.Type, } for _, key := range m.subkeysFirst() { + curLocations := filedLocations parts := strings.Split(key, ".") fieldName := parts[len(parts)-1] + parts = parts[:len(parts)-1] - curLocations := filedLocations for _, part := range parts { if curLocations, err = curLocations.populateSubFields(part); err != nil { return nil, err @@ -102,10 +133,7 @@ func (m *modifierBase[T]) getStructType(outputType reflect.Type) (reflect.Type, } } - newStruct := filedLocations.makeNewType() - m.onToOffChainType[outputType] = newStruct - m.offToOnChainType[newStruct] = outputType - return newStruct, nil + return filedLocations.makeNewType(), nil } // subkeysFirst returns a list of keys that will always have a sub-key before the key if both are present @@ -122,6 +150,34 @@ func (m *modifierBase[T]) subkeysFirst() []string { return orderedKeys } +func (m *modifierBase[T]) onToOffChainTyper(onChainType reflect.Type, itemType string) (reflect.Type, error) { + onChainRefType := onChainType + if itemType != "" { + onChainRefType = m.onChainStructType + } + + offChainType, ok := m.onToOffChainType[onChainRefType] + if !ok { + return nil, fmt.Errorf("%w: cannot rename unknown type %v", types.ErrInvalidType, onChainType) + } + + return typeForPath(offChainType, itemType) +} + +func (m *modifierBase[T]) offToOnChainTyper(offChainType reflect.Type, itemType string) (reflect.Type, error) { + offChainRefType := offChainType + if itemType != "" { + offChainRefType = m.offChainStructType + } + + onChainType, ok := m.offToOnChainType[offChainRefType] + if !ok { + return nil, fmt.Errorf("%w: cannot rename unknown type %v", types.ErrInvalidType, offChainType) + } + + return typeForPath(onChainType, itemType) +} + // subkeysLast returns a list of keys that will always have a sub-key after the key if both are present func subkeysLast[T any](fields map[string]T) []string { orderedKeys := make([]string, 0, len(fields)) @@ -130,6 +186,7 @@ func subkeysLast[T any](fields map[string]T) []string { } sort.Strings(orderedKeys) + return orderedKeys } @@ -264,6 +321,39 @@ func doForMapElements[T any](valueMapping map[string]any, fields map[string]T, f return nil } +func typeForPath(from reflect.Type, itemType string) (reflect.Type, error) { + if itemType == "" { + return from, nil + } + + switch from.Kind() { + case reflect.Pointer: + elem, err := typeForPath(from.Elem(), itemType) + if err != nil { + return nil, err + } + + return elem, nil + case reflect.Array, reflect.Slice: + return nil, fmt.Errorf("%w: cannot extract a field from an array or slice", types.ErrInvalidType) + case reflect.Struct: + head, tail := extendedItemType(itemType).next() + + field, ok := from.FieldByName(head) + if !ok { + return nil, fmt.Errorf("%w: field not found for path %s and itemType %s", types.ErrInvalidType, from, itemType) + } + + if tail == "" { + return field.Type, nil + } + + return typeForPath(field.Type, tail) + default: + return nil, fmt.Errorf("%w: cannot extract a field from kind %s", types.ErrInvalidType, from.Kind()) + } +} + type PathMappingError struct { Err error Path string @@ -276,3 +366,18 @@ func (e PathMappingError) Error() string { func (e PathMappingError) Cause() error { return e.Err } + +type extendedItemType string + +func (t extendedItemType) next() (string, string) { + if string(t) == "" { + return "", "" + } + + path := strings.Split(string(t), ".") + if len(path) == 1 { + return path[0], "" + } + + return path[0], strings.Join(path[1:], ".") +} diff --git a/pkg/codec/renamer.go b/pkg/codec/renamer.go index b2414964a..845abeef4 100644 --- a/pkg/codec/renamer.go +++ b/pkg/codec/renamer.go @@ -2,7 +2,9 @@ package codec import ( "fmt" + "log" "reflect" + "strings" "unicode" "github.com/smartcontractkit/chainlink-common/pkg/types" @@ -30,26 +32,72 @@ type renamer struct { modifierBase[string] } -func (r *renamer) TransformToOffChain(onChainValue any, _ string) (any, error) { - rOutput, err := renameTransform(r.onToOffChainType, reflect.ValueOf(onChainValue)) +func (r *renamer) TransformToOffChain(onChainValue any, itemType string) (any, error) { + // itemType references the on-chain type + // remap to the off-chain field name + if itemType != "" { + var ref string + + parts := strings.Split(itemType, ".") + if len(parts) > 0 { + ref = parts[len(parts)-1] + } + + for on, off := range r.fields { + if ref == on { + // B.A -> C == B.C + parts[len(parts)-1] = off + itemType = strings.Join(parts, ".") + + break + } + } + } + + rOutput, err := renameTransform(r.onToOffChainTyper, reflect.ValueOf(onChainValue), itemType) if err != nil { return nil, err } + return rOutput.Interface(), nil } -func (r *renamer) TransformToOnChain(offChainValue any, _ string) (any, error) { - rOutput, err := renameTransform(r.offToOnChainType, reflect.ValueOf(offChainValue)) +func (r *renamer) TransformToOnChain(offChainValue any, itemType string) (any, error) { + log.Println(itemType) + if itemType != "" { + log.Println(itemType) + var ref string + + parts := strings.Split(itemType, ".") + if len(parts) > 0 { + ref = parts[len(parts)-1] + } + + for on, off := range r.fields { + if ref == off { + itemType = on + + break + } + } + } + + rOutput, err := renameTransform(r.offToOnChainTyper, reflect.ValueOf(offChainValue), itemType) if err != nil { return nil, err } + return rOutput.Interface(), nil } -func renameTransform(typeMap map[reflect.Type]reflect.Type, rInput reflect.Value) (reflect.Value, error) { - toType, ok := typeMap[rInput.Type()] - if !ok { - return reflect.Value{}, fmt.Errorf("%w: cannot rename unknown type %v", types.ErrInvalidType, toType) +func renameTransform( + typeFunc func(reflect.Type, string) (reflect.Type, error), + rInput reflect.Value, + itemType string, +) (reflect.Value, error) { + toType, err := typeFunc(rInput.Type(), itemType) + if err != nil { + return reflect.Value{}, err } if toType == rInput.Type() { @@ -70,6 +118,10 @@ func transformNonPointer(toType reflect.Type, rInput reflect.Value) (reflect.Val // make sure the input is addressable ptr := reflect.New(rInput.Type()) reflect.Indirect(ptr).Set(rInput) + + // UnsafePointer is a bit of a Go hack but works because the data types/structure and data for the two types + // are the same. The only change is the names of the fields. changed := reflect.NewAt(toType, ptr.UnsafePointer()).Elem() + return changed, nil } diff --git a/pkg/codec/renamer_test.go b/pkg/codec/renamer_test.go index 55453ff16..62fe47bee 100644 --- a/pkg/codec/renamer_test.go +++ b/pkg/codec/renamer_test.go @@ -385,6 +385,45 @@ func TestRenamer(t *testing.T) { require.NoError(t, err) assert.Equal(t, iOffchain.Interface(), newInput) }) + + t.Run("TransformToOnChain and TransformToOffChain works on nested fields even if the field itself is renamed for path", func(t *testing.T) { + offChainType, err := nestedRenamer.RetypeToOffChain(reflect.TypeOf(nestedTestStruct{}), "") + require.NoError(t, err) + iOffchain := reflect.Indirect(reflect.New(offChainType)) + + iOffchain.FieldByName("X").SetString("foo") + rY := iOffchain.FieldByName("Y") + rY.FieldByName("X").SetString("foo") + rY.FieldByName("B").SetInt(10) + rY.FieldByName("Z").SetInt(20) + + rC := iOffchain.FieldByName("C") + rC.Set(reflect.MakeSlice(rC.Type(), 2, 2)) + iElm := rC.Index(0) + iElm.FieldByName("X").SetString("foo") + iElm.FieldByName("B").SetInt(10) + iElm.FieldByName("Z").SetInt(20) + iElm = rC.Index(1) + iElm.FieldByName("X").SetString("baz") + iElm.FieldByName("B").SetInt(15) + iElm.FieldByName("Z").SetInt(25) + + iOffchain.FieldByName("D").SetString("bar") + + output, err := nestedRenamer.TransformToOnChain(iOffchain.FieldByName("Y").Interface(), "Y") + + require.NoError(t, err) + + expected := testStruct{ + A: "foo", + B: 10, + C: 20, + } + assert.Equal(t, expected, output) + newInput, err := nestedRenamer.TransformToOffChain(expected, "B") + require.NoError(t, err) + assert.Equal(t, iOffchain.FieldByName("Y").Interface(), newInput) + }) } func assertBasicRenameTransform(t *testing.T, offChainType reflect.Type) { From ea4e07962d39343c5a8150cb8ab53555ebb574e7 Mon Sep 17 00:00:00 2001 From: Awbrey Hughlett Date: Thu, 23 Jan 2025 09:53:57 -0600 Subject: [PATCH 2/8] hard coder modifier updated --- pkg/codec/by_item_type_modifier.go | 33 +++++++++++++++++++-- pkg/codec/byte_string_modifier.go | 4 +-- pkg/codec/element_extractor.go | 4 +-- pkg/codec/encodings/struct.go | 21 ++------------ pkg/codec/encodings/type_codec.go | 5 +--- pkg/codec/epoch_to_time.go | 4 +-- pkg/codec/hard_coder.go | 8 ++---- pkg/codec/hard_coder_test.go | 46 ++++++++++++++++++++++++++++++ pkg/codec/modifier_base.go | 22 +++++++------- pkg/codec/precodec.go | 4 +-- pkg/codec/renamer.go | 3 -- pkg/codec/wrapper.go | 4 +-- 12 files changed, 105 insertions(+), 53 deletions(-) diff --git a/pkg/codec/by_item_type_modifier.go b/pkg/codec/by_item_type_modifier.go index 3009858ca..5c10df51e 100644 --- a/pkg/codec/by_item_type_modifier.go +++ b/pkg/codec/by_item_type_modifier.go @@ -2,6 +2,7 @@ package codec import ( "fmt" + "log" "reflect" "github.com/smartcontractkit/chainlink-common/pkg/types" @@ -15,20 +16,43 @@ func NewByItemTypeModifier(modByItemType map[string]Modifier) (Modifier, error) return &byItemTypeModifier{ modByitemType: modByItemType, + enableNesting: false, + }, nil +} + +// NewNestableByItemTypeModifier returns a Modifier that uses modByItemType to determine which Modifier to use for a +// given itemType. If itemType is structured as a dot-separated string like 'A.B.C', the first part 'A' will be used to +// match in the mod map and the remaining list will be provided to the found Modifier 'B.C'. +func NewNestableByItemTypeModifier(modByItemType map[string]Modifier) (Modifier, error) { + if modByItemType == nil { + modByItemType = map[string]Modifier{} + } + + return &byItemTypeModifier{ + modByitemType: modByItemType, + enableNesting: true, }, nil } type byItemTypeModifier struct { modByitemType map[string]Modifier + enableNesting bool } // RetypeToOffChain attempts to apply a modifier using the provided itemType. To allow access to nested fields, this // function applies no modifications if a modifier by the specified name is not found. func (b *byItemTypeModifier) RetypeToOffChain(onChainType reflect.Type, itemType string) (reflect.Type, error) { - head, tail := extendedItemType(itemType).next() + head := itemType + tail := itemType + + if b.enableNesting { + head, tail = ItemTyper(itemType).Next() + } + log.Println("byItemTypeModifier", "RetypeToOffChain", onChainType, head, ":", tail) mod, ok := b.modByitemType[head] if !ok { + log.Println(mod) return nil, fmt.Errorf("%w: cannot find modifier for %s", types.ErrInvalidType, itemType) } @@ -48,7 +72,12 @@ func (b *byItemTypeModifier) transform( itemType string, transform func(Modifier, any, string) (any, error), ) (any, error) { - head, tail := extendedItemType(itemType).next() + head := itemType + tail := itemType + + if b.enableNesting { + head, tail = ItemTyper(itemType).Next() + } if mod, ok := b.modByitemType[head]; ok { return transform(mod, val, tail) diff --git a/pkg/codec/byte_string_modifier.go b/pkg/codec/byte_string_modifier.go index 153cc6e20..5d7b66ac5 100644 --- a/pkg/codec/byte_string_modifier.go +++ b/pkg/codec/byte_string_modifier.go @@ -87,12 +87,12 @@ func (t *bytesToStringModifier) RetypeToOffChain(onChainType reflect.Type, _ str // TransformToOnChain uses the AddressModifier for string-to-address conversion. func (t *bytesToStringModifier) TransformToOnChain(offChainValue any, _ string) (any, error) { - return transformWithMaps(offChainValue, t.offToOnChainType, t.fields, noop, stringToAddressHookForOnChain(t.modifier)) + return transformWithMaps(offChainValue, t.offToOnChainTyper, "", t.fields, noop, stringToAddressHookForOnChain(t.modifier)) } // TransformToOffChain uses the AddressModifier for address-to-string conversion. func (t *bytesToStringModifier) TransformToOffChain(onChainValue any, _ string) (any, error) { - return transformWithMaps(onChainValue, t.onToOffChainType, t.fields, + return transformWithMaps(onChainValue, t.onToOffChainTyper, "", t.fields, addressTransformationAction(t.modifier.Length()), addressToStringHookForOffChain(t.modifier), ) diff --git a/pkg/codec/element_extractor.go b/pkg/codec/element_extractor.go index c6dbeebaf..73e4e3502 100644 --- a/pkg/codec/element_extractor.go +++ b/pkg/codec/element_extractor.go @@ -78,11 +78,11 @@ type elementExtractor struct { } func (e *elementExtractor) TransformToOnChain(offChainValue any, _ string) (any, error) { - return transformWithMaps(offChainValue, e.offToOnChainType, e.fields, extractMap) + return transformWithMaps(offChainValue, e.offToOnChainTyper, "", e.fields, extractMap) } func (e *elementExtractor) TransformToOffChain(onChainValue any, _ string) (any, error) { - return transformWithMaps(onChainValue, e.onToOffChainType, e.fields, expandMap) + return transformWithMaps(onChainValue, e.onToOffChainTyper, "", e.fields, expandMap) } func extractMap(extractMap map[string]any, key string, elementLocation *ElementExtractorLocation) error { diff --git a/pkg/codec/encodings/struct.go b/pkg/codec/encodings/struct.go index a7a474777..0e9bfda45 100644 --- a/pkg/codec/encodings/struct.go +++ b/pkg/codec/encodings/struct.go @@ -3,8 +3,8 @@ package encodings import ( "fmt" "reflect" - "strings" + "github.com/smartcontractkit/chainlink-common/pkg/codec" "github.com/smartcontractkit/chainlink-common/pkg/types" ) @@ -122,10 +122,8 @@ func (s *structCodec) SizeAtTopLevel(numItems int) (int, error) { } func (s *structCodec) FieldCodec(itemType string) (TypeCodec, error) { - path := extendedItemType(itemType) - // itemType could recurse into nested structs - fieldName, tail := path.next() + fieldName, tail := codec.ItemTyper(itemType).Next() if fieldName == "" { return nil, fmt.Errorf("%w: field name required", types.ErrInvalidType) } @@ -148,18 +146,3 @@ func (s *structCodec) FieldCodec(itemType string) (TypeCodec, error) { return structType.FieldCodec(tail) } - -type extendedItemType string - -func (t extendedItemType) next() (string, string) { - if string(t) == "" { - return "", "" - } - - path := strings.Split(string(t), ".") - if len(path) == 1 { - return path[0], "" - } - - return path[0], strings.Join(path[1:], ".") -} diff --git a/pkg/codec/encodings/type_codec.go b/pkg/codec/encodings/type_codec.go index 5b0d35b28..edb719c7b 100644 --- a/pkg/codec/encodings/type_codec.go +++ b/pkg/codec/encodings/type_codec.go @@ -151,10 +151,7 @@ func decode(c map[string]TypeCodec, raw []byte, into any, itemType string, exact func getCodec(c map[string]TypeCodec, itemType string) (TypeCodec, error) { // itemType could recurse into nested structs - path := extendedItemType(itemType) - - // itemType could recurse into nested structs - head, tail := path.next() + head, tail := codec.ItemTyper(itemType).Next() if head == "" { return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) } diff --git a/pkg/codec/epoch_to_time.go b/pkg/codec/epoch_to_time.go index 287de807d..5c2695bcc 100644 --- a/pkg/codec/epoch_to_time.go +++ b/pkg/codec/epoch_to_time.go @@ -42,12 +42,12 @@ type timeToUnixModifier struct { func (t *timeToUnixModifier) TransformToOnChain(offChainValue any, itemType string) (any, error) { // since the hook will convert time.Time to epoch, we don't need to worry about converting them in the maps - return transformWithMaps(offChainValue, t.offToOnChainType, t.fields, noop, EpochToTimeHook, BigIntHook) + return transformWithMaps(offChainValue, t.offToOnChainTyper, "", t.fields, noop, EpochToTimeHook, BigIntHook) } func (t *timeToUnixModifier) TransformToOffChain(onChainValue any, itemType string) (any, error) { // since the hook will convert epoch to time.Time, we don't need to worry about converting them in the maps - return transformWithMaps(onChainValue, t.onToOffChainType, t.fields, noop, EpochToTimeHook, BigIntHook) + return transformWithMaps(onChainValue, t.onToOffChainTyper, "", t.fields, noop, EpochToTimeHook, BigIntHook) } func noop(_ map[string]any, _ string, _ bool) error { diff --git a/pkg/codec/hard_coder.go b/pkg/codec/hard_coder.go index 60e2633e7..291f110fb 100644 --- a/pkg/codec/hard_coder.go +++ b/pkg/codec/hard_coder.go @@ -2,7 +2,6 @@ package codec import ( "fmt" - "log" "reflect" "strings" @@ -83,15 +82,14 @@ func verifyHardCodeKeys(values map[string]any) error { } func (o *onChainHardCoder) TransformToOnChain(offChainValue any, itemType string) (any, error) { - log.Println(itemType) - return transformWithMaps(offChainValue, o.offToOnChainType, o.onChain, hardCode, o.hooks...) + return transformWithMaps(offChainValue, o.offToOnChainTyper, itemType, o.onChain, hardCode, o.hooks...) } -func (o *onChainHardCoder) TransformToOffChain(onChainValue any, _ string) (any, error) { +func (o *onChainHardCoder) TransformToOffChain(onChainValue any, itemType string) (any, error) { allHooks := make([]mapstructure.DecodeHookFunc, len(o.hooks)+1) copy(allHooks, o.hooks) allHooks[len(o.hooks)] = hardCodeManyHook - return transformWithMaps(onChainValue, o.onToOffChainType, o.fields, hardCode, allHooks...) + return transformWithMaps(onChainValue, o.onToOffChainTyper, itemType, o.fields, hardCode, allHooks...) } func hardCode(extractMap map[string]any, key string, item any) error { diff --git a/pkg/codec/hard_coder_test.go b/pkg/codec/hard_coder_test.go index 6dc1ba0ab..31ed1aace 100644 --- a/pkg/codec/hard_coder_test.go +++ b/pkg/codec/hard_coder_test.go @@ -469,6 +469,52 @@ func TestHardCoder(t *testing.T) { require.NoError(t, err) assert.Equal(t, int32(123), reflect.ValueOf(offChain).FieldByName("B").Interface()) }) + + t.Run("TransformToOnChain and TransformToOffChain works for itemType path", func(t *testing.T) { + nestedHardCoder, err := codec.NewHardCoder(map[string]any{ + "A": "Top", + "B.A": "Foo", + "B.C": []int32{2, 3}, + "C.A": "Foo", + "C.C": []int32{2, 3}, + }, map[string]any{ + "B.Z": "Bar", + "B.Q": []struct { + A int + B string + }{{1, "a"}, {2, "b"}}, + "C.Z": "Bar", + "C.Q": []struct { + A int + B string + }{{1, "a"}, {2, "b"}}, + }) + require.NoError(t, err) + + offChainType, err := nestedHardCoder.RetypeToOffChain(reflect.TypeOf(nestedTestStruct{}), "") + require.NoError(t, err) + + _, err = nestedHardCoder.RetypeToOffChain(reflect.TypeOf(""), "B.A") + require.NoError(t, err) + + iInput := reflect.Indirect(reflect.New(offChainType)) + iB := iInput.FieldByName("B") + iB.FieldByName("B").SetInt(1) + iC := iInput.FieldByName("C") + iC.Set(reflect.MakeSlice(iC.Type(), 2, 2)) + iC.Index(0).FieldByName("B").SetInt(2) + iC.Index(1).FieldByName("B").SetInt(3) + iInput.FieldByName("D").SetInt(1) + + actual, err := nestedHardCoder.TransformToOnChain(iInput.FieldByName("B").FieldByName("A").Interface(), "B.A") + require.NoError(t, err) + + expected := "Foo" + assert.Equal(t, expected, actual) + + _, err = nestedHardCoder.TransformToOffChain(expected, "B.A") + require.NoError(t, err) + }) } // Since we're using the on-chain values that have their hard-coded values set to diff --git a/pkg/codec/modifier_base.go b/pkg/codec/modifier_base.go index c50fe5245..a3c48a8c3 100644 --- a/pkg/codec/modifier_base.go +++ b/pkg/codec/modifier_base.go @@ -194,18 +194,19 @@ type mapAction[T any] func(extractMap map[string]any, key string, element T) err func transformWithMaps[T any]( item any, - typeMap map[reflect.Type]reflect.Type, + typeFn func(offChainType reflect.Type, itemType string) (reflect.Type, error), + itemType string, fields map[string]T, fn mapAction[T], hooks ...mapstructure.DecodeHookFunc) (any, error) { rItem := reflect.ValueOf(item) - toType, ok := typeMap[rItem.Type()] - if !ok { - return reflect.Value{}, fmt.Errorf("%w: cannot retype %v", types.ErrInvalidType, rItem.Type()) + toType, err := typeFn(rItem.Type(), itemType) + if err != nil { + return reflect.Value{}, err } - rOutput, err := transformWithMapsHelper(rItem, toType, fields, fn, hooks) + rOutput, err := transformWithMapsHelper(rItem, toType, itemType, fields, fn, hooks) if err != nil { return reflect.Value{}, err } @@ -216,6 +217,7 @@ func transformWithMaps[T any]( func transformWithMapsHelper[T any]( rItem reflect.Value, toType reflect.Type, + itemType string, fields map[string]T, fn mapAction[T], hooks []mapstructure.DecodeHookFunc) (reflect.Value, error) { @@ -229,7 +231,7 @@ func transformWithMapsHelper[T any]( return into, err } - tmp, err := transformWithMapsHelper(elm, toType.Elem(), fields, fn, hooks) + tmp, err := transformWithMapsHelper(elm, toType.Elem(), itemType, fields, fn, hooks) result := reflect.New(toType.Elem()) reflect.Indirect(result).Set(tmp) @@ -262,7 +264,7 @@ func doMany[T any](rInput, rOutput reflect.Value, fields map[string]T, fn mapAct inTmp := rInput.Index(i) outTmp := rOutput.Index(i) - output, err := transformWithMapsHelper(inTmp, outTmp.Type(), fields, fn, hooks) + output, err := transformWithMapsHelper(inTmp, outTmp.Type(), "", fields, fn, hooks) if err != nil { return err } @@ -337,7 +339,7 @@ func typeForPath(from reflect.Type, itemType string) (reflect.Type, error) { case reflect.Array, reflect.Slice: return nil, fmt.Errorf("%w: cannot extract a field from an array or slice", types.ErrInvalidType) case reflect.Struct: - head, tail := extendedItemType(itemType).next() + head, tail := ItemTyper(itemType).Next() field, ok := from.FieldByName(head) if !ok { @@ -367,9 +369,9 @@ func (e PathMappingError) Cause() error { return e.Err } -type extendedItemType string +type ItemTyper string -func (t extendedItemType) next() (string, string) { +func (t ItemTyper) Next() (string, string) { if string(t) == "" { return "", "" } diff --git a/pkg/codec/precodec.go b/pkg/codec/precodec.go index de5dec055..1be5b9945 100644 --- a/pkg/codec/precodec.go +++ b/pkg/codec/precodec.go @@ -60,7 +60,7 @@ func (pc *preCodec) TransformToOffChain(onChainValue any, _ string) (any, error) allHooks := make([]mapstructure.DecodeHookFunc, 1) allHooks[0] = hardCodeManyHook - return transformWithMaps(onChainValue, pc.onToOffChainType, pc.fields, pc.decodeFieldMapAction, allHooks...) + return transformWithMaps(onChainValue, pc.onToOffChainTyper, "", pc.fields, pc.decodeFieldMapAction, allHooks...) } func (pc *preCodec) decodeFieldMapAction(extractMap map[string]any, key string, typeDef string) error { @@ -90,7 +90,7 @@ func (pc *preCodec) TransformToOnChain(offChainValue any, _ string) (any, error) allHooks := make([]mapstructure.DecodeHookFunc, 1) allHooks[0] = hardCodeManyHook - return transformWithMaps(offChainValue, pc.offToOnChainType, pc.fields, pc.encodeFieldMapAction, allHooks...) + return transformWithMaps(offChainValue, pc.offToOnChainTyper, "", pc.fields, pc.encodeFieldMapAction, allHooks...) } func (pc *preCodec) encodeFieldMapAction(extractMap map[string]any, key string, typeDef string) error { diff --git a/pkg/codec/renamer.go b/pkg/codec/renamer.go index 845abeef4..51b4e8f63 100644 --- a/pkg/codec/renamer.go +++ b/pkg/codec/renamer.go @@ -2,7 +2,6 @@ package codec import ( "fmt" - "log" "reflect" "strings" "unicode" @@ -63,9 +62,7 @@ func (r *renamer) TransformToOffChain(onChainValue any, itemType string) (any, e } func (r *renamer) TransformToOnChain(offChainValue any, itemType string) (any, error) { - log.Println(itemType) if itemType != "" { - log.Println(itemType) var ref string parts := strings.Split(itemType, ".") diff --git a/pkg/codec/wrapper.go b/pkg/codec/wrapper.go index dd1061244..4461cff3c 100644 --- a/pkg/codec/wrapper.go +++ b/pkg/codec/wrapper.go @@ -30,11 +30,11 @@ type wrapperModifier struct { } func (t *wrapperModifier) TransformToOnChain(offChainValue any, _ string) (any, error) { - return transformWithMaps(offChainValue, t.offToOnChainType, t.fields, unwrapFieldMapAction) + return transformWithMaps(offChainValue, t.offToOnChainTyper, "", t.fields, unwrapFieldMapAction) } func (t *wrapperModifier) TransformToOffChain(onChainValue any, _ string) (any, error) { - return transformWithMaps(onChainValue, t.onToOffChainType, t.fields, wrapFieldMapAction) + return transformWithMaps(onChainValue, t.onToOffChainTyper, "", t.fields, wrapFieldMapAction) } func wrapFieldMapAction(typesMap map[string]any, fieldName string, wrappedFieldName string) error { From 670a4b3b35fdc94c2a786a7cbe2e8055c5a64651 Mon Sep 17 00:00:00 2001 From: Awbrey Hughlett Date: Thu, 23 Jan 2025 15:07:19 -0600 Subject: [PATCH 3/8] bypass nestable keys --- pkg/codec/encodings/type_codec.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pkg/codec/encodings/type_codec.go b/pkg/codec/encodings/type_codec.go index edb719c7b..79d08db30 100644 --- a/pkg/codec/encodings/type_codec.go +++ b/pkg/codec/encodings/type_codec.go @@ -158,7 +158,12 @@ func getCodec(c map[string]TypeCodec, itemType string) (TypeCodec, error) { ntcwt, ok := c[head] if !ok { - return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + if ntcwt, ok = c[itemType]; !ok { + return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + } + + // in this case, the codec is structured to not have nestable keys + return ntcwt, nil } if tail == "" { From b368873637cd17e9e0ec66aa1a1ffd89a0e0a804 Mon Sep 17 00:00:00 2001 From: Awbrey Hughlett Date: Thu, 23 Jan 2025 15:58:14 -0600 Subject: [PATCH 4/8] maintain backward compatibility --- pkg/codec/by_item_type_modifier.go | 1 - pkg/codec/hard_coder.go | 10 ++++++++++ pkg/codec/modifier_base.go | 7 +++++++ pkg/codec/renamer.go | 10 ++++++++++ 4 files changed, 27 insertions(+), 1 deletion(-) diff --git a/pkg/codec/by_item_type_modifier.go b/pkg/codec/by_item_type_modifier.go index 5c10df51e..e2c2522f3 100644 --- a/pkg/codec/by_item_type_modifier.go +++ b/pkg/codec/by_item_type_modifier.go @@ -48,7 +48,6 @@ func (b *byItemTypeModifier) RetypeToOffChain(onChainType reflect.Type, itemType if b.enableNesting { head, tail = ItemTyper(itemType).Next() } - log.Println("byItemTypeModifier", "RetypeToOffChain", onChainType, head, ":", tail) mod, ok := b.modByitemType[head] if !ok { diff --git a/pkg/codec/hard_coder.go b/pkg/codec/hard_coder.go index 291f110fb..f59096961 100644 --- a/pkg/codec/hard_coder.go +++ b/pkg/codec/hard_coder.go @@ -82,10 +82,20 @@ func verifyHardCodeKeys(values map[string]any) error { } func (o *onChainHardCoder) TransformToOnChain(offChainValue any, itemType string) (any, error) { + // set itemType to an ignore value if path traversal is not enabled + if !o.modifierBase.enablePathTraverse { + itemType = "" + } + return transformWithMaps(offChainValue, o.offToOnChainTyper, itemType, o.onChain, hardCode, o.hooks...) } func (o *onChainHardCoder) TransformToOffChain(onChainValue any, itemType string) (any, error) { + // set itemType to an ignore value if path traversal is not enabled + if !o.modifierBase.enablePathTraverse { + itemType = "" + } + allHooks := make([]mapstructure.DecodeHookFunc, len(o.hooks)+1) copy(allHooks, o.hooks) allHooks[len(o.hooks)] = hardCodeManyHook diff --git a/pkg/codec/modifier_base.go b/pkg/codec/modifier_base.go index a3c48a8c3..a269a6e00 100644 --- a/pkg/codec/modifier_base.go +++ b/pkg/codec/modifier_base.go @@ -12,6 +12,7 @@ import ( ) type modifierBase[T any] struct { + enablePathTraverse bool fields map[string]T onToOffChainType map[reflect.Type]reflect.Type offToOnChainType map[reflect.Type]reflect.Type @@ -35,6 +36,12 @@ func (m *modifierBase[T]) RetypeToOffChain(onChainType reflect.Type, itemType st } }() + // path traverse allows an item type of Struct.FieldA.NestedField to isolate modifiers + // associated with the nested field `NestedField`. + if !m.enablePathTraverse { + itemType = "" + } + // if itemType is empty, store the type mappings // if itemType is not empty, assume a sub-field property is expected to be extracted onChainStructType := onChainType diff --git a/pkg/codec/renamer.go b/pkg/codec/renamer.go index 51b4e8f63..0a47d9c81 100644 --- a/pkg/codec/renamer.go +++ b/pkg/codec/renamer.go @@ -32,6 +32,11 @@ type renamer struct { } func (r *renamer) TransformToOffChain(onChainValue any, itemType string) (any, error) { + // set itemType to an ignore value if path traversal is not enabled + if !r.modifierBase.enablePathTraverse { + itemType = "" + } + // itemType references the on-chain type // remap to the off-chain field name if itemType != "" { @@ -62,6 +67,11 @@ func (r *renamer) TransformToOffChain(onChainValue any, itemType string) (any, e } func (r *renamer) TransformToOnChain(offChainValue any, itemType string) (any, error) { + // set itemType to an ignore value if path traversal is not enabled + if !r.modifierBase.enablePathTraverse { + itemType = "" + } + if itemType != "" { var ref string From daabd1de7ddaa5ab296e0fe52a46e4c911be5331 Mon Sep 17 00:00:00 2001 From: Awbrey Hughlett Date: Thu, 23 Jan 2025 16:20:00 -0600 Subject: [PATCH 5/8] optional constructors for path traversal --- pkg/codec/hard_coder.go | 19 ++++++++++++++++--- pkg/codec/renamer.go | 11 ++++++++--- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/pkg/codec/hard_coder.go b/pkg/codec/hard_coder.go index f59096961..34c4981ce 100644 --- a/pkg/codec/hard_coder.go +++ b/pkg/codec/hard_coder.go @@ -11,9 +11,22 @@ import ( ) // NewHardCoder creates a modifier that will hard-code values for on-chain and off-chain types -// The modifier will override any values of the same name, if you need an overwritten value to be used in a different field, -// NewRenamer must be used before NewHardCoder. -func NewHardCoder(onChain map[string]any, offChain map[string]any, hooks ...mapstructure.DecodeHookFunc) (Modifier, error) { +// The modifier will override any values of the same name, if you need an overwritten value to be used in a different +// field. NewRenamer must be used before NewHardCoder. +func NewHardCoder( + onChain map[string]any, + offChain map[string]any, + hooks ...mapstructure.DecodeHookFunc, +) (Modifier, error) { + return NewPathTraverseHardCoder(onChain, offChain, false, hooks...) +} + +func NewPathTraverseHardCoder( + onChain map[string]any, + offChain map[string]any, + enablePathTraverse bool, + hooks ...mapstructure.DecodeHookFunc, +) (Modifier, error) { if err := verifyHardCodeKeys(onChain); err != nil { return nil, err } else if err = verifyHardCodeKeys(offChain); err != nil { diff --git a/pkg/codec/renamer.go b/pkg/codec/renamer.go index 0a47d9c81..328e3f670 100644 --- a/pkg/codec/renamer.go +++ b/pkg/codec/renamer.go @@ -10,11 +10,16 @@ import ( ) func NewRenamer(fields map[string]string) Modifier { + return NewPathTraverseRenamer(fields, false) +} + +func NewPathTraverseRenamer(fields map[string]string, enablePathTraverse bool) Modifier { m := &renamer{ modifierBase: modifierBase[string]{ - fields: fields, - onToOffChainType: map[reflect.Type]reflect.Type{}, - offToOnChainType: map[reflect.Type]reflect.Type{}, + enablePathTraverse: enablePathTraverse, + fields: fields, + onToOffChainType: map[reflect.Type]reflect.Type{}, + offToOnChainType: map[reflect.Type]reflect.Type{}, }, } m.modifyFieldForInput = func(pkgPath string, field *reflect.StructField, _, newName string) error { From 1257f5a7b0d84632a501456785ab7ac599c64a3f Mon Sep 17 00:00:00 2001 From: Awbrey Hughlett Date: Thu, 23 Jan 2025 19:43:11 -0600 Subject: [PATCH 6/8] make tests pass --- pkg/codec/by_item_type_modifier.go | 9 ++++----- pkg/codec/renamer_test.go | 6 +++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/pkg/codec/by_item_type_modifier.go b/pkg/codec/by_item_type_modifier.go index e2c2522f3..a768c2971 100644 --- a/pkg/codec/by_item_type_modifier.go +++ b/pkg/codec/by_item_type_modifier.go @@ -2,7 +2,6 @@ package codec import ( "fmt" - "log" "reflect" "github.com/smartcontractkit/chainlink-common/pkg/types" @@ -51,7 +50,6 @@ func (b *byItemTypeModifier) RetypeToOffChain(onChainType reflect.Type, itemType mod, ok := b.modByitemType[head] if !ok { - log.Println(mod) return nil, fmt.Errorf("%w: cannot find modifier for %s", types.ErrInvalidType, itemType) } @@ -78,11 +76,12 @@ func (b *byItemTypeModifier) transform( head, tail = ItemTyper(itemType).Next() } - if mod, ok := b.modByitemType[head]; ok { - return transform(mod, val, tail) + mod, ok := b.modByitemType[head] + if !ok { + return nil, fmt.Errorf("%w: cannot find modifier for %s", types.ErrInvalidType, itemType) } - return val, nil + return transform(mod, val, tail) } var _ Modifier = &byItemTypeModifier{} diff --git a/pkg/codec/renamer_test.go b/pkg/codec/renamer_test.go index 62fe47bee..9fbc170c8 100644 --- a/pkg/codec/renamer_test.go +++ b/pkg/codec/renamer_test.go @@ -28,9 +28,9 @@ func TestRenamer(t *testing.T) { D string } - renamer := codec.NewRenamer(map[string]string{"A": "X", "C": "Z"}) - invalidRenamer := codec.NewRenamer(map[string]string{"W": "X", "C": "Z"}) - nestedRenamer := codec.NewRenamer(map[string]string{"A": "X", "B.A": "X", "B.C": "Z", "C.A": "X", "C.C": "Z", "B": "Y"}) + renamer := codec.NewPathTraverseRenamer(map[string]string{"A": "X", "C": "Z"}, true) + invalidRenamer := codec.NewPathTraverseRenamer(map[string]string{"W": "X", "C": "Z"}, true) + nestedRenamer := codec.NewPathTraverseRenamer(map[string]string{"A": "X", "B.A": "X", "B.C": "Z", "C.A": "X", "C.C": "Z", "B": "Y"}, true) t.Run("RetypeToOffChain renames fields keeping structure", func(t *testing.T) { offChainType, err := renamer.RetypeToOffChain(reflect.TypeOf(testStruct{}), "") require.NoError(t, err) From b2a29c63e28067bdc23182b1ecc35e5037d2ac80 Mon Sep 17 00:00:00 2001 From: Awbrey Hughlett Date: Thu, 23 Jan 2025 20:07:02 -0600 Subject: [PATCH 7/8] address comments --- pkg/codec/by_item_type_modifier.go | 3 ++- pkg/codec/encodings/struct.go | 5 +++++ pkg/codec/encodings/struct_test.go | 15 +++++++++++++++ pkg/codec/modifier.go | 12 ++++++++++++ pkg/codec/renamer.go | 2 +- 5 files changed, 35 insertions(+), 2 deletions(-) diff --git a/pkg/codec/by_item_type_modifier.go b/pkg/codec/by_item_type_modifier.go index a768c2971..f97bd27c4 100644 --- a/pkg/codec/by_item_type_modifier.go +++ b/pkg/codec/by_item_type_modifier.go @@ -39,7 +39,8 @@ type byItemTypeModifier struct { } // RetypeToOffChain attempts to apply a modifier using the provided itemType. To allow access to nested fields, this -// function applies no modifications if a modifier by the specified name is not found. +// function returns an error if a modifier by the specified name is not found. If nesting is enabled, the itemType can +// be of the form `Path.To.Type` and this modifier will attempt to only match on `Path` to find a valid modifier. func (b *byItemTypeModifier) RetypeToOffChain(onChainType reflect.Type, itemType string) (reflect.Type, error) { head := itemType tail := itemType diff --git a/pkg/codec/encodings/struct.go b/pkg/codec/encodings/struct.go index 0e9bfda45..0baee9fb4 100644 --- a/pkg/codec/encodings/struct.go +++ b/pkg/codec/encodings/struct.go @@ -133,8 +133,13 @@ func (s *structCodec) FieldCodec(itemType string) (TypeCodec, error) { return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) } + if idx >= len(s.fields) { + return nil, fmt.Errorf("%w: invalid field index for type %s", types.ErrInvalidType, itemType) + } + codec := s.fields[idx] + // if itemType wasn't referencing a nested field if tail == "" { return codec, nil } diff --git a/pkg/codec/encodings/struct_test.go b/pkg/codec/encodings/struct_test.go index 0a9ace59c..f678d0aa8 100644 --- a/pkg/codec/encodings/struct_test.go +++ b/pkg/codec/encodings/struct_test.go @@ -14,6 +14,10 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types" ) +type fieldCodec interface { + FieldCodec(string) (encodings.TypeCodec, error) +} + func TestStructCodec(t *testing.T) { t.Parallel() t.Run("NewStructCodec returns an error if names are repeated", func(t *testing.T) { @@ -176,6 +180,17 @@ func TestStructCodec(t *testing.T) { _, err := structCodecWithErr.SizeAtTopLevel(100) assert.Equal(t, errCodec.Err, err) }) + + t.Run("FieldCodec returns a nested field codec", func(t *testing.T) { + fc, ok := structCodec.(fieldCodec) + + require.True(t, ok) + + tc, err := fc.FieldCodec("Bar") + + require.NoError(t, err) + assert.Equal(t, reflect.PointerTo(reflect.TypeOf(uint64(0))), tc.GetType()) + }) } func toPointer[T any](t T) *T { diff --git a/pkg/codec/modifier.go b/pkg/codec/modifier.go index a25a44599..da3e7eda3 100644 --- a/pkg/codec/modifier.go +++ b/pkg/codec/modifier.go @@ -7,13 +7,25 @@ import ( // Modifier allows you to modify the off-chain type to be used on-chain, and vice-versa. // A modifier is set up by retyping the on-chain type to a type used off-chain. type Modifier interface { + // RetypeToOffChain will retype the onChainType to its correlated offChainType. The itemType should be empty for an + // expected whole struct. A dot-separated string can be provided when path traversal is supported on the modifier + // to retype a nested field. + // + // For most modifiers, RetypeToOffChain must be called first with the entire struct to be retyped/modified before + // any other transformations or path traversal can function. RetypeToOffChain(onChainType reflect.Type, itemType string) (reflect.Type, error) // TransformToOnChain transforms a type returned from AdjustForInput into the outputType. // You may also pass a pointer to the type returned by AdjustForInput to get a pointer to outputType. + // + // Modifiers should also optionally provide support for path traversal using itemType. In the case of using path + // traversal, the offChainValue should be the field value being modified as identified by itemType. TransformToOnChain(offChainValue any, itemType string) (any, error) // TransformToOffChain is the reverse of TransformForOnChain input. // It is used to send back the object after it has been decoded + // + // Modifiers should also optionally provide support for path traversal using itemType. In the case of using path + // traversal, the onChainValue should be the field value being modified as identified by itemType. TransformToOffChain(onChainValue any, itemType string) (any, error) } diff --git a/pkg/codec/renamer.go b/pkg/codec/renamer.go index 328e3f670..fed6b3f53 100644 --- a/pkg/codec/renamer.go +++ b/pkg/codec/renamer.go @@ -43,7 +43,7 @@ func (r *renamer) TransformToOffChain(onChainValue any, itemType string) (any, e } // itemType references the on-chain type - // remap to the off-chain field name + // rename field/subfield path in itemType to match the modifier renaming if itemType != "" { var ref string From 3eb5f3b9abd8a092248180b91188a9a644ddd1b6 Mon Sep 17 00:00:00 2001 From: Awbrey Hughlett Date: Fri, 24 Jan 2025 08:49:06 -0600 Subject: [PATCH 8/8] all tests pass --- pkg/codec/encodings/type_codec_test.go | 58 +++++++++++++++++++++++++- pkg/codec/modifier_base.go | 10 +++++ 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/pkg/codec/encodings/type_codec_test.go b/pkg/codec/encodings/type_codec_test.go index f23ed90ec..ee9110153 100644 --- a/pkg/codec/encodings/type_codec_test.go +++ b/pkg/codec/encodings/type_codec_test.go @@ -127,7 +127,7 @@ func TestCodecFromTypeCodecs(t *testing.T) { t.Run("CreateType works for nested struct values and modifiers", func(t *testing.T) { itemType := strings.Join([]string{TestItemWithConfigExtra, "AccountStruct", "Account"}, ".") ts := CreateTestStruct(0, biit) - c := biit.GetCodec(t) + c := biit.GetNestableCodec(t) encoded, err := c.Encode(tests.Context(t), ts.AccountStruct.Account, itemType) require.NoError(t, err) @@ -141,7 +141,7 @@ func TestCodecFromTypeCodecs(t *testing.T) { t.Run("CreateType works for nested struct values", func(t *testing.T) { itemType := strings.Join([]string{TestItemType, "NestedDynamicStruct", "Inner", "S"}, ".") ts := CreateTestStruct(0, biit) - c := biit.GetCodec(t) + c := biit.GetNestableCodec(t) encoded, err := c.Encode(tests.Context(t), ts.NestedDynamicStruct.Inner.S, itemType) require.NoError(t, err) @@ -354,6 +354,60 @@ func (b *bigEndianInterfaceTester) GetCodec(t *testing.T) types.Codec { return modCodec } +func (b *bigEndianInterfaceTester) GetNestableCodec(t *testing.T) types.Codec { + testStruct := newTestStructCodec(t, binary.BigEndian()) + size, err := binary.BigEndian().Int(1) + require.NoError(t, err) + slice, err := encodings.NewSlice(testStruct, size) + require.NoError(t, err) + arr1, err := encodings.NewArray(1, testStruct) + require.NoError(t, err) + arr2, err := encodings.NewArray(2, testStruct) + require.NoError(t, err) + + ts := CreateTestStruct(0, b) + + tc := &encodings.CodecFromTypeCodec{ + TestItemType: testStruct, + TestItemSliceType: slice, + TestItemArray1Type: arr1, + TestItemArray2Type: arr2, + TestItemWithConfigExtra: testStruct, + NilType: encodings.Empty{}, + } + + require.NoError(t, err) + + var c types.RemoteCodec = tc + if b.lenient { + c = (*encodings.LenientCodecFromTypeCodec)(tc) + } + + mod, err := codec.NewPathTraverseHardCoder(map[string]any{ + "BigField": ts.BigField.String(), + "AccountStruct.Account": ts.AccountStruct.Account, + }, map[string]any{"ExtraField": AnyExtraValue}, true, codec.BigIntHook) + require.NoError(t, err) + + byTypeMod, err := codec.NewNestableByItemTypeModifier(map[string]codec.Modifier{ + TestItemType: codec.MultiModifier{}, + TestItemSliceType: codec.MultiModifier{}, + TestItemArray1Type: codec.MultiModifier{}, + TestItemArray2Type: codec.MultiModifier{}, + TestItemWithConfigExtra: mod, + NilType: codec.MultiModifier{}, + }) + require.NoError(t, err) + + modCodec, err := codec.NewModifierCodec(c, byTypeMod, codec.BigIntHook) + require.NoError(t, err) + + _, err = mod.RetypeToOffChain(reflect.PointerTo(testStruct.GetType()), "") + require.NoError(t, err) + + return modCodec +} + func (b *bigEndianInterfaceTester) IncludeArrayEncodingSizeEnforcement() bool { return true } diff --git a/pkg/codec/modifier_base.go b/pkg/codec/modifier_base.go index a269a6e00..ba86a4e0c 100644 --- a/pkg/codec/modifier_base.go +++ b/pkg/codec/modifier_base.go @@ -97,6 +97,11 @@ func (m *modifierBase[T]) RetypeToOffChain(onChainType reflect.Type, itemType st return nil, err } default: + // if the types don't match, it means we are attempting to traverse the main struct + if onChainType != m.onChainStructType { + return onChainType, nil + } + return nil, fmt.Errorf("%w: cannot retype the kind %v", types.ErrInvalidType, onChainStructType.Kind()) } @@ -213,6 +218,10 @@ func transformWithMaps[T any]( return reflect.Value{}, err } + if rItem.Type() == toType { + return rItem.Interface(), nil + } + rOutput, err := transformWithMapsHelper(rItem, toType, itemType, fields, fn, hooks) if err != nil { return reflect.Value{}, err @@ -260,6 +269,7 @@ func transformWithMapsHelper[T any]( return into, err default: + panic("error") return reflect.Value{}, fmt.Errorf("%w: cannot retype %v", types.ErrInvalidType, rItem.Type()) } }