Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow unpacking of TypedDict into TypedDict #13353

Closed
wants to merge 13 commits into from
122 changes: 79 additions & 43 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from mypy.maptype import map_instance_to_supertype
from mypy.meet import is_overlapping_types, narrow_declared_type
from mypy.message_registry import ErrorMessage
from mypy.messages import MessageBuilder
from mypy.messages import MessageBuilder, format_type_bare
from mypy.nodes import (
ARG_NAMED,
ARG_POS,
Expand Down Expand Up @@ -629,13 +629,11 @@ def check_typeddict_call(
args: List[Expression],
context: Context,
) -> Type:
if len(args) >= 1 and all([ak == ARG_NAMED for ak in arg_kinds]):
# ex: Point(x=42, y=1337)
assert all(arg_name is not None for arg_name in arg_names)
item_names = cast(List[str], arg_names)
if all(ak in {ARG_NAMED, ARG_STAR2} for ak in arg_kinds):
# ex: Point(x=42, y=1337, **other_point)
item_args = args
return self.check_typeddict_call_with_kwargs(
callee, dict(zip(item_names, item_args)), context
callee, list(zip(arg_names, item_args)), context
)

if len(args) == 1 and arg_kinds[0] == ARG_POS:
Expand All @@ -647,44 +645,46 @@ def check_typeddict_call(
# ex: Point(dict(x=42, y=1337))
return self.check_typeddict_call_with_dict(callee, unique_arg.analyzed, context)

if len(args) == 0:
# ex: EmptyDict()
return self.check_typeddict_call_with_kwargs(callee, {}, context)

self.chk.fail(message_registry.INVALID_TYPEDDICT_ARGS, context)
return AnyType(TypeOfAny.from_error)

def validate_typeddict_kwargs(self, kwargs: DictExpr) -> "Optional[Dict[str, Expression]]":
item_args = [item[1] for item in kwargs.items]

item_names = [] # List[str]
def validate_typeddict_kwargs(
self, kwargs: DictExpr
) -> Optional[List[Tuple[Optional[str], Expression]]]:
"""Validate kwargs for TypedDict constructor, e.g. Point({'x': 1, 'y': 2}).
Check that all items have string literal keys or are using unpack operator (**)
"""
items: List[Tuple[Optional[str], Expression]] = []
for item_name_expr, item_arg in kwargs.items:
# If unpack operator (**) was used, name will be None
if item_name_expr is None:
items.append((None, item_arg))
continue
literal_value = None
if item_name_expr:
key_type = self.accept(item_name_expr)
values = try_getting_str_literals(item_name_expr, key_type)
if values and len(values) == 1:
literal_value = values[0]
key_type = self.accept(item_name_expr)
values = try_getting_str_literals(item_name_expr, key_type)
if values and len(values) == 1:
literal_value = values[0]
if literal_value is None:
key_context = item_name_expr or item_arg
self.chk.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, key_context)
return None
else:
item_names.append(literal_value)
return dict(zip(item_names, item_args))
items.append((literal_value, item_arg))
return items

def match_typeddict_call_with_dict(
self, callee: TypedDictType, kwargs: DictExpr, context: Context
) -> bool:
def match_typeddict_call_with_dict(self, callee: TypedDictType, kwargs: DictExpr) -> bool:
"""Check that kwargs is valid set of TypedDict items, contains all required keys of callee, and has no extraneous keys"""
validated_kwargs = self.validate_typeddict_kwargs(kwargs=kwargs)
if validated_kwargs is not None:
return callee.required_keys <= set(validated_kwargs.keys()) <= set(callee.items.keys())
return callee.required_keys <= dict(validated_kwargs).keys() <= callee.items.keys()
else:
return False

def check_typeddict_call_with_dict(
self, callee: TypedDictType, kwargs: DictExpr, context: Context
) -> Type:
"""Check TypedDict constructor of format Point({'x': 1, 'y': 2})"""
validated_kwargs = self.validate_typeddict_kwargs(kwargs=kwargs)
if validated_kwargs is not None:
return self.check_typeddict_call_with_kwargs(
Expand All @@ -694,30 +694,66 @@ def check_typeddict_call_with_dict(
return AnyType(TypeOfAny.from_error)

def check_typeddict_call_with_kwargs(
self, callee: TypedDictType, kwargs: Dict[str, Expression], context: Context
self,
callee: TypedDictType,
kwargs: List[Tuple[Optional[str], Expression]],
context: Context,
) -> Type:
if not (callee.required_keys <= set(kwargs.keys()) <= set(callee.items.keys())):
expected_keys = [
key
for key in callee.items.keys()
if key in callee.required_keys or key in kwargs.keys()
]
actual_keys = kwargs.keys()
"""Check TypedDict constructor of format Point(x=1, y=2)"""
# Infer types of item values and expand unpack operators
items: Dict[str, Tuple[Expression, Type]] = {}
sure_keys: List[str] = []
maybe_keys: List[str] = [] # Will contain non-required items of unpacked TypedDicts
for key, value_expr in kwargs:
if key is not None:
# Regular key and value
value_type = self.accept(value_expr, callee.items.get(key))
items[key] = (value_expr, value_type)
sure_keys.append(key)
else:
# Unpack operator (**) was used; unpack all items of the type of this expression into items list
value_type = self.accept(value_expr, callee)
proper_type = get_proper_type(value_type)
if isinstance(proper_type, TypedDictType):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will not work for unions of TypedDicts (or other more complex types). I'm open to suggestions to improve this

for nested_key, nested_value_type in proper_type.items.items():
items[nested_key] = (value_expr, nested_value_type)
if nested_key in proper_type.required_keys:
sure_keys.append(nested_key)
else:
maybe_keys.append(nested_key)
else:
# Fail when trying to unpack anything but TypedDict
self.chk.fail(
ErrorMessage.format(
message_registry.TYPEDDICT_UNPACKING_MUST_BE_TYPEDDICT,
format_type_bare(value_type),
),
value_expr,
)
return AnyType(TypeOfAny.from_error)

if not (
callee.required_keys
<= set(sure_keys)
<= set(sure_keys + maybe_keys)
<= set(callee.items.keys())
):
self.msg.unexpected_typeddict_keys(
callee, expected_keys=expected_keys, actual_keys=list(actual_keys), context=context
callee, actual_sure_keys=sure_keys, actual_maybe_keys=maybe_keys, context=context
)
return AnyType(TypeOfAny.from_error)

# Check item value types
for (item_name, item_expected_type) in callee.items.items():
if item_name in kwargs:
item_value = kwargs[item_name]
self.chk.check_simple_assignment(
lvalue_type=item_expected_type,
rvalue=item_value,
context=item_value,
if item_name in items:
item_value_expr, item_actual_type = items[item_name]
self.chk.check_subtype(
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm calling check_subtype instead of check_simple_assignment because that allows consistent handling of both immediate and unpacked items. Let me know if check_simple_assignment does any important additional checks that might be important

subtype=item_actual_type,
supertype=item_expected_type,
context=item_value_expr,
msg=message_registry.INCOMPATIBLE_TYPES,
lvalue_name=f'TypedDict item "{item_name}"',
rvalue_name="expression",
subtype_label="expression has type",
supertype_label=f'TypedDict item "{item_name}" has type',
code=codes.TYPEDDICT_ITEM,
)

Expand Down Expand Up @@ -4009,7 +4045,7 @@ def find_typeddict_context(
for item in context.items:
item_context = self.find_typeddict_context(item, dict_expr)
if item_context is not None and self.match_typeddict_call_with_dict(
item_context, dict_expr, dict_expr
item_context, dict_expr
):
items.append(item_context)
if len(items) == 1:
Expand Down
3 changes: 3 additions & 0 deletions mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def format(self, *args: object, **kwargs: object) -> "ErrorMessage":
TYPEDDICT_KEY_MUST_BE_STRING_LITERAL: Final = ErrorMessage(
"Expected TypedDict key to be string literal"
)
TYPEDDICT_UNPACKING_MUST_BE_TYPEDDICT: Final = ErrorMessage(
"{} cannot be unpacked into TypedDict (must be TypedDict)"
)
MALFORMED_ASSERT: Final = ErrorMessage("Assertion is always true, perhaps remove parentheses?")
DUPLICATE_TYPE_SIGNATURES: Final = "Function has duplicate type signatures"
DESCRIPTOR_SET_NOT_CALLABLE: Final = ErrorMessage("{}.__set__ is not callable")
Expand Down
66 changes: 29 additions & 37 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,46 +1569,38 @@ def explicit_any(self, ctx: Context) -> None:
def unexpected_typeddict_keys(
self,
typ: TypedDictType,
expected_keys: List[str],
actual_keys: List[str],
actual_sure_keys: List[str],
actual_maybe_keys: List[str],
context: Context,
) -> None:
actual_set = set(actual_keys)
expected_set = set(expected_keys)
actual_keys = actual_sure_keys + actual_maybe_keys
actual_sure_set = set(actual_sure_keys)
required_keys = [k for k in typ.items.keys() if k in typ.required_keys]
required_set = set(required_keys)
expected_set = set(typ.items.keys())

if not typ.is_anonymous():
# Generate simpler messages for some common special cases.
if actual_set < expected_set:
# Use list comprehension instead of set operations to preserve order.
missing = [key for key in expected_keys if key not in actual_set]
self.fail(
"Missing {} for TypedDict {}".format(
format_key_list(missing, short=True), format_type(typ)
),
context,
code=codes.TYPEDDICT_ITEM,
)
return
else:
extra = [key for key in actual_keys if key not in expected_set]
if extra:
# If there are both extra and missing keys, only report extra ones for
# simplicity.
self.fail(
"Extra {} for TypedDict {}".format(
format_key_list(extra, short=True), format_type(typ)
),
context,
code=codes.TYPEDDICT_ITEM,
)
return
found = format_key_list(actual_keys, short=True)
if not expected_keys:
self.fail(f"Unexpected TypedDict {found}", context)
return
expected = format_key_list(expected_keys)
if actual_keys and actual_set < expected_set:
found = f"only {found}"
self.fail(f"Expected {expected} but found {found}", context, code=codes.TYPEDDICT_ITEM)
type_description = f" for TypedDict {format_type(typ)}"
else:
type_description = ""

if actual_sure_set < required_set:
# Use list comprehension instead of set operations to preserve order.
missing = [key for key in required_keys if key not in actual_sure_set]
self.fail(
f"Missing {format_key_list(missing, short=True)}{type_description}",
context,
code=codes.TYPEDDICT_ITEM,
)
else:
# If there are both extra and missing keys, only report extra ones for
# simplicity.
extra = [key for key in actual_keys if key not in expected_set]
self.fail(
f"Extra {format_key_list(extra, short=True)}{type_description}",
context,
code=codes.TYPEDDICT_ITEM,
)

def typeddict_key_must_be_string_literal(self, typ: TypedDictType, context: Context) -> None:
self.fail(
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-errorcodes.test
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ class E(TypedDict):
a: D = {'x': ''} # E: Incompatible types (expression has type "str", TypedDict item "x" has type "int") [typeddict-item]
b: D = {'y': ''} # E: Extra key "y" for TypedDict "D" [typeddict-item]
c = D(x=0) if int() else E(x=0, y=0)
c = {} # E: Expected TypedDict key "x" but found no keys [typeddict-item]
c = {} # E: Missing key "x" [typeddict-item]

a['y'] = 1 # E: TypedDict "D" has no key "y" [typeddict-item]
a['x'] = 'x' # E: Value of "x" has incompatible type "str"; expected "int" [typeddict-item]
Expand Down
Loading