From 6129001528ebba99117a9e2168dbff3d41ef99fb Mon Sep 17 00:00:00 2001 From: Venkatesh-Prasad Ranganath Date: Fri, 20 Dec 2024 23:58:54 -0600 Subject: [PATCH] Fixes #15269 Using the first type of a union type as the type of the result of `Enumerable#sum()` call can cause runtime failures. A safer alternative is to flag the use of union types with `Enumerable#sum()` and suggest the use of `Enumerable#sum(initial)` with an initial value of the expected type of the `sum` call. --- spec/std/enumerable_spec.cr | 13 +++++++++++++ src/enumerable.cr | 6 +++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/spec/std/enumerable_spec.cr b/spec/std/enumerable_spec.cr index 084fe80dcf96..a28358e10404 100644 --- a/spec/std/enumerable_spec.cr +++ b/spec/std/enumerable_spec.cr @@ -1,5 +1,6 @@ require "spec" require "spec/helpers/iterate" +require "../spec_helper" module SomeInterface; end @@ -1364,6 +1365,18 @@ describe "Enumerable" do it { [1, 2, 3].sum(4.5).should eq(10.5) } it { (1..3).sum { |x| x * 2 }.should eq(12) } it { (1..3).sum(1.5) { |x| x * 2 }.should eq(13.5) } + it { [1, 3_u64].sum(0_i32).should eq(4_u32) } + it { [1, 3].sum(0_u64).should eq(4_u64) } + it { [1, 10000000000_u64].sum(0_u64).should eq(10000000001) } + it "raises if union types are summed" do + exc = assert_error <<-CRYSTAL, + require "prelude" + [1, 10000000000_u64].sum + CRYSTAL + "Enumerable#sum() does support Union types. Instead, " + + "use Enumerable#sum(initial) with an initial value of " + + "the expected type of the sum call." + end it "uses additive_identity from type" do typeof([1, 2, 3].sum).should eq(Int32) diff --git a/src/enumerable.cr b/src/enumerable.cr index 0993f38bbc4d..0ae315617286 100644 --- a/src/enumerable.cr +++ b/src/enumerable.cr @@ -2292,7 +2292,11 @@ module Enumerable(T) # if the type is a union. def self.first {% if X.union? %} - {{X.union_types.first}} + {{ + raise("Enumerable#sum() does support Union types. Instead, " + + "use Enumerable#sum(initial) with an initial value of " + + "the expected type of the sum call.") + }} {% else %} X {% end %}