From 90d0d1fbdc703bd147484093fc5ee4e09a0c9df8 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 14 Nov 2022 19:53:07 +0000 Subject: [PATCH] Fix crash on nested generic callable --- mypy/applytype.py | 19 +++++++++++++----- mypy/expandtype.py | 29 ++++++++++++++++++++------- mypy/subtypes.py | 6 +++++- test-data/unit/check-dataclasses.test | 23 +++++++++++++++++++++ 4 files changed, 64 insertions(+), 13 deletions(-) diff --git a/mypy/applytype.py b/mypy/applytype.py index 1c401664568d..d7f31b36c244 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -73,6 +73,7 @@ def apply_generic_arguments( report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None], context: Context, skip_unsatisfied: bool = False, + allow_erased_callables: bool = False, ) -> CallableType: """Apply generic type arguments to a callable type. @@ -130,18 +131,26 @@ def apply_generic_arguments( + callable.arg_names[star_index + 1 :] ) arg_types = ( - [expand_type(at, id_to_type) for at in callable.arg_types[:star_index]] + [ + expand_type(at, id_to_type, allow_erased_callables) + for at in callable.arg_types[:star_index] + ] + expanded - + [expand_type(at, id_to_type) for at in callable.arg_types[star_index + 1 :]] + + [ + expand_type(at, id_to_type, allow_erased_callables) + for at in callable.arg_types[star_index + 1 :] + ] ) else: - arg_types = [expand_type(at, id_to_type) for at in callable.arg_types] + arg_types = [ + expand_type(at, id_to_type, allow_erased_callables) for at in callable.arg_types + ] arg_kinds = callable.arg_kinds arg_names = callable.arg_names # Apply arguments to TypeGuard if any. if callable.type_guard is not None: - type_guard = expand_type(callable.type_guard, id_to_type) + type_guard = expand_type(callable.type_guard, id_to_type, allow_erased_callables) else: type_guard = None @@ -150,7 +159,7 @@ def apply_generic_arguments( return callable.copy_modified( arg_types=arg_types, - ret_type=expand_type(callable.ret_type, id_to_type), + ret_type=expand_type(callable.ret_type, id_to_type, allow_erased_callables), variables=remaining_tvars, type_guard=type_guard, arg_kinds=arg_kinds, diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 08bc216689fb..89568838232f 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -39,20 +39,26 @@ @overload -def expand_type(typ: ProperType, env: Mapping[TypeVarId, Type]) -> ProperType: +def expand_type( + typ: ProperType, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ... +) -> ProperType: ... @overload -def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type: +def expand_type( + typ: Type, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ... +) -> Type: ... -def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type: +def expand_type( + typ: Type, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = False +) -> Type: """Substitute any type variable references in a type given by a type environment. """ - return typ.accept(ExpandTypeVisitor(env)) + return typ.accept(ExpandTypeVisitor(env, allow_erased_callables)) @overload @@ -129,8 +135,11 @@ class ExpandTypeVisitor(TypeVisitor[Type]): variables: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value - def __init__(self, variables: Mapping[TypeVarId, Type]) -> None: + def __init__( + self, variables: Mapping[TypeVarId, Type], allow_erased_callables: bool = False + ) -> None: self.variables = variables + self.allow_erased_callables = allow_erased_callables def visit_unbound_type(self, t: UnboundType) -> Type: return t @@ -148,8 +157,14 @@ def visit_deleted_type(self, t: DeletedType) -> Type: return t def visit_erased_type(self, t: ErasedType) -> Type: - # Should not get here. - raise RuntimeError() + if not self.allow_erased_callables: + raise RuntimeError() + # This may happen during type inference if some function argument + # type is a generic callable, and its erased form will appear in inferred + # constraints, then solver may check subtyping between them, which will trigger + # unify_generic_callables(), this is why we can get here. In all other cases it + # is a sign of a bug, since should never appear in any stored types. + return t def visit_instance(self, t: Instance) -> Type: args = self.expand_types_with_unpack(list(t.args)) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index f928e1cc7918..09ebf266545d 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1665,8 +1665,12 @@ def report(*args: Any) -> None: nonlocal had_errors had_errors = True + # This function may be called by the solver, so we need to allow erased types here. + # We anyway allow checking subtyping between other types containing + # (probably also because solver needs subtyping). See also comment in + # ExpandTypeVisitor.visit_erased_type(). applied = mypy.applytype.apply_generic_arguments( - type, non_none_inferred_vars, report, context=target + type, non_none_inferred_vars, report, context=target, allow_erased_callables=True ) if had_errors: return None diff --git a/test-data/unit/check-dataclasses.test b/test-data/unit/check-dataclasses.test index d4064124109b..a27e7ff4f617 100644 --- a/test-data/unit/check-dataclasses.test +++ b/test-data/unit/check-dataclasses.test @@ -1933,3 +1933,26 @@ B = List[C] class C(CC): ... class CC: ... [builtins fixtures/dataclasses.pyi] + +[case testNoCrashOnNestedGenericCallable] +from dataclasses import dataclass +from typing import Generic, TypeVar, Callable + +T = TypeVar('T') +R = TypeVar('R') +X = TypeVar('X') + +@dataclass +class Box(Generic[T]): + inner: T + +@dataclass +class Cont(Generic[R]): + run: Box[Callable[[X], R]] + +def const_two(x: T) -> str: + return "two" + +c = Cont(Box(const_two)) +reveal_type(c) # N: Revealed type is "__main__.Cont[builtins.str]" +[builtins fixtures/dataclasses.pyi]