diff --git a/decode.go b/decode.go index c5d6f485..042e0257 100644 --- a/decode.go +++ b/decode.go @@ -2,6 +2,7 @@ package yaml import ( "bytes" + "encoding" "encoding/base64" "fmt" "io" @@ -360,6 +361,19 @@ func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error { return errors.Wrapf(err, "failed to UnmarshalYAML") } return nil + } else if _, ok := dst.Addr().Interface().(*time.Time); ok { + return d.decodeTime(dst, src) + } else if unmarshaler, isText := dst.Addr().Interface().(encoding.TextUnmarshaler); isText { + var b string + if scalar, isScalar := src.(ast.ScalarNode); isScalar { + b = scalar.GetValue().(string) + } else { + b = src.String() + } + if err := unmarshaler.UnmarshalText([]byte(b)); err != nil { + return errors.Wrapf(err, "failed to UnmarshalText") + } + return nil } switch valueType.Kind() { case reflect.Ptr: @@ -388,9 +402,6 @@ func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error { case reflect.Slice: return d.decodeSlice(dst, src) case reflect.Struct: - if _, ok := dst.Addr().Interface().(*time.Time); ok { - return d.decodeTime(dst, src) - } return d.decodeStruct(dst, src) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: v := d.nodeToValue(src) diff --git a/encode.go b/encode.go index 413942eb..30af0a28 100644 --- a/encode.go +++ b/encode.go @@ -1,6 +1,7 @@ package yaml import ( + "encoding" "fmt" "io" "math" @@ -128,6 +129,18 @@ func (e *Encoder) encodeValue(v reflect.Value, column int) (ast.Node, error) { return nil, errors.Wrapf(err, "failed to MarshalYAML") } return e.encodeValue(reflect.ValueOf(marshalV), column) + } else if t, ok := v.Interface().(time.Time); ok { + return e.encodeTime(t, column), nil + } else if marshaler, ok := v.Interface().(encoding.TextMarshaler); ok { + doc, err := marshaler.MarshalText() + if err != nil { + return nil, errors.Wrapf(err, "failed to MarshalText") + } + node, err := e.encodeDocument(doc) + if err != nil { + return nil, errors.Wrapf(err, "failed to encode document") + } + return node, nil } } switch v.Type().Kind() { diff --git a/encode_test.go b/encode_test.go index 06fcc26d..f1e55c89 100644 --- a/encode_test.go +++ b/encode_test.go @@ -745,7 +745,10 @@ type FastMarshaler struct { A string B int } - +type TextMarshaler int64 +type TextMarshalerContainer struct { + Field TextMarshaler `yaml:"field"` + } func (v SlowMarshaler) MarshalYAML() ([]byte, error) { var buf bytes.Buffer buf.WriteString("tags:\n") @@ -763,6 +766,10 @@ func (v FastMarshaler) MarshalYAML() (interface{}, error) { }, nil } +func (t TextMarshaler) MarshalText() ([]byte, error) { + return []byte(strconv.FormatInt(int64(t), 8)), nil +} + func Example_MarshalYAML() { var slow SlowMarshaler slow.A = "Hello slow poke" @@ -782,6 +789,16 @@ func Example_MarshalYAML() { panic(err.Error()) } + fmt.Println(string(buf)) + + text := TextMarshalerContainer{ + Field: 11, + } + buf, err = yaml.Marshal(text) + if err != nil { + panic(err.Error()) + } + fmt.Println(string(buf)) // OUTPUT: // tags: @@ -793,4 +810,6 @@ func Example_MarshalYAML() { // - fast-marshaler // a: Hello speed demon // b: 100 + // + // field: 13 }