From b5f63d50d0f61b670cc2e34ebf262d2ed0e0cd99 Mon Sep 17 00:00:00 2001 From: Morris Kelly Date: Tue, 16 Jul 2024 11:38:06 +0100 Subject: [PATCH] Fix decoding of scientific notation (#463) * Fix scientific notation decoding and add encoding test cases * Deal with ints and uints * Add coverage for uint changes --- decode.go | 34 ++++++++++++++++++ decode_test.go | 96 +++++++++++++++++++++++++++++++++++++++++++++++++- encode_test.go | 20 +++++++++++ 3 files changed, 149 insertions(+), 1 deletion(-) diff --git a/decode.go b/decode.go index b6e22a5e..72af5e22 100644 --- a/decode.go +++ b/decode.go @@ -488,6 +488,21 @@ func (d *Decoder) fileToNode(f *ast.File) ast.Node { func (d *Decoder) convertValue(v reflect.Value, typ reflect.Type, src ast.Node) (reflect.Value, error) { if typ.Kind() != reflect.String { if !v.Type().ConvertibleTo(typ) { + + // Special case for "strings -> floats" aka scientific notation + // If the destination type is a float and the source type is a string, check if we can + // use strconv.ParseFloat to convert the string to a float. + if (typ.Kind() == reflect.Float32 || typ.Kind() == reflect.Float64) && + v.Type().Kind() == reflect.String { + if f, err := strconv.ParseFloat(v.String(), 64); err == nil { + if typ.Kind() == reflect.Float32 { + return reflect.ValueOf(float32(f)), nil + } else if typ.Kind() == reflect.Float64 { + return reflect.ValueOf(f), nil + } + // else, fall through to the error below + } + } return reflect.Zero(typ), errTypeMismatch(typ, v.Type(), src.GetToken()) } return v.Convert(typ), nil @@ -877,6 +892,15 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No dst.SetInt(int64(vv)) return nil } + case string: // handle scientific notation + if i, err := strconv.ParseFloat(vv, 64); err == nil { + if 0 <= i && i <= math.MaxUint64 && !dst.OverflowInt(int64(i)) { + dst.SetInt(int64(i)) + return nil + } + } else { // couldn't be parsed as float + return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) + } default: return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) } @@ -899,6 +923,16 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No dst.SetUint(uint64(vv)) return nil } + case string: // handle scientific notation + if i, err := strconv.ParseFloat(vv, 64); err == nil { + if 0 <= i && i <= math.MaxUint64 && !dst.OverflowUint(uint64(i)) { + dst.SetUint(uint64(i)) + return nil + } + } else { // couldn't be parsed as float + return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) + } + default: return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) } diff --git a/decode_test.go b/decode_test.go index 7e9b6635..f0b0c080 100644 --- a/decode_test.go +++ b/decode_test.go @@ -253,6 +253,10 @@ func TestDecoder(t *testing.T) { "v: 4294967295", map[string]uint{"v": math.MaxUint32}, }, + { + "v: 1e3", + map[string]uint{"v": 1000}, + }, // uint64 { @@ -271,6 +275,10 @@ func TestDecoder(t *testing.T) { "v: 9223372036854775807", map[string]uint64{"v": math.MaxInt64}, }, + { + "v: 1e3", + map[string]uint64{"v": 1000}, + }, // float32 { @@ -289,6 +297,10 @@ func TestDecoder(t *testing.T) { "v: 18446744073709551616", map[string]float32{"v": float32(math.MaxUint64 + 1)}, }, + { + "v: 1e-06", + map[string]float32{"v": 1e-6}, + }, // float64 { @@ -307,6 +319,10 @@ func TestDecoder(t *testing.T) { "v: 18446744073709551616", map[string]float64{"v": float64(math.MaxUint64 + 1)}, }, + { + "v: 1e-06", + map[string]float64{"v": 1e-06}, + }, // Timestamps { @@ -1093,6 +1109,73 @@ c: } } +func TestDecoder_ScientificNotation(t *testing.T) { + tests := []struct { + source string + value interface{} + }{ + { + "v: 1e3", + map[string]uint{"v": 1000}, + }, + { + "v: 1e-3", + map[string]uint{"v": 0}, + }, + { + "v: 1e3", + map[string]int{"v": 1000}, + }, + { + "v: 1e-3", + map[string]int{"v": 0}, + }, + { + "v: 1e3", + map[string]float32{"v": 1000}, + }, + { + "v: 1.0e3", + map[string]float64{"v": 1000}, + }, + { + "v: 1e-3", + map[string]float64{"v": 0.001}, + }, + { + "v: 1.0e-3", + map[string]float64{"v": 0.001}, + }, + { + "v: 1.0e+3", + map[string]float64{"v": 1000}, + }, + { + "v: 1.0e+3", + map[string]float64{"v": 1000}, + }, + } + for _, test := range tests { + t.Run(test.source, func(t *testing.T) { + buf := bytes.NewBufferString(test.source) + dec := yaml.NewDecoder(buf) + typ := reflect.ValueOf(test.value).Type() + value := reflect.New(typ) + if err := dec.Decode(value.Interface()); err != nil { + if err == io.EOF { + return + } + t.Fatalf("%s: %+v", test.source, err) + } + actual := fmt.Sprintf("%+v", value.Elem().Interface()) + expect := fmt.Sprintf("%+v", test.value) + if actual != expect { + t.Fatalf("failed to test [%s], actual=[%s], expect=[%s]", test.source, actual, expect) + } + }) + } +} + func TestDecoder_TypeConversionError(t *testing.T) { t.Run("type conversion for struct", func(t *testing.T) { type T struct { @@ -1115,6 +1198,17 @@ func TestDecoder_TypeConversionError(t *testing.T) { t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) } }) + t.Run("string to uint", func(t *testing.T) { + var v T + err := yaml.Unmarshal([]byte(`b: str`), &v) + if err == nil { + t.Fatal("expected to error") + } + msg := "cannot unmarshal string into Go struct field T.B of type uint" + if !strings.Contains(err.Error(), msg) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) + } + }) t.Run("string to bool", func(t *testing.T) { var v T err := yaml.Unmarshal([]byte(`d: str`), &v) @@ -2932,4 +3026,4 @@ func TestMapKeyCustomUnmarshaler(t *testing.T) { if val != "value" { t.Fatalf("expected to have value \"value\", but got %q", val) } -} \ No newline at end of file +} diff --git a/encode_test.go b/encode_test.go index 3ff6f1c1..9ca44040 100644 --- a/encode_test.go +++ b/encode_test.go @@ -80,6 +80,16 @@ func TestEncoder(t *testing.T) { map[string]float32{"v": 0.99}, nil, }, + { + "v: 1e-06\n", + map[string]float32{"v": 1e-06}, + nil, + }, + { + "v: 1e-06\n", + map[string]float64{"v": 0.000001}, + nil, + }, { "v: 0.123456789\n", map[string]float64{"v": 0.123456789}, @@ -100,6 +110,16 @@ func TestEncoder(t *testing.T) { map[string]float64{"v": 1000000}, nil, }, + { + "v: 1e-06\n", + map[string]float64{"v": 0.000001}, + nil, + }, + { + "v: 1e-06\n", + map[string]float64{"v": 1e-06}, + nil, + }, { "v: .inf\n", map[string]interface{}{"v": math.Inf(0)},