Skip to content

Commit

Permalink
Merge pull request #124 from goccy/feature/fix-issue-122
Browse files Browse the repository at this point in the history
Support decoding with DisallowDuplicateKey option for struct
  • Loading branch information
goccy authored Jun 8, 2020
2 parents 85a4ca1 + 4013d13 commit 24e2c3f
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 11 deletions.
17 changes: 12 additions & 5 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@ func (d *Decoder) keyToNodeMap(node ast.Node, ignoreMergeKey bool, getKeyOrValue
if err != nil {
return nil, errors.Wrapf(err, "failed to get map node")
}
keyMap := map[string]struct{}{}
keyToNodeMap := map[string]ast.Node{}
if mapNode == nil {
return keyToNodeMap, nil
Expand All @@ -597,13 +598,19 @@ func (d *Decoder) keyToNodeMap(node ast.Node, ignoreMergeKey bool, getKeyOrValue
return nil, errors.Wrapf(err, "failed to get keyToNodeMap by MergeKey node")
}
for k, v := range mergeMap {
if err := d.validateDuplicateKey(keyMap, k, v); err != nil {
return nil, errors.Wrapf(err, "invalid struct key")
}
keyToNodeMap[k] = v
}
} else {
key, ok := d.nodeToValue(keyNode).(string)
if !ok {
return nil, errors.Wrapf(err, "failed to decode map key")
}
if err := d.validateDuplicateKey(keyMap, key, keyNode); err != nil {
return nil, errors.Wrapf(err, "invalid struct key")
}
keyToNodeMap[key] = getKeyOrValueNode(mapIter)
}
}
Expand Down Expand Up @@ -970,7 +977,7 @@ func (d *Decoder) decodeMapItem(dst *MapItem, src ast.Node) error {
return nil
}

func (d *Decoder) validateMapKey(keyMap map[string]struct{}, key interface{}, keyNode ast.Node) error {
func (d *Decoder) validateDuplicateKey(keyMap map[string]struct{}, key interface{}, keyNode ast.Node) error {
k, ok := key.(string)
if !ok {
return nil
Expand Down Expand Up @@ -1004,15 +1011,15 @@ func (d *Decoder) decodeMapSlice(dst *MapSlice, src ast.Node) error {
return errors.Wrapf(err, "failed to decode map with merge key")
}
for _, v := range m {
if err := d.validateMapKey(keyMap, v.Key, value); err != nil {
if err := d.validateDuplicateKey(keyMap, v.Key, value); err != nil {
return errors.Wrapf(err, "invalid map key")
}
mapSlice = append(mapSlice, v)
}
continue
}
k := d.nodeToValue(key)
if err := d.validateMapKey(keyMap, k, key); err != nil {
if err := d.validateDuplicateKey(keyMap, k, key); err != nil {
return errors.Wrapf(err, "invalid map key")
}
mapSlice = append(mapSlice, MapItem{
Expand Down Expand Up @@ -1048,7 +1055,7 @@ func (d *Decoder) decodeMap(dst reflect.Value, src ast.Node) error {
}
iter := dst.MapRange()
for iter.Next() {
if err := d.validateMapKey(keyMap, iter.Key(), value); err != nil {
if err := d.validateDuplicateKey(keyMap, iter.Key(), value); err != nil {
return errors.Wrapf(err, "invalid map key")
}
mapValue.SetMapIndex(iter.Key(), iter.Value())
Expand All @@ -1060,7 +1067,7 @@ func (d *Decoder) decodeMap(dst reflect.Value, src ast.Node) error {
k = k.Convert(keyType)
}
if k.IsValid() {
if err := d.validateMapKey(keyMap, k.Interface(), key); err != nil {
if err := d.validateDuplicateKey(keyMap, k.Interface(), key); err != nil {
return errors.Wrapf(err, "invalid map key")
}
}
Expand Down
30 changes: 24 additions & 6 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1491,12 +1491,30 @@ a: c
> 3 | a: c
^
`
var v map[string]string
err := yaml.NewDecoder(strings.NewReader(yml), yaml.DisallowDuplicateKey()).Decode(&v)
actual := "\n" + err.Error()
if expected != actual {
t.Fatalf("expected:[%s] actual:[%s]", expected, actual)
}
t.Run("map", func(t *testing.T) {
var v map[string]string
err := yaml.NewDecoder(strings.NewReader(yml), yaml.DisallowDuplicateKey()).Decode(&v)
if err == nil {
t.Fatal("decoding should fail")
}
actual := "\n" + err.Error()
if expected != actual {
t.Fatalf("expected:[%s] actual:[%s]", expected, actual)
}
})
t.Run("struct", func(t *testing.T) {
var v struct {
A string
}
err := yaml.NewDecoder(strings.NewReader(yml), yaml.DisallowDuplicateKey()).Decode(&v)
if err == nil {
t.Fatal("decoding should fail")
}
actual := "\n" + err.Error()
if expected != actual {
t.Fatalf("expected:[%s] actual:[%s]", expected, actual)
}
})
}

func TestDecoder_DefaultValues(t *testing.T) {
Expand Down

0 comments on commit 24e2c3f

Please sign in to comment.