diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index d839ad4925fd..c990e9b59f98 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1583,21 +1583,21 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type] res.append(arg_type) return res - @contextmanager - def allow_unions(self, type_context: Type) -> Iterator[None]: - # This is a hack to better support inference for recursive types. - # When the outer context for a function call is known to be recursive, - # we solve type constraints inferred from arguments using unions instead - # of joins. This is a bit arbitrary, but in practice it works for most - # cases. A cleaner alternative would be to switch to single bin type - # inference, but this is a lot of work. + def infer_more_unions_for_recursive_type(self, type_context: Type) -> bool: + """Adjust type inference of unions if type context has a recursive type. + + Return the old state. The caller must assign it to type_state.infer_unions + afterwards. + + This is a hack to better support inference for recursive types. + + Note: This is performance-sensitive and must not be a context manager + until mypyc supports them better. + """ old = type_state.infer_unions if has_recursive_types(type_context): type_state.infer_unions = True - try: - yield - finally: - type_state.infer_unions = old + return old def infer_arg_types_in_context( self, @@ -1618,8 +1618,16 @@ def infer_arg_types_in_context( for i, actuals in enumerate(formal_to_actual): for ai in actuals: if not arg_kinds[ai].is_star(): - with self.allow_unions(callee.arg_types[i]): - res[ai] = self.accept(args[ai], callee.arg_types[i]) + arg_type = callee.arg_types[i] + # When the outer context for a function call is known to be recursive, + # we solve type constraints inferred from arguments using unions instead + # of joins. This is a bit arbitrary, but in practice it works for most + # cases. A cleaner alternative would be to switch to single bin type + # inference, but this is a lot of work. + old = self.infer_more_unions_for_recursive_type(arg_type) + res[ai] = self.accept(args[ai], arg_type) + # We need to manually restore union inference state, ugh. + type_state.infer_unions = old # Fill in the rest of the argument types. for i, t in enumerate(res):