diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index e245ca1cbd8f..01ee003995b4 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -32,7 +32,7 @@ from mypy.maptype import map_instance_to_supertype from mypy.meet import is_overlapping_types, narrow_declared_type from mypy.message_registry import ErrorMessage -from mypy.messages import MessageBuilder +from mypy.messages import MessageBuilder, format_type_bare from mypy.nodes import ( ARG_NAMED, ARG_POS, @@ -629,13 +629,11 @@ def check_typeddict_call( args: List[Expression], context: Context, ) -> Type: - if len(args) >= 1 and all([ak == ARG_NAMED for ak in arg_kinds]): - # ex: Point(x=42, y=1337) - assert all(arg_name is not None for arg_name in arg_names) - item_names = cast(List[str], arg_names) + if all(ak in {ARG_NAMED, ARG_STAR2} for ak in arg_kinds): + # ex: Point(x=42, y=1337, **other_point) item_args = args return self.check_typeddict_call_with_kwargs( - callee, dict(zip(item_names, item_args)), context + callee, list(zip(arg_names, item_args)), context ) if len(args) == 1 and arg_kinds[0] == ARG_POS: @@ -647,44 +645,46 @@ def check_typeddict_call( # ex: Point(dict(x=42, y=1337)) return self.check_typeddict_call_with_dict(callee, unique_arg.analyzed, context) - if len(args) == 0: - # ex: EmptyDict() - return self.check_typeddict_call_with_kwargs(callee, {}, context) - self.chk.fail(message_registry.INVALID_TYPEDDICT_ARGS, context) return AnyType(TypeOfAny.from_error) - def validate_typeddict_kwargs(self, kwargs: DictExpr) -> "Optional[Dict[str, Expression]]": - item_args = [item[1] for item in kwargs.items] - - item_names = [] # List[str] + def validate_typeddict_kwargs( + self, kwargs: DictExpr + ) -> Optional[List[Tuple[Optional[str], Expression]]]: + """Validate kwargs for TypedDict constructor, e.g. Point({'x': 1, 'y': 2}). + Check that all items have string literal keys or are using unpack operator (**) + """ + items: List[Tuple[Optional[str], Expression]] = [] for item_name_expr, item_arg in kwargs.items: + # If unpack operator (**) was used, name will be None + if item_name_expr is None: + items.append((None, item_arg)) + continue literal_value = None - if item_name_expr: - key_type = self.accept(item_name_expr) - values = try_getting_str_literals(item_name_expr, key_type) - if values and len(values) == 1: - literal_value = values[0] + key_type = self.accept(item_name_expr) + values = try_getting_str_literals(item_name_expr, key_type) + if values and len(values) == 1: + literal_value = values[0] if literal_value is None: key_context = item_name_expr or item_arg self.chk.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, key_context) return None else: - item_names.append(literal_value) - return dict(zip(item_names, item_args)) + items.append((literal_value, item_arg)) + return items - def match_typeddict_call_with_dict( - self, callee: TypedDictType, kwargs: DictExpr, context: Context - ) -> bool: + def match_typeddict_call_with_dict(self, callee: TypedDictType, kwargs: DictExpr) -> bool: + """Check that kwargs is valid set of TypedDict items, contains all required keys of callee, and has no extraneous keys""" validated_kwargs = self.validate_typeddict_kwargs(kwargs=kwargs) if validated_kwargs is not None: - return callee.required_keys <= set(validated_kwargs.keys()) <= set(callee.items.keys()) + return callee.required_keys <= dict(validated_kwargs).keys() <= callee.items.keys() else: return False def check_typeddict_call_with_dict( self, callee: TypedDictType, kwargs: DictExpr, context: Context ) -> Type: + """Check TypedDict constructor of format Point({'x': 1, 'y': 2})""" validated_kwargs = self.validate_typeddict_kwargs(kwargs=kwargs) if validated_kwargs is not None: return self.check_typeddict_call_with_kwargs( @@ -694,30 +694,66 @@ def check_typeddict_call_with_dict( return AnyType(TypeOfAny.from_error) def check_typeddict_call_with_kwargs( - self, callee: TypedDictType, kwargs: Dict[str, Expression], context: Context + self, + callee: TypedDictType, + kwargs: List[Tuple[Optional[str], Expression]], + context: Context, ) -> Type: - if not (callee.required_keys <= set(kwargs.keys()) <= set(callee.items.keys())): - expected_keys = [ - key - for key in callee.items.keys() - if key in callee.required_keys or key in kwargs.keys() - ] - actual_keys = kwargs.keys() + """Check TypedDict constructor of format Point(x=1, y=2)""" + # Infer types of item values and expand unpack operators + items: Dict[str, Tuple[Expression, Type]] = {} + sure_keys: List[str] = [] + maybe_keys: List[str] = [] # Will contain non-required items of unpacked TypedDicts + for key, value_expr in kwargs: + if key is not None: + # Regular key and value + value_type = self.accept(value_expr, callee.items.get(key)) + items[key] = (value_expr, value_type) + sure_keys.append(key) + else: + # Unpack operator (**) was used; unpack all items of the type of this expression into items list + value_type = self.accept(value_expr, callee) + proper_type = get_proper_type(value_type) + if isinstance(proper_type, TypedDictType): + for nested_key, nested_value_type in proper_type.items.items(): + items[nested_key] = (value_expr, nested_value_type) + if nested_key in proper_type.required_keys: + sure_keys.append(nested_key) + else: + maybe_keys.append(nested_key) + else: + # Fail when trying to unpack anything but TypedDict + self.chk.fail( + ErrorMessage.format( + message_registry.TYPEDDICT_UNPACKING_MUST_BE_TYPEDDICT, + format_type_bare(value_type), + ), + value_expr, + ) + return AnyType(TypeOfAny.from_error) + + if not ( + callee.required_keys + <= set(sure_keys) + <= set(sure_keys + maybe_keys) + <= set(callee.items.keys()) + ): self.msg.unexpected_typeddict_keys( - callee, expected_keys=expected_keys, actual_keys=list(actual_keys), context=context + callee, actual_sure_keys=sure_keys, actual_maybe_keys=maybe_keys, context=context ) return AnyType(TypeOfAny.from_error) + # Check item value types for (item_name, item_expected_type) in callee.items.items(): - if item_name in kwargs: - item_value = kwargs[item_name] - self.chk.check_simple_assignment( - lvalue_type=item_expected_type, - rvalue=item_value, - context=item_value, + if item_name in items: + item_value_expr, item_actual_type = items[item_name] + self.chk.check_subtype( + subtype=item_actual_type, + supertype=item_expected_type, + context=item_value_expr, msg=message_registry.INCOMPATIBLE_TYPES, - lvalue_name=f'TypedDict item "{item_name}"', - rvalue_name="expression", + subtype_label="expression has type", + supertype_label=f'TypedDict item "{item_name}" has type', code=codes.TYPEDDICT_ITEM, ) @@ -4009,7 +4045,7 @@ def find_typeddict_context( for item in context.items: item_context = self.find_typeddict_context(item, dict_expr) if item_context is not None and self.match_typeddict_call_with_dict( - item_context, dict_expr, dict_expr + item_context, dict_expr ): items.append(item_context) if len(items) == 1: diff --git a/mypy/message_registry.py b/mypy/message_registry.py index 11c8696f73f4..ce9b6447e3f6 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -116,6 +116,9 @@ def format(self, *args: object, **kwargs: object) -> "ErrorMessage": TYPEDDICT_KEY_MUST_BE_STRING_LITERAL: Final = ErrorMessage( "Expected TypedDict key to be string literal" ) +TYPEDDICT_UNPACKING_MUST_BE_TYPEDDICT: Final = ErrorMessage( + "{} cannot be unpacked into TypedDict (must be TypedDict)" +) MALFORMED_ASSERT: Final = ErrorMessage("Assertion is always true, perhaps remove parentheses?") DUPLICATE_TYPE_SIGNATURES: Final = "Function has duplicate type signatures" DESCRIPTOR_SET_NOT_CALLABLE: Final = ErrorMessage("{}.__set__ is not callable") diff --git a/mypy/messages.py b/mypy/messages.py index 88e98633649e..7e04ccc2971a 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1569,46 +1569,38 @@ def explicit_any(self, ctx: Context) -> None: def unexpected_typeddict_keys( self, typ: TypedDictType, - expected_keys: List[str], - actual_keys: List[str], + actual_sure_keys: List[str], + actual_maybe_keys: List[str], context: Context, ) -> None: - actual_set = set(actual_keys) - expected_set = set(expected_keys) + actual_keys = actual_sure_keys + actual_maybe_keys + actual_sure_set = set(actual_sure_keys) + required_keys = [k for k in typ.items.keys() if k in typ.required_keys] + required_set = set(required_keys) + expected_set = set(typ.items.keys()) + if not typ.is_anonymous(): - # Generate simpler messages for some common special cases. - if actual_set < expected_set: - # Use list comprehension instead of set operations to preserve order. - missing = [key for key in expected_keys if key not in actual_set] - self.fail( - "Missing {} for TypedDict {}".format( - format_key_list(missing, short=True), format_type(typ) - ), - context, - code=codes.TYPEDDICT_ITEM, - ) - return - else: - extra = [key for key in actual_keys if key not in expected_set] - if extra: - # If there are both extra and missing keys, only report extra ones for - # simplicity. - self.fail( - "Extra {} for TypedDict {}".format( - format_key_list(extra, short=True), format_type(typ) - ), - context, - code=codes.TYPEDDICT_ITEM, - ) - return - found = format_key_list(actual_keys, short=True) - if not expected_keys: - self.fail(f"Unexpected TypedDict {found}", context) - return - expected = format_key_list(expected_keys) - if actual_keys and actual_set < expected_set: - found = f"only {found}" - self.fail(f"Expected {expected} but found {found}", context, code=codes.TYPEDDICT_ITEM) + type_description = f" for TypedDict {format_type(typ)}" + else: + type_description = "" + + if actual_sure_set < required_set: + # Use list comprehension instead of set operations to preserve order. + missing = [key for key in required_keys if key not in actual_sure_set] + self.fail( + f"Missing {format_key_list(missing, short=True)}{type_description}", + context, + code=codes.TYPEDDICT_ITEM, + ) + else: + # If there are both extra and missing keys, only report extra ones for + # simplicity. + extra = [key for key in actual_keys if key not in expected_set] + self.fail( + f"Extra {format_key_list(extra, short=True)}{type_description}", + context, + code=codes.TYPEDDICT_ITEM, + ) def typeddict_key_must_be_string_literal(self, typ: TypedDictType, context: Context) -> None: self.fail( diff --git a/test-data/unit/check-errorcodes.test b/test-data/unit/check-errorcodes.test index f1a6f3c77ada..ff29a33e89f6 100644 --- a/test-data/unit/check-errorcodes.test +++ b/test-data/unit/check-errorcodes.test @@ -452,7 +452,7 @@ class E(TypedDict): a: D = {'x': ''} # E: Incompatible types (expression has type "str", TypedDict item "x" has type "int") [typeddict-item] b: D = {'y': ''} # E: Extra key "y" for TypedDict "D" [typeddict-item] c = D(x=0) if int() else E(x=0, y=0) -c = {} # E: Expected TypedDict key "x" but found no keys [typeddict-item] +c = {} # E: Missing key "x" [typeddict-item] a['y'] = 1 # E: TypedDict "D" has no key "y" [typeddict-item] a['x'] = 'x' # E: Value of "x" has incompatible type "str"; expected "int" [typeddict-item] diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 62ac5e31da45..aad547ef1b90 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -943,7 +943,7 @@ T = TypeVar('T') def join(x: T, y: T) -> T: return x ab = join(A(x=1, y=1), B(x=1, y='')) if int(): - ab = {'x': 1, 'z': 1} # E: Expected TypedDict key "x" but found keys ("x", "z") + ab = {'x': 1, 'z': 1} # E: Extra key "z" [builtins fixtures/dict.pyi] [case testCannotCreateAnonymousTypedDictInstanceUsingDictLiteralWithMissingItems] @@ -955,7 +955,7 @@ T = TypeVar('T') def join(x: T, y: T) -> T: return x ab = join(A(x=1, y=1, z=1), B(x=1, y=1, z='')) if int(): - ab = {} # E: Expected TypedDict keys ("x", "y") but found no keys + ab = {} # E: Missing keys ("x", "y") [builtins fixtures/dict.pyi] @@ -1653,9 +1653,9 @@ a.update({'x': 1}) a.update({'x': ''}) # E: Incompatible types (expression has type "str", TypedDict item "x" has type "int") a.update({'x': 1, 'y': []}) a.update({'x': 1, 'y': [1]}) -a.update({'z': 1}) # E: Unexpected TypedDict key "z" -a.update({'z': 1, 'zz': 1}) # E: Unexpected TypedDict keys ("z", "zz") -a.update({'z': 1, 'x': 1}) # E: Expected TypedDict key "x" but found keys ("z", "x") +a.update({'z': 1}) # E: Extra key "z" +a.update({'z': 1, 'zz': 1}) # E: Extra keys ("z", "zz") +a.update({'z': 1, 'x': 1}) # E: Extra key "z" d = {'x': 1} a.update(d) # E: Argument 1 to "update" of "TypedDict" has incompatible type "Dict[str, int]"; expected "TypedDict({'x'?: int, 'y'?: List[int]})" [builtins fixtures/dict.pyi] @@ -2395,3 +2395,232 @@ def func(foo: Union[F1, F2]): # E: Argument 1 to "__setitem__" has incompatible type "int"; expected "str" [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] + + +[case testTypedDictUnpackSame] +from typing import TypedDict + +class Foo(TypedDict): + a: int + b: int + +foo1: Foo = {'a': 1, 'b': 1} +foo2: Foo = {**foo1, 'b': 2} +foo3 = Foo(**foo1, b=2) +foo4 = Foo({**foo1, 'b': 2}) + +[typing fixtures/typing-typeddict.pyi] + + +[case testTypedDictUnpackCompatible] +from typing import TypedDict + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + a: int + b: int + +foo: Foo = {'a': 1} +bar: Bar = {**foo, 'b': 2} + +[typing fixtures/typing-typeddict.pyi] + + + +[case testTypedDictUnpackIncompatible] +from typing import TypedDict + +class Foo(TypedDict): + a: int + b: str + +class Bar(TypedDict): + a: int + b: int + +foo: Foo = {'a': 1, 'b': 'a'} +bar1: Bar = {**foo, 'b': 2} # Incompatible item is overriden +bar2: Bar = {**foo, 'a': 2} # E: Incompatible types (expression has type "str", TypedDict item "b" has type "int") + +[builtins fixtures/tuple.pyi] +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + + +[case testTypedDictUnpackNotRequiredKeyIncompatible] +from typing import TypedDict, NotRequired + +class Foo(TypedDict): + a: NotRequired[str] + +class Bar(TypedDict): + a: NotRequired[int] + +foo: Foo = {} +bar: Bar = {**foo} # E: Incompatible types (expression has type "str", TypedDict item "a" has type "int") + +[typing fixtures/typing-typeddict.pyi] + + +[case testTypedDictUnpackMissingOrExtraKey] +from typing import TypedDict + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + a: int + b: int + +foo1: Foo = {'a': 1} +bar1: Bar = {'a': 1, 'b': 1} +foo2: Foo = {**bar1} # E: Extra key "b" for TypedDict "Foo" +bar2: Bar = {**foo1} # E: Missing key "b" for TypedDict "Bar" + +[typing fixtures/typing-typeddict.pyi] + + +[case testTypedDictUnpackNotRequiredKeyExtra] +from typing import TypedDict, NotRequired + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + a: int + b: NotRequired[int] + +foo1: Foo = {'a': 1} +bar1: Bar = {'a': 1} +foo2: Foo = {**bar1} # E: Extra key "b" for TypedDict "Foo" +bar2: Bar = {**foo1} + +[typing fixtures/typing-typeddict.pyi] + + +[case testTypedDictUnpackRequiredKeyMissing] +from typing import TypedDict, NotRequired + +class Foo(TypedDict): + a: NotRequired[int] + +class Bar(TypedDict): + a: int + +foo: Foo = {'a': 1} +bar: Bar = {**foo} # E: Missing key "a" for TypedDict "Bar" + +[typing fixtures/typing-typeddict.pyi] + + +[case testTypedDictUnpackMultiple] +from typing import TypedDict + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + b: int + +class Baz(TypedDict): + a: int + b: int + c: int + +foo: Foo = {'a': 1} +bar: Bar = {'b': 1} +baz: Baz = {**foo, **bar, 'c': 1} + +[typing fixtures/typing-typeddict.pyi] + + +[case testTypedDictUnpackNested] +from typing import TypedDict + +class Foo(TypedDict): + a: int + b: int + +class Bar(TypedDict): + c: Foo + d: int + +foo: Foo = {'a': 1, 'b': 1} +bar: Bar = {'c': foo, 'd': 1} +bar2: Bar = {**bar, 'c': {**bar['c'], 'b': 2}, 'd': 2} + +[builtins fixtures/tuple.pyi] +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + + +[case testTypedDictUnpackNestedError] +from typing import TypedDict + +class Foo(TypedDict): + a: int + b: int + +class Bar(TypedDict): + c: Foo + d: int + +foo: Foo = {'a': 1, 'b': 1} +bar: Bar = {'c': foo, 'd': 1} +bar2: Bar = {**bar, 'c': {**bar['c'], 'b': 'wrong'}, 'd': 2} # E: Incompatible types (expression has type "str", TypedDict item "b" has type "int") + + +[typing fixtures/typing-typeddict.pyi] + + +[case testTypedDictUnpackUntypedDict] +from typing import TypedDict + +class Bar(TypedDict): + pass + +foo: dict = {} +bar: Bar = {**foo} # E: Dict[Any, Any] cannot be unpacked into TypedDict (must be TypedDict) + +[builtins fixtures/tuple.pyi] +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + + +[case testTypedDictUnpackIntoUnion] +from typing import TypedDict, Union + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + b: int + +# Would be great if this worked in the future +foo: Foo = {'a': 1} +foo_or_bar: Union[Foo, Bar] = {**foo} # E: Incompatible types in assignment (expression has type "Dict[str, object]", variable has type "Union[Foo, Bar]") + +[builtins fixtures/tuple.pyi] +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + + +[case testTypedDictUnpackFromUnion] +from typing import TypedDict, Union + +class Foo(TypedDict): + a: int + b: int + +class Bar(TypedDict): + b: int + +# Would be great if this worked in the future +foo_or_bar: Union[Foo, Bar] = {'b': 1} +foo: Bar = {**foo_or_bar} # E: Union[Foo, Bar] cannot be unpacked into TypedDict (must be TypedDict) + +[builtins fixtures/tuple.pyi] +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index a79e8743fb97..214eb1921d4b 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -1129,7 +1129,7 @@ _testTypedDictMappingMethods.py:10: note: Revealed type is "typing.ItemsView[bui _testTypedDictMappingMethods.py:11: note: Revealed type is "typing.ValuesView[builtins.object]" _testTypedDictMappingMethods.py:12: note: Revealed type is "TypedDict('_testTypedDictMappingMethods.Cell', {'value': builtins.int})" _testTypedDictMappingMethods.py:13: note: Revealed type is "builtins.int" -_testTypedDictMappingMethods.py:15: error: Unexpected TypedDict key "invalid" +_testTypedDictMappingMethods.py:15: error: Extra key "invalid" _testTypedDictMappingMethods.py:16: error: Key "value" of TypedDict "Cell" cannot be deleted _testTypedDictMappingMethods.py:21: note: Revealed type is "builtins.int"