From 4da64d9a029eef6d7eda7fb959a8d5c2d5a42bed Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Wed, 27 Nov 2024 17:08:29 +0900 Subject: [PATCH] guard for stack overflow --- decode.go | 110 +++++++++++++++++++++++++++++++++++++++++++++++++++++- error.go | 1 + 2 files changed, 109 insertions(+), 2 deletions(-) diff --git a/decode.go b/decode.go index d07f1f4..f02fc05 100644 --- a/decode.go +++ b/decode.go @@ -43,6 +43,7 @@ type Decoder struct { useJSONUnmarshaler bool parsedFile *ast.File streamIndex int + decodeDepth int } // NewDecoder returns a new decoder that reads from r. @@ -65,6 +66,20 @@ func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder { } } +const maxDecodeDepth = 10000 + +func (d *Decoder) stepIn() { + d.decodeDepth++ +} + +func (d *Decoder) stepOut() { + d.decodeDepth-- +} + +func (d *Decoder) isExceededMaxDepth() bool { + return d.decodeDepth > maxDecodeDepth +} + func (d *Decoder) castToFloat(v interface{}) interface{} { switch vv := v.(type) { case int: @@ -123,6 +138,12 @@ func (d *Decoder) mapKeyNodeToString(node ast.MapKeyNode) (string, error) { } func (d *Decoder) setToMapValue(node ast.Node, m map[string]interface{}) error { + d.stepIn() + defer d.stepOut() + if d.isExceededMaxDepth() { + return ErrExceededMaxDepth + } + d.setPathToCommentMap(node) switch n := node.(type) { case *ast.MappingValueNode: @@ -155,6 +176,12 @@ func (d *Decoder) setToMapValue(node ast.Node, m map[string]interface{}) error { } func (d *Decoder) setToOrderedMapValue(node ast.Node, m *MapSlice) error { + d.stepIn() + defer d.stepOut() + if d.isExceededMaxDepth() { + return ErrExceededMaxDepth + } + d.setPathToCommentMap(node) switch n := node.(type) { case *ast.MappingValueNode: @@ -304,6 +331,12 @@ func (d *Decoder) addCommentToMap(path string, comment *Comment) { } func (d *Decoder) nodeToValue(node ast.Node) (any, error) { + d.stepIn() + defer d.stepOut() + if d.isExceededMaxDepth() { + return nil, ErrExceededMaxDepth + } + d.setPathToCommentMap(node) switch n := node.(type) { case *ast.NullNode: @@ -345,7 +378,14 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) { if err != nil { return nil, err } - b, _ := base64.StdEncoding.DecodeString(v.(string)) + str, ok := v.(string) + if !ok { + return nil, errors.ErrSyntax( + fmt.Sprintf("cannot convert %q to string", fmt.Sprint(v)), + n.Value.GetToken(), + ) + } + b, _ := base64.StdEncoding.DecodeString(str) return b, nil case token.BooleanTag: v, err := d.nodeToValue(n.Value) @@ -399,7 +439,6 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) { if err != nil { return nil, err } - // once the correct alias value is obtained, overwrite with that value. d.aliasValueMap[n] = aliasValue return aliasValue, nil @@ -471,6 +510,12 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) { } func (d *Decoder) resolveAlias(node ast.Node) (ast.Node, error) { + d.stepIn() + defer d.stepOut() + if d.isExceededMaxDepth() { + return nil, ErrExceededMaxDepth + } + switch n := node.(type) { case *ast.MappingNode: for idx, v := range n.Values { @@ -534,6 +579,12 @@ func (d *Decoder) resolveAlias(node ast.Node) (ast.Node, error) { } func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) { + d.stepIn() + defer d.stepOut() + if d.isExceededMaxDepth() { + return nil, ErrExceededMaxDepth + } + if _, ok := node.(*ast.NullNode); ok { return nil, nil } @@ -564,6 +615,12 @@ func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) { } func (d *Decoder) getArrayNode(node ast.Node) (ast.ArrayNode, error) { + d.stepIn() + defer d.stepOut() + if d.isExceededMaxDepth() { + return nil, ErrExceededMaxDepth + } + if _, ok := node.(*ast.NullNode); ok { return nil, nil } @@ -887,6 +944,12 @@ var ( ) func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.Node) error { + d.stepIn() + defer d.stepOut() + if d.isExceededMaxDepth() { + return ErrExceededMaxDepth + } + if src.Type() == ast.AnchorType { anchorName := src.(*ast.AnchorNode).Name.GetToken().Value if _, exists := d.anchorValueMap[anchorName]; !exists { @@ -1088,6 +1151,12 @@ func (d *Decoder) createDecodedNewValue( } func (d *Decoder) keyToNodeMap(node ast.Node, ignoreMergeKey bool, getKeyOrValueNode func(*ast.MapNodeIter) ast.Node) (map[string]ast.Node, error) { + d.stepIn() + defer d.stepOut() + if d.isExceededMaxDepth() { + return nil, ErrExceededMaxDepth + } + mapNode, err := d.getMapNode(node) if err != nil { return nil, err @@ -1274,6 +1343,12 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N if src == nil { return nil } + d.stepIn() + defer d.stepOut() + if d.isExceededMaxDepth() { + return ErrExceededMaxDepth + } + structType := dst.Type() srcValue := reflect.ValueOf(src) srcType := srcValue.Type() @@ -1439,6 +1514,12 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N } func (d *Decoder) decodeArray(ctx context.Context, dst reflect.Value, src ast.Node) error { + d.stepIn() + defer d.stepOut() + if d.isExceededMaxDepth() { + return ErrExceededMaxDepth + } + arrayNode, err := d.getArrayNode(src) if err != nil { return err @@ -1479,6 +1560,12 @@ func (d *Decoder) decodeArray(ctx context.Context, dst reflect.Value, src ast.No } func (d *Decoder) decodeSlice(ctx context.Context, dst reflect.Value, src ast.Node) error { + d.stepIn() + defer d.stepOut() + if d.isExceededMaxDepth() { + return ErrExceededMaxDepth + } + arrayNode, err := d.getArrayNode(src) if err != nil { return err @@ -1516,6 +1603,12 @@ func (d *Decoder) decodeSlice(ctx context.Context, dst reflect.Value, src ast.No } func (d *Decoder) decodeMapItem(ctx context.Context, dst *MapItem, src ast.Node) error { + d.stepIn() + defer d.stepOut() + if d.isExceededMaxDepth() { + return ErrExceededMaxDepth + } + mapNode, err := d.getMapNode(src) if err != nil { return err @@ -1562,6 +1655,12 @@ func (d *Decoder) validateDuplicateKey(keyMap map[string]struct{}, key interface } func (d *Decoder) decodeMapSlice(ctx context.Context, dst *MapSlice, src ast.Node) error { + d.stepIn() + defer d.stepOut() + if d.isExceededMaxDepth() { + return ErrExceededMaxDepth + } + mapNode, err := d.getMapNode(src) if err != nil { return err @@ -1606,6 +1705,12 @@ func (d *Decoder) decodeMapSlice(ctx context.Context, dst *MapSlice, src ast.Nod } func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node) error { + d.stepIn() + defer d.stepOut() + if d.isExceededMaxDepth() { + return ErrExceededMaxDepth + } + mapNode, err := d.getMapNode(src) if err != nil { return err @@ -1839,6 +1944,7 @@ func (d *Decoder) decodeInit() error { } func (d *Decoder) decode(ctx context.Context, v reflect.Value) error { + d.decodeDepth = 0 if len(d.parsedFile.Docs) <= d.streamIndex { return io.EOF } diff --git a/error.go b/error.go index 3d09712..6d2a759 100644 --- a/error.go +++ b/error.go @@ -15,6 +15,7 @@ var ( ErrUnknownCommentPositionType = errors.New("unknown comment position type") ErrInvalidCommentMapValue = errors.New("invalid comment map value. it must be not nil value") ErrDecodeRequiredPointerType = errors.New("required pointer type value") + ErrExceededMaxDepth = errors.New("exceeded max depth") ) type (