Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the return type of Enumerable#sum/product for union elements #15314

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions spec/std/enumerable_spec.cr
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
require "spec"
require "../spec_helper"
require "spec/helpers/iterate"

module SomeInterface; end
Expand Down Expand Up @@ -1364,6 +1365,18 @@ describe "Enumerable" do
it { [1, 2, 3].sum(4.5).should eq(10.5) }
it { (1..3).sum { |x| x * 2 }.should eq(12) }
it { (1..3).sum(1.5) { |x| x * 2 }.should eq(13.5) }
it { [1, 3_u64].sum(0_i32).should eq(4_u32) }
it { [1, 3].sum(0_u64).should eq(4_u64) }
it { [1, 10000000000_u64].sum(0_u64).should eq(10000000001) }
it "raises if union types are summed", tags: %w[slow] do
exc = assert_error <<-CRYSTAL,
require "prelude"
[1, 10000000000_u64].sum
CRYSTAL
"Enumerable#sum/product() does support Union types. Instead, " +
"use Enumerable#sum/product(initial) with an initial value of " +
"the expected type of the sum/product call."
end

it "uses additive_identity from type" do
typeof([1, 2, 3].sum).should eq(Int32)
Expand Down Expand Up @@ -1405,6 +1418,19 @@ describe "Enumerable" do
typeof([1.5, 2.5, 3.5].product).should eq(Float64)
typeof([1, 2, 3].product(&.to_f)).should eq(Float64)
end

it { [1, 3_u64].product(3_i32).should eq(9_u32) }
it { [1, 3].product(3_u64).should eq(9_u64) }
it { [1, 10000000000_u64].product(3_u64).should eq(30000000000_u64) }
it "raises if union types are multiplied", tags: %w[slow] do
exc = assert_error <<-CRYSTAL,
require "prelude"
[1, 10000000000_u64].product
CRYSTAL
"Enumerable#sum/product() does support Union types. Instead, " +
"use Enumerable#sum/product(initial) with an initial value of " +
"the expected type of the sum/product call."
end
end

describe "first" do
Expand Down
18 changes: 14 additions & 4 deletions src/enumerable.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1808,7 +1808,10 @@ module Enumerable(T)
# Expects all types returned from the block to respond to `#+` method.
#
# This method calls `.additive_identity` on the yielded type to determine the
# type of the sum value.
# type of the sum value. Hence, it can fail to compile if
# `.additive_identity` fails to determine a safe type, e.g., in case of
# union types. In such cases, use `sum(initial)` with an initial value of
# the expected type of the sum value.
#
# If the collection is empty, returns `additive_identity`.
#
Expand Down Expand Up @@ -1886,8 +1889,11 @@ module Enumerable(T)
#
# Expects all types returned from the block to respond to `#*` method.
#
# This method calls `.multiplicative_identity` on the element type to determine the
# type of the sum value.
# This method calls `.multiplicative_identity` on the element type to
# determine the type of the product value. Hence, it can fail to compile if
# `.multiplicative_identity` fails to determine a safe type, e.g., in case
# of union types. In such cases, use `product(initial)` with an initial
# value of the expected type of the product value.
#
# If the collection is empty, returns `multiplicative_identity`.
#
Expand Down Expand Up @@ -2292,7 +2298,11 @@ module Enumerable(T)
# if the type is a union.
def self.first
{% if X.union? %}
{{X.union_types.first}}
{{
raise("Enumerable#sum/product() does support Union types. " +
"Instead, use Enumerable#sum/product(initial) with an " +
"initial value of the expected type of the sum/product call.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Polish the error message:

Suggested change
raise("Enumerable#sum/product() does support Union types. " +
"Instead, use Enumerable#sum/product(initial) with an " +
"initial value of the expected type of the sum/product call.")
raise("`Enumerable#sum` and `#product` cannot determine the initial value from a union type. " +
"Please pass an initial value of the intended type to the call.")

Copy link
Author

@rvprasad rvprasad Dec 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After reading your fix, I realized a missing negation in the first sentence of error message I committed :)

Also, how about we update the error message as follows?

raise("Enumerable#sum` and #product do not support union types. Instead, use Enumerable#sum/product(initial) with an initial value of the expected type of the sum/product call.")

Following is my reasoning for the above change.

  1. The existing logic fails to identify an appropriate/safe type that can hold the final value of the call.
  2. Using the identity is a way of determining the appropriate/safe type. Since it is an implementation detail, I think surface such details in documentation while mentioning what (not how) is wrong and its fix in the error message may be better.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"do not support union types" isn't entirely accurate. They do support them. Just without automatically determining the return type, hence the need to specify explicitly.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this change, "sum()" and "product()" will not support union types (but "sum(initial)" and "product(initial)" will support union types). So, do you mean that we should be precise in the error message and the doc by referring to the no-arg variant of the methods?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we could differentiate between the different overloads by signature.

Enumerable#sum() and #product() do not support union types.`

It's a bit of a subtle detail though. Hence my suggestion to write out the issue explicitly:

Enumerable#sum and #product cannot determine the initial value from a union type.

I'm fine either way though.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have updated the PR such that

  1. error messages and doc refer to specific overloads and
  2. error messages focus only on the what is wrong and how to fix it.

}}
{% else %}
X
{% end %}
Expand Down