diff --git a/crates/red_knot_python_semantic/resources/mdtest/loops/for_loop.md b/crates/red_knot_python_semantic/resources/mdtest/loops/for_loop.md index 117108928f00f..d2e30b0f521d0 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/loops/for_loop.md +++ b/crates/red_knot_python_semantic/resources/mdtest/loops/for_loop.md @@ -144,3 +144,82 @@ class NotIterable: for x in NotIterable(): # error: "Object of type `NotIterable` is not iterable" pass ``` + +## Union type as iterable + +```py +class TestIter: + def __next__(self) -> int: + return 42 + +class Test: + def __iter__(self) -> TestIter: + return TestIter() + +class Test2: + def __iter__(self) -> TestIter: + return TestIter() + +def bool_instance() -> bool: + return True + +flag = bool_instance() + +for x in Test() if flag else Test2(): + reveal_type(x) # revealed: int +``` + +## Union type as iterator + +```py +class TestIter: + def __next__(self) -> int: + return 42 + +class TestIter2: + def __next__(self) -> int: + return 42 + +class Test: + def __iter__(self) -> TestIter | TestIter2: + return TestIter() + +for x in Test(): + reveal_type(x) # revealed: int +``` + +## Union type as iterable and union type as iterator + +```py +class TestIter: + def __next__(self) -> int | Exception: + return 42 + +class TestIter2: + def __next__(self) -> str | tuple[int, int]: + return "42" + +class TestIter3: + def __next__(self) -> bytes: + return b"42" + +class TestIter4: + def __next__(self) -> memoryview: + return memoryview(b"42") + +class Test: + def __iter__(self) -> TestIter | TestIter2: + return TestIter() + +class Test2: + def __iter__(self) -> TestIter3 | TestIter4: + return TestIter3() + +def bool_instance() -> bool: + return True + +flag = bool_instance() + +for x in Test() if flag else Test2(): + reveal_type(x) # revealed: int | Exception | str | tuple[int, int] | bytes | memoryview +``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 81c8ddba70f7c..e9dab63f0c7ba 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1104,10 +1104,7 @@ impl<'db> Type<'db> { let dunder_iter_method = iterable_meta_type.member(db, "__iter__"); if !dunder_iter_method.is_unbound() { - let CallOutcome::Callable { - return_ty: iterator_ty, - } = dunder_iter_method.call(db, &[self]) - else { + let Some(iterator_ty) = dunder_iter_method.call(db, &[self]).return_ty(db) else { return IterationOutcome::NotIterable { not_iterable_ty: self, }; @@ -1115,7 +1112,7 @@ impl<'db> Type<'db> { let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__"); return dunder_next_method - .call(db, &[self]) + .call(db, &[iterator_ty]) .return_ty(db) .map(|element_ty| IterationOutcome::Iterable { element_ty }) .unwrap_or(IterationOutcome::NotIterable {