diff --git a/decode.go b/decode.go index 06690bc..a600981 100644 --- a/decode.go +++ b/decode.go @@ -38,7 +38,7 @@ type Decoder struct { isResolvedReference bool validator StructValidator disallowUnknownField bool - disallowDuplicateKey bool + allowDuplicateMapKey bool useOrderedMap bool useJSONUnmarshaler bool parsedFile *ast.File @@ -60,7 +60,7 @@ func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder { isRecursiveDir: false, isResolvedReference: false, disallowUnknownField: false, - disallowDuplicateKey: false, + allowDuplicateMapKey: false, useOrderedMap: false, } } @@ -1570,7 +1570,7 @@ func (d *Decoder) validateDuplicateKey(keyMap map[string]struct{}, key interface if !ok { return nil } - if d.disallowDuplicateKey { + if !d.allowDuplicateMapKey { if _, exists := keyMap[k]; exists { return errDuplicateKey(fmt.Sprintf(`duplicate key "%s"`, k), keyNode.GetToken()) } @@ -1806,7 +1806,11 @@ func (d *Decoder) parse(bytes []byte) (*ast.File, error) { if d.toCommentMap != nil { parseMode = parser.ParseComments } - f, err := parser.ParseBytes(bytes, parseMode) + var opts []parser.Option + if d.allowDuplicateMapKey { + opts = append(opts, parser.AllowDuplicateMapKey()) + } + f, err := parser.ParseBytes(bytes, parseMode, opts...) if err != nil { return nil, err } diff --git a/decode_test.go b/decode_test.go index a442a40..902ba04 100644 --- a/decode_test.go +++ b/decode_test.go @@ -956,10 +956,6 @@ func TestDecoder(t *testing.T) { "a: &a [1, 2]\nb: *a\n", struct{ B []int }{[]int{1, 2}}, }, - { - "&0: *0\n*0:\n*0:", - map[string]any{"null": nil}, - }, { "key1: &anchor\n subkey: *anchor\nkey2: *anchor\n", map[string]any{ @@ -1502,7 +1498,7 @@ items: Items []*Item } buf := bytes.NewBufferString(yml) - dec := yaml.NewDecoder(buf) + dec := yaml.NewDecoder(buf, yaml.AllowDuplicateMapKey()) var v T if err := dec.Decode(&v); err != nil { t.Fatalf("%+v", err) @@ -1834,39 +1830,23 @@ children: }) } -func TestDecoder_DisallowDuplicateKey(t *testing.T) { +func TestDecoder_AllowDuplicateMapKey(t *testing.T) { yml := ` a: b a: c -` - expected := ` -[3:1] duplicate key "a" - 2 | a: b -> 3 | a: c - ^ ` t.Run("map", func(t *testing.T) { var v map[string]string - err := yaml.NewDecoder(strings.NewReader(yml), yaml.DisallowDuplicateKey()).Decode(&v) - if err == nil { - t.Fatal("decoding should fail") - } - actual := "\n" + err.Error() - if expected != actual { - t.Fatalf("expected:[%s] actual:[%s]", expected, actual) + if err := yaml.NewDecoder(strings.NewReader(yml), yaml.AllowDuplicateMapKey()).Decode(&v); err != nil { + t.Fatal(err) } }) t.Run("struct", func(t *testing.T) { var v struct { A string } - err := yaml.NewDecoder(strings.NewReader(yml), yaml.DisallowDuplicateKey()).Decode(&v) - if err == nil { - t.Fatal("decoding should fail") - } - actual := "\n" + err.Error() - if expected != actual { - t.Fatalf("expected:[%s] actual:[%s]", expected, actual) + if err := yaml.NewDecoder(strings.NewReader(yml), yaml.AllowDuplicateMapKey()).Decode(&v); err != nil { + t.Fatal(err) } }) } @@ -3099,3 +3079,23 @@ nested: t.Fatal("decoder doesn't preserve struct defaults") } } + +func TestDecodeError(t *testing.T) { + tests := []struct { + name string + source string + }{ + { + name: "duplicated map key name with anchor-alias", + source: "&0: *0\n*0:\n*0:", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var v any + if err := yaml.Unmarshal([]byte(test.source), &v); err == nil { + t.Fatal("cannot catch decode error") + } + }) + } +} diff --git a/option.go b/option.go index eab5d43..ad43c47 100644 --- a/option.go +++ b/option.go @@ -54,7 +54,6 @@ func Validator(v StructValidator) DecodeOption { func Strict() DecodeOption { return func(d *Decoder) error { d.disallowUnknownField = true - d.disallowDuplicateKey = true return nil } } @@ -69,10 +68,10 @@ func DisallowUnknownField() DecodeOption { } } -// DisallowDuplicateKey causes an error when mapping keys that are duplicates -func DisallowDuplicateKey() DecodeOption { +// AllowDuplicateMapKey ignore syntax error when mapping keys that are duplicates. +func AllowDuplicateMapKey() DecodeOption { return func(d *Decoder) error { - d.disallowDuplicateKey = true + d.allowDuplicateMapKey = true return nil } } diff --git a/parser/option.go b/parser/option.go new file mode 100644 index 0000000..3121a64 --- /dev/null +++ b/parser/option.go @@ -0,0 +1,12 @@ +package parser + +// Option represents parser's option. +type Option func(p *parser) + +// AllowDuplicateMapKey allow the use of keys with the same name in the same map, +// but by default, this is not permitted. +func AllowDuplicateMapKey() Option { + return func(p *parser) { + p.allowDuplicateMapKey = true + } +} diff --git a/parser/parser.go b/parser/parser.go index 71c99a8..ec54ce7 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -12,12 +12,14 @@ import ( ) type parser struct { - idx int - size int - tokens token.Tokens + idx int + size int + tokens token.Tokens + pathMap map[string]ast.Node + allowDuplicateMapKey bool } -func newParser(tokens token.Tokens, mode Mode) *parser { +func newParser(tokens token.Tokens, mode Mode, opts []Option) *parser { filteredTokens := []*token.Token{} if mode&ParseComments != 0 { filteredTokens = tokens @@ -31,11 +33,16 @@ func newParser(tokens token.Tokens, mode Mode) *parser { filteredTokens = append(filteredTokens, tk) } } - return &parser{ - idx: 0, - size: len(filteredTokens), - tokens: token.Tokens(filteredTokens), + p := &parser{ + idx: 0, + size: len(filteredTokens), + tokens: token.Tokens(filteredTokens), + pathMap: make(map[string]ast.Node), } + for _, opt := range opts { + opt(p) + } + return p } func (p *parser) next() bool { @@ -281,7 +288,16 @@ func (p *parser) existsNewLineCharacter(src string) bool { return false } -func (p *parser) validateMapKey(tk *token.Token) error { +func (p *parser) validateMapKey(tk *token.Token, keyPath string) error { + if !p.allowDuplicateMapKey { + if n, exists := p.pathMap[keyPath]; exists { + pos := n.GetToken().Position + return errors.ErrSyntax( + fmt.Sprintf("mapping key %q already defined at [%d:%d]", tk.Value, pos.Line, pos.Column), + tk, + ) + } + } if tk.Type != token.StringType { return nil } @@ -323,7 +339,7 @@ func (p *parser) createMapValueNode(ctx *context, key ast.MapKeyNode, colonToken if tk.Type == token.CommentType { comment = p.parseCommentOnly(ctx) if comment != nil { - comment.SetPath(ctx.withChild(key.GetToken().Value).path) + comment.SetPath(ctx.withChild(p.mapKeyText(key)).path) } tk = p.currentToken() } @@ -397,16 +413,28 @@ func (p *parser) validateMapValue(ctx *context, key, value ast.Node) error { return nil } +func (p *parser) mapKeyText(n ast.Node) string { + switch nn := n.(type) { + case *ast.MappingKeyNode: + return p.mapKeyText(nn.Value) + case *ast.TagNode: + return p.mapKeyText(nn.Value) + } + return n.GetToken().Value +} + func (p *parser) parseMappingValue(ctx *context) (ast.Node, error) { key, err := p.parseMapKey(ctx) if err != nil { return nil, err } - keyText := key.GetToken().Value - key.SetPath(ctx.withChild(keyText).path) - if err := p.validateMapKey(key.GetToken()); err != nil { + keyText := p.mapKeyText(key) + keyPath := ctx.withChild(keyText).path + key.SetPath(keyPath) + if err := p.validateMapKey(key.GetToken(), keyPath); err != nil { return nil, err } + p.pathMap[keyPath] = key p.progress(1) // progress to mapping value token if ctx.isFlow { // if "{key}" or "{key," style, returns MappingValueNode. @@ -557,7 +585,7 @@ func (p *parser) parseFlowMapNullValue(ctx *context, key ast.MapKeyNode) (*ast.M return nil, err } node := ast.MappingValue(tk, key, value) - node.SetPath(ctx.withChild(key.GetToken().Value).path) + node.SetPath(ctx.withChild(p.mapKeyText(key)).path) return node, nil } @@ -868,7 +896,7 @@ func (p *parser) parseMappingKey(ctx *context) (*ast.MappingKeyNode, error) { node := ast.MappingKey(keyTk) node.SetPath(ctx.path) p.progress(1) // skip mapping key token - value, err := p.parseToken(ctx.withChild(keyTk.Value), p.currentToken()) + value, err := p.parseToken(ctx, p.currentToken()) if err != nil { return nil, err } @@ -962,9 +990,9 @@ const ( ) // ParseBytes parse from byte slice, and returns ast.File -func ParseBytes(bytes []byte, mode Mode) (*ast.File, error) { +func ParseBytes(bytes []byte, mode Mode, opts ...Option) (*ast.File, error) { tokens := lexer.Tokenize(string(bytes)) - f, err := Parse(tokens, mode) + f, err := Parse(tokens, mode, opts...) if err != nil { return nil, err } @@ -972,11 +1000,11 @@ func ParseBytes(bytes []byte, mode Mode) (*ast.File, error) { } // Parse parse from token instances, and returns ast.File -func Parse(tokens token.Tokens, mode Mode) (*ast.File, error) { +func Parse(tokens token.Tokens, mode Mode, opts ...Option) (*ast.File, error) { if tk := tokens.InvalidToken(); tk != nil { return nil, errors.ErrSyntax("found invalid token", tk) } - f, err := newParser(tokens, mode).parse(newContext()) + f, err := newParser(tokens, mode, opts).parse(newContext()) if err != nil { return nil, err } @@ -984,12 +1012,12 @@ func Parse(tokens token.Tokens, mode Mode) (*ast.File, error) { } // Parse parse from filename, and returns ast.File -func ParseFile(filename string, mode Mode) (*ast.File, error) { +func ParseFile(filename string, mode Mode, opts ...Option) (*ast.File, error) { file, err := os.ReadFile(filename) if err != nil { return nil, err } - f, err := ParseBytes(file, mode) + f, err := ParseBytes(file, mode, opts...) if err != nil { return nil, err } diff --git a/parser/parser_test.go b/parser/parser_test.go index 189b970..8d4e4d1 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1284,6 +1284,42 @@ a: [1:4] found an invalid key for this map > 1 | [a]: v ^ +`, + }, + { + ` +foo: + bar: + foo: 2 + baz: + foo: 3 +foo: 2 +`, + ` +[7:1] mapping key "foo" already defined at [2:1] + 4 | foo: 2 + 5 | baz: + 6 | foo: 3 +> 7 | foo: 2 + ^ +`, + }, + { + ` +foo: + bar: + foo: 2 + baz: + foo: 3 + foo: 4 +`, + ` +[7:5] mapping key "foo" already defined at [6:5] + 4 | foo: 2 + 5 | baz: + 6 | foo: 3 +> 7 | foo: 4 + ^ `, }, }