Skip to content

Commit

Permalink
Infer ParamSpec constraint from arguments (#15896)
Browse files Browse the repository at this point in the history
Fixes #12278
Fixes #13191 (more tricky nested
use cases with optional/keyword args still don't work, but they are
quite tricky to fix and may selectively fixed later)

This unfortunately requires some special-casing, here is its summary:
* If actual argument for `Callable[P, T]` is non-generic and non-lambda,
do not put it into inference second pass.
* If we are able to infer constraints for `P` without using arguments
mapped to `*args: P.args` etc., do not add the constraint for `P` vs
those arguments (this applies to both top-level callable constraints,
and for nested callable constraints against callables that are known to
have imprecise argument kinds).

(Btw TODO I added is not related to this PR, I just noticed something
obviously wrong)
  • Loading branch information
ilevkivskyi authored Aug 25, 2023
1 parent f9b1db6 commit 7f65cc7
Show file tree
Hide file tree
Showing 8 changed files with 244 additions and 69 deletions.
41 changes: 36 additions & 5 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1987,7 +1987,7 @@ def infer_function_type_arguments(
)

arg_pass_nums = self.get_arg_infer_passes(
callee_type.arg_types, formal_to_actual, len(args)
callee_type, args, arg_types, formal_to_actual, len(args)
)

