diff --git a/decode.go b/decode.go index 20bf4cfb..d489cac9 100644 --- a/decode.go +++ b/decode.go @@ -89,8 +89,11 @@ func (d *Decoder) castToFloat(v interface{}) interface{} { func (d *Decoder) setToMapValue(node ast.Node, m map[string]interface{}) { switch n := node.(type) { case *ast.MappingValueNode: - if n.Key.Type() == ast.MergeKeyType { - d.setToMapValue(n.Value, m) + if n.Key.Type() == ast.MergeKeyType && n.Value.Type() == ast.AliasType { + aliasNode := n.Value.(*ast.AliasNode) + aliasName := aliasNode.Value.GetToken().Value + node := d.anchorNodeMap[aliasName] + d.setToMapValue(node, m) } else { key := n.Key.GetToken().Value m[key] = d.nodeToValue(n.Value) @@ -143,9 +146,12 @@ func (d *Decoder) nodeToValue(node ast.Node) interface{} { case *ast.LiteralNode: return n.Value.GetValue() case *ast.MappingValueNode: - if n.Key.Type() == ast.MergeKeyType { + if n.Key.Type() == ast.MergeKeyType && n.Value.Type() == ast.AliasType { + aliasNode := n.Value.(*ast.AliasNode) + aliasName := aliasNode.Value.GetToken().Value + node := d.anchorNodeMap[aliasName] m := map[string]interface{}{} - d.setToMapValue(n.Value, m) + d.setToMapValue(node, m) return m } key := n.Key.GetToken().Value @@ -846,11 +852,20 @@ func (d *Decoder) decodeMap(dst reflect.Value, src ast.Node) error { keyType := mapValue.Type().Key() valueType := mapValue.Type().Elem() mapIter := mapNode.MapRange() - var foundErr error for mapIter.Next() { key := mapIter.Key() value := mapIter.Value() + if key.Type() == ast.MergeKeyType { + if err := d.decodeMap(dst, value); err != nil { + return errors.Wrapf(err, "failed to decode map with merge key") + } + iter := dst.MapRange() + for iter.Next() { + mapValue.SetMapIndex(iter.Key(), iter.Value()) + } + continue + } k := reflect.ValueOf(d.nodeToValue(key)) if k.IsValid() && k.Type().ConvertibleTo(keyType) { k = k.Convert(keyType) diff --git a/decode_test.go b/decode_test.go index 7547cf7c..1dfeb6be 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1082,6 +1082,87 @@ items: if v.Items[1].B != 1 || v.Items[1].C != "world" { t.Fatal("failed to decode with merge key") } + t.Run("decode with interface{}", func(t *testing.T) { + buf := bytes.NewBufferString(yml) + dec := yaml.NewDecoder(buf) + var v interface{} + if err := dec.Decode(&v); err != nil { + t.Fatalf("%+v", err) + } + items := v.(map[string]interface{})["items"].([]interface{}) + if len(items) != 2 { + t.Fatal("failed to decode with merge key") + } + b0 := items[0].(map[string]interface{})["b"] + if _, ok := b0.(uint64); !ok { + t.Fatal("failed to decode with merge key") + } + if b0.(uint64) != 1 { + t.Fatal("failed to decode with merge key") + } + c0 := items[0].(map[string]interface{})["c"] + if _, ok := c0.(string); !ok { + t.Fatal("failed to decode with merge key") + } + if c0.(string) != "hello" { + t.Fatal("failed to decode with merge key") + } + b1 := items[1].(map[string]interface{})["b"] + if _, ok := b1.(uint64); !ok { + t.Fatal("failed to decode with merge key") + } + if b1.(uint64) != 1 { + t.Fatal("failed to decode with merge key") + } + c1 := items[1].(map[string]interface{})["c"] + if _, ok := c1.(string); !ok { + t.Fatal("failed to decode with merge key") + } + if c1.(string) != "world" { + t.Fatal("failed to decode with merge key") + } + }) + t.Run("decode with map", func(t *testing.T) { + var v struct { + Items []map[string]interface{} + } + buf := bytes.NewBufferString(yml) + dec := yaml.NewDecoder(buf) + if err := dec.Decode(&v); err != nil { + t.Fatalf("%+v", err) + } + if len(v.Items) != 2 { + t.Fatal("failed to decode with merge key") + } + b0 := v.Items[0]["b"] + if _, ok := b0.(uint64); !ok { + t.Fatal("failed to decode with merge key") + } + if b0.(uint64) != 1 { + t.Fatal("failed to decode with merge key") + } + c0 := v.Items[0]["c"] + if _, ok := c0.(string); !ok { + t.Fatal("failed to decode with merge key") + } + if c0.(string) != "hello" { + t.Fatal("failed to decode with merge key") + } + b1 := v.Items[1]["b"] + if _, ok := b1.(uint64); !ok { + t.Fatal("failed to decode with merge key") + } + if b1.(uint64) != 1 { + t.Fatal("failed to decode with merge key") + } + c1 := v.Items[1]["c"] + if _, ok := c1.(string); !ok { + t.Fatal("failed to decode with merge key") + } + if c1.(string) != "world" { + t.Fatal("failed to decode with merge key") + } + }) } func TestDecoder_Inline(t *testing.T) {