Skip to content

Commit

Permalink
fix error (#535)
Browse files Browse the repository at this point in the history
  • Loading branch information
goccy authored Nov 14, 2024
1 parent 2c6a0e7 commit 271213a
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 155 deletions.
64 changes: 16 additions & 48 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) {
if ok {
return mapNode, nil
}
return nil, errUnexpectedNodeType(anchor.Value.Type(), ast.MappingType, node.GetToken())
return nil, errors.ErrUnexpectedNodeType(anchor.Value.Type(), ast.MappingType, node.GetToken())
}
if alias, ok := node.(*ast.AliasNode); ok {
aliasName := alias.Value.GetToken().Value
Expand All @@ -540,11 +540,11 @@ func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) {
if ok {
return mapNode, nil
}
return nil, errUnexpectedNodeType(node.Type(), ast.MappingType, node.GetToken())
return nil, errors.ErrUnexpectedNodeType(node.Type(), ast.MappingType, node.GetToken())
}
mapNode, ok := node.(ast.MapNode)
if !ok {
return nil, errUnexpectedNodeType(node.Type(), ast.MappingType, node.GetToken())
return nil, errors.ErrUnexpectedNodeType(node.Type(), ast.MappingType, node.GetToken())
}
return mapNode, nil
}
Expand All @@ -559,7 +559,7 @@ func (d *Decoder) getArrayNode(node ast.Node) (ast.ArrayNode, error) {
return arrayNode, nil
}

return nil, errUnexpectedNodeType(anchor.Value.Type(), ast.SequenceType, node.GetToken())
return nil, errors.ErrUnexpectedNodeType(anchor.Value.Type(), ast.SequenceType, node.GetToken())
}
if alias, ok := node.(*ast.AliasNode); ok {
aliasName := alias.Value.GetToken().Value
Expand All @@ -571,11 +571,11 @@ func (d *Decoder) getArrayNode(node ast.Node) (ast.ArrayNode, error) {
if ok {
return arrayNode, nil
}
return nil, errUnexpectedNodeType(node.Type(), ast.SequenceType, node.GetToken())
return nil, errors.ErrUnexpectedNodeType(node.Type(), ast.SequenceType, node.GetToken())
}
arrayNode, ok := node.(ast.ArrayNode)
if !ok {
return nil, errUnexpectedNodeType(node.Type(), ast.SequenceType, node.GetToken())
return nil, errors.ErrUnexpectedNodeType(node.Type(), ast.SequenceType, node.GetToken())
}
return arrayNode, nil
}
Expand All @@ -598,7 +598,7 @@ func (d *Decoder) convertValue(v reflect.Value, typ reflect.Type, src ast.Node)
// else, fall through to the error below
}
}
return reflect.Zero(typ), errTypeMismatch(typ, v.Type(), src.GetToken())
return reflect.Zero(typ), errors.ErrTypeMismatch(typ, v.Type(), src.GetToken())
}
return v.Convert(typ), nil
}
Expand All @@ -614,43 +614,11 @@ func (d *Decoder) convertValue(v reflect.Value, typ reflect.Type, src ast.Node)
return reflect.ValueOf(fmt.Sprint(v.Bool())), nil
}
if !v.Type().ConvertibleTo(typ) {
return reflect.Zero(typ), errTypeMismatch(typ, v.Type(), src.GetToken())
return reflect.Zero(typ), errors.ErrTypeMismatch(typ, v.Type(), src.GetToken())
}
return v.Convert(typ), nil
}

func errTypeMismatch(dstType, srcType reflect.Type, token *token.Token) *errors.TypeError {
return &errors.TypeError{DstType: dstType, SrcType: srcType, Token: token}
}

type unknownFieldError struct {
err error
}

func (e *unknownFieldError) Error() string {
return e.err.Error()
}

func errUnknownField(msg string, tk *token.Token) *unknownFieldError {
return &unknownFieldError{err: errors.ErrSyntax(msg, tk)}
}

func errUnexpectedNodeType(actual, expected ast.NodeType, tk *token.Token) error {
return errors.ErrSyntax(fmt.Sprintf("%s was used where %s is expected", actual.YAMLName(), expected.YAMLName()), tk)
}

