Skip to content

Commit

Permalink
support to disallow duplicate map key at parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
goccy committed Nov 13, 2024
1 parent 4385176 commit 1cf609f
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 55 deletions.
12 changes: 8 additions & 4 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type Decoder struct {
isResolvedReference bool
validator StructValidator
disallowUnknownField bool
disallowDuplicateKey bool
allowDuplicateMapKey bool
useOrderedMap bool
useJSONUnmarshaler bool
parsedFile *ast.File
Expand All @@ -60,7 +60,7 @@ func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder {
isRecursiveDir: false,
isResolvedReference: false,
disallowUnknownField: false,
disallowDuplicateKey: false,
allowDuplicateMapKey: false,
useOrderedMap: false,
}
}
Expand Down Expand Up @@ -1570,7 +1570,7 @@ func (d *Decoder) validateDuplicateKey(keyMap map[string]struct{}, key interface
if !ok {
return nil
}
if d.disallowDuplicateKey {
if !d.allowDuplicateMapKey {
if _, exists := keyMap[k]; exists {
return errDuplicateKey(fmt.Sprintf(`duplicate key "%s"`, k), keyNode.GetToken())
}
Expand Down Expand Up @@ -1806,7 +1806,11 @@ func (d *Decoder) parse(bytes []byte) (*ast.File, error) {
if d.toCommentMap != nil {
parseMode = parser.ParseComments
}
f, err := parser.ParseBytes(bytes, parseMode)
var opts []parser.Option
if d.allowDuplicateMapKey {
opts = append(opts, parser.AllowDuplicateMapKey())
}
f, err := parser.ParseBytes(bytes, parseMode, opts...)
if err != nil {
return nil, err
}
Expand Down
52 changes: 26 additions & 26 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -956,10 +956,6 @@ func TestDecoder(t *testing.T) {
"a: &a [1, 2]\nb: *a\n",
struct{ B []int }{[]int{1, 2}},
},
{
"&0: *0\n*0:\n*0:",
map[string]any{"null": nil},
},
{
"key1: &anchor\n subkey: *anchor\nkey2: *anchor\n",
map[string]any{
Expand Down Expand Up @@ -1502,7 +1498,7 @@ items:
Items []*Item
}
buf := bytes.NewBufferString(yml)
dec := yaml.NewDecoder(buf)
dec := yaml.NewDecoder(buf, yaml.AllowDuplicateMapKey())
var v T
if err := dec.Decode(&v); err != nil {
t.Fatalf("%+v", err)
Expand Down Expand Up @@ -1834,39 +1830,23 @@ children:
})
}

func TestDecoder_DisallowDuplicateKey(t *testing.T) {
func TestDecoder_AllowDuplicateMapKey(t *testing.T) {
yml := `
a: b
a: c
`
expected := `
[3:1] duplicate key "a"
2 | a: b
> 3 | a: c
^
`
t.Run("map", func(t *testing.T) {
var v map[string]string
err := yaml.NewDecoder(strings.NewReader(yml), yaml.DisallowDuplicateKey()).Decode(&v)
if err == nil {
t.Fatal("decoding should fail")
}
actual := "\n" + err.Error()
if expected != actual {
t.Fatalf("expected:[%s] actual:[%s]", expected, actual)
if err := yaml.NewDecoder(strings.NewReader(yml), yaml.AllowDuplicateMapKey()).Decode(&v); err != nil {
t.Fatal(err)
}
})
t.Run("struct", func(t *testing.T) {
var v struct {
A string
}
err := yaml.NewDecoder(strings.NewReader(yml), yaml.DisallowDuplicateKey()).Decode(&v)
if err == nil {
t.Fatal("decoding should fail")
}
actual := "\n" + err.Error()
if expected != actual {
t.Fatalf("expected:[%s] actual:[%s]", expected, actual)
if err := yaml.NewDecoder(strings.NewReader(yml), yaml.AllowDuplicateMapKey()).Decode(&v); err != nil {
t.Fatal(err)
}
})
}
Expand Down Expand Up @@ -3099,3 +3079,23 @@ nested:
t.Fatal("decoder doesn't preserve struct defaults")
}
}

func TestDecodeError(t *testing.T) {
tests := []struct {
name string
source string
}{
{
name: "duplicated map key name with anchor-alias",
source: "&0: *0\n*0:\n*0:",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var v any
if err := yaml.Unmarshal([]byte(test.source), &v); err == nil {
t.Fatal("cannot catch decode error")
}
})
}
}
7 changes: 3 additions & 4 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ func Validator(v StructValidator) DecodeOption {
func Strict() DecodeOption {
return func(d *Decoder) error {
d.disallowUnknownField = true
d.disallowDuplicateKey = true
return nil
}
}
Expand All @@ -69,10 +68,10 @@ func DisallowUnknownField() DecodeOption {
}
}

// DisallowDuplicateKey causes an error when mapping keys that are duplicates
func DisallowDuplicateKey() DecodeOption {
// AllowDuplicateMapKey ignore syntax error when mapping keys that are duplicates.
func AllowDuplicateMapKey() DecodeOption {
return func(d *Decoder) error {
d.disallowDuplicateKey = true
d.allowDuplicateMapKey = true
return nil
}
}
Expand Down
12 changes: 12 additions & 0 deletions parser/option.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package parser

// Option represents parser's option.
type Option func(p *parser)

// AllowDuplicateMapKey allow the use of keys with the same name in the same map,
// but by default, this is not permitted.
func AllowDuplicateMapKey() Option {
return func(p *parser) {
p.allowDuplicateMapKey = true
}
}
70 changes: 49 additions & 21 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ import (
)

type parser struct {
idx int
size int
tokens token.Tokens
idx int
size int
tokens token.Tokens
pathMap map[string]ast.Node
allowDuplicateMapKey bool
}

func newParser(tokens token.Tokens, mode Mode) *parser {
func newParser(tokens token.Tokens, mode Mode, opts []Option) *parser {
filteredTokens := []*token.Token{}
if mode&ParseComments != 0 {
filteredTokens = tokens
Expand All @@ -31,11 +33,16 @@ func newParser(tokens token.Tokens, mode Mode) *parser {
filteredTokens = append(filteredTokens, tk)
}
}
return &parser{
idx: 0,
size: len(filteredTokens),
tokens: token.Tokens(filteredTokens),
p := &parser{
idx: 0,
size: len(filteredTokens),
tokens: token.Tokens(filteredTokens),
pathMap: make(map[string]ast.Node),
}
for _, opt := range opts {
opt(p)
}
return p
}

func (p *parser) next() bool {
Expand Down Expand Up @@ -281,7 +288,16 @@ func (p *parser) existsNewLineCharacter(src string) bool {
return false
}

func (p *parser) validateMapKey(tk *token.Token) error {
func (p *parser) validateMapKey(tk *token.Token, keyPath string) error {
if !p.allowDuplicateMapKey {
if n, exists := p.pathMap[keyPath]; exists {
pos := n.GetToken().Position
return errors.ErrSyntax(
fmt.Sprintf("mapping key %q already defined at [%d:%d]", tk.Value, pos.Line, pos.Column),
tk,
)
}
}
if tk.Type != token.StringType {
return nil
}
Expand Down Expand Up @@ -323,7 +339,7 @@ func (p *parser) createMapValueNode(ctx *context, key ast.MapKeyNode, colonToken
if tk.Type == token.CommentType {
comment = p.parseCommentOnly(ctx)
if comment != nil {
comment.SetPath(ctx.withChild(key.GetToken().Value).path)
comment.SetPath(ctx.withChild(p.mapKeyText(key)).path)
}
tk = p.currentToken()
}
Expand Down Expand Up @@ -397,16 +413,28 @@ func (p *parser) validateMapValue(ctx *context, key, value ast.Node) error {
return nil
}

func (p *parser) mapKeyText(n ast.Node) string {
switch nn := n.(type) {
case *ast.MappingKeyNode:
return p.mapKeyText(nn.Value)
case *ast.TagNode:
return p.mapKeyText(nn.Value)
}
return n.GetToken().Value
}

func (p *parser) parseMappingValue(ctx *context) (ast.Node, error) {
key, err := p.parseMapKey(ctx)
if err != nil {
return nil, err
}
keyText := key.GetToken().Value
key.SetPath(ctx.withChild(keyText).path)
if err := p.validateMapKey(key.GetToken()); err != nil {
keyText := p.mapKeyText(key)
keyPath := ctx.withChild(keyText).path
key.SetPath(keyPath)
if err := p.validateMapKey(key.GetToken(), keyPath); err != nil {
return nil, err
}
p.pathMap[keyPath] = key
p.progress(1) // progress to mapping value token
if ctx.isFlow {
// if "{key}" or "{key," style, returns MappingValueNode.
Expand Down Expand Up @@ -557,7 +585,7 @@ func (p *parser) parseFlowMapNullValue(ctx *context, key ast.MapKeyNode) (*ast.M
return nil, err
}
node := ast.MappingValue(tk, key, value)
node.SetPath(ctx.withChild(key.GetToken().Value).path)
node.SetPath(ctx.withChild(p.mapKeyText(key)).path)
return node, nil
}

Expand Down Expand Up @@ -868,7 +896,7 @@ func (p *parser) parseMappingKey(ctx *context) (*ast.MappingKeyNode, error) {
node := ast.MappingKey(keyTk)
node.SetPath(ctx.path)
p.progress(1) // skip mapping key token
value, err := p.parseToken(ctx.withChild(keyTk.Value), p.currentToken())
value, err := p.parseToken(ctx, p.currentToken())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -962,34 +990,34 @@ const (
)

// ParseBytes parse from byte slice, and returns ast.File
func ParseBytes(bytes []byte, mode Mode) (*ast.File, error) {
func ParseBytes(bytes []byte, mode Mode, opts ...Option) (*ast.File, error) {
tokens := lexer.Tokenize(string(bytes))
f, err := Parse(tokens, mode)
f, err := Parse(tokens, mode, opts...)
if err != nil {
return nil, err
}
return f, nil
}

// Parse parse from token instances, and returns ast.File
func Parse(tokens token.Tokens, mode Mode) (*ast.File, error) {
func Parse(tokens token.Tokens, mode Mode, opts ...Option) (*ast.File, error) {
if tk := tokens.InvalidToken(); tk != nil {
return nil, errors.ErrSyntax("found invalid token", tk)
}
f, err := newParser(tokens, mode).parse(newContext())
f, err := newParser(tokens, mode, opts).parse(newContext())
if err != nil {
return nil, err
}
return f, nil
}

// Parse parse from filename, and returns ast.File
func ParseFile(filename string, mode Mode) (*ast.File, error) {
func ParseFile(filename string, mode Mode, opts ...Option) (*ast.File, error) {
file, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
f, err := ParseBytes(file, mode)
f, err := ParseBytes(file, mode, opts...)
if err != nil {
return nil, err
}
Expand Down
36 changes: 36 additions & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1284,6 +1284,42 @@ a:
[1:4] found an invalid key for this map
> 1 | [a]: v
^
`,
},
{
`
foo:
bar:
foo: 2
baz:
foo: 3
foo: 2
`,
`
[7:1] mapping key "foo" already defined at [2:1]
4 | foo: 2
5 | baz:
6 | foo: 3
> 7 | foo: 2
^
`,
},
{
`
foo:
bar:
foo: 2
baz:
foo: 3
foo: 4
`,
`
[7:5] mapping key "foo" already defined at [6:5]
4 | foo: 2
5 | baz:
6 | foo: 3
> 7 | foo: 4
^
`,
},
}
Expand Down

0 comments on commit 1cf609f

Please sign in to comment.