Skip to content

Commit

Permalink
Fix false negatives involving Unions and generators or coroutines (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
hauntsaninja authored Dec 1, 2022
1 parent 6e9227a commit 3c71548
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 0 deletions.
12 changes: 12 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,10 @@ def get_generator_yield_type(self, return_type: Type, is_coroutine: bool) -> Typ

if isinstance(return_type, AnyType):
return AnyType(TypeOfAny.from_another_any, source_any=return_type)
elif isinstance(return_type, UnionType):
return make_simplified_union(
[self.get_generator_yield_type(item, is_coroutine) for item in return_type.items]
)
elif not self.is_generator_return_type(
return_type, is_coroutine
) and not self.is_async_generator_return_type(return_type):
Expand Down Expand Up @@ -878,6 +882,10 @@ def get_generator_receive_type(self, return_type: Type, is_coroutine: bool) -> T

if isinstance(return_type, AnyType):
return AnyType(TypeOfAny.from_another_any, source_any=return_type)
elif isinstance(return_type, UnionType):
return make_simplified_union(
[self.get_generator_receive_type(item, is_coroutine) for item in return_type.items]
)
elif not self.is_generator_return_type(
return_type, is_coroutine
) and not self.is_async_generator_return_type(return_type):
Expand Down Expand Up @@ -917,6 +925,10 @@ def get_generator_return_type(self, return_type: Type, is_coroutine: bool) -> Ty

if isinstance(return_type, AnyType):
return AnyType(TypeOfAny.from_another_any, source_any=return_type)
elif isinstance(return_type, UnionType):
return make_simplified_union(
[self.get_generator_return_type(item, is_coroutine) for item in return_type.items]
)
elif not self.is_generator_return_type(return_type, is_coroutine):
# If the function doesn't have a proper Generator (or
# Awaitable) return type, anything is permissible.
Expand Down
18 changes: 18 additions & 0 deletions test-data/unit/check-async-await.test
Original file line number Diff line number Diff line change
Expand Up @@ -925,3 +925,21 @@ async def f() -> AsyncGenerator[int, None]:

[builtins fixtures/async_await.pyi]
[typing fixtures/typing-async.pyi]

[case testAwaitUnion]
from typing import overload, Union

class A: ...
class B: ...

@overload
async def foo(x: A) -> B: ...
@overload
async def foo(x: B) -> A: ...
async def foo(x): ...

async def bar(x: Union[A, B]) -> None:
reveal_type(await foo(x)) # N: Revealed type is "Union[__main__.B, __main__.A]"

[builtins fixtures/async_await.pyi]
[typing fixtures/typing-async.pyi]
9 changes: 9 additions & 0 deletions test-data/unit/check-statements.test
Original file line number Diff line number Diff line change
Expand Up @@ -2206,3 +2206,12 @@ def foo():
x: int = "no" # N: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs
y = "no" # type: int # N: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs
z: int # N: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs

[case testGeneratorUnion]
from typing import Generator, Union

class A: pass
class B: pass

def foo(x: int) -> Union[Generator[A, None, None], Generator[B, None, None]]:
yield x # E: Incompatible types in "yield" (actual type "int", expected type "Union[A, B]")

0 comments on commit 3c71548

Please sign in to comment.