From 5594599465e6f02e79b0c22ae7193cd86e9b6f03 Mon Sep 17 00:00:00 2001 From: Kyoichiro Yamada Date: Sun, 12 Jan 2020 02:22:38 +0900 Subject: [PATCH 1/4] wip From 81b7b6d4a71c95a23d53e22998341f625e02deec Mon Sep 17 00:00:00 2001 From: Kyoichiro Yamada Date: Sun, 12 Jan 2020 02:50:51 +0900 Subject: [PATCH 2/4] support encoding.TextMarshaler and encoding.TextUnmarshaler --- decode.go | 12 +++++++++--- encode.go | 13 +++++++++++++ encode_test.go | 21 ++++++++++++++++++++- go.mod | 2 ++ go.sum | 2 ++ 5 files changed, 46 insertions(+), 4 deletions(-) diff --git a/decode.go b/decode.go index b9b21f9f..45ae81dd 100644 --- a/decode.go +++ b/decode.go @@ -2,6 +2,7 @@ package yaml import ( "bytes" + "encoding" "encoding/base64" "fmt" "io" @@ -355,6 +356,14 @@ func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error { return errors.Wrapf(err, "failed to UnmarshalYAML") } return nil + } else if _, ok := dst.Addr().Interface().(*time.Time); ok { + return d.decodeTime(dst, src) + } else if unmarshaler, isText := dst.Addr().Interface().(encoding.TextUnmarshaler); isText { + b := fmt.Sprintf("%v", src) + if err := unmarshaler.UnmarshalText([]byte(b)); err != nil { + return errors.Wrapf(err, "failed to UnmarshalText") + } + return nil } switch valueType.Kind() { case reflect.Ptr: @@ -383,9 +392,6 @@ func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error { case reflect.Slice: return d.decodeSlice(dst, src) case reflect.Struct: - if _, ok := dst.Addr().Interface().(*time.Time); ok { - return d.decodeTime(dst, src) - } return d.decodeStruct(dst, src) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: v := d.nodeToValue(src) diff --git a/encode.go b/encode.go index 413942eb..30af0a28 100644 --- a/encode.go +++ b/encode.go @@ -1,6 +1,7 @@ package yaml import ( + "encoding" "fmt" "io" "math" @@ -128,6 +129,18 @@ func (e *Encoder) encodeValue(v reflect.Value, column int) (ast.Node, error) { return nil, errors.Wrapf(err, "failed to MarshalYAML") } return e.encodeValue(reflect.ValueOf(marshalV), column) + } else if t, ok := v.Interface().(time.Time); ok { + return e.encodeTime(t, column), nil + } else if marshaler, ok := v.Interface().(encoding.TextMarshaler); ok { + doc, err := marshaler.MarshalText() + if err != nil { + return nil, errors.Wrapf(err, "failed to MarshalText") + } + node, err := e.encodeDocument(doc) + if err != nil { + return nil, errors.Wrapf(err, "failed to encode document") + } + return node, nil } } switch v.Type().Kind() { diff --git a/encode_test.go b/encode_test.go index 06fcc26d..f1e55c89 100644 --- a/encode_test.go +++ b/encode_test.go @@ -745,7 +745,10 @@ type FastMarshaler struct { A string B int } - +type TextMarshaler int64 +type TextMarshalerContainer struct { + Field TextMarshaler `yaml:"field"` + } func (v SlowMarshaler) MarshalYAML() ([]byte, error) { var buf bytes.Buffer buf.WriteString("tags:\n") @@ -763,6 +766,10 @@ func (v FastMarshaler) MarshalYAML() (interface{}, error) { }, nil } +func (t TextMarshaler) MarshalText() ([]byte, error) { + return []byte(strconv.FormatInt(int64(t), 8)), nil +} + func Example_MarshalYAML() { var slow SlowMarshaler slow.A = "Hello slow poke" @@ -782,6 +789,16 @@ func Example_MarshalYAML() { panic(err.Error()) } + fmt.Println(string(buf)) + + text := TextMarshalerContainer{ + Field: 11, + } + buf, err = yaml.Marshal(text) + if err != nil { + panic(err.Error()) + } + fmt.Println(string(buf)) // OUTPUT: // tags: @@ -793,4 +810,6 @@ func Example_MarshalYAML() { // - fast-marshaler // a: Hello speed demon // b: 100 + // + // field: 13 } diff --git a/go.mod b/go.mod index 71a617d6..b6e88deb 100644 --- a/go.mod +++ b/go.mod @@ -12,4 +12,6 @@ require ( golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898 gopkg.in/go-playground/assert.v1 v1.2.1 // indirect gopkg.in/go-playground/validator.v9 v9.30.0 + gopkg.in/yaml.v2 v2.2.2 + gopkg.in/yaml.v3 v3.0.0-20191120175047-4206685974f2 ) diff --git a/go.sum b/go.sum index d4ff87ef..f142df7e 100644 --- a/go.sum +++ b/go.sum @@ -36,3 +36,5 @@ gopkg.in/go-playground/validator.v9 v9.30.0 h1:Wk0Z37oBmKj9/n+tPyBHZmeL19LaCoK3Q gopkg.in/go-playground/validator.v9 v9.30.0/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20191120175047-4206685974f2 h1:XZx7nhd5GMaZpmDaEHFVafUZC7ya0fuo7cSJ3UCKYmM= +gopkg.in/yaml.v3 v3.0.0-20191120175047-4206685974f2/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 30b673aeaa44fe5a85d3cc10119eb2f2a18793f9 Mon Sep 17 00:00:00 2001 From: Kyoichiro Yamada Date: Sun, 12 Jan 2020 02:58:51 +0900 Subject: [PATCH 3/4] revert go.* --- go.mod | 2 -- go.sum | 2 -- 2 files changed, 4 deletions(-) diff --git a/go.mod b/go.mod index b6e88deb..71a617d6 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,4 @@ require ( golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898 gopkg.in/go-playground/assert.v1 v1.2.1 // indirect gopkg.in/go-playground/validator.v9 v9.30.0 - gopkg.in/yaml.v2 v2.2.2 - gopkg.in/yaml.v3 v3.0.0-20191120175047-4206685974f2 ) diff --git a/go.sum b/go.sum index f142df7e..d4ff87ef 100644 --- a/go.sum +++ b/go.sum @@ -36,5 +36,3 @@ gopkg.in/go-playground/validator.v9 v9.30.0 h1:Wk0Z37oBmKj9/n+tPyBHZmeL19LaCoK3Q gopkg.in/go-playground/validator.v9 v9.30.0/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20191120175047-4206685974f2 h1:XZx7nhd5GMaZpmDaEHFVafUZC7ya0fuo7cSJ3UCKYmM= -gopkg.in/yaml.v3 v3.0.0-20191120175047-4206685974f2/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From a59c1d6d35a8edbec1aeb6e5de2f89c633b8c897 Mon Sep 17 00:00:00 2001 From: Kyoichiro Yamada Date: Sun, 12 Jan 2020 22:37:09 +0900 Subject: [PATCH 4/4] don't quote string before unmarshal text --- decode.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/decode.go b/decode.go index 45ae81dd..cba48569 100644 --- a/decode.go +++ b/decode.go @@ -359,7 +359,12 @@ func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error { } else if _, ok := dst.Addr().Interface().(*time.Time); ok { return d.decodeTime(dst, src) } else if unmarshaler, isText := dst.Addr().Interface().(encoding.TextUnmarshaler); isText { - b := fmt.Sprintf("%v", src) + var b string + if scalar, isScalar := src.(ast.ScalarNode); isScalar { + b = scalar.GetValue().(string) + } else { + b = src.String() + } if err := unmarshaler.UnmarshalText([]byte(b)); err != nil { return errors.Wrapf(err, "failed to UnmarshalText") }