Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support to disallow duplicate map key at parsing #531

Merged
merged 1 commit into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type Decoder struct {
isResolvedReference bool
validator StructValidator
disallowUnknownField bool
disallowDuplicateKey bool
allowDuplicateMapKey bool
useOrderedMap bool
useJSONUnmarshaler bool
parsedFile *ast.File
Expand All @@ -60,7 +60,7 @@ func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder {
isRecursiveDir: false,
isResolvedReference: false,
disallowUnknownField: false,
disallowDuplicateKey: false,
allowDuplicateMapKey: false,
useOrderedMap: false,
}
}
Expand Down Expand Up @@ -1570,7 +1570,7 @@ func (d *Decoder) validateDuplicateKey(keyMap map[string]struct{}, key interface
if !ok {
return nil
}
if d.disallowDuplicateKey {
if !d.allowDuplicateMapKey {
if _, exists := keyMap[k]; exists {
return errDuplicateKey(fmt.Sprintf(`duplicate key "%s"`, k), keyNode.GetToken())
}
Expand Down Expand Up @@ -1806,7 +1806,11 @@ func (d *Decoder) parse(bytes []byte) (*ast.File, error) {
if d.toCommentMap != nil {
parseMode = parser.ParseComments
}
f, err := parser.ParseBytes(bytes, parseMode)
var opts []parser.Option
if d.allowDuplicateMapKey {
opts = append(opts, parser.AllowDuplicateMapKey())
}
f, err := parser.ParseBytes(bytes, parseMode, opts...)
if err != nil {
return nil, err
}
Expand Down
52 changes: 26 additions & 26 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -956,10 +956,6 @@ func TestDecoder(t *testing.T) {
"a: &a [1, 2]\nb: *a\n",
struct{ B []int }{[]int{1, 2}},
},
{
"&0: *0\n*0:\n*0:",
map[string]any{"null": nil},
},
{
"key1: &anchor\n subkey: *anchor\nkey2: *anchor\n",
map[string]any{
Expand Down Expand Up @@ -1502,7 +1498,7 @@ items:
Items []*Item
}
buf := bytes.NewBufferString(yml)
dec := yaml.NewDecoder(buf)
dec := yaml.NewDecoder(buf, yaml.AllowDuplicateMapKey())
var v T
if err := dec.Decode(&v); err != nil {
t.Fatalf("%+v", err)
Expand Down Expand Up @@ -1834,39 +1830,23 @@ children:
})
}

func TestDecoder_DisallowDuplicateKey(t *testing.T) {
func TestDecoder_AllowDuplicateMapKey(t *testing.T) {
yml := `
a: b
a: c
`
expected := `
[3:1] duplicate key "a"
2 | a: b
> 3 | a: c
^
`
t.Run("map", func(t *testing.T) {
var v map[string]string
err := yaml.NewDecoder(strings.NewReader(yml), yaml.DisallowDuplicateKey()).Decode(&v)
if err == nil {
t.Fatal("decoding should fail")
}
actual := "\n" + err.Error()
if expected != actual {
t.Fatalf("expected:[%s] actual:[%s]", expected, actual)
if err := yaml.NewDecoder(strings.NewReader(yml), yaml.AllowDuplicateMapKey()).Decode(&v); err != nil {
t.Fatal(err)
}
})
t.Run("struct", func(t *testing.T) {
var v struct {
A string
}
err := yaml.NewDecoder(strings.NewReader(yml), yaml.DisallowDuplicateKey()).Decode(&v)
if err == nil {
t.Fatal("decoding should fail")
}
actual := "\n" + err.Error()
if expected != actual {
t.Fatalf("expected:[%s] actual:[%s]", expected, actual)
if err := yaml.NewDecoder(strings.NewReader(yml), yaml.AllowDuplicateMapKey()).Decode(&v); err != nil {
t.Fatal(err)
}
})
}
Expand Down Expand Up @@ -3099,3 +3079,23 @@ nested:
t.Fatal("decoder doesn't preserve struct defaults")
}
}

func TestDecodeError(t *testing.T) {
tests := []struct {
name string
source string
}{
{
name: "duplicated map key name with anchor-alias",
source: "&0: *0\n*0:\n*0:",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var v any
if err := yaml.Unmarshal([]byte(test.source), &v); err == nil {
t.Fatal("cannot catch decode error")
}
})
}
}
7 changes: 3 additions & 4 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ func Validator(v StructValidator) DecodeOption {
func Strict() DecodeOption {
return func(d *Decoder) error {
d.disallowUnknownField = true
d.disallowDuplicateKey = true
return nil
}
}
Expand All @@ -69,10 +68,10 @@ func DisallowUnknownField() DecodeOption {
}
}

// DisallowDuplicateKey causes an error when mapping keys that are duplicates
func DisallowDuplicateKey() DecodeOption {
// AllowDuplicateMapKey ignore syntax error when mapping keys that are duplicates.
func AllowDuplicateMapKey() DecodeOption {
return func(d *Decoder) error {
d.disallowDuplicateKey = true
d.allowDuplicateMapKey = true
return nil
}
}
Expand Down
12 changes: 12 additions & 0 deletions parser/option.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package parser

// Option represents parser's option.
type Option func(p *parser)

// AllowDuplicateMapKey allow the use of keys with the same name in the same map,
// but by default, this is not permitted.
func AllowDuplicateMapKey() Option {
return func(p *parser) {
p.allowDuplicateMapKey = true
}
}
70 changes: 49 additions & 21 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ import (
)

type parser struct {
idx int
size int
tokens token.Tokens
idx int
size int
tokens token.Tokens
pathMap map[string]ast.Node
allowDuplicateMapKey bool
}

func newParser(tokens token.Tokens, mode Mode) *parser {
func newParser(tokens token.Tokens, mode Mode, opts []Option) *parser {
filteredTokens := []*token.Token{}
if mode&ParseComments != 0 {
filteredTokens = tokens
Expand All @@ -31,11 +33,16 @@ func newParser(tokens token.Tokens, mode Mode) *parser {
filteredTokens = append(filteredTokens, tk)
}
}
return &parser{
idx: 0,
size: len(filteredTokens),
tokens: token.Tokens(filteredTokens),
p := &parser{
idx: 0,
size: len(filteredTokens),
tokens: token.Tokens(filteredTokens),
pathMap: make(map[string]ast.Node),
}
for _, opt := range opts {
opt(p)
}
return p
}

func (p *parser) next() bool {
Expand Down Expand Up @@ -281,7 +288,16 @@ func (p *parser) existsNewLineCharacter(src string) bool {
return false
}

func (p *parser) validateMapKey(tk *token.Token) error {
func (p *parser) validateMapKey(tk *token.Token, keyPath string) error {
if !p.allowDuplicateMapKey {
if n, exists := p.pathMap[keyPath]; exists {
pos := n.GetToken().Position
return errors.ErrSyntax(
fmt.Sprintf("mapping key %q already defined at [%d:%d]", tk.Value, pos.Line, pos.Column),
tk,
)
}
}
if tk.Type != token.StringType {
return nil
}
Expand Down Expand Up @@ -323,7 +339,7 @@ func (p *parser) createMapValueNode(ctx *context, key ast.MapKeyNode, colonToken
if tk.Type == token.CommentType {
comment = p.parseCommentOnly(ctx)
if comment != nil {
comment.SetPath(ctx.withChild(key.GetToken().Value).path)
comment.SetPath(ctx.withChild(p.mapKeyText(key)).path)
}
tk = p.currentToken()
}
Expand Down Expand Up @@ -397,16 +413,28 @@ func (p *parser) validateMapValue(ctx *context, key, value ast.Node) error {
return nil
}

func (p *parser) mapKeyText(n ast.Node) string {
switch nn := n.(type) {
case *ast.MappingKeyNode:
return p.mapKeyText(nn.Value)
case *ast.TagNode:
return p.mapKeyText(nn.Value)
}
return n.GetToken().Value
}

func (p *parser) parseMappingValue(ctx *context) (ast.Node, error) {
key, err := p.parseMapKey(ctx)
if err != nil {
return nil, err
}
keyText := key.GetToken().Value
key.SetPath(ctx.withChild(keyText).path)
if err := p.validateMapKey(key.GetToken()); err != nil {
keyText := p.mapKeyText(key)
keyPath := ctx.withChild(keyText).path
key.SetPath(keyPath)
if err := p.validateMapKey(key.GetToken(), keyPath); err != nil {
return nil, err
}
p.pathMap[keyPath] = key
p.progress(1) // progress to mapping value token
if ctx.isFlow {
// if "{key}" or "{key," style, returns MappingValueNode.
Expand Down Expand Up @@ -557,7 +585,7 @@ func (p *parser) parseFlowMapNullValue(ctx *context, key ast.MapKeyNode) (*ast.M
return nil, err
}
node := ast.MappingValue(tk, key, value)
node.SetPath(ctx.withChild(key.GetToken().Value).path)
node.SetPath(ctx.withChild(p.mapKeyText(key)).path)
return node, nil
}

Expand Down Expand Up @@ -868,7 +896,7 @@ func (p *parser) parseMappingKey(ctx *context) (*ast.MappingKeyNode, error) {
node := ast.MappingKey(keyTk)
node.SetPath(ctx.path)
p.progress(1) // skip mapping key token
value, err := p.parseToken(ctx.withChild(keyTk.Value), p.currentToken())
value, err := p.parseToken(ctx, p.currentToken())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -962,34 +990,34 @@ const (
)

// ParseBytes parse from byte slice, and returns ast.File
func ParseBytes(bytes []byte, mode Mode) (*ast.File, error) {
func ParseBytes(bytes []byte, mode Mode, opts ...Option) (*ast.File, error) {
tokens := lexer.Tokenize(string(bytes))
f, err := Parse(tokens, mode)
f, err := Parse(tokens, mode, opts...)
if err != nil {
return nil, err
}
return f, nil
}

// Parse parse from token instances, and returns ast.File
func Parse(tokens token.Tokens, mode Mode) (*ast.File, error) {
func Parse(tokens token.Tokens, mode Mode, opts ...Option) (*ast.File, error) {
if tk := tokens.InvalidToken(); tk != nil {
return nil, errors.ErrSyntax("found invalid token", tk)
}
f, err := newParser(tokens, mode).parse(newContext())
f, err := newParser(tokens, mode, opts).parse(newContext())
if err != nil {
return nil, err
}
return f, nil
}

// Parse parse from filename, and returns ast.File
func ParseFile(filename string, mode Mode) (*ast.File, error) {
func ParseFile(filename string, mode Mode, opts ...Option) (*ast.File, error) {
file, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
f, err := ParseBytes(file, mode)
f, err := ParseBytes(file, mode, opts...)
if err != nil {
return nil, err
}
Expand Down
36 changes: 36 additions & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1284,6 +1284,42 @@ a:
[1:4] found an invalid key for this map
> 1 | [a]: v
^
`,
},
{
`
foo:
bar:
foo: 2
baz:
foo: 3
foo: 2
`,
`
[7:1] mapping key "foo" already defined at [2:1]
4 | foo: 2
5 | baz:
6 | foo: 3
> 7 | foo: 2
^
`,
},
{
`
foo:
bar:
foo: 2
baz:
foo: 3
foo: 4
`,
`
[7:5] mapping key "foo" already defined at [6:5]
4 | foo: 2
5 | baz:
6 | foo: 3
> 7 | foo: 4
^
`,
},
}
Expand Down
Loading