Skip to content

Commit

Permalink
Eager union loading
Browse files Browse the repository at this point in the history
  • Loading branch information
Apakottur committed Sep 3, 2023
1 parent f83d6eb commit ff75260
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
10 changes: 10 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1946,6 +1946,16 @@ def infer_function_type_arguments_using_context(
# in this case external context is almost everything we have.
if not is_generic_instance(ctx) and not is_literal_type_like(ctx):
return callable.copy_modified()

proper_ret_type = get_proper_type(ret_type)
if isinstance(proper_ret_type, Instance) and proper_ret_type.type.fullname == "typing.Coroutine":
proper_ret_type = proper_ret_type.args[-1]

if isinstance(proper_ret_type, UnionType) and any(isinstance(t, TypeVarType) for t in proper_ret_type.items):
# Avoid over eager inference of type variables in unions containing a type variable.
# See github issue #15886
return callable.copy_modified()

args = infer_type_arguments(callable.variables, ret_type, erased_ctx)
# Only substitute non-Uninhabited and non-erased types.
new_args: list[Type | None] = []
Expand Down
28 changes: 28 additions & 0 deletions test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -3402,3 +3402,31 @@ reveal_type(dec(g)) # N: Revealed type is "def (builtins.int) -> __main__.Foo[b
h: Callable[[Unpack[Us]], Foo[int]]
reveal_type(dec(g)) # N: Revealed type is "def (builtins.int) -> __main__.Foo[builtins.int]"
[builtins fixtures/list.pyi]

[case testEagerInferenceOfGenericUnionReturn]
from typing import Generic, TypeVar, Union

T = TypeVar("T")

class Cls(Generic[T]):
pass

def inner(c: Cls[T]) -> Union[T, int]:
return 1

def outer(c: Cls[T]) -> Union[T, int]:
return inner(c)

[case testEagerInferenceOfGenericUnionReturnAsync]
from typing import Generic, TypeVar, Optional

T = TypeVar("T")

class Cls(Generic[T]):
pass

async def inner(c: Cls[T]) -> Optional[T]:
return None

async def outer(c: Cls[T]) -> Optional[T]:
return await inner(c)

0 comments on commit ff75260

Please sign in to comment.