Skip to content

Commit

Permalink
Fix type variable clash in nested positions and in attributes (#14095)
Browse files Browse the repository at this point in the history
Addresses the non-crash part of #10244 (and similar situations).

The `freshen_function_type_vars()` use in `checkmember.py` was
inconsistent:
* It needs to be applied to attributes too, not just methods
* It needs to be a visitor, since generic callable can appear in a
nested position

The downsides are ~2% performance regression, and people will see more
large ids in `reveal_type()` (since refreshing functions uses a global
unique counter). But since this is a correctness issue that can cause
really bizarre error messages, I think it is totally worth it.
  • Loading branch information
ilevkivskyi authored Nov 16, 2022
1 parent 49316f9 commit 48c4a47
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 21 deletions.
41 changes: 23 additions & 18 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

from mypy import meet, message_registry, subtypes
from mypy.erasetype import erase_typevars
from mypy.expandtype import expand_self_type, expand_type_by_instance, freshen_function_type_vars
from mypy.expandtype import (
expand_self_type,
expand_type_by_instance,
freshen_all_functions_type_vars,
)
from mypy.maptype import map_instance_to_supertype
from mypy.messages import MessageBuilder
from mypy.nodes import (
Expand Down Expand Up @@ -66,6 +70,7 @@
get_proper_type,
has_type_vars,
)
from mypy.typetraverser import TypeTraverserVisitor

if TYPE_CHECKING: # import for forward declaration only
import mypy.checker
Expand Down Expand Up @@ -311,7 +316,7 @@ def analyze_instance_member_access(
if mx.is_lvalue:
mx.msg.cant_assign_to_method(mx.context)
signature = function_type(method, mx.named_type("builtins.function"))
signature = freshen_function_type_vars(signature)
signature = freshen_all_functions_type_vars(signature)
if name == "__new__" or method.is_static:
# __new__ is special and behaves like a static method -- don't strip
# the first argument.
Expand All @@ -329,7 +334,7 @@ def analyze_instance_member_access(
# Since generic static methods should not be allowed.
typ = map_instance_to_supertype(typ, method.info)
member_type = expand_type_by_instance(signature, typ)
freeze_type_vars(member_type)
freeze_all_type_vars(member_type)
return member_type
else:
# Not a method.
Expand Down Expand Up @@ -727,11 +732,13 @@ def analyze_var(
mx.msg.read_only_property(name, itype.type, mx.context)
if mx.is_lvalue and var.is_classvar:
mx.msg.cant_assign_to_classvar(name, mx.context)
t = freshen_all_functions_type_vars(typ)
if not (mx.is_self or mx.is_super) or supported_self_type(
get_proper_type(mx.original_type)
):
typ = expand_self_type(var, typ, mx.original_type)
t = get_proper_type(expand_type_by_instance(typ, itype))
t = expand_self_type(var, t, mx.original_type)
t = get_proper_type(expand_type_by_instance(t, itype))
freeze_all_type_vars(t)
result: Type = t
typ = get_proper_type(typ)
if (
Expand Down Expand Up @@ -759,13 +766,13 @@ def analyze_var(
# In `x.f`, when checking `x` against A1 we assume x is compatible with A
# and similarly for B1 when checking against B
dispatched_type = meet.meet_types(mx.original_type, itype)
signature = freshen_function_type_vars(functype)
signature = freshen_all_functions_type_vars(functype)
signature = check_self_arg(
signature, dispatched_type, var.is_classmethod, mx.context, name, mx.msg
)
signature = bind_self(signature, mx.self_type, var.is_classmethod)
expanded_signature = expand_type_by_instance(signature, itype)
freeze_type_vars(expanded_signature)
freeze_all_type_vars(expanded_signature)
if var.is_property:
# A property cannot have an overloaded type => the cast is fine.
assert isinstance(expanded_signature, CallableType)
Expand All @@ -788,16 +795,14 @@ def analyze_var(
return result


def freeze_type_vars(member_type: Type) -> None:
if not isinstance(member_type, ProperType):
return
if isinstance(member_type, CallableType):
for v in member_type.variables:
def freeze_all_type_vars(member_type: Type) -> None:
member_type.accept(FreezeTypeVarsVisitor())


class FreezeTypeVarsVisitor(TypeTraverserVisitor):
def visit_callable_type(self, t: CallableType) -> None:
for v in t.variables:
v.id.meta_level = 0
if isinstance(member_type, Overloaded):
for it in member_type.items:
for v in it.variables:
v.id.meta_level = 0


def lookup_member_var_or_accessor(info: TypeInfo, name: str, is_lvalue: bool) -> SymbolNode | None:
Expand Down Expand Up @@ -1131,11 +1136,11 @@ class B(A[str]): pass
if isinstance(t, CallableType):
tvars = original_vars if original_vars is not None else []
if is_classmethod:
t = freshen_function_type_vars(t)
t = freshen_all_functions_type_vars(t)
t = bind_self(t, original_type, is_classmethod=True)
assert isuper is not None
t = cast(CallableType, expand_type_by_instance(t, isuper))
freeze_type_vars(t)
freeze_all_type_vars(t)
return t.copy_modified(variables=list(tvars) + list(t.variables))
elif isinstance(t, Overloaded):
return Overloaded(
Expand Down
21 changes: 21 additions & 0 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Iterable, Mapping, Sequence, TypeVar, cast, overload

from mypy.nodes import ARG_STAR, Var
from mypy.type_visitor import TypeTranslator
from mypy.types import (
AnyType,
CallableType,
Expand Down Expand Up @@ -130,6 +131,26 @@ def freshen_function_type_vars(callee: F) -> F:
return cast(F, fresh_overload)


T = TypeVar("T", bound=Type)


def freshen_all_functions_type_vars(t: T) -> T:
result = t.accept(FreshenCallableVisitor())
assert isinstance(result, type(t))
return result


class FreshenCallableVisitor(TypeTranslator):
def visit_callable_type(self, t: CallableType) -> Type:
result = super().visit_callable_type(t)
assert isinstance(result, ProperType) and isinstance(result, CallableType)
return freshen_function_type_vars(result)

def visit_type_alias_type(self, t: TypeAliasType) -> Type:
# Same as for ExpandTypeVisitor
return t.copy_modified(args=[arg.accept(self) for arg in t.args])


class ExpandTypeVisitor(TypeVisitor[Type]):
"""Visitor that substitutes type variables with values."""

Expand Down
3 changes: 2 additions & 1 deletion mypy/typestate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from mypy.nodes import TypeInfo
from mypy.server.trigger import make_trigger
from mypy.types import Instance, Type, get_proper_type
from mypy.types import Instance, Type, TypeVarId, get_proper_type

# Represents that the 'left' instance is a subtype of the 'right' instance
SubtypeRelationship: _TypeAlias = Tuple[Instance, Instance]
Expand Down Expand Up @@ -275,3 +275,4 @@ def reset_global_state() -> None:
"""
TypeState.reset_all_subtype_caches()
TypeState.reset_protocol_deps()
TypeVarId.next_raw_id = 1
67 changes: 66 additions & 1 deletion test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -1544,7 +1544,7 @@ class C(Generic[T]):
reveal_type(C.F(17).foo()) # N: Revealed type is "builtins.int"
reveal_type(C("").F(17).foo()) # N: Revealed type is "builtins.int"
reveal_type(C.F) # N: Revealed type is "def [K] (k: K`1) -> __main__.C.F[K`1]"
reveal_type(C("").F) # N: Revealed type is "def [K] (k: K`1) -> __main__.C.F[K`1]"
reveal_type(C("").F) # N: Revealed type is "def [K] (k: K`6) -> __main__.C.F[K`6]"


-- Callable subtyping with generic functions
Expand Down Expand Up @@ -2580,3 +2580,68 @@ class Bar(Foo[AnyStr]):
[out]
main:10: error: Argument 1 to "method1" of "Foo" has incompatible type "str"; expected "AnyStr"
main:10: error: Argument 2 to "method1" of "Foo" has incompatible type "bytes"; expected "AnyStr"

[case testTypeVariableClashVar]
from typing import Generic, TypeVar, Callable

T = TypeVar("T")
R = TypeVar("R")
class C(Generic[R]):
x: Callable[[T], R]

def func(x: C[R]) -> R:
return x.x(42) # OK

[case testTypeVariableClashVarTuple]
from typing import Generic, TypeVar, Callable, Tuple

T = TypeVar("T")
R = TypeVar("R")
class C(Generic[R]):
x: Callable[[T], Tuple[R, T]]

def func(x: C[R]) -> R:
if bool():
return x.x(42)[0] # OK
else:
return x.x(42)[1] # E: Incompatible return value type (got "int", expected "R")
[builtins fixtures/tuple.pyi]

[case testTypeVariableClashMethod]
from typing import Generic, TypeVar, Callable

T = TypeVar("T")
R = TypeVar("R")
class C(Generic[R]):
def x(self) -> Callable[[T], R]: ...

def func(x: C[R]) -> R:
return x.x()(42) # OK

[case testTypeVariableClashMethodTuple]
from typing import Generic, TypeVar, Callable, Tuple

T = TypeVar("T")
R = TypeVar("R")
class C(Generic[R]):
def x(self) -> Callable[[T], Tuple[R, T]]: ...

def func(x: C[R]) -> R:
if bool():
return x.x()(42)[0] # OK
else:
return x.x()(42)[1] # E: Incompatible return value type (got "int", expected "R")
[builtins fixtures/tuple.pyi]

[case testTypeVariableClashVarSelf]
from typing import Self, TypeVar, Generic, Callable

T = TypeVar("T")
S = TypeVar("S")

class C(Generic[T]):
x: Callable[[S], Self]
y: T

def foo(x: C[T]) -> T:
return x.x(42).y # OK
2 changes: 1 addition & 1 deletion test-data/unit/check-selftype.test
Original file line number Diff line number Diff line change
Expand Up @@ -1674,7 +1674,7 @@ class C:
def bar(self) -> Self: ...
foo: Callable[[S, Self], Tuple[Self, S]]

reveal_type(C().foo) # N: Revealed type is "def [S] (S`-1, __main__.C) -> Tuple[__main__.C, S`-1]"
reveal_type(C().foo) # N: Revealed type is "def [S] (S`1, __main__.C) -> Tuple[__main__.C, S`1]"
reveal_type(C().foo(42, C())) # N: Revealed type is "Tuple[__main__.C, builtins.int]"
class This: ...
[builtins fixtures/tuple.pyi]
Expand Down

0 comments on commit 48c4a47

Please sign in to comment.