type duplicateKeyError struct {
err error
}

func (e *duplicateKeyError) Error() string {
return e.err.Error()
}

func errDuplicateKey(msg string, tk *token.Token) *duplicateKeyError {
return &duplicateKeyError{err: errors.ErrSyntax(msg, tk)}
}

func (d *Decoder) deleteStructKeys(structType reflect.Type, unknownFields map[string]ast.Node) error {
if structType.Kind() == reflect.Ptr {
structType = structType.Elem()
Expand Down Expand Up @@ -988,10 +956,10 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No
return nil
}
} else { // couldn't be parsed as float
return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken())
return errors.ErrTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken())
}
default:
return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken())
return errors.ErrTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken())
}
return errors.ErrOverflow(valueType, fmt.Sprint(v), src.GetToken())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
Expand Down Expand Up @@ -1022,11 +990,11 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No
return nil
}
} else { // couldn't be parsed as float
return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken())
return errors.ErrTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken())
}

default:
return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken())
return errors.ErrTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken())
}
return errors.ErrOverflow(valueType, fmt.Sprint(v), src.GetToken())
}
Expand Down Expand Up @@ -1215,7 +1183,7 @@ func (d *Decoder) castToTime(src ast.Node) (time.Time, error) {
}
s, ok := v.(string)
if !ok {
return time.Time{}, errTypeMismatch(reflect.TypeOf(time.Time{}), reflect.TypeOf(v), src.GetToken())
return time.Time{}, errors.ErrTypeMismatch(reflect.TypeOf(time.Time{}), reflect.TypeOf(v), src.GetToken())
}
for _, format := range allowedTimestampFormats {
t, err := time.Parse(format, s)
Expand Down Expand Up @@ -1250,7 +1218,7 @@ func (d *Decoder) castToDuration(src ast.Node) (time.Duration, error) {
}
s, ok := v.(string)
if !ok {
return 0, errTypeMismatch(reflect.TypeOf(time.Duration(0)), reflect.TypeOf(v), src.GetToken())
return 0, errors.ErrTypeMismatch(reflect.TypeOf(time.Duration(0)), reflect.TypeOf(v), src.GetToken())
}
t, err := time.ParseDuration(s)
if err != nil {
Expand Down Expand Up @@ -1421,7 +1389,7 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N
// Unknown fields are expected (they could be fields from the parent struct).
if len(unknownFields) != 0 && d.disallowUnknownField && src.GetToken() != nil {
for key, node := range unknownFields {
return errUnknownField(fmt.Sprintf(`unknown field "%s"`, key), node.GetToken())
return errors.ErrUnknownField(fmt.Sprintf(`unknown field "%s"`, key), node.GetToken())
}
}

Expand Down Expand Up @@ -1572,7 +1540,7 @@ func (d *Decoder) validateDuplicateKey(keyMap map[string]struct{}, key interface
}
if !d.allowDuplicateMapKey {
if _, exists := keyMap[k]; exists {
return errDuplicateKey(fmt.Sprintf(`duplicate key "%s"`, k), keyNode.GetToken())
return errors.ErrDuplicateKey(fmt.Sprintf(`duplicate key "%s"`, k), keyNode.GetToken())
}
}
keyMap[k] = struct{}{}
Expand Down
11 changes: 10 additions & 1 deletion error.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package yaml

import (
"errors"
"fmt"

"github.com/goccy/go-yaml/ast"
"github.com/goccy/go-yaml/internal/errors"
)

var (
Expand All @@ -17,6 +17,15 @@ var (
ErrDecodeRequiredPointerType = errors.New("required pointer type value")
)

type (
SyntaxError = errors.SyntaxError
TypeError = errors.TypeError
OverflowError = errors.OverflowError
DuplicateKeyError = errors.DuplicateKeyError
UnknownFieldError = errors.UnknownFieldError
UnexpectedNodeTypeError = errors.UnexpectedNodeTypeError
)

func ErrUnsupportedHeadPositionType(node ast.Node) error {
return fmt.Errorf("unsupported comment head position for %s", node.Type())
}
Expand Down
Loading

0 comments on commit 271213a

Please sign in to comment.