diff --git a/decode.go b/decode.go index e1c1f6d8..b9b21f9f 100644 --- a/decode.go +++ b/decode.go @@ -24,7 +24,8 @@ import ( type Decoder struct { reader io.Reader referenceReaders []io.Reader - anchorMap map[string]ast.Node + anchorNodeMap map[string]ast.Node + anchorValueMap map[string]reflect.Value opts []DecodeOption referenceFiles []string referenceDirs []string @@ -38,7 +39,8 @@ type Decoder struct { func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder { return &Decoder{ reader: r, - anchorMap: map[string]ast.Node{}, + anchorNodeMap: map[string]ast.Node{}, + anchorValueMap: map[string]reflect.Value{}, opts: opts, referenceReaders: []io.Reader{}, referenceFiles: []string{}, @@ -131,11 +133,12 @@ func (d *Decoder) nodeToValue(node ast.Node) interface{} { case *ast.AnchorNode: anchorName := n.Name.GetToken().Value anchorValue := d.nodeToValue(n.Value) - d.anchorMap[anchorName] = n.Value + d.anchorNodeMap[anchorName] = n.Value return anchorValue case *ast.AliasNode: aliasName := n.Value.GetToken().Value - return d.nodeToValue(d.anchorMap[aliasName]) + node := d.anchorNodeMap[aliasName] + return d.nodeToValue(node) case *ast.LiteralNode: return n.Value.GetValue() case *ast.MappingValueNode: @@ -177,12 +180,15 @@ func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) { } if alias, ok := node.(*ast.AliasNode); ok { aliasName := alias.Value.GetToken().Value - anchorNode := d.anchorMap[aliasName] - mapNode, ok := anchorNode.(ast.MapNode) + node := d.anchorNodeMap[aliasName] + if node == nil { + return nil, xerrors.Errorf("cannot find anchor by alias name %s", aliasName) + } + mapNode, ok := node.(ast.MapNode) if ok { return mapNode, nil } - return nil, xerrors.Errorf("%s node doesn't MapNode", anchorNode.Type()) + return nil, xerrors.Errorf("%s node doesn't MapNode", node.Type()) } mapNode, ok := node.(ast.MapNode) if !ok { @@ -204,12 +210,15 @@ func (d *Decoder) getArrayNode(node ast.Node) (ast.ArrayNode, error) { } if alias, ok := node.(*ast.AliasNode); ok { aliasName := alias.Value.GetToken().Value - anchorNode := d.anchorMap[aliasName] - arrayNode, ok := anchorNode.(ast.ArrayNode) + node := d.anchorNodeMap[aliasName] + if node == nil { + return nil, xerrors.Errorf("cannot find anchor by alias name %s", aliasName) + } + arrayNode, ok := node.(ast.ArrayNode) if ok { return arrayNode, nil } - return nil, xerrors.Errorf("%s node doesn't ArrayNode", anchorNode.Type()) + return nil, xerrors.Errorf("%s node doesn't ArrayNode", node.Type()) } arrayNode, ok := node.(ast.ArrayNode) if !ok { @@ -321,6 +330,10 @@ func (d *Decoder) deleteStructKeys(structValue reflect.Value, unknownFields map[ } func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error { + if src.Type() == ast.AnchorType { + anchorName := src.(*ast.AnchorNode).Name.GetToken().Value + d.anchorValueMap[anchorName] = dst + } valueType := dst.Type() if unmarshaler, ok := dst.Addr().Interface().(BytesUnmarshaler); ok { b := fmt.Sprintf("%v", src) @@ -460,6 +473,21 @@ func (d *Decoder) castToAssignableValue(value reflect.Value, target reflect.Type return value } +func (d *Decoder) createDecodedNewValue(typ reflect.Type, node ast.Node) (reflect.Value, error) { + if node.Type() == ast.AliasType { + aliasName := node.(*ast.AliasNode).Value.GetToken().Value + newValue := d.anchorValueMap[aliasName] + if newValue.IsValid() { + return newValue, nil + } + } + newValue := d.createDecodableValue(typ) + if err := d.decodeValue(newValue, node); err != nil { + return newValue, errors.Wrapf(err, "failed to decode value") + } + return newValue, nil +} + func (d *Decoder) keyToNodeMap(node ast.Node, getKeyOrValueNode func(*ast.MapNodeIter) ast.Node) (map[string]ast.Node, error) { mapNode, err := d.getMapNode(node) if err != nil { @@ -614,9 +642,7 @@ func (d *Decoder) decodeStruct(dst reflect.Value, src ast.Node) error { fieldValue.Set(reflect.Zero(fieldValue.Type())) continue } - newFieldValue := d.createDecodableValue(fieldValue.Type()) - err := d.decodeValue(newFieldValue, src) - + newFieldValue, err := d.createDecodedNewValue(fieldValue.Type(), src) if d.disallowUnknownField { var ufe *unknownFieldError if xerrors.As(err, &ufe) { @@ -663,8 +689,8 @@ func (d *Decoder) decodeStruct(dst reflect.Value, src ast.Node) error { fieldValue.Set(reflect.Zero(fieldValue.Type())) continue } - newFieldValue := d.createDecodableValue(fieldValue.Type()) - if err := d.decodeValue(newFieldValue, v); err != nil { + newFieldValue, err := d.createDecodedNewValue(fieldValue.Type(), v) + if err != nil { if foundErr != nil { continue } @@ -732,8 +758,8 @@ func (d *Decoder) decodeArray(dst reflect.Value, src ast.Node) error { // set nil value to pointer arrayValue.Index(idx).Set(reflect.Zero(elemType)) } else { - dstValue := d.createDecodableValue(elemType) - if err := d.decodeValue(dstValue, v); err != nil { + dstValue, err := d.createDecodedNewValue(elemType, v) + if err != nil { if foundErr == nil { foundErr = err } @@ -772,8 +798,8 @@ func (d *Decoder) decodeSlice(dst reflect.Value, src ast.Node) error { sliceValue = reflect.Append(sliceValue, reflect.Zero(elemType)) continue } - dstValue := d.createDecodableValue(elemType) - if err := d.decodeValue(dstValue, v); err != nil { + dstValue, err := d.createDecodedNewValue(elemType, v) + if err != nil { if foundErr == nil { foundErr = err } @@ -815,8 +841,8 @@ func (d *Decoder) decodeMap(dst reflect.Value, src ast.Node) error { mapValue.SetMapIndex(k, reflect.Zero(valueType)) continue } - dstValue := d.createDecodableValue(valueType) - if err := d.decodeValue(dstValue, value); err != nil { + dstValue, err := d.createDecodedNewValue(valueType, value) + if err != nil { if foundErr == nil { foundErr = err }