diff --git a/decode.go b/decode.go index 2cb7a44f..b0990c8e 100644 --- a/decode.go +++ b/decode.go @@ -34,6 +34,7 @@ type Decoder struct { isResolvedReference bool validator StructValidator disallowUnknownField bool + disallowDuplicateKey bool useOrderedMap bool parsedFile *ast.File streamIndex int @@ -52,6 +53,7 @@ func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder { isRecursiveDir: false, isResolvedReference: false, disallowUnknownField: false, + disallowDuplicateKey: false, useOrderedMap: false, } } @@ -347,6 +349,18 @@ func errUnknownField(msg string, tk *token.Token) *unknownFieldError { return &unknownFieldError{err: errors.ErrSyntax(msg, tk)} } +type duplicateKeyError struct { + err error +} + +func (e *duplicateKeyError) Error() string { + return e.err.Error() +} + +func errDuplicateKey(msg string, tk *token.Token) *duplicateKeyError { + return &duplicateKeyError{err: errors.ErrSyntax(msg, tk)} +} + func (d *Decoder) deleteStructKeys(structValue reflect.Value, unknownFields map[string]ast.Node) error { strType := structValue.Type() structFieldMap, err := structFieldMap(strType) @@ -953,6 +967,20 @@ func (d *Decoder) decodeMapItem(dst *MapItem, src ast.Node) error { return nil } +func (d *Decoder) validateMapKey(keyMap map[string]struct{}, key interface{}, keyNode ast.Node) error { + k, ok := key.(string) + if !ok { + return nil + } + if d.disallowDuplicateKey { + if _, exists := keyMap[k]; exists { + return errDuplicateKey(fmt.Sprintf(`duplicate key "%s"`, k), keyNode.GetToken()) + } + } + keyMap[k] = struct{}{} + return nil +} + func (d *Decoder) decodeMapSlice(dst *MapSlice, src ast.Node) error { mapNode, err := d.getMapNode(src) if err != nil { @@ -963,6 +991,7 @@ func (d *Decoder) decodeMapSlice(dst *MapSlice, src ast.Node) error { } mapSlice := MapSlice{} mapIter := mapNode.MapRange() + keyMap := map[string]struct{}{} for mapIter.Next() { key := mapIter.Key() value := mapIter.Value() @@ -972,12 +1001,19 @@ func (d *Decoder) decodeMapSlice(dst *MapSlice, src ast.Node) error { return errors.Wrapf(err, "failed to decode map with merge key") } for _, v := range m { + if err := d.validateMapKey(keyMap, v.Key, value); err != nil { + return errors.Wrapf(err, "invalid map key") + } mapSlice = append(mapSlice, v) } continue } + k := d.nodeToValue(key) + if err := d.validateMapKey(keyMap, k, key); err != nil { + return errors.Wrapf(err, "invalid map key") + } mapSlice = append(mapSlice, MapItem{ - Key: d.nodeToValue(key), + Key: k, Value: d.nodeToValue(value), }) } @@ -998,6 +1034,7 @@ func (d *Decoder) decodeMap(dst reflect.Value, src ast.Node) error { keyType := mapValue.Type().Key() valueType := mapValue.Type().Elem() mapIter := mapNode.MapRange() + keyMap := map[string]struct{}{} var foundErr error for mapIter.Next() { key := mapIter.Key() @@ -1008,6 +1045,9 @@ func (d *Decoder) decodeMap(dst reflect.Value, src ast.Node) error { } iter := dst.MapRange() for iter.Next() { + if err := d.validateMapKey(keyMap, iter.Key(), value); err != nil { + return errors.Wrapf(err, "invalid map key") + } mapValue.SetMapIndex(iter.Key(), iter.Value()) } continue @@ -1016,6 +1056,11 @@ func (d *Decoder) decodeMap(dst reflect.Value, src ast.Node) error { if k.IsValid() && k.Type().ConvertibleTo(keyType) { k = k.Convert(keyType) } + if k.IsValid() { + if err := d.validateMapKey(keyMap, k.Interface(), key); err != nil { + return errors.Wrapf(err, "invalid map key") + } + } if valueType.Kind() == reflect.Ptr && value.Type() == ast.NullType { // set nil value to pointer mapValue.SetMapIndex(k, reflect.Zero(valueType)) diff --git a/decode_test.go b/decode_test.go index 86792195..ab5757f2 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1438,6 +1438,26 @@ children: }) } +func TestDecoder_DisallowDuplicateKey(t *testing.T) { + yml := ` +a: b +a: c +` + expected := ` +[3:1] duplicate key "a" + 2 | +> 3 | a: b + 4 | a: c + ^ +` + var v map[string]string + err := yaml.NewDecoder(strings.NewReader(yml), yaml.DisallowDuplicateKey()).Decode(&v) + actual := "\n" + err.Error() + if expected != actual { + t.Fatalf("expected:[%s] actual:[%s]", expected, actual) + } +} + func TestDecoder_DefaultValues(t *testing.T) { v := struct { A string `yaml:"a"` diff --git a/option.go b/option.go index b81ed116..1b40f3eb 100644 --- a/option.go +++ b/option.go @@ -49,6 +49,15 @@ func Validator(v StructValidator) DecodeOption { } } +// Strict enable DisallowUnknownField and DisallowDuplicateKey +func Strict() DecodeOption { + return func(d *Decoder) error { + d.disallowUnknownField = true + d.disallowDuplicateKey = true + return nil + } +} + // DisallowUnknownField causes the Decoder to return an error when the destination // is a struct and the input contains object keys which do not match any // non-ignored, exported fields in the destination. @@ -59,6 +68,14 @@ func DisallowUnknownField() DecodeOption { } } +// DisallowDuplicateKey causes an error when mapping keys that are duplicates +func DisallowDuplicateKey() DecodeOption { + return func(d *Decoder) error { + d.disallowDuplicateKey = true + return nil + } +} + // UseOrderedMap can be interpreted as a map, // and uses MapSlice ( ordered map ) aggressively if there is no type specification func UseOrderedMap() DecodeOption {