Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the return type of Enumerable#sum/product for union elements #15314

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions spec/std/enumerable_spec.cr
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
require "spec"
require "../spec_helper"
require "spec/helpers/iterate"

module SomeInterface; end
Expand Down Expand Up @@ -1364,6 +1365,19 @@ describe "Enumerable" do
it { [1, 2, 3].sum(4.5).should eq(10.5) }
it { (1..3).sum { |x| x * 2 }.should eq(12) }
it { (1..3).sum(1.5) { |x| x * 2 }.should eq(13.5) }
it { [1, 3_u64].sum(0_i32).should eq(4_u32) }
it { [1, 3].sum(0_u64).should eq(4_u64) }
it { [1, 10000000000_u64].sum(0_u64).should eq(10000000001) }
it "raises if union types are summed", tags: %w[slow] do
exc = assert_error <<-CRYSTAL,
require "prelude"
[1, 10000000000_u64].sum
CRYSTAL
"`Enumerable#sum()` and `#product()` do not support Union " +
"types. Instead, use `Enumerable#sum(initial)` and " +
"`#product(initial)`, respectively, with an initial value " +
"of the intended type of the call."
end

it "uses additive_identity from type" do
typeof([1, 2, 3].sum).should eq(Int32)
Expand Down Expand Up @@ -1405,6 +1419,20 @@ describe "Enumerable" do
typeof([1.5, 2.5, 3.5].product).should eq(Float64)
typeof([1, 2, 3].product(&.to_f)).should eq(Float64)
end

it { [1, 3_u64].product(3_i32).should eq(9_u32) }
it { [1, 3].product(3_u64).should eq(9_u64) }
it { [1, 10000000000_u64].product(3_u64).should eq(30000000000_u64) }
it "raises if union types are multiplied", tags: %w[slow] do
exc = assert_error <<-CRYSTAL,
require "prelude"
[1, 10000000000_u64].product
CRYSTAL
"`Enumerable#sum()` and `#product()` do not support Union " +
"types. Instead, use `Enumerable#sum(initial)` and " +
"`#product(initial)`, respectively, with an initial value " +
"of the intended type of the call."
end
end

describe "first" do
Expand Down
34 changes: 22 additions & 12 deletions src/enumerable.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1771,7 +1771,7 @@ module Enumerable(T)
end

private def additive_identity(reflect)
type = reflect.first
type = reflect.type
if type.responds_to? :additive_identity
type.additive_identity
else
Expand Down Expand Up @@ -1808,7 +1808,10 @@ module Enumerable(T)
# Expects all types returned from the block to respond to `#+` method.
#
# This method calls `.additive_identity` on the yielded type to determine the
# type of the sum value.
# type of the sum value. Hence, it can fail to compile if
# `.additive_identity` fails to determine a safe type, e.g., in case of
# union types. In such cases, use `sum(initial)` with an initial value of
# the expected type of the sum value.
#
# If the collection is empty, returns `additive_identity`.
#
Expand Down Expand Up @@ -1847,15 +1850,15 @@ module Enumerable(T)
# ```
#
# This method calls `.multiplicative_identity` on the element type to determine the
# type of the sum value.
# type of the product value.
#
# If the collection is empty, returns `multiplicative_identity`.
#
# ```
# ([] of Int32).product # => 1
# ```
def product
product Reflect(T).first.multiplicative_identity
product Reflect(T).type.multiplicative_identity
end

# Multiplies *initial* and all the elements in the collection
Expand Down Expand Up @@ -1886,16 +1889,19 @@ module Enumerable(T)
#
# Expects all types returned from the block to respond to `#*` method.
#
# This method calls `.multiplicative_identity` on the element type to determine the
# type of the sum value.
# This method calls `.multiplicative_identity` on the element type to
# determine the type of the product value. Hence, it can fail to compile if
# `.multiplicative_identity` fails to determine a safe type, e.g., in case
# of union types. In such cases, use `product(initial)` with an initial
# value of the expected type of the product value.
#
# If the collection is empty, returns `multiplicative_identity`.
#
# ```
# ([] of Int32).product { |x| x + 1 } # => 1
# ```
def product(& : T -> _)
product(Reflect(typeof(yield Enumerable.element_type(self))).first.multiplicative_identity) do |value|
product(Reflect(typeof(yield Enumerable.element_type(self))).type.multiplicative_identity) do |value|
yield value
end
end
Expand Down Expand Up @@ -2287,12 +2293,16 @@ module Enumerable(T)

# :nodoc:
private struct Reflect(X)
# For now it's just a way to implement `Enumerable#sum` in a way that the
# initial value given to it has the type of the first type in the union,
# if the type is a union.
def self.first
# For now, Reflect is used to reject union types in `#sum()` and
# `#product()` methods.
def self.type
{% if X.union? %}
{{X.union_types.first}}
{{
raise("`Enumerable#sum()` and `#product()` do not support Union " +
"types. Instead, use `Enumerable#sum(initial)` and " +
"`#product(initial)`, respectively, with an initial value " +
"of the intended type of the call.")
}}
{% else %}
X
{% end %}
Expand Down