Skip to content

Commit

Permalink
Reuse decoded value for alias
Browse files Browse the repository at this point in the history
  • Loading branch information
goccy committed Jan 9, 2020
1 parent 17e1bea commit 7a0a7ff
Showing 1 changed file with 47 additions and 21 deletions.
68 changes: 47 additions & 21 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ import (
type Decoder struct {
reader io.Reader
referenceReaders []io.Reader
anchorMap map[string]ast.Node
anchorNodeMap map[string]ast.Node
anchorValueMap map[string]reflect.Value
opts []DecodeOption
referenceFiles []string
referenceDirs []string
Expand All @@ -38,7 +39,8 @@ type Decoder struct {
func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder {
return &Decoder{
reader: r,
anchorMap: map[string]ast.Node{},
anchorNodeMap: map[string]ast.Node{},
anchorValueMap: map[string]reflect.Value{},
opts: opts,
referenceReaders: []io.Reader{},
referenceFiles: []string{},
Expand Down Expand Up @@ -131,11 +133,12 @@ func (d *Decoder) nodeToValue(node ast.Node) interface{} {
case *ast.AnchorNode:
anchorName := n.Name.GetToken().Value
anchorValue := d.nodeToValue(n.Value)
d.anchorMap[anchorName] = n.Value
d.anchorNodeMap[anchorName] = n.Value
return anchorValue
case *ast.AliasNode:
aliasName := n.Value.GetToken().Value
return d.nodeToValue(d.anchorMap[aliasName])
node := d.anchorNodeMap[aliasName]
return d.nodeToValue(node)
case *ast.LiteralNode:
return n.Value.GetValue()
case *ast.MappingValueNode:
Expand Down Expand Up @@ -177,12 +180,15 @@ func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) {
}
if alias, ok := node.(*ast.AliasNode); ok {
aliasName := alias.Value.GetToken().Value
anchorNode := d.anchorMap[aliasName]
mapNode, ok := anchorNode.(ast.MapNode)
node := d.anchorNodeMap[aliasName]
if node == nil {
return nil, xerrors.Errorf("cannot find anchor by alias name %s", aliasName)
}
mapNode, ok := node.(ast.MapNode)
if ok {
return mapNode, nil
}
return nil, xerrors.Errorf("%s node doesn't MapNode", anchorNode.Type())
return nil, xerrors.Errorf("%s node doesn't MapNode", node.Type())
}
mapNode, ok := node.(ast.MapNode)
if !ok {
Expand All @@ -204,12 +210,15 @@ func (d *Decoder) getArrayNode(node ast.Node) (ast.ArrayNode, error) {
}
if alias, ok := node.(*ast.AliasNode); ok {
aliasName := alias.Value.GetToken().Value
anchorNode := d.anchorMap[aliasName]
arrayNode, ok := anchorNode.(ast.ArrayNode)
node := d.anchorNodeMap[aliasName]
if node == nil {
return nil, xerrors.Errorf("cannot find anchor by alias name %s", aliasName)
}
arrayNode, ok := node.(ast.ArrayNode)
if ok {
return arrayNode, nil
}
return nil, xerrors.Errorf("%s node doesn't ArrayNode", anchorNode.Type())
return nil, xerrors.Errorf("%s node doesn't ArrayNode", node.Type())
}
arrayNode, ok := node.(ast.ArrayNode)
if !ok {
Expand Down Expand Up @@ -321,6 +330,10 @@ func (d *Decoder) deleteStructKeys(structValue reflect.Value, unknownFields map[
}

func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error {
if src.Type() == ast.AnchorType {
anchorName := src.(*ast.AnchorNode).Name.GetToken().Value
d.anchorValueMap[anchorName] = dst
}
valueType := dst.Type()
if unmarshaler, ok := dst.Addr().Interface().(BytesUnmarshaler); ok {
b := fmt.Sprintf("%v", src)
Expand Down Expand Up @@ -460,6 +473,21 @@ func (d *Decoder) castToAssignableValue(value reflect.Value, target reflect.Type
return value
}

func (d *Decoder) createDecodedNewValue(typ reflect.Type, node ast.Node) (reflect.Value, error) {
if node.Type() == ast.AliasType {
aliasName := node.(*ast.AliasNode).Value.GetToken().Value
newValue := d.anchorValueMap[aliasName]
if newValue.IsValid() {
return newValue, nil
}
}
newValue := d.createDecodableValue(typ)
if err := d.decodeValue(newValue, node); err != nil {
return newValue, errors.Wrapf(err, "failed to decode value")
}
return newValue, nil
}

func (d *Decoder) keyToNodeMap(node ast.Node, getKeyOrValueNode func(*ast.MapNodeIter) ast.Node) (map[string]ast.Node, error) {
mapNode, err := d.getMapNode(node)
if err != nil {
Expand Down Expand Up @@ -614,9 +642,7 @@ func (d *Decoder) decodeStruct(dst reflect.Value, src ast.Node) error {
fieldValue.Set(reflect.Zero(fieldValue.Type()))
continue
}
newFieldValue := d.createDecodableValue(fieldValue.Type())
err := d.decodeValue(newFieldValue, src)

newFieldValue, err := d.createDecodedNewValue(fieldValue.Type(), src)
if d.disallowUnknownField {
var ufe *unknownFieldError
if xerrors.As(err, &ufe) {
Expand Down Expand Up @@ -663,8 +689,8 @@ func (d *Decoder) decodeStruct(dst reflect.Value, src ast.Node) error {
fieldValue.Set(reflect.Zero(fieldValue.Type()))
continue
}
newFieldValue := d.createDecodableValue(fieldValue.Type())
if err := d.decodeValue(newFieldValue, v); err != nil {
newFieldValue, err := d.createDecodedNewValue(fieldValue.Type(), v)
if err != nil {
if foundErr != nil {
continue
}
Expand Down Expand Up @@ -732,8 +758,8 @@ func (d *Decoder) decodeArray(dst reflect.Value, src ast.Node) error {
// set nil value to pointer
arrayValue.Index(idx).Set(reflect.Zero(elemType))
} else {
dstValue := d.createDecodableValue(elemType)
if err := d.decodeValue(dstValue, v); err != nil {
dstValue, err := d.createDecodedNewValue(elemType, v)
if err != nil {
if foundErr == nil {
foundErr = err
}
Expand Down Expand Up @@ -772,8 +798,8 @@ func (d *Decoder) decodeSlice(dst reflect.Value, src ast.Node) error {
sliceValue = reflect.Append(sliceValue, reflect.Zero(elemType))
continue
}
dstValue := d.createDecodableValue(elemType)
if err := d.decodeValue(dstValue, v); err != nil {
dstValue, err := d.createDecodedNewValue(elemType, v)
if err != nil {
if foundErr == nil {
foundErr = err
}
Expand Down Expand Up @@ -815,8 +841,8 @@ func (d *Decoder) decodeMap(dst reflect.Value, src ast.Node) error {
mapValue.SetMapIndex(k, reflect.Zero(valueType))
continue
}
dstValue := d.createDecodableValue(valueType)
if err := d.decodeValue(dstValue, value); err != nil {
dstValue, err := d.createDecodedNewValue(valueType, value)
if err != nil {
if foundErr == nil {
foundErr = err
}
Expand Down

0 comments on commit 7a0a7ff

Please sign in to comment.