From c395c330c56c0efb2b84344343da5a512386faa8 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Mon, 26 Oct 2020 19:00:49 +0900 Subject: [PATCH] Add UseJSONMarshaler option for encoder and UseJSONUnmarshaler option for decoder --- decode.go | 17 +++++++++++++++++ decode_test.go | 23 +++++++++++++++++++++++ encode.go | 30 ++++++++++++++++++++++++++---- encode_test.go | 22 ++++++++++++++++++++++ option.go | 19 +++++++++++++++++++ 5 files changed, 107 insertions(+), 4 deletions(-) diff --git a/decode.go b/decode.go index 3889b864..e310d8e6 100644 --- a/decode.go +++ b/decode.go @@ -36,6 +36,7 @@ type Decoder struct { disallowUnknownField bool disallowDuplicateKey bool useOrderedMap bool + useJSONUnmarshaler bool parsedFile *ast.File streamIndex int } @@ -488,6 +489,10 @@ func (d *Decoder) unmarshalableText(node ast.Node) ([]byte, bool) { return nil, false } +type jsonUnmarshaler interface { + UnmarshalJSON([]byte) error +} + func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error { if src.Type() == ast.AnchorType { anchorName := src.(*ast.AnchorNode).Name.GetToken().Value @@ -525,6 +530,18 @@ func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error { } return nil } + } else if d.useJSONUnmarshaler { + if unmarshaler, ok := dst.Addr().Interface().(jsonUnmarshaler); ok { + jsonBytes, err := YAMLToJSON(d.unmarshalableDocument(src)) + if err != nil { + return errors.Wrapf(err, "failed to convert yaml to json") + } + jsonBytes = bytes.TrimRight(jsonBytes, "\n") + if err := unmarshaler.UnmarshalJSON(jsonBytes); err != nil { + return errors.Wrapf(err, "failed to UnmarshalJSON") + } + return nil + } } switch valueType.Kind() { case reflect.Ptr: diff --git a/decode_test.go b/decode_test.go index f4b8d6f1..01dc3350 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1618,6 +1618,29 @@ B: d // c } +type useJSONUnmarshalerTest struct { + s string +} + +func (t *useJSONUnmarshalerTest) UnmarshalJSON(b []byte) error { + s, err := strconv.Unquote(string(b)) + if err != nil { + return err + } + t.s = s + return nil +} + +func TestDecoder_UseJSONUnmarshaler(t *testing.T) { + var v useJSONUnmarshalerTest + if err := yaml.UnmarshalWithOptions([]byte(`"a"`), &v, yaml.UseJSONUnmarshaler()); err != nil { + t.Fatal(err) + } + if v.s != "a" { + t.Fatalf("unexpected decoded value: %s", v.s) + } +} + func Example_JSONTags() { yml := `--- foo: 1 diff --git a/encode.go b/encode.go index e3fef008..9bcd7eca 100644 --- a/encode.go +++ b/encode.go @@ -31,6 +31,7 @@ type Encoder struct { indent int isFlowStyle bool isJSONStyle bool + useJSONMarshaler bool anchorCallback func(*ast.AnchorNode, interface{}) error anchorPtrToNameMap map[uintptr]string @@ -118,12 +119,17 @@ func (e *Encoder) isInvalidValue(v reflect.Value) bool { return false } +type jsonMarshaler interface { + MarshalJSON() ([]byte, error) +} + func (e *Encoder) encodeValue(v reflect.Value, column int) (ast.Node, error) { if e.isInvalidValue(v) { return e.encodeNil(), nil } if v.CanInterface() { - if marshaler, ok := v.Interface().(BytesMarshaler); ok { + iface := v.Interface() + if marshaler, ok := iface.(BytesMarshaler); ok { doc, err := marshaler.MarshalYAML() if err != nil { return nil, errors.Wrapf(err, "failed to MarshalYAML") @@ -133,15 +139,15 @@ func (e *Encoder) encodeValue(v reflect.Value, column int) (ast.Node, error) { return nil, errors.Wrapf(err, "failed to encode document") } return node, nil - } else if marshaler, ok := v.Interface().(InterfaceMarshaler); ok { + } else if marshaler, ok := iface.(InterfaceMarshaler); ok { marshalV, err := marshaler.MarshalYAML() if err != nil { return nil, errors.Wrapf(err, "failed to MarshalYAML") } return e.encodeValue(reflect.ValueOf(marshalV), column) - } else if t, ok := v.Interface().(time.Time); ok { + } else if t, ok := iface.(time.Time); ok { return e.encodeTime(t, column), nil - } else if marshaler, ok := v.Interface().(encoding.TextMarshaler); ok { + } else if marshaler, ok := iface.(encoding.TextMarshaler); ok { doc, err := marshaler.MarshalText() if err != nil { return nil, errors.Wrapf(err, "failed to MarshalText") @@ -151,6 +157,22 @@ func (e *Encoder) encodeValue(v reflect.Value, column int) (ast.Node, error) { return nil, errors.Wrapf(err, "failed to encode document") } return node, nil + } else if e.useJSONMarshaler { + if marshaler, ok := iface.(jsonMarshaler); ok { + jsonBytes, err := marshaler.MarshalJSON() + if err != nil { + return nil, errors.Wrapf(err, "failed to MarshalJSON") + } + doc, err := JSONToYAML(jsonBytes) + if err != nil { + return nil, errors.Wrapf(err, "failed to convert json to yaml") + } + node, err := e.encodeDocument(doc) + if err != nil { + return nil, errors.Wrapf(err, "failed to encode document") + } + return node, nil + } } } switch v.Type().Kind() { diff --git a/encode_test.go b/encode_test.go index 093ceab0..4d5cef47 100644 --- a/encode_test.go +++ b/encode_test.go @@ -820,6 +820,28 @@ queues: } } +type useJSONMarshalerTest struct{} + +func (t useJSONMarshalerTest) MarshalJSON() ([]byte, error) { + return []byte(`{"a":[1, 2, 3]}`), nil +} + +func TestEncoder_UseJSONMarshaler(t *testing.T) { + got, err := yaml.MarshalWithOptions(useJSONMarshalerTest{}, yaml.UseJSONMarshaler()) + if err != nil { + t.Fatal(err) + } + expected := ` +a: +- 1 +- 2 +- 3 +` + if expected != "\n"+string(got) { + t.Fatalf("failed to use json marshaler. expected [%q] but got [%q]", expected, string(got)) + } +} + func Example_Marshal_ExplicitAnchorAlias() { type T struct { A int diff --git a/option.go b/option.go index 2184b7cf..343fc55d 100644 --- a/option.go +++ b/option.go @@ -85,6 +85,15 @@ func UseOrderedMap() DecodeOption { } } +// UseJSONUnmarshaler if neither `BytesUnmarshaler` nor `InterfaceUnmarshaler` is implemented +// and `UnmashalJSON([]byte)error` is implemented, convert the argument from `YAML` to `JSON` and then call it. +func UseJSONUnmarshaler() DecodeOption { + return func(d *Decoder) error { + d.useJSONUnmarshaler = true + return nil + } +} + // EncodeOption functional option type for Encoder type EncodeOption func(e *Encoder) error @@ -120,3 +129,13 @@ func MarshalAnchor(callback func(*ast.AnchorNode, interface{}) error) EncodeOpti return nil } } + +// UseJSONMarshaler if neither `BytesMarshaler` nor `InterfaceMarshaler` +// nor `encoding.TextMarshaler` is implemented and `MarshalJSON()([]byte, error)` is implemented, +// call `MarshalJSON` to convert the returned `JSON` to `YAML` for processing. +func UseJSONMarshaler() EncodeOption { + return func(e *Encoder) error { + e.useJSONMarshaler = true + return nil + } +}