From 84014b3253921e712a2d99b187ca099b7ffd5c1f Mon Sep 17 00:00:00 2001 From: Morris Kelly Date: Thu, 11 Jul 2024 22:56:03 +0100 Subject: [PATCH 1/3] Fix scientific notation decoding and add encoding test cases --- decode.go | 15 +++++++++++++++ decode_test.go | 10 +++++++++- encode_test.go | 20 ++++++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/decode.go b/decode.go index b6e22a5e..e851feb4 100644 --- a/decode.go +++ b/decode.go @@ -488,6 +488,21 @@ func (d *Decoder) fileToNode(f *ast.File) ast.Node { func (d *Decoder) convertValue(v reflect.Value, typ reflect.Type, src ast.Node) (reflect.Value, error) { if typ.Kind() != reflect.String { if !v.Type().ConvertibleTo(typ) { + + // Special case for "strings -> floats" aka scientific notation + // If the destination type is a float and the source type is a string, check if we can + // use strconv.ParseFloat to convert the string to a float. + if (typ.Kind() == reflect.Float32 || typ.Kind() == reflect.Float64) && + v.Type().Kind() == reflect.String { + if f, err := strconv.ParseFloat(v.String(), 64); err == nil { + if typ.Kind() == reflect.Float32 { + return reflect.ValueOf(float32(f)), nil + } else if typ.Kind() == reflect.Float64 { + return reflect.ValueOf(f), nil + } + // else, fall through to the error below + } + } return reflect.Zero(typ), errTypeMismatch(typ, v.Type(), src.GetToken()) } return v.Convert(typ), nil diff --git a/decode_test.go b/decode_test.go index 7e9b6635..bf532b63 100644 --- a/decode_test.go +++ b/decode_test.go @@ -289,6 +289,10 @@ func TestDecoder(t *testing.T) { "v: 18446744073709551616", map[string]float32{"v": float32(math.MaxUint64 + 1)}, }, + { + "v: 1e-06", + map[string]float32{"v": 1e-6}, + }, // float64 { @@ -307,6 +311,10 @@ func TestDecoder(t *testing.T) { "v: 18446744073709551616", map[string]float64{"v": float64(math.MaxUint64 + 1)}, }, + { + "v: 1e-06", + map[string]float64{"v": 1e-06}, + }, // Timestamps { @@ -2932,4 +2940,4 @@ func TestMapKeyCustomUnmarshaler(t *testing.T) { if val != "value" { t.Fatalf("expected to have value \"value\", but got %q", val) } -} \ No newline at end of file +} diff --git a/encode_test.go b/encode_test.go index 3ff6f1c1..9ca44040 100644 --- a/encode_test.go +++ b/encode_test.go @@ -80,6 +80,16 @@ func TestEncoder(t *testing.T) { map[string]float32{"v": 0.99}, nil, }, + { + "v: 1e-06\n", + map[string]float32{"v": 1e-06}, + nil, + }, + { + "v: 1e-06\n", + map[string]float64{"v": 0.000001}, + nil, + }, { "v: 0.123456789\n", map[string]float64{"v": 0.123456789}, @@ -100,6 +110,16 @@ func TestEncoder(t *testing.T) { map[string]float64{"v": 1000000}, nil, }, + { + "v: 1e-06\n", + map[string]float64{"v": 0.000001}, + nil, + }, + { + "v: 1e-06\n", + map[string]float64{"v": 1e-06}, + nil, + }, { "v: .inf\n", map[string]interface{}{"v": math.Inf(0)}, From b4e78408d8721f68737a49f96145f45c622012a2 Mon Sep 17 00:00:00 2001 From: Morris Kelly Date: Fri, 12 Jul 2024 09:35:30 +0100 Subject: [PATCH 2/3] Deal with ints and uints --- decode.go | 20 +++++++++++++- decode_test.go | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/decode.go b/decode.go index e851feb4..c194da6c 100644 --- a/decode.go +++ b/decode.go @@ -490,7 +490,7 @@ func (d *Decoder) convertValue(v reflect.Value, typ reflect.Type, src ast.Node) if !v.Type().ConvertibleTo(typ) { // Special case for "strings -> floats" aka scientific notation - // If the destination type is a float and the source type is a string, check if we can + // If the destination type is a float and the source type is a string, check if we can // use strconv.ParseFloat to convert the string to a float. if (typ.Kind() == reflect.Float32 || typ.Kind() == reflect.Float64) && v.Type().Kind() == reflect.String { @@ -892,6 +892,15 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No dst.SetInt(int64(vv)) return nil } + case string: // handle scientific notation + if i, err := strconv.ParseFloat(vv, 64); err == nil { + if 0 <= i && i <= math.MaxUint64 && !dst.OverflowInt(int64(i)) { + dst.SetInt(int64(i)) + return nil + } + } else { // couldn't be parsed as float + return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) + } default: return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) } @@ -914,6 +923,15 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No dst.SetUint(uint64(vv)) return nil } + case string: // handle scientific notation + if i, err := strconv.ParseFloat(vv, 64); err == nil { + if 0 <= i && i <= math.MaxUint64 && !dst.OverflowUint(uint64(i)) { + dst.SetUint(uint64(i)) + return nil + } else { // couldn't be parsed as float + return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) + } + } default: return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) } diff --git a/decode_test.go b/decode_test.go index bf532b63..a76be944 100644 --- a/decode_test.go +++ b/decode_test.go @@ -253,6 +253,10 @@ func TestDecoder(t *testing.T) { "v: 4294967295", map[string]uint{"v": math.MaxUint32}, }, + { + "v: 1e3", + map[string]uint{"v": 1000}, + }, // uint64 { @@ -271,6 +275,10 @@ func TestDecoder(t *testing.T) { "v: 9223372036854775807", map[string]uint64{"v": math.MaxInt64}, }, + { + "v: 1e3", + map[string]uint64{"v": 1000}, + }, // float32 { @@ -1101,6 +1109,73 @@ c: } } +func TestDecoder_ScientificNotation(t *testing.T) { + tests := []struct { + source string + value interface{} + }{ + { + "v: 1e3", + map[string]uint{"v": 1000}, + }, + { + "v: 1e-3", + map[string]uint{"v": 0}, + }, + { + "v: 1e3", + map[string]int{"v": 1000}, + }, + { + "v: 1e-3", + map[string]int{"v": 0}, + }, + { + "v: 1e3", + map[string]float32{"v": 1000}, + }, + { + "v: 1.0e3", + map[string]float64{"v": 1000}, + }, + { + "v: 1e-3", + map[string]float64{"v": 0.001}, + }, + { + "v: 1.0e-3", + map[string]float64{"v": 0.001}, + }, + { + "v: 1.0e+3", + map[string]float64{"v": 1000}, + }, + { + "v: 1.0e+3", + map[string]float64{"v": 1000}, + }, + } + for _, test := range tests { + t.Run(test.source, func(t *testing.T) { + buf := bytes.NewBufferString(test.source) + dec := yaml.NewDecoder(buf) + typ := reflect.ValueOf(test.value).Type() + value := reflect.New(typ) + if err := dec.Decode(value.Interface()); err != nil { + if err == io.EOF { + return + } + t.Fatalf("%s: %+v", test.source, err) + } + actual := fmt.Sprintf("%+v", value.Elem().Interface()) + expect := fmt.Sprintf("%+v", test.value) + if actual != expect { + t.Fatalf("failed to test [%s], actual=[%s], expect=[%s]", test.source, actual, expect) + } + }) + } +} + func TestDecoder_TypeConversionError(t *testing.T) { t.Run("type conversion for struct", func(t *testing.T) { type T struct { From e9430c0b3714c7d5b103c79b1436cbb9044616a7 Mon Sep 17 00:00:00 2001 From: Morris Kelly Date: Fri, 12 Jul 2024 09:59:04 +0100 Subject: [PATCH 3/3] Add coverage for uint changes --- decode.go | 5 +++-- decode_test.go | 11 +++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/decode.go b/decode.go index c194da6c..72af5e22 100644 --- a/decode.go +++ b/decode.go @@ -928,10 +928,11 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No if 0 <= i && i <= math.MaxUint64 && !dst.OverflowUint(uint64(i)) { dst.SetUint(uint64(i)) return nil - } else { // couldn't be parsed as float - return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) } + } else { // couldn't be parsed as float + return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) } + default: return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) } diff --git a/decode_test.go b/decode_test.go index a76be944..f0b0c080 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1198,6 +1198,17 @@ func TestDecoder_TypeConversionError(t *testing.T) { t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) } }) + t.Run("string to uint", func(t *testing.T) { + var v T + err := yaml.Unmarshal([]byte(`b: str`), &v) + if err == nil { + t.Fatal("expected to error") + } + msg := "cannot unmarshal string into Go struct field T.B of type uint" + if !strings.Contains(err.Error(), msg) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) + } + }) t.Run("string to bool", func(t *testing.T) { var v T err := yaml.Unmarshal([]byte(`d: str`), &v)