pass1_args: list[Type | None] = []
Expand All @@ -2001,6 +2001,7 @@ def infer_function_type_arguments(
callee_type,
pass1_args,
arg_kinds,
arg_names,
formal_to_actual,
context=self.argument_infer_context(),
strict=self.chk.in_checked_function(),
Expand Down Expand Up @@ -2061,6 +2062,7 @@ def infer_function_type_arguments(
callee_type,
arg_types,
arg_kinds,
arg_names,
formal_to_actual,
context=self.argument_infer_context(),
strict=self.chk.in_checked_function(),
Expand Down Expand Up @@ -2140,6 +2142,7 @@ def infer_function_type_arguments_pass2(
callee_type,
arg_types,
arg_kinds,
arg_names,
formal_to_actual,
context=self.argument_infer_context(),
)
Expand All @@ -2152,7 +2155,12 @@ def argument_infer_context(self) -> ArgumentInferContext:
)

def get_arg_infer_passes(
self, arg_types: list[Type], formal_to_actual: list[list[int]], num_actuals: int
self,
callee: CallableType,
args: list[Expression],
arg_types: list[Type],
formal_to_actual: list[list[int]],
num_actuals: int,
) -> list[int]:
"""Return pass numbers for args for two-pass argument type inference.
Expand All @@ -2163,8 +2171,28 @@ def get_arg_infer_passes(
lambdas more effectively.
"""
res = [1] * num_actuals
for i, arg in enumerate(arg_types):
if arg.accept(ArgInferSecondPassQuery()):
for i, arg in enumerate(callee.arg_types):
skip_param_spec = False
p_formal = get_proper_type(callee.arg_types[i])
if isinstance(p_formal, CallableType) and p_formal.param_spec():
for j in formal_to_actual[i]:
p_actual = get_proper_type(arg_types[j])
# This is an exception from the usual logic where we put generic Callable
# arguments in the second pass. If we have a non-generic actual, it is
# likely to infer good constraints, for example if we have:
# def run(Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ...
# def test(x: int, y: int) -> int: ...
# run(test, 1, 2)
# we will use `test` for inference, since it will allow to infer also
# argument *names* for P <: [x: int, y: int].
if (
isinstance(p_actual, CallableType)
and not p_actual.variables
and not isinstance(args[j], LambdaExpr)
):
skip_param_spec = True
break
if not skip_param_spec and arg.accept(ArgInferSecondPassQuery()):
for j in formal_to_actual[i]:
res[j] = 2
return res
Expand Down Expand Up @@ -4903,7 +4931,9 @@ def infer_lambda_type_using_context(
self.chk.fail(message_registry.CANNOT_INFER_LAMBDA_TYPE, e)
return None, None

return callable_ctx, callable_ctx
# Type of lambda must have correct argument names, to prevent false
# negatives when lambdas appear in `ParamSpec` context.
return callable_ctx.copy_modified(arg_names=e.arg_names), callable_ctx

def visit_super_expr(self, e: SuperExpr) -> Type:
"""Type check a super expression (non-lvalue)."""
Expand Down Expand Up @@ -5921,6 +5951,7 @@ def __init__(self) -> None:
super().__init__(types.ANY_STRATEGY)

def visit_callable_type(self, t: CallableType) -> bool:
# TODO: we need to check only for type variables of original callable.
return self.query_types(t.arg_types) or t.accept(HasTypeVarQuery())


Expand Down
136 changes: 97 additions & 39 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def infer_constraints_for_callable(
callee: CallableType,
arg_types: Sequence[Type | None],
arg_kinds: list[ArgKind],
arg_names: Sequence[str | None] | None,
formal_to_actual: list[list[int]],
context: ArgumentInferContext,
) -> list[Constraint]:
Expand All @@ -118,6 +119,20 @@ def infer_constraints_for_callable(
constraints: list[Constraint] = []
mapper = ArgTypeExpander(context)

param_spec = callee.param_spec()
param_spec_arg_types = []
param_spec_arg_names = []
param_spec_arg_kinds = []

incomplete_star_mapping = False
for i, actuals in enumerate(formal_to_actual):
for actual in actuals:
if actual is None and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2):
# We can't use arguments to infer ParamSpec constraint, if only some
# are present in the current inference pass.
incomplete_star_mapping = True
break

for i, actuals in enumerate(formal_to_actual):
if isinstance(callee.arg_types[i], UnpackType):
unpack_type = callee.arg_types[i]
Expand Down Expand Up @@ -194,11 +209,47 @@ def infer_constraints_for_callable(
actual_type = mapper.expand_actual_type(
actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i]
)
# TODO: if callee has ParamSpec, we need to collect all actuals that map to star
# args and create single constraint between P and resulting Parameters instead.
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
constraints.extend(c)

if (
param_spec
and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2)
and not incomplete_star_mapping
):
# If actual arguments are mapped to ParamSpec type, we can't infer individual
# constraints, instead store them and infer single constraint at the end.
# It is impossible to map actual kind to formal kind, so use some heuristic.
# This inference is used as a fallback, so relying on heuristic should be OK.
param_spec_arg_types.append(
mapper.expand_actual_type(
actual_arg_type, arg_kinds[actual], None, arg_kinds[actual]
)
)
actual_kind = arg_kinds[actual]
param_spec_arg_kinds.append(
ARG_POS if actual_kind not in (ARG_STAR, ARG_STAR2) else actual_kind
)
param_spec_arg_names.append(arg_names[actual] if arg_names else None)
else:
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
constraints.extend(c)
if (
param_spec
and not any(c.type_var == param_spec.id for c in constraints)
and not incomplete_star_mapping
):
# Use ParamSpec constraint from arguments only if there are no other constraints,
# since as explained above it is quite ad-hoc.
constraints.append(
Constraint(
param_spec,
SUPERTYPE_OF,
Parameters(
arg_types=param_spec_arg_types,
arg_kinds=param_spec_arg_kinds,
arg_names=param_spec_arg_names,
imprecise_arg_kinds=True,
),
)
)
return constraints


Expand Down Expand Up @@ -949,6 +1000,14 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
res: list[Constraint] = []
cactual = self.actual.with_unpacked_kwargs()
param_spec = template.param_spec()

template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
if template.type_guard is not None:
template_ret_type = template.type_guard
if cactual.type_guard is not None:
cactual_ret_type = cactual.type_guard
res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction))

if param_spec is None:
# TODO: Erase template variables if it is generic?
if (
Expand Down Expand Up @@ -1008,51 +1067,50 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
)
extra_tvars = True

# Compare prefixes as well
cactual_prefix = cactual.copy_modified(
arg_types=cactual.arg_types[:prefix_len],
arg_kinds=cactual.arg_kinds[:prefix_len],
arg_names=cactual.arg_names[:prefix_len],
)
res.extend(
infer_callable_arguments_constraints(prefix, cactual_prefix, self.direction)
)

