diff --git a/mypy/checker.py b/mypy/checker.py index 6d70dcb90e94..c52e4c4887bc 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6583,6 +6583,9 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa partial_type_maps = [] for operator, expr_indices in simplified_operator_list: + if_map: TypeMap + else_map: TypeMap + if operator in {"is", "is not", "==", "!="}: if_map, else_map = self.equality_type_narrowing_helper( node, @@ -6598,14 +6601,24 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa item_type = operand_types[left_index] iterable_type = operand_types[right_index] - if_map, else_map = {}, {} + if_map = {} + else_map = {} if left_index in narrowable_operand_index_to_hash: - # We only try and narrow away 'None' for now - if is_overlapping_none(item_type): - collection_item_type = get_proper_type(builtin_item_type(iterable_type)) + collection_item_type = get_proper_type(builtin_item_type(iterable_type)) + if collection_item_type is not None: + if_map, else_map = self.narrow_type_by_equality( + "==", + operands=[operands[left_index], operands[right_index]], + operand_types=[item_type, collection_item_type], + expr_indices=[left_index, right_index], + narrowable_indices={0}, + ) + + # We only try and narrow away 'None' for now if ( - collection_item_type is not None + if_map is not None + and is_overlapping_none(item_type) and not is_overlapping_none(collection_item_type) and not ( isinstance(collection_item_type, Instance) @@ -6622,11 +6635,11 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa expr = operands[right_index] if if_type is None: if_map = None - else: + elif if_map is not None: if_map[expr] = if_type if else_type is None: else_map = None - else: + elif else_map is not None: else_map[expr] = else_type else: diff --git a/mypy/constraints.py b/mypy/constraints.py index cfb627e9f2b5..05cdf2986020 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -124,7 +124,7 @@ def infer_constraints_for_callable( param_spec = callee.param_spec() param_spec_arg_types = [] param_spec_arg_names = [] - param_spec_arg_kinds = [] + param_spec_arg_kinds: list[ArgKind] = [] incomplete_star_mapping = False for i, actuals in enumerate(formal_to_actual): # TODO: isn't this `enumerate(arg_types)`? diff --git a/mypyc/irbuild/util.py b/mypyc/irbuild/util.py index 3028e940f7f9..912deb581c9a 100644 --- a/mypyc/irbuild/util.py +++ b/mypyc/irbuild/util.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Final, Literal, TypedDict, cast +from typing import Any, Final, Literal, TypedDict from typing_extensions import NotRequired from mypy.nodes import ( @@ -138,7 +138,6 @@ def get_mypyc_attrs( def set_mypyc_attr(key: str, value: Any, line: int) -> None: if key in MYPYC_ATTRS: - key = cast(MypycAttr, key) attrs[key] = value lines[key] = line else: diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 237271558ac6..2c59e840aa56 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1373,13 +1373,13 @@ else: reveal_type(val) # N: Revealed type is "None" if val in (None,): - reveal_type(val) # N: Revealed type is "__main__.A | None" + reveal_type(val) # N: Revealed type is "None" else: - reveal_type(val) # N: Revealed type is "__main__.A | None" + reveal_type(val) # N: Revealed type is "__main__.A" if val not in (None,): - reveal_type(val) # N: Revealed type is "__main__.A | None" + reveal_type(val) # N: Revealed type is "__main__.A" else: - reveal_type(val) # N: Revealed type is "__main__.A | None" + reveal_type(val) # N: Revealed type is "None" class Hmm: def __eq__(self, other) -> bool: ... @@ -2294,9 +2294,8 @@ def f(x: str | int) -> None: y = x if x in ["x"]: - # TODO: we should fix this reveal https://github.com/python/mypy/issues/3229 - reveal_type(x) # N: Revealed type is "builtins.str | builtins.int" - y = x # E: Incompatible types in assignment (expression has type "str | int", variable has type "str") + reveal_type(x) # N: Revealed type is "builtins.str" + y = x z = x z = y [builtins fixtures/primitives.pyi] @@ -2806,3 +2805,126 @@ class X: reveal_type(self.y) # N: Revealed type is "builtins.list[builtins.str]" self.y[0].does_not_exist # E: "str" has no attribute "does_not_exist" [builtins fixtures/dict.pyi] + + +[case testTypeNarrowingStringInLiteralUnion] +from typing import Literal, Tuple +typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b') +x: str = "hi!" +if x in typ: + reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']" +else: + reveal_type(x) # N: Revealed type is "builtins.str" +[builtins fixtures/tuple.pyi] +[typing fixtures/typing-medium.pyi] + +[case testTypeNarrowingStringInLiteralUnionSubset] +from typing import Literal, Tuple +typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b') +strIn: str = "b" +strOut: str = "c" +if strIn in typeAlpha: + reveal_type(strIn) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']" +else: + reveal_type(strIn) # N: Revealed type is "builtins.str" +if strOut in typeAlpha: + reveal_type(strOut) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']" +else: + reveal_type(strOut) # N: Revealed type is "builtins.str" +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-medium.pyi] + +[case testNarrowingStringNotInLiteralUnion] +from typing import Literal, Tuple +typeAlpha: Tuple[Literal['a', 'b', 'c'],...] = ('a', 'b', 'c') +strIn: str = "c" +strOut: str = "d" +if strIn not in typeAlpha: + reveal_type(strIn) # N: Revealed type is "builtins.str" +else: + reveal_type(strIn) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']" +if strOut in typeAlpha: + reveal_type(strOut) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']" +else: + reveal_type(strOut) # N: Revealed type is "builtins.str" +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-medium.pyi] + +[case testNarrowingStringInLiteralUnionDontExpand] +from typing import Literal, Tuple +typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b', 'c') +strIn: Literal['c'] = "c" +reveal_type(strIn) # N: Revealed type is "Literal['c']" +#Check we don't expand a Literal into the Union type +if strIn not in typeAlpha: + reveal_type(strIn) # N: Revealed type is "Literal['c']" +else: + reveal_type(strIn) # N: Revealed type is "Literal['c']" +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-medium.pyi] + +[case testTypeNarrowingStringInMixedUnion] +from typing import Literal, Tuple +typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b') +x: str = "hi!" +if x in typ: + reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']" +else: + reveal_type(x) # N: Revealed type is "builtins.str" +[builtins fixtures/tuple.pyi] +[typing fixtures/typing-medium.pyi] + +[case testTypeNarrowingStringInSet] +from typing import Literal, Set +typ: Set[Literal['a', 'b']] = {'a', 'b'} +x: str = "hi!" +if x in typ: + reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']" +else: + reveal_type(x) # N: Revealed type is "builtins.str" +if x not in typ: + reveal_type(x) # N: Revealed type is "builtins.str" +else: + reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']" +[builtins fixtures/narrowing.pyi] +[typing fixtures/typing-medium.pyi] + +[case testTypeNarrowingStringInList] +from typing import Literal, List +typ: List[Literal['a', 'b']] = ['a', 'b'] +x: str = "hi!" +if x in typ: + reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']" +else: + reveal_type(x) # N: Revealed type is "builtins.str" +if x not in typ: + reveal_type(x) # N: Revealed type is "builtins.str" +else: + reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']" +[builtins fixtures/narrowing.pyi] +[typing fixtures/typing-medium.pyi] + +[case testTypeNarrowingUnionStringFloat] +from typing import Union +def foobar(foo: Union[str, float]): + if foo in ['a', 'b']: + reveal_type(foo) # N: Revealed type is "builtins.str" + else: + reveal_type(foo) # N: Revealed type is "builtins.str | builtins.float" +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-medium.pyi] + +[case testNarrowAnyWithEqualityOrContainment] +# https://github.com/python/mypy/issues/17841 +from typing import Any + +def f1(x: Any) -> None: + if x is not None and x not in ["x"]: + return + reveal_type(x) # N: Revealed type is "Any" + +def f2(x: Any) -> None: + if x is not None and x != "x": + return + reveal_type(x) # N: Revealed type is "Any" +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/fixtures/narrowing.pyi b/test-data/unit/fixtures/narrowing.pyi index 89ee011c1c80..a36ac7f29bd2 100644 --- a/test-data/unit/fixtures/narrowing.pyi +++ b/test-data/unit/fixtures/narrowing.pyi @@ -1,5 +1,5 @@ # Builtins stub used in check-narrowing test cases. -from typing import Generic, Sequence, Tuple, Type, TypeVar, Union +from typing import Generic, Sequence, Tuple, Type, TypeVar, Union, Iterable Tco = TypeVar('Tco', covariant=True) @@ -15,6 +15,13 @@ class function: pass class ellipsis: pass class int: pass class str: pass +class float: pass class dict(Generic[KT, VT]): pass def isinstance(x: object, t: Union[Type[object], Tuple[Type[object], ...]]) -> bool: pass + +class list(Sequence[Tco]): + def __contains__(self, other: object) -> bool: pass +class set(Iterable[Tco], Generic[Tco]): + def __init__(self, iterable: Iterable[Tco] = ...) -> None: ... + def __contains__(self, item: object) -> bool: pass