Skip to content

Commit

Permalink
Merge pull request #74 from kyoh86/text-marshal
Browse files Browse the repository at this point in the history
Support encoding.TextMarshaler/TextUnmarshaler
  • Loading branch information
goccy authored Feb 14, 2020
2 parents 61d0bc0 + a59c1d6 commit 609746c
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 4 deletions.
17 changes: 14 additions & 3 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package yaml

import (
"bytes"
"encoding"
"encoding/base64"
"fmt"
"io"
Expand Down Expand Up @@ -360,6 +361,19 @@ func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
return nil
} else if _, ok := dst.Addr().Interface().(*time.Time); ok {
return d.decodeTime(dst, src)
} else if unmarshaler, isText := dst.Addr().Interface().(encoding.TextUnmarshaler); isText {
var b string
if scalar, isScalar := src.(ast.ScalarNode); isScalar {
b = scalar.GetValue().(string)
} else {
b = src.String()
}
if err := unmarshaler.UnmarshalText([]byte(b)); err != nil {
return errors.Wrapf(err, "failed to UnmarshalText")
}
return nil
}
switch valueType.Kind() {
case reflect.Ptr:
Expand Down Expand Up @@ -388,9 +402,6 @@ func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error {
case reflect.Slice:
return d.decodeSlice(dst, src)
case reflect.Struct:
if _, ok := dst.Addr().Interface().(*time.Time); ok {
return d.decodeTime(dst, src)
}
return d.decodeStruct(dst, src)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
v := d.nodeToValue(src)
Expand Down
13 changes: 13 additions & 0 deletions encode.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package yaml

import (
"encoding"
"fmt"
"io"
"math"
Expand Down Expand Up @@ -128,6 +129,18 @@ func (e *Encoder) encodeValue(v reflect.Value, column int) (ast.Node, error) {
return nil, errors.Wrapf(err, "failed to MarshalYAML")
}
return e.encodeValue(reflect.ValueOf(marshalV), column)
} else if t, ok := v.Interface().(time.Time); ok {
return e.encodeTime(t, column), nil
} else if marshaler, ok := v.Interface().(encoding.TextMarshaler); ok {
doc, err := marshaler.MarshalText()
if err != nil {
return nil, errors.Wrapf(err, "failed to MarshalText")
}
node, err := e.encodeDocument(doc)
if err != nil {
return nil, errors.Wrapf(err, "failed to encode document")
}
return node, nil
}
}
switch v.Type().Kind() {
Expand Down
21 changes: 20 additions & 1 deletion encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,10 @@ type FastMarshaler struct {
A string
B int
}

type TextMarshaler int64
type TextMarshalerContainer struct {
Field TextMarshaler `yaml:"field"`
}
func (v SlowMarshaler) MarshalYAML() ([]byte, error) {
var buf bytes.Buffer
buf.WriteString("tags:\n")
Expand All @@ -763,6 +766,10 @@ func (v FastMarshaler) MarshalYAML() (interface{}, error) {
}, nil
}

func (t TextMarshaler) MarshalText() ([]byte, error) {
return []byte(strconv.FormatInt(int64(t), 8)), nil
}

func Example_MarshalYAML() {
var slow SlowMarshaler
slow.A = "Hello slow poke"
Expand All @@ -782,6 +789,16 @@ func Example_MarshalYAML() {
panic(err.Error())
}

fmt.Println(string(buf))

text := TextMarshalerContainer{
Field: 11,
}
buf, err = yaml.Marshal(text)
if err != nil {
panic(err.Error())
}

fmt.Println(string(buf))
// OUTPUT:
// tags:
Expand All @@ -793,4 +810,6 @@ func Example_MarshalYAML() {
// - fast-marshaler
// a: Hello speed demon
// b: 100
//
// field: 13
}

0 comments on commit 609746c

Please sign in to comment.