Skip to content

Commit

Permalink
guard stack overflow at decoding (#549)
Browse files Browse the repository at this point in the history
  • Loading branch information
goccy authored Nov 27, 2024
1 parent 23c9234 commit 3584ab7
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 2 deletions.
110 changes: 108 additions & 2 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type Decoder struct {
useJSONUnmarshaler bool
parsedFile *ast.File
streamIndex int
decodeDepth int
}

// NewDecoder returns a new decoder that reads from r.
Expand All @@ -65,6 +66,20 @@ func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder {
}
}

const maxDecodeDepth = 10000

func (d *Decoder) stepIn() {
d.decodeDepth++
}

func (d *Decoder) stepOut() {
d.decodeDepth--
}

func (d *Decoder) isExceededMaxDepth() bool {
return d.decodeDepth > maxDecodeDepth
}

func (d *Decoder) castToFloat(v interface{}) interface{} {
switch vv := v.(type) {
case int:
Expand Down Expand Up @@ -123,6 +138,12 @@ func (d *Decoder) mapKeyNodeToString(node ast.MapKeyNode) (string, error) {
}

func (d *Decoder) setToMapValue(node ast.Node, m map[string]interface{}) error {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
return ErrExceededMaxDepth
}

d.setPathToCommentMap(node)
switch n := node.(type) {
case *ast.MappingValueNode:
Expand Down Expand Up @@ -155,6 +176,12 @@ func (d *Decoder) setToMapValue(node ast.Node, m map[string]interface{}) error {
}

func (d *Decoder) setToOrderedMapValue(node ast.Node, m *MapSlice) error {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
return ErrExceededMaxDepth
}

d.setPathToCommentMap(node)
switch n := node.(type) {
case *ast.MappingValueNode:
Expand Down Expand Up @@ -304,6 +331,12 @@ func (d *Decoder) addCommentToMap(path string, comment *Comment) {
}

func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
return nil, ErrExceededMaxDepth
}

d.setPathToCommentMap(node)
switch n := node.(type) {
case *ast.NullNode:
Expand Down Expand Up @@ -345,7 +378,14 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
if err != nil {
return nil, err
}
b, _ := base64.StdEncoding.DecodeString(v.(string))
str, ok := v.(string)
if !ok {
return nil, errors.ErrSyntax(
fmt.Sprintf("cannot convert %q to string", fmt.Sprint(v)),
n.Value.GetToken(),
)
}
b, _ := base64.StdEncoding.DecodeString(str)
return b, nil
case token.BooleanTag:
v, err := d.nodeToValue(n.Value)
Expand Down Expand Up @@ -399,7 +439,6 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
if err != nil {
return nil, err
}

// once the correct alias value is obtained, overwrite with that value.
d.aliasValueMap[n] = aliasValue
return aliasValue, nil
Expand Down Expand Up @@ -471,6 +510,12 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
}

func (d *Decoder) resolveAlias(node ast.Node) (ast.Node, error) {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
return nil, ErrExceededMaxDepth
}

switch n := node.(type) {
case *ast.MappingNode:
for idx, v := range n.Values {
Expand Down Expand Up @@ -534,6 +579,12 @@ func (d *Decoder) resolveAlias(node ast.Node) (ast.Node, error) {
}

func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
return nil, ErrExceededMaxDepth
}

if _, ok := node.(*ast.NullNode); ok {
return nil, nil
}
Expand Down Expand Up @@ -564,6 +615,12 @@ func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) {
}

func (d *Decoder) getArrayNode(node ast.Node) (ast.ArrayNode, error) {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
return nil, ErrExceededMaxDepth
}

if _, ok := node.(*ast.NullNode); ok {
return nil, nil
}
Expand Down Expand Up @@ -887,6 +944,12 @@ var (
)

func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.Node) error {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
return ErrExceededMaxDepth
}

if src.Type() == ast.AnchorType {
anchorName := src.(*ast.AnchorNode).Name.GetToken().Value
if _, exists := d.anchorValueMap[anchorName]; !exists {
Expand Down Expand Up @@ -1088,6 +1151,12 @@ func (d *Decoder) createDecodedNewValue(
}

func (d *Decoder) keyToNodeMap(node ast.Node, ignoreMergeKey bool, getKeyOrValueNode func(*ast.MapNodeIter) ast.Node) (map[string]ast.Node, error) {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
return nil, ErrExceededMaxDepth
}

mapNode, err := d.getMapNode(node)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1274,6 +1343,12 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N
if src == nil {
return nil
}
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
return ErrExceededMaxDepth
}

structType := dst.Type()
srcValue := reflect.ValueOf(src)
srcType := srcValue.Type()
Expand Down Expand Up @@ -1439,6 +1514,12 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N
}

func (d *Decoder) decodeArray(ctx context.Context, dst reflect.Value, src ast.Node) error {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
return ErrExceededMaxDepth
}

arrayNode, err := d.getArrayNode(src)
if err != nil {
return err
Expand Down Expand Up @@ -1479,6 +1560,12 @@ func (d *Decoder) decodeArray(ctx context.Context, dst reflect.Value, src ast.No
}

func (d *Decoder) decodeSlice(ctx context.Context, dst reflect.Value, src ast.Node) error {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
return ErrExceededMaxDepth
}

arrayNode, err := d.getArrayNode(src)
if err != nil {
return err
Expand Down Expand Up @@ -1516,6 +1603,12 @@ func (d *Decoder) decodeSlice(ctx context.Context, dst reflect.Value, src ast.No
}

func (d *Decoder) decodeMapItem(ctx context.Context, dst *MapItem, src ast.Node) error {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
return ErrExceededMaxDepth
}

mapNode, err := d.getMapNode(src)
if err != nil {
return err
Expand Down Expand Up @@ -1562,6 +1655,12 @@ func (d *Decoder) validateDuplicateKey(keyMap map[string]struct{}, key interface
}

func (d *Decoder) decodeMapSlice(ctx context.Context, dst *MapSlice, src ast.Node) error {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
return ErrExceededMaxDepth
}

mapNode, err := d.getMapNode(src)
if err != nil {
return err
Expand Down Expand Up @@ -1606,6 +1705,12 @@ func (d *Decoder) decodeMapSlice(ctx context.Context, dst *MapSlice, src ast.Nod
}

func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node) error {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
return ErrExceededMaxDepth
}

mapNode, err := d.getMapNode(src)
if err != nil {
return err
Expand Down Expand Up @@ -1839,6 +1944,7 @@ func (d *Decoder) decodeInit() error {
}

func (d *Decoder) decode(ctx context.Context, v reflect.Value) error {
d.decodeDepth = 0
if len(d.parsedFile.Docs) <= d.streamIndex {
return io.EOF
}
Expand Down
1 change: 1 addition & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ var (
ErrUnknownCommentPositionType = errors.New("unknown comment position type")
ErrInvalidCommentMapValue = errors.New("invalid comment map value. it must be not nil value")
ErrDecodeRequiredPointerType = errors.New("required pointer type value")
ErrExceededMaxDepth = errors.New("exceeded max depth")
)

type (
Expand Down

0 comments on commit 3584ab7

Please sign in to comment.