Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 44 additions & 40 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6718,6 +6718,7 @@ def narrow_type_by_equality(
is_target_for_value_narrowing = is_singleton_identity_type
should_coerce_literals = True
should_narrow_by_identity_equality = True
enum_comparison_is_ambiguous = False

elif operator in {"==", "!="}:
is_target_for_value_narrowing = is_singleton_equality_type
Expand All @@ -6730,9 +6731,8 @@ def narrow_type_by_equality(
break

expr_types = [operand_types[i] for i in expr_indices]
should_narrow_by_identity_equality = not any(
map(has_custom_eq_checks, expr_types)
) and not is_ambiguous_mix_of_enums(expr_types)
should_narrow_by_identity_equality = not any(map(has_custom_eq_checks, expr_types))
enum_comparison_is_ambiguous = True
else:
raise AssertionError

Expand Down Expand Up @@ -6765,11 +6765,18 @@ def narrow_type_by_equality(
for i in expr_indices:
if i not in narrowable_indices:
continue
expr_type = coerce_to_literal(operand_types[i])
expr_type = try_expanding_sum_type_to_union(expr_type, None)
expr_enum_keys = ambiguous_enum_equality_keys(expr_type)
for j, target in value_targets:
if i == j:
continue
expr_type = coerce_to_literal(operand_types[i])
expr_type = try_expanding_sum_type_to_union(expr_type, None)
if (
# See comments in ambiguous_enum_equality_keys
enum_comparison_is_ambiguous
and len(expr_enum_keys | ambiguous_enum_equality_keys(target.item)) > 1
):
continue
if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target])
)
Expand All @@ -6779,10 +6786,10 @@ def narrow_type_by_equality(
for i in expr_indices:
if i not in narrowable_indices:
continue
expr_type = operand_types[i]
for j, target in type_targets:
if i == j:
continue
expr_type = operand_types[i]
if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target])
)
Expand Down Expand Up @@ -9371,47 +9378,44 @@ def visit_starred_pattern(self, p: StarredPattern) -> None:
self.lvalue = False


def is_ambiguous_mix_of_enums(types: list[Type]) -> bool:
"""Do types have IntEnum/StrEnum types that are potentially overlapping with other types?
def ambiguous_enum_equality_keys(t: Type) -> set[str]:
"""
Used when narrowing types based on equality.

If True, we shouldn't attempt type narrowing based on enum values, as it gets
too ambiguous.
Certain kinds of enums can compare equal to values of other types, so doing type math
the way `conditional_types` does will be misleading if you expect it to correspond to
conditions based on equality comparisons.

For example, return True if there's an 'int' type together with an IntEnum literal.
However, IntEnum together with a literal of the same IntEnum type is not ambiguous.
For example, StrEnum classes can compare equal to str values. So if we see
`val: StrEnum; if val == "foo": ...` we currently avoid narrowing.
Note that we do wish to continue narrowing for `if val == StrEnum.MEMBER: ...`
"""
# We need these things for this to be ambiguous:
# (1) an IntEnum or StrEnum type
# (1) an IntEnum or StrEnum type or enum subclass of int or str
# (2) either a different IntEnum/StrEnum type or a non-enum type ("<other>")
#
# It would be slightly more correct to calculate this separately for IntEnum and
# StrEnum related types, as an IntEnum can't be confused with a StrEnum.
return len(_ambiguous_enum_variants(types)) > 1


def _ambiguous_enum_variants(types: list[Type]) -> set[str]:
result = set()
for t in types:
t = get_proper_type(t)
if isinstance(t, UnionType):
result.update(_ambiguous_enum_variants(t.items))
elif isinstance(t, Instance):
if t.last_known_value:
result.update(_ambiguous_enum_variants([t.last_known_value]))
elif t.type.is_enum and any(
base.fullname in ("enum.IntEnum", "enum.StrEnum") for base in t.type.mro
):
result.add(t.type.fullname)
elif not t.type.is_enum:
# These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so
# let's be conservative
result.add("<other>")
elif isinstance(t, LiteralType):
result.update(_ambiguous_enum_variants([t.fallback]))
elif isinstance(t, NoneType):
pass
else:
t = get_proper_type(t)
if isinstance(t, UnionType):
for item in t.items:
result.update(ambiguous_enum_equality_keys(item))
elif isinstance(t, Instance):
if t.last_known_value:
result.update(ambiguous_enum_equality_keys(t.last_known_value))
elif t.type.is_enum and any(
base.fullname in ("enum.IntEnum", "enum.StrEnum", "builtins.str", "builtins.int")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is important line in the diff / review without whitespace changes

for base in t.type.mro
):
result.add(t.type.fullname)
elif not t.type.is_enum:
# These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so
# let's be conservative
result.add("<other>")
elif isinstance(t, LiteralType):
result.update(ambiguous_enum_equality_keys(t.fallback))
elif isinstance(t, NoneType):
pass
else:
result.add("<other>")
return result


Expand Down
25 changes: 23 additions & 2 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -2124,7 +2124,7 @@ else:
[builtins fixtures/ops.pyi]

[case testNarrowingWithIntEnum]
# mypy: strict-equality
# flags: --strict-equality --warn-unreachable
from __future__ import annotations
from typing import Any
from enum import IntEnum
Expand Down Expand Up @@ -2179,7 +2179,7 @@ def f6(x: IE) -> None:
[builtins fixtures/primitives.pyi]

[case testNarrowingWithIntEnum2]
# mypy: strict-equality
# flags: --strict-equality --warn-unreachable
from __future__ import annotations
from typing import Any
from enum import IntEnum, Enum
Expand Down Expand Up @@ -2284,6 +2284,27 @@ def f4(x: SE) -> None:
reveal_type(x) # N: Revealed type is "Literal[__main__.SE.B]"
[builtins fixtures/primitives.pyi]

[case testNarrowingWithEnumStrSubclass]
# flags: --strict-equality --warn-unreachable
from enum import Enum

class ParameterLocation(str, Enum):
QUERY = "query"
HEADER = "header"
PATH = "path"

def foo(location: ParameterLocation):
if location == "path":
reveal_type(location) # N: Revealed type is "__main__.ParameterLocation"
else:
reveal_type(location) # N: Revealed type is "__main__.ParameterLocation"

if location == ParameterLocation.PATH:
reveal_type(location) # N: Revealed type is "Literal[__main__.ParameterLocation.PATH]"
else:
reveal_type(location) # N: Revealed type is "Literal[__main__.ParameterLocation.QUERY] | Literal[__main__.ParameterLocation.HEADER]"
[builtins fixtures/primitives.pyi]

[case testConsistentNarrowingEqAndIn]
# flags: --python-version 3.10

Expand Down