diff --git a/mypy/checker.py b/mypy/checker.py index 6d70dcb90e94..0a506ee983ba 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6718,6 +6718,7 @@ def narrow_type_by_equality( is_target_for_value_narrowing = is_singleton_identity_type should_coerce_literals = True should_narrow_by_identity_equality = True + enum_comparison_is_ambiguous = False elif operator in {"==", "!="}: is_target_for_value_narrowing = is_singleton_equality_type @@ -6730,9 +6731,8 @@ def narrow_type_by_equality( break expr_types = [operand_types[i] for i in expr_indices] - should_narrow_by_identity_equality = not any( - map(has_custom_eq_checks, expr_types) - ) and not is_ambiguous_mix_of_enums(expr_types) + should_narrow_by_identity_equality = not any(map(has_custom_eq_checks, expr_types)) + enum_comparison_is_ambiguous = True else: raise AssertionError @@ -6765,11 +6765,18 @@ def narrow_type_by_equality( for i in expr_indices: if i not in narrowable_indices: continue + expr_type = coerce_to_literal(operand_types[i]) + expr_type = try_expanding_sum_type_to_union(expr_type, None) + expr_enum_keys = ambiguous_enum_equality_keys(expr_type) for j, target in value_targets: if i == j: continue - expr_type = coerce_to_literal(operand_types[i]) - expr_type = try_expanding_sum_type_to_union(expr_type, None) + if ( + # See comments in ambiguous_enum_equality_keys + enum_comparison_is_ambiguous + and len(expr_enum_keys | ambiguous_enum_equality_keys(target.item)) > 1 + ): + continue if_map, else_map = conditional_types_to_typemaps( operands[i], *conditional_types(expr_type, [target]) ) @@ -6779,10 +6786,10 @@ def narrow_type_by_equality( for i in expr_indices: if i not in narrowable_indices: continue + expr_type = operand_types[i] for j, target in type_targets: if i == j: continue - expr_type = operand_types[i] if_map, else_map = conditional_types_to_typemaps( operands[i], *conditional_types(expr_type, [target]) ) @@ -9371,47 +9378,44 @@ def visit_starred_pattern(self, p: StarredPattern) -> None: self.lvalue = False -def is_ambiguous_mix_of_enums(types: list[Type]) -> bool: - """Do types have IntEnum/StrEnum types that are potentially overlapping with other types? +def ambiguous_enum_equality_keys(t: Type) -> set[str]: + """ + Used when narrowing types based on equality. - If True, we shouldn't attempt type narrowing based on enum values, as it gets - too ambiguous. + Certain kinds of enums can compare equal to values of other types, so doing type math + the way `conditional_types` does will be misleading if you expect it to correspond to + conditions based on equality comparisons. - For example, return True if there's an 'int' type together with an IntEnum literal. - However, IntEnum together with a literal of the same IntEnum type is not ambiguous. + For example, StrEnum classes can compare equal to str values. So if we see + `val: StrEnum; if val == "foo": ...` we currently avoid narrowing. + Note that we do wish to continue narrowing for `if val == StrEnum.MEMBER: ...` """ # We need these things for this to be ambiguous: - # (1) an IntEnum or StrEnum type + # (1) an IntEnum or StrEnum type or enum subclass of int or str # (2) either a different IntEnum/StrEnum type or a non-enum type ("") - # - # It would be slightly more correct to calculate this separately for IntEnum and - # StrEnum related types, as an IntEnum can't be confused with a StrEnum. - return len(_ambiguous_enum_variants(types)) > 1 - - -def _ambiguous_enum_variants(types: list[Type]) -> set[str]: result = set() - for t in types: - t = get_proper_type(t) - if isinstance(t, UnionType): - result.update(_ambiguous_enum_variants(t.items)) - elif isinstance(t, Instance): - if t.last_known_value: - result.update(_ambiguous_enum_variants([t.last_known_value])) - elif t.type.is_enum and any( - base.fullname in ("enum.IntEnum", "enum.StrEnum") for base in t.type.mro - ): - result.add(t.type.fullname) - elif not t.type.is_enum: - # These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so - # let's be conservative - result.add("") - elif isinstance(t, LiteralType): - result.update(_ambiguous_enum_variants([t.fallback])) - elif isinstance(t, NoneType): - pass - else: + t = get_proper_type(t) + if isinstance(t, UnionType): + for item in t.items: + result.update(ambiguous_enum_equality_keys(item)) + elif isinstance(t, Instance): + if t.last_known_value: + result.update(ambiguous_enum_equality_keys(t.last_known_value)) + elif t.type.is_enum and any( + base.fullname in ("enum.IntEnum", "enum.StrEnum", "builtins.str", "builtins.int") + for base in t.type.mro + ): + result.add(t.type.fullname) + elif not t.type.is_enum: + # These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so + # let's be conservative result.add("") + elif isinstance(t, LiteralType): + result.update(ambiguous_enum_equality_keys(t.fallback)) + elif isinstance(t, NoneType): + pass + else: + result.add("") return result diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index d0543c6965e0..663a09bcaa5b 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2124,7 +2124,7 @@ else: [builtins fixtures/ops.pyi] [case testNarrowingWithIntEnum] -# mypy: strict-equality +# flags: --strict-equality --warn-unreachable from __future__ import annotations from typing import Any from enum import IntEnum @@ -2179,7 +2179,7 @@ def f6(x: IE) -> None: [builtins fixtures/primitives.pyi] [case testNarrowingWithIntEnum2] -# mypy: strict-equality +# flags: --strict-equality --warn-unreachable from __future__ import annotations from typing import Any from enum import IntEnum, Enum @@ -2284,6 +2284,27 @@ def f4(x: SE) -> None: reveal_type(x) # N: Revealed type is "Literal[__main__.SE.B]" [builtins fixtures/primitives.pyi] +[case testNarrowingWithEnumStrSubclass] +# flags: --strict-equality --warn-unreachable +from enum import Enum + +class ParameterLocation(str, Enum): + QUERY = "query" + HEADER = "header" + PATH = "path" + +def foo(location: ParameterLocation): + if location == "path": + reveal_type(location) # N: Revealed type is "__main__.ParameterLocation" + else: + reveal_type(location) # N: Revealed type is "__main__.ParameterLocation" + + if location == ParameterLocation.PATH: + reveal_type(location) # N: Revealed type is "Literal[__main__.ParameterLocation.PATH]" + else: + reveal_type(location) # N: Revealed type is "Literal[__main__.ParameterLocation.QUERY] | Literal[__main__.ParameterLocation.HEADER]" +[builtins fixtures/primitives.pyi] + [case testConsistentNarrowingEqAndIn] # flags: --python-version 3.10