diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 0e98ed048197..01d9c4463174 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -383,8 +383,6 @@ def visit_callable_type(self, t: CallableType) -> CallableType: t = t.expand_param_spec(repl) return t.copy_modified( arg_types=self.expand_types(t.arg_types), - arg_kinds=t.arg_kinds, - arg_names=t.arg_names, ret_type=t.ret_type.accept(self), type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None), ) @@ -402,6 +400,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType: arg_kinds=t.arg_kinds[:-2] + prefix.arg_kinds + t.arg_kinds[-2:], arg_names=t.arg_names[:-2] + prefix.arg_names + t.arg_names[-2:], ret_type=t.ret_type.accept(self), + from_concatenate=t.from_concatenate or bool(repl.prefix.arg_types), ) var_arg = t.var_arg() diff --git a/mypy/messages.py b/mypy/messages.py index c9bf26f8952e..aab30ee29108 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -2116,9 +2116,11 @@ def report_protocol_problems( return # Report member type conflicts - conflict_types = get_conflict_protocol_types(subtype, supertype, class_obj=class_obj) + conflict_types = get_conflict_protocol_types( + subtype, supertype, class_obj=class_obj, options=self.options + ) if conflict_types and ( - not is_subtype(subtype, erase_type(supertype)) + not is_subtype(subtype, erase_type(supertype), options=self.options) or not subtype.type.defn.type_vars or not supertype.type.defn.type_vars ): @@ -2780,7 +2782,11 @@ def [T <: int] f(self, x: int, y: T) -> None slash = True # If we got a "special arg" (i.e: self, cls, etc...), prepend it to the arg list - if isinstance(tp.definition, FuncDef) and hasattr(tp.definition, "arguments"): + if ( + isinstance(tp.definition, FuncDef) + and hasattr(tp.definition, "arguments") + and not tp.from_concatenate + ): definition_arg_names = [arg.variable.name for arg in tp.definition.arguments] if ( len(definition_arg_names) > len(tp.arg_names) @@ -2857,7 +2863,7 @@ def get_missing_protocol_members(left: Instance, right: Instance, skip: list[str def get_conflict_protocol_types( - left: Instance, right: Instance, class_obj: bool = False + left: Instance, right: Instance, class_obj: bool = False, options: Options | None = None ) -> list[tuple[str, Type, Type]]: """Find members that are defined in 'left' but have incompatible types. Return them as a list of ('member', 'got', 'expected'). @@ -2872,9 +2878,9 @@ def get_conflict_protocol_types( subtype = mypy.typeops.get_protocol_member(left, member, class_obj) if not subtype: continue - is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True) + is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True, options=options) if IS_SETTABLE in get_member_flags(member, right): - is_compat = is_compat and is_subtype(supertype, subtype) + is_compat = is_compat and is_subtype(supertype, subtype, options=options) if not is_compat: conflicts.append((member, subtype, supertype)) return conflicts diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 60fccc7e357c..11847858c62c 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -600,7 +600,7 @@ def check_mixed( type_state.record_negative_subtype_cache_entry(self._subtype_kind, left, right) return nominal if right.type.is_protocol and is_protocol_implementation( - left, right, proper_subtype=self.proper_subtype + left, right, proper_subtype=self.proper_subtype, options=self.options ): return True # We record negative cache entry here, and not in the protocol check like we do for @@ -647,7 +647,7 @@ def visit_param_spec(self, left: ParamSpecType) -> bool: and right.id == left.id and right.flavor == left.flavor ): - return True + return self._is_subtype(left.prefix, right.prefix) if isinstance(right, Parameters) and are_trivial_parameters(right): return True return self._is_subtype(left.upper_bound, self.right) @@ -696,7 +696,7 @@ def visit_callable_type(self, left: CallableType) -> bool: ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names, strict_concatenate=(self.options.extra_checks or self.options.strict_concatenate) if self.options - else True, + else False, ) elif isinstance(right, Overloaded): return all(self._is_subtype(left, item) for item in right.items) @@ -863,7 +863,7 @@ def visit_overloaded(self, left: Overloaded) -> bool: strict_concat = ( (self.options.extra_checks or self.options.strict_concatenate) if self.options - else True + else False ) if left_index not in matched_overloads and ( is_callable_compatible( @@ -1003,6 +1003,7 @@ def is_protocol_implementation( proper_subtype: bool = False, class_obj: bool = False, skip: list[str] | None = None, + options: Options | None = None, ) -> bool: """Check whether 'left' implements the protocol 'right'. @@ -1068,7 +1069,9 @@ def f(self) -> A: ... # Nominal check currently ignores arg names # NOTE: If we ever change this, be sure to also change the call to # SubtypeVisitor.build_subtype_kind(...) down below. - is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=ignore_names) + is_compat = is_subtype( + subtype, supertype, ignore_pos_arg_names=ignore_names, options=options + ) else: is_compat = is_proper_subtype(subtype, supertype) if not is_compat: @@ -1080,7 +1083,7 @@ def f(self) -> A: ... superflags = get_member_flags(member, right) if IS_SETTABLE in superflags: # Check opposite direction for settable attributes. - if not is_subtype(supertype, subtype): + if not is_subtype(supertype, subtype, options=options): return False if not class_obj: if IS_SETTABLE not in superflags: @@ -1479,7 +1482,7 @@ def are_parameters_compatible( ignore_pos_arg_names: bool = False, check_args_covariantly: bool = False, allow_partial_overlap: bool = False, - strict_concatenate_check: bool = True, + strict_concatenate_check: bool = False, ) -> bool: """Helper function for is_callable_compatible, used for Parameter compatibility""" if right.is_ellipsis_args: diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index e59b12d47980..4a4c19b4a0e9 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6483,7 +6483,7 @@ P = ParamSpec("P") R = TypeVar("R") @overload -def func(x: Callable[Concatenate[Any, P], R]) -> Callable[P, R]: ... +def func(x: Callable[Concatenate[Any, P], R]) -> Callable[P, R]: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types @overload def func(x: Callable[P, R]) -> Callable[Concatenate[str, P], R]: ... def func(x: Callable[..., R]) -> Callable[..., R]: ... diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index f523cb005a2c..b06944389623 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1576,3 +1576,73 @@ def test() -> None: ... # TODO: avoid this error, although it may be non-trivial. apply(apply, test) # E: Argument 2 to "apply" has incompatible type "Callable[[], None]"; expected "Callable[P, T]" [builtins fixtures/paramspec.pyi] + +[case testParamSpecPrefixSubtypingGenericInvalid] +from typing import Generic +from typing_extensions import ParamSpec, Concatenate + +P = ParamSpec("P") + +class A(Generic[P]): + def foo(self, *args: P.args, **kwargs: P.kwargs): + ... + +def bar(b: A[P]) -> A[Concatenate[int, P]]: + return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]") +[builtins fixtures/paramspec.pyi] + +[case testParamSpecPrefixSubtypingProtocolInvalid] +from typing import Protocol +from typing_extensions import ParamSpec, Concatenate + +P = ParamSpec("P") + +class A(Protocol[P]): + def foo(self, *args: P.args, **kwargs: P.kwargs): + ... + +def bar(b: A[P]) -> A[Concatenate[int, P]]: + return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]") +[builtins fixtures/paramspec.pyi] + +[case testParamSpecPrefixSubtypingValidNonStrict] +from typing import Protocol +from typing_extensions import ParamSpec, Concatenate + +P = ParamSpec("P") + +class A(Protocol[P]): + def foo(self, a: int, *args: P.args, **kwargs: P.kwargs): + ... + +class B(Protocol[P]): + def foo(self, a: int, b: int, *args: P.args, **kwargs: P.kwargs): + ... + +def bar(b: B[P]) -> A[Concatenate[int, P]]: + return b +[builtins fixtures/paramspec.pyi] + +[case testParamSpecPrefixSubtypingInvalidStrict] +# flags: --extra-checks +from typing import Protocol +from typing_extensions import ParamSpec, Concatenate + +P = ParamSpec("P") + +class A(Protocol[P]): + def foo(self, a: int, *args: P.args, **kwargs: P.kwargs): + ... + +class B(Protocol[P]): + def foo(self, a: int, b: int, *args: P.args, **kwargs: P.kwargs): + ... + +def bar(b: B[P]) -> A[Concatenate[int, P]]: + return b # E: Incompatible return value type (got "B[P]", expected "A[[int, **P]]") \ + # N: Following member(s) of "B[P]" have conflicts: \ + # N: Expected: \ + # N: def foo(self, a: int, int, /, *args: P.args, **kwargs: P.kwargs) -> Any \ + # N: Got: \ + # N: def foo(self, a: int, b: int, *args: P.args, **kwargs: P.kwargs) -> Any +[builtins fixtures/paramspec.pyi]