param_spec_target: Type | None = None
skip_imprecise = (
any(c.type_var == param_spec.id for c in res) and cactual.imprecise_arg_kinds
)
if not cactual_ps:
max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)])
prefix_len = min(prefix_len, max_prefix_len)
res.append(
Constraint(
param_spec,
neg_op(self.direction),
Parameters(
arg_types=cactual.arg_types[prefix_len:],
arg_kinds=cactual.arg_kinds[prefix_len:],
arg_names=cactual.arg_names[prefix_len:],
variables=cactual.variables
if not type_state.infer_polymorphic
else [],
),
# This logic matches top-level callable constraint exception, if we managed
# to get other constraints for ParamSpec, don't infer one with imprecise kinds
if not skip_imprecise:
param_spec_target = Parameters(
arg_types=cactual.arg_types[prefix_len:],
arg_kinds=cactual.arg_kinds[prefix_len:],
arg_names=cactual.arg_names[prefix_len:],
variables=cactual.variables
if not type_state.infer_polymorphic
else [],
imprecise_arg_kinds=cactual.imprecise_arg_kinds,
)
)
else:
if len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types):
cactual_ps = cactual_ps.copy_modified(
if (
len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types)
and not skip_imprecise
):
param_spec_target = cactual_ps.copy_modified(
prefix=Parameters(
arg_types=cactual_ps.prefix.arg_types[prefix_len:],
arg_kinds=cactual_ps.prefix.arg_kinds[prefix_len:],
arg_names=cactual_ps.prefix.arg_names[prefix_len:],
imprecise_arg_kinds=cactual_ps.prefix.imprecise_arg_kinds,
)
)
res.append(Constraint(param_spec, neg_op(self.direction), cactual_ps))

# Compare prefixes as well
cactual_prefix = cactual.copy_modified(
arg_types=cactual.arg_types[:prefix_len],
arg_kinds=cactual.arg_kinds[:prefix_len],
arg_names=cactual.arg_names[:prefix_len],
)
res.extend(
infer_callable_arguments_constraints(prefix, cactual_prefix, self.direction)
)

template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
if template.type_guard is not None:
template_ret_type = template.type_guard
if cactual.type_guard is not None:
cactual_ret_type = cactual.type_guard

res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction))
if param_spec_target is not None:
res.append(Constraint(param_spec, neg_op(self.direction), param_spec_target))
if extra_tvars:
for c in res:
c.extra_tvars += cactual.variables
Expand Down
2 changes: 2 additions & 0 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
arg_types=self.expand_types(t.arg_types),
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
imprecise_arg_kinds=(t.imprecise_arg_kinds or repl.imprecise_arg_kinds),
)
elif isinstance(repl, ParamSpecType):
# We're substituting one ParamSpec for another; this can mean that the prefix
Expand All @@ -352,6 +353,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
arg_names=t.arg_names[:-2] + prefix.arg_names + t.arg_names[-2:],
ret_type=t.ret_type.accept(self),
from_concatenate=t.from_concatenate or bool(repl.prefix.arg_types),
imprecise_arg_kinds=(t.imprecise_arg_kinds or prefix.imprecise_arg_kinds),
)

var_arg = t.var_arg()
Expand Down
3 changes: 2 additions & 1 deletion mypy/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def infer_function_type_arguments(
callee_type: CallableType,
arg_types: Sequence[Type | None],
arg_kinds: list[ArgKind],
arg_names: Sequence[str | None] | None,
formal_to_actual: list[list[int]],
context: ArgumentInferContext,
strict: bool = True,
Expand All @@ -53,7 +54,7 @@ def infer_function_type_arguments(
"""
# Infer constraints.
constraints = infer_constraints_for_callable(
callee_type, arg_types, arg_kinds, formal_to_actual, context
callee_type, arg_types, arg_kinds, arg_names, formal_to_actual, context
)

# Solve constraints.
Expand Down
Loading

0 comments on commit 7f65cc7

Please sign in to comment.