diff --git a/decode.go b/decode.go index c01d404c..2cb7a44f 100644 --- a/decode.go +++ b/decode.go @@ -35,6 +35,8 @@ type Decoder struct { validator StructValidator disallowUnknownField bool useOrderedMap bool + parsedFile *ast.File + streamIndex int } // NewDecoder returns a new decoder that reads from r. @@ -1131,7 +1133,7 @@ func (d *Decoder) resolveReference() error { } // assign new anchor definition to anchorMap - if _, err := d.decode(bytes); err != nil { + if _, err := d.parse(bytes); err != nil { return errors.Wrapf(err, "failed to decode") } } @@ -1139,42 +1141,85 @@ func (d *Decoder) resolveReference() error { return nil } -func (d *Decoder) decode(bytes []byte) (ast.Node, error) { +func (d *Decoder) parse(bytes []byte) (*ast.File, error) { f, err := parser.ParseBytes(bytes, 0) if err != nil { return nil, errors.Wrapf(err, "failed to parse yaml") } - return d.fileToNode(f), nil + normalizedFile := &ast.File{} + for _, doc := range f.Docs { + // try to decode ast.Node to value and map anchor value to anchorMap + if v := d.nodeToValue(doc.Body); v != nil { + normalizedFile.Docs = append(normalizedFile.Docs, doc) + } + } + return normalizedFile, nil } -// Decode reads the next YAML-encoded value from its input -// and stores it in the value pointed to by v. -// -// See the documentation for Unmarshal for details about the -// conversion of YAML into a Go value. -func (d *Decoder) Decode(v interface{}) error { +func (d *Decoder) isInitialized() bool { + return d.parsedFile != nil +} + +func (d *Decoder) decodeInit() error { if !d.isResolvedReference { if err := d.resolveReference(); err != nil { return errors.Wrapf(err, "failed to resolve reference") } } - rv := reflect.ValueOf(v) - if rv.Type().Kind() != reflect.Ptr { - return errors.ErrDecodeRequiredPointerType - } var buf bytes.Buffer if _, err := io.Copy(&buf, d.reader); err != nil { return errors.Wrapf(err, "failed to copy from reader") } - node, err := d.decode(buf.Bytes()) + file, err := d.parse(buf.Bytes()) if err != nil { return errors.Wrapf(err, "failed to decode") } - if node == nil { + d.parsedFile = file + return nil +} + +func (d *Decoder) decode(v reflect.Value) error { + if len(d.parsedFile.Docs) <= d.streamIndex { + return io.EOF + } + body := d.parsedFile.Docs[d.streamIndex].Body + if body == nil { return nil } - if err := d.decodeValue(rv.Elem(), node); err != nil { + if err := d.decodeValue(v.Elem(), body); err != nil { return errors.Wrapf(err, "failed to decode value") } + d.streamIndex++ + return nil +} + +// Decode reads the next YAML-encoded value from its input +// and stores it in the value pointed to by v. +// +// See the documentation for Unmarshal for details about the +// conversion of YAML into a Go value. +func (d *Decoder) Decode(v interface{}) error { + rv := reflect.ValueOf(v) + if rv.Type().Kind() != reflect.Ptr { + return errors.ErrDecodeRequiredPointerType + } + if d.isInitialized() { + if err := d.decode(rv); err != nil { + if err == io.EOF { + return err + } + return errors.Wrapf(err, "failed to decode") + } + return nil + } + if err := d.decodeInit(); err != nil { + return errors.Wrapf(err, "failed to decodInit") + } + if err := d.decode(rv); err != nil { + if err == io.EOF { + return err + } + return errors.Wrapf(err, "failed to decode") + } return nil } diff --git a/decode_test.go b/decode_test.go index 7cad20b3..86792195 100644 --- a/decode_test.go +++ b/decode_test.go @@ -3,6 +3,7 @@ package yaml_test import ( "bytes" "fmt" + "io" "log" "math" "reflect" @@ -921,6 +922,9 @@ func TestDecoder(t *testing.T) { typ := reflect.ValueOf(test.value).Type() value := reflect.New(typ) if err := dec.Decode(value.Interface()); err != nil { + if err == io.EOF { + continue + } t.Fatalf("%s: %+v", test.source, err) } actual := fmt.Sprintf("%+v", value.Elem().Interface()) @@ -1728,3 +1732,41 @@ j: k t.Fatalf("expected:[%s] actual:[%s]", string(yml), "\n"+string(bytes)) } } + +func TestDecoder_Stream(t *testing.T) { + yml := ` +--- +a: b +c: d +--- +e: f +g: h +--- +i: j +k: l +` + dec := yaml.NewDecoder(strings.NewReader(yml)) + values := []map[string]string{} + for { + var v map[string]string + if err := dec.Decode(&v); err != nil { + if err == io.EOF { + break + } + t.Fatalf("%+v", err) + } + values = append(values, v) + } + if len(values) != 3 { + t.Fatal("failed to stream decoding") + } + if values[0]["a"] != "b" { + t.Fatal("failed to stream decoding") + } + if values[1]["e"] != "f" { + t.Fatal("failed to stream decoding") + } + if values[2]["i"] != "j" { + t.Fatal("failed to stream decoding") + } +} diff --git a/yaml.go b/yaml.go index b45d9207..6177b882 100644 --- a/yaml.go +++ b/yaml.go @@ -2,6 +2,7 @@ package yaml import ( "bytes" + "io" "github.com/goccy/go-yaml/internal/errors" "golang.org/x/xerrors" @@ -127,6 +128,9 @@ func Marshal(v interface{}) ([]byte, error) { func Unmarshal(data []byte, v interface{}) error { dec := NewDecoder(bytes.NewBuffer(data)) if err := dec.Decode(v); err != nil { + if err == io.EOF { + return nil + } return errors.Wrapf(err, "failed to unmarshal") } return nil