diff --git a/decode.go b/decode.go index bbee6d07..61035b3d 100644 --- a/decode.go +++ b/decode.go @@ -214,6 +214,26 @@ func (d *Decoder) nodeToValue(node ast.Node) interface{} { return nil } +func (d *Decoder) resolveAlias(node ast.Node) ast.Node { + switch n := node.(type) { + case *ast.MappingNode: + for idx, value := range n.Values { + n.Values[idx] = d.resolveAlias(value).(*ast.MappingValueNode) + } + case *ast.MappingValueNode: + n.Key = d.resolveAlias(n.Key) + n.Value = d.resolveAlias(n.Value) + case *ast.SequenceNode: + for idx, value := range n.Values { + n.Values[idx] = d.resolveAlias(value) + } + case *ast.AliasNode: + aliasName := n.Value.GetToken().Value + return d.anchorNodeMap[aliasName] + } + return node +} + func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) { if _, ok := node.(*ast.NullNode); ok { return nil, nil @@ -399,6 +419,7 @@ func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error { } valueType := dst.Type() if unmarshaler, ok := dst.Addr().Interface().(BytesUnmarshaler); ok { + src = d.resolveAlias(src) var b string if scalar, isScalar := src.(ast.ScalarNode); isScalar { b = fmt.Sprint(scalar.GetValue()) diff --git a/decode_test.go b/decode_test.go index ef995bb9..90b954d1 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1875,3 +1875,52 @@ k: l t.Fatal("failed to stream decoding") } } + +type unmarshalYAMLWithAliasString string + +func (v *unmarshalYAMLWithAliasString) UnmarshalYAML(b []byte) error { + var s string + if err := yaml.Unmarshal(b, &s); err != nil { + return err + } + *v = unmarshalYAMLWithAliasString(s) + return nil +} + +type unmarshalYAMLWithAliasMap map[string]interface{} + +func (v *unmarshalYAMLWithAliasMap) UnmarshalYAML(b []byte) error { + var m map[string]interface{} + if err := yaml.Unmarshal(b, &m); err != nil { + return err + } + *v = unmarshalYAMLWithAliasMap(m) + return nil +} + +func TestDecoder_UnmarshalYAMLWithAlias(t *testing.T) { + yml := ` +anchors: + x: &x hello + map: &y + a: b + c: d +a: *x +b: + <<: *y + e: f +` + var v struct { + A unmarshalYAMLWithAliasString + B unmarshalYAMLWithAliasMap + } + if err := yaml.Unmarshal([]byte(yml), &v); err != nil { + t.Fatalf("%+v", err) + } + if v.A != "hello" { + t.Fatal("failed to unmarshal with alias") + } + if len(v.B) != 3 { + t.Fatal("failed to unmarshal with alias") + } +}