From e27add428c9c2f5b120ad55c7a1abff73a2635c4 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Sat, 2 Nov 2024 17:40:43 +0900 Subject: [PATCH] fix invalid alias (#497) --- decode.go | 256 +++++++++++++++++++++++++++++++++++-------------- decode_test.go | 31 +++++- 2 files changed, 216 insertions(+), 71 deletions(-) diff --git a/decode.go b/decode.go index 9ff9f3eb..27171f14 100644 --- a/decode.go +++ b/decode.go @@ -107,51 +107,78 @@ func (d *Decoder) mergeValueNode(value ast.Node) ast.Node { return value } -func (d *Decoder) mapKeyNodeToString(node ast.MapKeyNode) string { - key := d.nodeToValue(node) +func (d *Decoder) mapKeyNodeToString(node ast.MapKeyNode) (string, error) { + key, err := d.nodeToValue(node) + if err != nil { + return "", err + } if key == nil { - return "null" + return "null", nil } if k, ok := key.(string); ok { - return k + return k, nil } - return fmt.Sprint(key) + return fmt.Sprint(key), nil } -func (d *Decoder) setToMapValue(node ast.Node, m map[string]interface{}) { +func (d *Decoder) setToMapValue(node ast.Node, m map[string]interface{}) error { d.setPathToCommentMap(node) switch n := node.(type) { case *ast.MappingValueNode: if n.Key.Type() == ast.MergeKeyType { - d.setToMapValue(d.mergeValueNode(n.Value), m) + if err := d.setToMapValue(d.mergeValueNode(n.Value), m); err != nil { + return err + } } else { - key := d.mapKeyNodeToString(n.Key) - m[key] = d.nodeToValue(n.Value) + key, err := d.mapKeyNodeToString(n.Key) + if err != nil { + return err + } + v, err := d.nodeToValue(n.Value) + if err != nil { + return err + } + m[key] = v } case *ast.MappingNode: for _, value := range n.Values { - d.setToMapValue(value, m) + if err := d.setToMapValue(value, m); err != nil { + return err + } } case *ast.AnchorNode: anchorName := n.Name.GetToken().Value d.anchorNodeMap[anchorName] = n.Value } + return nil } -func (d *Decoder) setToOrderedMapValue(node ast.Node, m *MapSlice) { +func (d *Decoder) setToOrderedMapValue(node ast.Node, m *MapSlice) error { switch n := node.(type) { case *ast.MappingValueNode: if n.Key.Type() == ast.MergeKeyType { - d.setToOrderedMapValue(d.mergeValueNode(n.Value), m) + if err := d.setToOrderedMapValue(d.mergeValueNode(n.Value), m); err != nil { + return err + } } else { - key := d.mapKeyNodeToString(n.Key) - *m = append(*m, MapItem{Key: key, Value: d.nodeToValue(n.Value)}) + key, err := d.mapKeyNodeToString(n.Key) + if err != nil { + return err + } + value, err := d.nodeToValue(n.Value) + if err != nil { + return err + } + *m = append(*m, MapItem{Key: key, Value: value}) } case *ast.MappingNode: for _, value := range n.Values { - d.setToOrderedMapValue(value, m) + if err := d.setToOrderedMapValue(value, m); err != nil { + return err + } } } + return nil } func (d *Decoder) setPathToCommentMap(node ast.Node) { @@ -260,38 +287,50 @@ func (d *Decoder) addCommentToMap(path string, comment *Comment) { }) } -func (d *Decoder) nodeToValue(node ast.Node) interface{} { +func (d *Decoder) nodeToValue(node ast.Node) (any, error) { d.setPathToCommentMap(node) switch n := node.(type) { case *ast.NullNode: - return nil + return nil, nil case *ast.StringNode: - return n.GetValue() + return n.GetValue(), nil case *ast.IntegerNode: - return n.GetValue() + return n.GetValue(), nil case *ast.FloatNode: - return n.GetValue() + return n.GetValue(), nil case *ast.BoolNode: - return n.GetValue() + return n.GetValue(), nil case *ast.InfinityNode: - return n.GetValue() + return n.GetValue(), nil case *ast.NanNode: - return n.GetValue() + return n.GetValue(), nil case *ast.TagNode: switch token.ReservedTagKeyword(n.Start.Value) { case token.TimestampTag: t, _ := d.castToTime(n.Value) - return t + return t, nil case token.IntegerTag: - i, _ := strconv.Atoi(fmt.Sprint(d.nodeToValue(n.Value))) - return i + v, err := d.nodeToValue(n.Value) + if err != nil { + return nil, err + } + i, _ := strconv.Atoi(fmt.Sprint(v)) + return i, nil case token.FloatTag: - return d.castToFloat(d.nodeToValue(n.Value)) + v, err := d.nodeToValue(n.Value) + if err != nil { + return nil, err + } + return d.castToFloat(v), nil case token.NullTag: - return nil + return nil, nil case token.BinaryTag: - b, _ := base64.StdEncoding.DecodeString(d.nodeToValue(n.Value).(string)) - return b + v, err := d.nodeToValue(n.Value) + if err != nil { + return nil, err + } + b, _ := base64.StdEncoding.DecodeString(v.(string)) + return b, nil case token.StringTag: return d.nodeToValue(n.Value) case token.MappingTag: @@ -299,25 +338,38 @@ func (d *Decoder) nodeToValue(node ast.Node) interface{} { } case *ast.AnchorNode: anchorName := n.Name.GetToken().Value - anchorValue := d.nodeToValue(n.Value) + + // To handle the case where alias is processed recursively, the result of alias can be set to nil in advance. + d.anchorNodeMap[anchorName] = nil + anchorValue, err := d.nodeToValue(n.Value) + if err != nil { + delete(d.anchorNodeMap, anchorName) + return nil, err + } d.anchorNodeMap[anchorName] = n.Value - return anchorValue + return anchorValue, nil case *ast.AliasNode: if v, exists := d.aliasValueMap[n]; exists { - return v + return v, nil } // To handle the case where alias is processed recursively, the result of alias can be set to nil in advance. d.aliasValueMap[n] = nil aliasName := n.Value.GetToken().Value - node := d.anchorNodeMap[aliasName] - aliasValue := d.nodeToValue(node) + node, exists := d.anchorNodeMap[aliasName] + if !exists { + return nil, errors.ErrSyntax(fmt.Sprintf("could not find alias %q", aliasName), n.Value.GetToken()) + } + aliasValue, err := d.nodeToValue(node) + if err != nil { + return nil, err + } // once the correct alias value is obtained, overwrite with that value. d.aliasValueMap[n] = aliasValue - return aliasValue + return aliasValue, nil case *ast.LiteralNode: - return n.Value.GetValue() + return n.Value.GetValue(), nil case *ast.MappingKeyNode: return d.nodeToValue(n.Value) case *ast.MappingValueNode: @@ -325,41 +377,62 @@ func (d *Decoder) nodeToValue(node ast.Node) interface{} { value := d.mergeValueNode(n.Value) if d.useOrderedMap { m := MapSlice{} - d.setToOrderedMapValue(value, &m) - return m + if err := d.setToOrderedMapValue(value, &m); err != nil { + return nil, err + } + return m, nil } m := map[string]interface{}{} - d.setToMapValue(value, m) - return m + if err := d.setToMapValue(value, m); err != nil { + return nil, err + } + return m, nil + } + key, err := d.mapKeyNodeToString(n.Key) + if err != nil { + return nil, err } - key := d.mapKeyNodeToString(n.Key) if d.useOrderedMap { - return MapSlice{{Key: key, Value: d.nodeToValue(n.Value)}} + v, err := d.nodeToValue(n.Value) + if err != nil { + return nil, err + } + return MapSlice{{Key: key, Value: v}}, nil } - return map[string]interface{}{ - key: d.nodeToValue(n.Value), + v, err := d.nodeToValue(n.Value) + if err != nil { + return nil, err } + return map[string]interface{}{key: v}, nil case *ast.MappingNode: if d.useOrderedMap { m := make(MapSlice, 0, len(n.Values)) for _, value := range n.Values { - d.setToOrderedMapValue(value, &m) + if err := d.setToOrderedMapValue(value, &m); err != nil { + return nil, err + } } - return m + return m, nil } m := make(map[string]interface{}, len(n.Values)) for _, value := range n.Values { - d.setToMapValue(value, m) + if err := d.setToMapValue(value, m); err != nil { + return nil, err + } } - return m + return m, nil case *ast.SequenceNode: v := make([]interface{}, 0, len(n.Values)) for _, value := range n.Values { - v = append(v, d.nodeToValue(value)) + vv, err := d.nodeToValue(value) + if err != nil { + return nil, err + } + v = append(v, vv) } - return v + return v, nil } - return nil + return nil, nil } func (d *Decoder) resolveAlias(node ast.Node) (ast.Node, error) { @@ -844,7 +917,11 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No dst.Set(reflect.ValueOf(src)) return nil } - v := reflect.ValueOf(d.nodeToValue(src)) + srcVal, err := d.nodeToValue(src) + if err != nil { + return err + } + v := reflect.ValueOf(srcVal) if v.IsValid() { dst.Set(v) } @@ -863,7 +940,10 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No } return d.decodeStruct(ctx, dst, src) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - v := d.nodeToValue(src) + v, err := d.nodeToValue(src) + if err != nil { + return err + } switch vv := v.(type) { case int64: if !dst.OverflowInt(vv) { @@ -894,7 +974,10 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No } return errors.ErrOverflow(valueType, fmt.Sprint(v), src.GetToken()) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - v := d.nodeToValue(src) + v, err := d.nodeToValue(src) + if err != nil { + return err + } switch vv := v.(type) { case int64: if 0 <= vv && !dst.OverflowUint(uint64(vv)) { @@ -926,7 +1009,11 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No } return errors.ErrOverflow(valueType, fmt.Sprint(v), src.GetToken()) } - v := reflect.ValueOf(d.nodeToValue(src)) + srcVal, err := d.nodeToValue(src) + if err != nil { + return err + } + v := reflect.ValueOf(srcVal) if v.IsValid() { convertedValue, err := d.convertValue(v, dst.Type(), src) if err != nil { @@ -1025,7 +1112,11 @@ func (d *Decoder) keyToNodeMap(node ast.Node, ignoreMergeKey bool, getKeyOrValue keyToNodeMap[k] = v } } else { - key, ok := d.nodeToValue(keyNode).(string) + keyVal, err := d.nodeToValue(keyNode) + if err != nil { + return nil, err + } + key, ok := keyVal.(string) if !ok { return nil, err } @@ -1094,7 +1185,10 @@ func (d *Decoder) castToTime(src ast.Node) (time.Time, error) { if src == nil { return time.Time{}, nil } - v := d.nodeToValue(src) + v, err := d.nodeToValue(src) + if err != nil { + return time.Time{}, err + } if t, ok := v.(time.Time); ok { return t, nil } @@ -1126,7 +1220,10 @@ func (d *Decoder) castToDuration(src ast.Node) (time.Duration, error) { if src == nil { return 0, nil } - v := d.nodeToValue(src) + v, err := d.nodeToValue(src) + if err != nil { + return 0, err + } if t, ok := v.(time.Duration); ok { return t, nil } @@ -1435,10 +1532,15 @@ func (d *Decoder) decodeMapItem(ctx context.Context, dst *MapItem, src ast.Node) } return nil } - *dst = MapItem{ - Key: d.nodeToValue(key), - Value: d.nodeToValue(value), + k, err := d.nodeToValue(key) + if err != nil { + return err + } + v, err := d.nodeToValue(value) + if err != nil { + return err } + *dst = MapItem{Key: k, Value: v} return nil } @@ -1483,14 +1585,18 @@ func (d *Decoder) decodeMapSlice(ctx context.Context, dst *MapSlice, src ast.Nod } continue } - k := d.nodeToValue(key) + k, err := d.nodeToValue(key) + if err != nil { + return err + } if err := d.validateDuplicateKey(keyMap, k, key); err != nil { return err } - mapSlice = append(mapSlice, MapItem{ - Key: k, - Value: d.nodeToValue(value), - }) + v, err := d.nodeToValue(value) + if err != nil { + return err + } + mapSlice = append(mapSlice, MapItem{Key: k, Value: v}) } *dst = mapSlice return nil @@ -1534,7 +1640,11 @@ func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node return err } } else { - k = reflect.ValueOf(d.nodeToValue(key)) + keyVal, err := d.nodeToValue(key) + if err != nil { + return err + } + k = reflect.ValueOf(keyVal) if k.IsValid() && k.Type().ConvertibleTo(keyType) { k = k.Convert(keyType) } @@ -1682,7 +1792,11 @@ func (d *Decoder) parse(bytes []byte) (*ast.File, error) { normalizedFile := &ast.File{} for _, doc := range f.Docs { // try to decode ast.Node to value and map anchor value to anchorMap - if v := d.nodeToValue(doc.Body); v != nil { + v, err := d.nodeToValue(doc.Body) + if err != nil { + return nil, err + } + if v != nil { normalizedFile.Docs = append(normalizedFile.Docs, doc) } } @@ -1780,7 +1894,9 @@ func (d *Decoder) DecodeFromNodeContext(ctx context.Context, node ast.Node, v in } } // resolve references to the anchor on the same file - d.nodeToValue(node) + if _, err := d.nodeToValue(node); err != nil { + return err + } if err := d.decodeValue(ctx, rv.Elem(), node); err != nil { return err } diff --git a/decode_test.go b/decode_test.go index 3e26f242..718ca2eb 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1125,6 +1125,35 @@ c: } } +func TestDecoder_Invalid(t *testing.T) { + tests := []struct { + src string + expect string + }{ + { + "*-0", + ` +[1:2] could not find alias "-0" +> 1 | *-0 + ^ +`, + }, + } + for _, test := range tests { + t.Run(test.src, func(t *testing.T) { + var v any + err := yaml.Unmarshal([]byte(test.src), &v) + if err == nil { + t.Fatal("cannot catch decode error") + } + actual := "\n" + err.Error() + if test.expect != actual { + t.Fatalf("expected: [%s] but got [%s]", test.expect, actual) + } + }) + } +} + func TestDecoder_ScientificNotation(t *testing.T) { tests := []struct { source string @@ -2634,7 +2663,7 @@ map: <<: *z e: f `, - err: "cannot find anchor by alias name y", + err: `could not find alias "y"`, }, }