From 72c6bc702ac504f3829b3d2c5b5bb38767e9dc83 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 15 Nov 2023 21:16:26 +0000 Subject: [PATCH] Make imprecise contraints handling more robust --- mypy/constraints.py | 76 +++++++++++-------- mypy/expandtype.py | 1 + .../unit/check-parameter-specification.test | 23 ++++++ 3 files changed, 67 insertions(+), 33 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 88ede372e011..d6a4b28799e5 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -226,25 +226,22 @@ def infer_constraints_for_callable( actual_type = mapper.expand_actual_type( actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i] ) - if ( - param_spec - and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2) - and not incomplete_star_mapping - ): + if param_spec and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2): # If actual arguments are mapped to ParamSpec type, we can't infer individual # constraints, instead store them and infer single constraint at the end. # It is impossible to map actual kind to formal kind, so use some heuristic. # This inference is used as a fallback, so relying on heuristic should be OK. - param_spec_arg_types.append( - mapper.expand_actual_type( - actual_arg_type, arg_kinds[actual], None, arg_kinds[actual] + if not incomplete_star_mapping: + param_spec_arg_types.append( + mapper.expand_actual_type( + actual_arg_type, arg_kinds[actual], None, arg_kinds[actual] + ) ) - ) - actual_kind = arg_kinds[actual] - param_spec_arg_kinds.append( - ARG_POS if actual_kind not in (ARG_STAR, ARG_STAR2) else actual_kind - ) - param_spec_arg_names.append(arg_names[actual] if arg_names else None) + actual_kind = arg_kinds[actual] + param_spec_arg_kinds.append( + ARG_POS if actual_kind not in (ARG_STAR, ARG_STAR2) else actual_kind + ) + param_spec_arg_names.append(arg_names[actual] if arg_names else None) else: c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF) constraints.extend(c) @@ -267,6 +264,9 @@ def infer_constraints_for_callable( ), ) ) + if any(isinstance(v, ParamSpecType) for v in callee.variables): + # As a perf optimization filter imprecise constraints only when we can have them. + constraints = filter_imprecise_kinds(constraints) return constraints @@ -1094,29 +1094,18 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: ) param_spec_target: Type | None = None - skip_imprecise = ( - any(c.type_var == param_spec.id for c in res) and cactual.imprecise_arg_kinds - ) if not cactual_ps: max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)]) prefix_len = min(prefix_len, max_prefix_len) - # This logic matches top-level callable constraint exception, if we managed - # to get other constraints for ParamSpec, don't infer one with imprecise kinds - if not skip_imprecise: - param_spec_target = Parameters( - arg_types=cactual.arg_types[prefix_len:], - arg_kinds=cactual.arg_kinds[prefix_len:], - arg_names=cactual.arg_names[prefix_len:], - variables=cactual.variables - if not type_state.infer_polymorphic - else [], - imprecise_arg_kinds=cactual.imprecise_arg_kinds, - ) + param_spec_target = Parameters( + arg_types=cactual.arg_types[prefix_len:], + arg_kinds=cactual.arg_kinds[prefix_len:], + arg_names=cactual.arg_names[prefix_len:], + variables=cactual.variables if not type_state.infer_polymorphic else [], + imprecise_arg_kinds=cactual.imprecise_arg_kinds, + ) else: - if ( - len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types) - and not skip_imprecise - ): + if len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types): param_spec_target = cactual_ps.copy_modified( prefix=Parameters( arg_types=cactual_ps.prefix.arg_types[prefix_len:], @@ -1611,3 +1600,24 @@ def infer_callable_arguments_constraints( infer_directed_arg_constraints(left_by_name.typ, right_by_name.typ, direction) ) return res + + +def filter_imprecise_kinds(cs: list[Constraint]) -> list[Constraint]: + """For each ParamSpec remove all imprecise constraints, if at least one precise available.""" + have_precise = set() + for c in cs: + if not isinstance(c.origin_type_var, ParamSpecType): + continue + if ( + isinstance(c.target, ParamSpecType) + or isinstance(c.target, Parameters) + and not c.target.imprecise_arg_kinds + ): + have_precise.add(c.type_var) + new_cs = [] + for c in cs: + if not isinstance(c.origin_type_var, ParamSpecType) or c.type_var not in have_precise: + new_cs.append(c) + if not isinstance(c.target, Parameters) or not c.target.imprecise_arg_kinds: + new_cs.append(c) + return new_cs diff --git a/mypy/expandtype.py b/mypy/expandtype.py index cb09a1ee99f5..3acec4b96d06 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -253,6 +253,7 @@ def visit_param_spec(self, t: ParamSpecType) -> Type: t.prefix.arg_kinds + repl.arg_kinds, t.prefix.arg_names + repl.arg_names, variables=[*t.prefix.variables, *repl.variables], + imprecise_arg_kinds=repl.imprecise_arg_kinds, ) else: # We could encode Any as trivial parameters etc., but it would be too verbose. diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index eb6fbf07f045..f178bd369d11 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2108,3 +2108,26 @@ def func2(arg: T) -> List[Union[T, str]]: reveal_type(func2) # N: Revealed type is "def [T] (arg: T`-1) -> Union[T`-1, builtins.str]" reveal_type(func2(42)) # N: Revealed type is "Union[builtins.int, builtins.str]" [builtins fixtures/paramspec.pyi] + +[case testParamSpecPreciseKindsUsedIfPossible] +from typing import Callable, Generic +from typing_extensions import ParamSpec + +P = ParamSpec('P') + +class Case(Generic[P]): + def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None: + pass + +def _test(a: int, b: int = 0) -> None: ... + +def parametrize( + func: Callable[P, None], *cases: Case[P], **named_cases: Case[P] +) -> Callable[[], None]: + ... + +parametrize(_test, Case(1, 2), Case(3, 4)) +parametrize(_test, Case(1, b=2), Case(3, b=4)) +parametrize(_test, Case(1, 2), Case(3)) +parametrize(_test, Case(1, 2), Case(3, b=4)) +[builtins fixtures/paramspec.pyi]