diff --git a/decode.go b/decode.go index 1fc13009..4a10e89d 100644 --- a/decode.go +++ b/decode.go @@ -408,8 +408,14 @@ func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error { case reflect.Array: return d.decodeArray(dst, src) case reflect.Slice: + if mapSlice, ok := dst.Addr().Interface().(*MapSlice); ok { + return d.decodeMapSlice(mapSlice, src) + } return d.decodeSlice(dst, src) case reflect.Struct: + if mapItem, ok := dst.Addr().Interface().(*MapItem); ok { + return d.decodeMapItem(mapItem, src) + } return d.decodeStruct(dst, src) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: v := d.nodeToValue(src) @@ -882,6 +888,65 @@ func (d *Decoder) decodeSlice(dst reflect.Value, src ast.Node) error { return nil } +func (d *Decoder) decodeMapItem(dst *MapItem, src ast.Node) error { + mapNode, err := d.getMapNode(src) + if err != nil { + return errors.Wrapf(err, "failed to get map node") + } + if mapNode == nil { + return nil + } + mapIter := mapNode.MapRange() + if !mapIter.Next() { + return nil + } + key := mapIter.Key() + value := mapIter.Value() + if key.Type() == ast.MergeKeyType { + if err := d.decodeMapItem(dst, value); err != nil { + return errors.Wrapf(err, "failed to decode map with merge key") + } + return nil + } + *dst = MapItem{ + Key: d.nodeToValue(key), + Value: d.nodeToValue(value), + } + return nil +} + +func (d *Decoder) decodeMapSlice(dst *MapSlice, src ast.Node) error { + mapNode, err := d.getMapNode(src) + if err != nil { + return errors.Wrapf(err, "failed to get map node") + } + if mapNode == nil { + return nil + } + mapSlice := MapSlice{} + mapIter := mapNode.MapRange() + for mapIter.Next() { + key := mapIter.Key() + value := mapIter.Value() + if key.Type() == ast.MergeKeyType { + var m MapSlice + if err := d.decodeMapSlice(&m, value); err != nil { + return errors.Wrapf(err, "failed to decode map with merge key") + } + for _, v := range m { + mapSlice = append(mapSlice, v) + } + continue + } + mapSlice = append(mapSlice, MapItem{ + Key: d.nodeToValue(key), + Value: d.nodeToValue(value), + }) + } + *dst = mapSlice + return nil +} + func (d *Decoder) decodeMap(dst reflect.Value, src ast.Node) error { mapNode, err := d.getMapNode(src) if err != nil { diff --git a/decode_test.go b/decode_test.go index 323d3f8f..27dacc5d 100644 --- a/decode_test.go +++ b/decode_test.go @@ -566,6 +566,18 @@ func TestDecoder(t *testing.T) { }{1}, }, + { + "a: 1\n", + yaml.MapItem{Key: "a", Value: 1}, + }, + { + "a: 1\nb: 2\nc: 3\n", + yaml.MapSlice{ + {Key: "a", Value: 1}, + {Key: "b", Value: 2}, + {Key: "c", Value: 3}, + }, + }, { "v:\n- A\n- 1\n- B:\n - 2\n - 3\n", map[string]interface{}{