Skip to content

Commit

Permalink
[red-knot] Fix bug where union of two iterable types was not recognis…
Browse files Browse the repository at this point in the history
…ed as iterable (#13992)
  • Loading branch information
AlexWaygood authored Oct 30, 2024
1 parent 1607d88 commit 42c7069
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 5 deletions.
79 changes: 79 additions & 0 deletions crates/red_knot_python_semantic/resources/mdtest/loops/for_loop.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,82 @@ class NotIterable:
for x in NotIterable(): # error: "Object of type `NotIterable` is not iterable"
pass
```

## Union type as iterable

```py
class TestIter:
def __next__(self) -> int:
return 42

class Test:
def __iter__(self) -> TestIter:
return TestIter()

class Test2:
def __iter__(self) -> TestIter:
return TestIter()

def bool_instance() -> bool:
return True

flag = bool_instance()

for x in Test() if flag else Test2():
reveal_type(x) # revealed: int
```

## Union type as iterator

```py
class TestIter:
def __next__(self) -> int:
return 42

class TestIter2:
def __next__(self) -> int:
return 42

class Test:
def __iter__(self) -> TestIter | TestIter2:
return TestIter()

for x in Test():
reveal_type(x) # revealed: int
```

## Union type as iterable and union type as iterator

```py
class TestIter:
def __next__(self) -> int | Exception:
return 42

class TestIter2:
def __next__(self) -> str | tuple[int, int]:
return "42"

class TestIter3:
def __next__(self) -> bytes:
return b"42"

class TestIter4:
def __next__(self) -> memoryview:
return memoryview(b"42")

class Test:
def __iter__(self) -> TestIter | TestIter2:
return TestIter()

class Test2:
def __iter__(self) -> TestIter3 | TestIter4:
return TestIter3()

def bool_instance() -> bool:
return True

flag = bool_instance()

for x in Test() if flag else Test2():
reveal_type(x) # revealed: int | Exception | str | tuple[int, int] | bytes | memoryview
```
7 changes: 2 additions & 5 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1104,18 +1104,15 @@ impl<'db> Type<'db> {

let dunder_iter_method = iterable_meta_type.member(db, "__iter__");
if !dunder_iter_method.is_unbound() {
let CallOutcome::Callable {
return_ty: iterator_ty,
} = dunder_iter_method.call(db, &[self])
else {
let Some(iterator_ty) = dunder_iter_method.call(db, &[self]).return_ty(db) else {
return IterationOutcome::NotIterable {
not_iterable_ty: self,
};
};

let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
return dunder_next_method
.call(db, &[self])
.call(db, &[iterator_ty])
.return_ty(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
Expand Down

0 comments on commit 42c7069

Please sign in to comment.