diff --git a/ast/ast.go b/ast/ast.go index b4d5ec41..6a0d510f 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -1414,6 +1414,11 @@ func (n *MappingValueNode) toString() string { return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String()) } else if _, ok := n.Value.(*AliasNode); ok { return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String()) + } else if tn, ok := n.Value.(*TagNode); ok { + if _, xok := tn.Value.(MapNode); xok { + return fmt.Sprintf("%s%s:%s", space, n.Key.String(), n.Value.String()) + } + return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String()) } if keyComment != nil { return fmt.Sprintf( @@ -1795,6 +1800,16 @@ func (n *TagNode) AddColumn(col int) { // String tag to text func (n *TagNode) String() string { + space := strings.Repeat(" ", n.GetToken().Position.Column-1) + + value := n.Value.String() + if s, ok := n.Value.(*SequenceNode); ok && !s.IsFlowStyle { + return fmt.Sprintf("%s\n%s", n.Start.Value, value) + } else if m, ok := n.Value.(*MappingNode); ok && !m.IsFlowStyle { + return fmt.Sprintf("\n%s%s\n%s", space, n.Start.Value, value) + } else if _, ok := n.Value.(*MappingValueNode); ok { + return fmt.Sprintf("\n%s%s\n%s", space, n.Start.Value, value) + } return fmt.Sprintf("%s %s", n.Start.Value, n.Value.String()) } diff --git a/encode.go b/encode.go index 3b9b2981..a65df9e5 100644 --- a/encode.go +++ b/encode.go @@ -609,6 +609,9 @@ func (e *Encoder) encodeMapItem(ctx context.Context, item MapItem, column int) ( if e.isMapNode(value) { value.AddColumn(e.indent) } + if e.isTagAndMapNode(value) { + value.AddColumn(e.indent) + } return ast.MappingValue( token.New("", "", e.pos(column)), e.encodeString(k.Interface().(string), column), @@ -633,6 +636,14 @@ func (e *Encoder) isMapNode(node ast.Node) bool { return ok } +func (e *Encoder) isTagAndMapNode(node ast.Node) bool { + tn, ok := node.(*ast.TagNode) + if ok { + _, ok = tn.Value.(ast.MapNode) + } + return ok +} + func (e *Encoder) encodeMap(ctx context.Context, value reflect.Value, column int) ast.Node { node := ast.Mapping(token.New("", "", e.pos(column)), e.isFlowStyle) keys := make([]interface{}, len(value.MapKeys())) @@ -652,6 +663,9 @@ func (e *Encoder) encodeMap(ctx context.Context, value reflect.Value, column int if e.isMapNode(value) { value.AddColumn(e.indent) } + if e.isTagAndMapNode(value) { + value.AddColumn(e.indent) + } node.Values = append(node.Values, ast.MappingValue( nil, e.encodeString(fmt.Sprint(key), column), diff --git a/encode_test.go b/encode_test.go index 3ff6f1c1..50e479f1 100644 --- a/encode_test.go +++ b/encode_test.go @@ -1630,3 +1630,109 @@ b: t.Fatalf("failed to encode. expected %s but got %s", expected, got) } } + +type tagMarshaler struct{} + +func (b *tagMarshaler) MarshalYAML() ([]byte, error) { + v, err := yaml.Marshal("test") + if err != nil { + return nil, err + } + return []byte(fmt.Sprintf("%s %s", "!!timestamp", string(v))), nil +} + +func TestBytesMarshalerWithTag(t *testing.T) { + b, err := yaml.Marshal(map[string]interface{}{ + "a": map[string]interface{}{ + "b": map[string]interface{}{ + "c": &tagMarshaler{}, + "d": []*tagMarshaler{&tagMarshaler{}, &tagMarshaler{}}, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + expected := ` +a: + b: + c: !!timestamp test + d: + - !!timestamp test + - !!timestamp test +` + got := "\n" + string(b) + if expected != got { + t.Fatalf("failed to encode. expected %s but got %s", expected, got) + } +} + +type tagMarshalerMapValue struct { + Tag string + Value any +} + +func (t *tagMarshalerMapValue) MarshalYAML() ([]byte, error) { + var out bytes.Buffer + _, _ = fmt.Fprintf(&out, "\n%s\n", t.Tag) + v, err := yaml.ValueToNode(t.Value, yaml.Flow(false)) + if err != nil { + return nil, err + } + _, _ = fmt.Fprintf(&out, "%s", v) + return out.Bytes(), nil +} + +func TestTagMarshalerMapValue(t *testing.T) { + b, err := yaml.Marshal(map[string]interface{}{ + "a": map[string]interface{}{ + "b": &tagMarshalerMapValue{ + Tag: "!mytag", + Value: map[string]interface{}{ + "c": 15, + "d": 99, + }, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + expected := ` +a: + b: + !mytag + c: 15 + d: 99 +` + got := "\n" + string(b) + if expected != got { + t.Fatalf("failed to encode. expected %s but got %s", expected, got) + } +} + +func TestTagMarshalerMapValue2(t *testing.T) { + b, err := yaml.Marshal(map[string]interface{}{ + "a": map[string]interface{}{ + "b": &tagMarshalerMapValue{ + Tag: "!mytag", + Value: map[string]interface{}{ + "c": 15, + }, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + expected := ` +a: + b: + !mytag + c: 15 +` + got := "\n" + string(b) + if expected != got { + t.Fatalf("failed to encode. expected %s but got %s", expected, got) + } +}