Skip to content

Commit

Permalink
Proposal: don't simplify unions in expand_type() (#14178)
Browse files Browse the repository at this point in the history
Fixes #6730

Currently `expand_type()` is inherently recursive, going through
`expand_type` -> `make_simplified_union` -> `is_proper_subtype` ->
`map_instance_to_supertype` -> `expand_type`. TBH I never liked this, so
I propose that we don't do this. One one hand, this is a significant
change in semantics, but on the other hand:
* This fixes a crash (actually a whole class of crashes) that can happen
even without recursive aliases
* This removes an ugly import and simplifies an import cycle in mypy
code
* This makes mypy 2% faster (measured on self-check)

To make transition smoother, I propose to make trivial simplifications,
like removing `<nothing>` (and `None` without strict optional), removing
everything else if there is an `object` type, and remove strict
duplicates. Notably, with these few things _all existing tests pass_
(and even without it, only half a dozen tests fail on `reveal_type()`).
  • Loading branch information
ilevkivskyi authored Nov 24, 2022
1 parent 13bd201 commit 4471c7e
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 5 deletions.
14 changes: 9 additions & 5 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
UnionType,
UnpackType,
expand_param_spec,
flatten_nested_unions,
get_proper_type,
remove_trivial,
)
from mypy.typevartuples import (
find_unpack_in_list,
Expand Down Expand Up @@ -405,11 +407,13 @@ def visit_literal_type(self, t: LiteralType) -> Type:
return t

def visit_union_type(self, t: UnionType) -> Type:
# After substituting for type variables in t.items,
# some of the resulting types might be subtypes of others.
from mypy.typeops import make_simplified_union # asdf

return make_simplified_union(self.expand_types(t.items), t.line, t.column)
expanded = self.expand_types(t.items)
# After substituting for type variables in t.items, some resulting types
# might be subtypes of others, however calling make_simplified_union()
# can cause recursion, so we just remove strict duplicates.
return UnionType.make_union(
remove_trivial(flatten_nested_unions(expanded)), t.line, t.column
)

def visit_partial_type(self, t: PartialType) -> Type:
return t
Expand Down
30 changes: 30 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3487,3 +3487,33 @@ def store_argument_type(
if not isinstance(arg_type, ParamSpecType) and not typ.unpack_kwargs:
arg_type = named_type("builtins.dict", [named_type("builtins.str", []), arg_type])
defn.arguments[i].variable.type = arg_type


def remove_trivial(types: Iterable[Type]) -> list[Type]:
"""Make trivial simplifications on a list of types without calling is_subtype().
This makes following simplifications:
* Remove bottom types (taking into account strict optional setting)
* Remove everything else if there is an `object`
* Remove strict duplicate types
"""
removed_none = False
new_types = []
all_types = set()
for t in types:
p_t = get_proper_type(t)
if isinstance(p_t, UninhabitedType):
continue
if isinstance(p_t, NoneType) and not state.strict_optional:
removed_none = True
continue
if isinstance(p_t, Instance) and p_t.type.fullname == "builtins.object":
return [p_t]
if p_t not in all_types:
new_types.append(t)
all_types.add(p_t)
if new_types:
return new_types
if removed_none:
return [NoneType()]
return [UninhabitedType()]
34 changes: 34 additions & 0 deletions test-data/unit/check-recursive-types.test
Original file line number Diff line number Diff line change
Expand Up @@ -837,3 +837,37 @@ def foo(x: T) -> C: ...

Nested = Union[C, Sequence[Nested]]
x: Nested = foo(42)

[case testNoRecursiveExpandInstanceUnionCrash]
from typing import List, Union

class Tag(List[Union[Tag, List[Tag]]]): ...
Tag()

[case testNoRecursiveExpandInstanceUnionCrashGeneric]
from typing import Generic, Iterable, TypeVar, Union

ValueT = TypeVar("ValueT")
class Recursive(Iterable[Union[ValueT, Recursive[ValueT]]]):
pass

class Base(Generic[ValueT]):
def __init__(self, element: ValueT):
pass
class Sub(Base[Union[ValueT, Recursive[ValueT]]]):
pass

x: Iterable[str]
reveal_type(Sub) # N: Revealed type is "def [ValueT] (element: Union[ValueT`1, __main__.Recursive[ValueT`1]]) -> __main__.Sub[ValueT`1]"
reveal_type(Sub(x)) # N: Revealed type is "__main__.Sub[typing.Iterable[builtins.str]]"

[case testNoRecursiveExpandInstanceUnionCrashInference]
from typing import TypeVar, Union, Generic, List

T = TypeVar("T")
InList = Union[T, InListRecurse[T]]
class InListRecurse(Generic[T], List[InList[T]]): ...

def list_thing(transforming: InList[T]) -> T:
...
reveal_type(list_thing([5])) # N: Revealed type is "builtins.list[builtins.int]"
16 changes: 16 additions & 0 deletions test-data/unit/pythoneval.test
Original file line number Diff line number Diff line change
Expand Up @@ -1735,3 +1735,19 @@ _testEnumNameWorkCorrectlyOn311.py:12: note: Revealed type is "Union[Literal[1]?
_testEnumNameWorkCorrectlyOn311.py:13: note: Revealed type is "Literal['X']?"
_testEnumNameWorkCorrectlyOn311.py:14: note: Revealed type is "builtins.int"
_testEnumNameWorkCorrectlyOn311.py:15: note: Revealed type is "builtins.int"

[case testTypedDictUnionGetFull]
from typing import Dict
from typing_extensions import TypedDict

class TD(TypedDict, total=False):
x: int
y: int

A = Dict[str, TD]
x: A
def foo(k: str) -> TD:
reveal_type(x.get(k, {}))
return x.get(k, {})
[out]
_testTypedDictUnionGetFull.py:11: note: Revealed type is "TypedDict('_testTypedDictUnionGetFull.TD', {'x'?: builtins.int, 'y'?: builtins.int})"

0 comments on commit 4471c7e

Please sign in to comment.