Skip to content

Commit

Permalink
[dataclass_transform] Support default parameters (#14580)
Browse files Browse the repository at this point in the history
PEP 681 defines several parameters for `typing.dataclass_transform`.
This commit adds support for collecting these arguments and forwarding
them to the dataclasses plugin. For this first iteration, only the
`*_default` parameters are supported; `field_specifiers` will be
implemented in a separate commit, since it is more complicated.
  • Loading branch information
wesleywright authored Feb 8, 2023
1 parent f505614 commit 9e85f9b
Show file tree
Hide file tree
Showing 10 changed files with 352 additions and 75 deletions.
73 changes: 64 additions & 9 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,13 +480,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T:
return visitor.visit_import_all(self)


FUNCBASE_FLAGS: Final = [
"is_property",
"is_class",
"is_static",
"is_final",
"is_dataclass_transform",
]
FUNCBASE_FLAGS: Final = ["is_property", "is_class", "is_static", "is_final"]


class FuncBase(Node):
Expand All @@ -512,7 +506,6 @@ class FuncBase(Node):
"is_static", # Uses "@staticmethod"
"is_final", # Uses "@final"
"_fullname",
"is_dataclass_transform", # Is decorated with "@typing.dataclass_transform" or similar
)

def __init__(self) -> None:
Expand All @@ -531,7 +524,6 @@ def __init__(self) -> None:
self.is_final = False
# Name with module prefix
self._fullname = ""
self.is_dataclass_transform = False

@property
@abstractmethod
Expand Down Expand Up @@ -758,6 +750,8 @@ class FuncDef(FuncItem, SymbolNode, Statement):
"deco_line",
"is_trivial_body",
"is_mypy_only",
# Present only when a function is decorated with @typing.dataclass_transform or similar
"dataclass_transform_spec",
)

__match_args__ = ("name", "arguments", "type", "body")
Expand Down Expand Up @@ -785,6 +779,7 @@ def __init__(
self.deco_line: int | None = None
# Definitions that appear in if TYPE_CHECKING are marked with this flag.
self.is_mypy_only = False
self.dataclass_transform_spec: DataclassTransformSpec | None = None

@property
def name(self) -> str:
Expand All @@ -810,6 +805,11 @@ def serialize(self) -> JsonDict:
"flags": get_flags(self, FUNCDEF_FLAGS),
"abstract_status": self.abstract_status,
# TODO: Do we need expanded, original_def?
"dataclass_transform_spec": (
None
if self.dataclass_transform_spec is None
else self.dataclass_transform_spec.serialize()
),
}

@classmethod
Expand All @@ -832,6 +832,11 @@ def deserialize(cls, data: JsonDict) -> FuncDef:
ret.arg_names = data["arg_names"]
ret.arg_kinds = [ArgKind(x) for x in data["arg_kinds"]]
ret.abstract_status = data["abstract_status"]
ret.dataclass_transform_spec = (
DataclassTransformSpec.deserialize(data["dataclass_transform_spec"])
if data["dataclass_transform_spec"] is not None
else None
)
# Leave these uninitialized so that future uses will trigger an error
del ret.arguments
del ret.max_pos
Expand Down Expand Up @@ -3857,6 +3862,56 @@ def deserialize(cls, data: JsonDict) -> SymbolTable:
return st


class DataclassTransformSpec:
    """Specifies how a dataclass-like transform should be applied.

    The fields here are based on the parameters accepted by `typing.dataclass_transform`.
    Every constructor argument is optional; an argument left as None is replaced by the
    default assigned in `__init__` below (eq: True, order/kw_only/frozen: False, and an
    empty tuple of field specifiers).
    """

    __slots__ = (
        "eq_default",
        "order_default",
        "kw_only_default",
        "frozen_default",
        "field_specifiers",
    )

    def __init__(
        self,
        *,
        eq_default: bool | None = None,
        order_default: bool | None = None,
        kw_only_default: bool | None = None,
        field_specifiers: tuple[str, ...] | None = None,
        # Specified outside of PEP 681:
        # frozen_default was added to CPython in https://github.com/python/cpython/pull/99958
        # citing positive discussion in typing-sig
        frozen_default: bool | None = None,
    ) -> None:
        self.eq_default = eq_default if eq_default is not None else True
        self.order_default = order_default if order_default is not None else False
        self.kw_only_default = kw_only_default if kw_only_default is not None else False
        self.frozen_default = frozen_default if frozen_default is not None else False
        self.field_specifiers = field_specifiers if field_specifiers is not None else ()

    def serialize(self) -> JsonDict:
        """Return a JSON-compatible dict describing this spec.

        The keys match the constructor's keyword names so that `deserialize` can read
        them back unchanged.
        """
        return {
            "eq_default": self.eq_default,
            "order_default": self.order_default,
            "kw_only_default": self.kw_only_default,
            # This key must be the one deserialize() looks up; a differing name would
            # silently reset the value to its default after a round trip.
            "frozen_default": self.frozen_default,
            # JSON has no tuple type, so emit an explicit list.
            "field_specifiers": list(self.field_specifiers),
        }

    @classmethod
    def deserialize(cls, data: JsonDict) -> DataclassTransformSpec:
        """Rebuild a spec from the dict produced by `serialize`.

        Keys missing from `data` fall back to the constructor defaults.
        """
        specifiers = data.get("field_specifiers")
        return cls(
            eq_default=data.get("eq_default"),
            order_default=data.get("order_default"),
            kw_only_default=data.get("kw_only_default"),
            frozen_default=data.get("frozen_default"),
            # Restore the tuple form used by __init__ (JSON hands back a list).
            field_specifiers=tuple(specifiers) if specifiers is not None else None,
        )


def get_flags(node: Node, names: list[str]) -> list[str]:
    """Return the entries of `names` whose attribute on `node` is truthy.

    The order of `names` is preserved in the result.
    """
    enabled: list[str] = []
    for flag in names:
        if getattr(node, flag):
            enabled.append(flag)
    return enabled

Expand Down
106 changes: 71 additions & 35 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
AssignmentStmt,
CallExpr,
Context,
DataclassTransformSpec,
Expression,
JsonDict,
NameExpr,
Node,
PlaceholderNode,
RefExpr,
SymbolTableNode,
Expand All @@ -37,6 +39,7 @@
add_method,
deserialize_and_fixup_type,
)
from mypy.semanal_shared import find_dataclass_transform_spec
from mypy.server.trigger import make_wildcard_trigger
from mypy.state import state
from mypy.typeops import map_type_from_supertype
Expand All @@ -56,11 +59,16 @@

# The set of decorators that generate dataclasses.
dataclass_makers: Final = {"dataclass", "dataclasses.dataclass"}
# The set of functions that generate dataclass fields.
field_makers: Final = {"dataclasses.field"}


SELF_TVAR_NAME: Final = "_DT"
_TRANSFORM_SPEC_FOR_DATACLASSES = DataclassTransformSpec(
eq_default=True,
order_default=False,
kw_only_default=False,
frozen_default=False,
field_specifiers=("dataclasses.Field", "dataclasses.field"),
)


class DataclassAttribute:
Expand Down Expand Up @@ -155,6 +163,7 @@ class DataclassTransformer:

def __init__(self, ctx: ClassDefContext) -> None:
self._ctx = ctx
self._spec = _get_transform_spec(ctx.reason)

def transform(self) -> bool:
"""Apply all the necessary transformations to the underlying
Expand All @@ -172,9 +181,9 @@ def transform(self) -> bool:
return False
decorator_arguments = {
"init": _get_decorator_bool_argument(self._ctx, "init", True),
"eq": _get_decorator_bool_argument(self._ctx, "eq", True),
"order": _get_decorator_bool_argument(self._ctx, "order", False),
"frozen": _get_decorator_bool_argument(self._ctx, "frozen", False),
"eq": _get_decorator_bool_argument(self._ctx, "eq", self._spec.eq_default),
"order": _get_decorator_bool_argument(self._ctx, "order", self._spec.order_default),
"frozen": _get_decorator_bool_argument(self._ctx, "frozen", self._spec.frozen_default),
"slots": _get_decorator_bool_argument(self._ctx, "slots", False),
"match_args": _get_decorator_bool_argument(self._ctx, "match_args", True),
}
Expand Down Expand Up @@ -411,7 +420,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:

# Second, collect attributes belonging to the current class.
current_attr_names: set[str] = set()
kw_only = _get_decorator_bool_argument(ctx, "kw_only", False)
kw_only = _get_decorator_bool_argument(ctx, "kw_only", self._spec.kw_only_default)
for stmt in cls.defs.body:
# Any assignment that doesn't use the new type declaration
# syntax can be ignored out of hand.
Expand Down Expand Up @@ -461,7 +470,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
if self._is_kw_only_type(node_type):
kw_only = True

has_field_call, field_args = _collect_field_args(stmt.rvalue, ctx)
has_field_call, field_args = self._collect_field_args(stmt.rvalue, ctx)

is_in_init_param = field_args.get("init")
if is_in_init_param is None:
Expand Down Expand Up @@ -614,6 +623,36 @@ def _add_dataclass_fields_magic_attribute(self) -> None:
kind=MDEF, node=var, plugin_generated=True
)

def _collect_field_args(
    self, expr: Expression, ctx: ClassDefContext
) -> tuple[bool, dict[str, Expression]]:
    """Inspect `expr` for a call to one of the transform's field specifiers.

    Returns a pair: the first item says whether `expr` is a call whose callee's
    fullname is listed in `self._spec.field_specifiers`; the second maps each keyword
    argument name of that call to its argument expression. When a positional or
    `**`-unpacked argument is found, an error is reported through `ctx.api.fail` and
    the mapping is returned empty.
    """
    if not (
        isinstance(expr, CallExpr)
        and isinstance(expr.callee, RefExpr)
        and expr.callee.fullname in self._spec.field_specifiers
    ):
        return False, {}

    # field() only takes keyword arguments.
    collected: dict[str, Expression] = {}
    for arg_name, arg_expr, arg_kind in zip(expr.arg_names, expr.args, expr.arg_kinds):
        if arg_kind.is_named():
            assert arg_name is not None
            collected[arg_name] = arg_expr
            continue

        if arg_kind.is_named(star=True):
            # This means that `field` is used with `**` unpacking,
            # the best we can do for now is not to fail.
            # TODO: we can infer what's inside `**` and try to collect it.
            message = 'Unpacking **kwargs in "field()" is not supported'
        else:
            message = '"field()" does not accept positional arguments'
        ctx.api.fail(message, expr)
        return True, {}

    return True, collected


def dataclass_tag_callback(ctx: ClassDefContext) -> None:
"""Record that we have a dataclass in the main semantic analysis pass.
Expand All @@ -631,32 +670,29 @@ def dataclass_class_maker_callback(ctx: ClassDefContext) -> bool:
return transformer.transform()


def _collect_field_args(
expr: Expression, ctx: ClassDefContext
) -> tuple[bool, dict[str, Expression]]:
"""Returns a tuple where the first value represents whether or not
the expression is a call to dataclass.field and the second is a
dictionary of the keyword arguments that field() was called with.
def _get_transform_spec(reason: Expression) -> DataclassTransformSpec:
"""Find the relevant transform parameters from the decorator/parent class/metaclass that
triggered the dataclasses plugin.
Although the resulting DataclassTransformSpec is based on the typing.dataclass_transform
function, we also use it for traditional dataclasses.dataclass classes as well for simplicity.
In those cases, we return a default spec rather than one based on a call to
`typing.dataclass_transform`.
"""
if (
isinstance(expr, CallExpr)
and isinstance(expr.callee, RefExpr)
and expr.callee.fullname in field_makers
):
# field() only takes keyword arguments.
args = {}
for name, arg, kind in zip(expr.arg_names, expr.args, expr.arg_kinds):
if not kind.is_named():
if kind.is_named(star=True):
# This means that `field` is used with `**` unpacking,
# the best we can do for now is not to fail.
# TODO: we can infer what's inside `**` and try to collect it.
message = 'Unpacking **kwargs in "field()" is not supported'
else:
message = '"field()" does not accept positional arguments'
ctx.api.fail(message, expr)
return True, {}
assert name is not None
args[name] = arg
return True, args
return False, {}
if _is_dataclasses_decorator(reason):
return _TRANSFORM_SPEC_FOR_DATACLASSES

spec = find_dataclass_transform_spec(reason)
assert spec is not None, (
"trying to find dataclass transform spec, but reason is neither dataclasses.dataclass nor "
"decorated with typing.dataclass_transform"
)
return spec


def _is_dataclasses_decorator(node: Node) -> bool:
    """Tell whether `node` names one of the entries in `dataclass_makers`.

    A call expression is unwrapped to its callee first, so both the bare and the
    called decorator forms are recognized. Anything that is not a reference
    expression after unwrapping yields False.
    """
    target = node.callee if isinstance(node, CallExpr) else node
    if not isinstance(target, RefExpr):
        return False
    return target.fullname in dataclass_makers
53 changes: 33 additions & 20 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
ConditionalExpr,
Context,
ContinueStmt,
DataclassTransformSpec,
Decorator,
DelStmt,
DictExpr,
Expand Down Expand Up @@ -213,6 +214,7 @@
PRIORITY_FALLBACKS,
SemanticAnalyzerInterface,
calculate_tuple_fallback,
find_dataclass_transform_spec,
has_placeholder,
set_callable_name as set_callable_name,
)
Expand Down Expand Up @@ -1523,7 +1525,7 @@ def visit_decorator(self, dec: Decorator) -> None:
elif isinstance(d, CallExpr) and refers_to_fullname(
d.callee, DATACLASS_TRANSFORM_NAMES
):
dec.func.is_dataclass_transform = True
dec.func.dataclass_transform_spec = self.parse_dataclass_transform_spec(d)
elif not dec.var.is_property:
# We have seen a "non-trivial" decorator before seeing @property, if
# we will see a @property later, give an error, as we don't support this.
Expand Down Expand Up @@ -1728,7 +1730,7 @@ def apply_class_plugin_hooks(self, defn: ClassDef) -> None:
# Special case: if the decorator is itself decorated with
# typing.dataclass_transform, apply the hook for the dataclasses plugin
# TODO: remove special casing here
if hook is None and is_dataclass_transform_decorator(decorator):
if hook is None and find_dataclass_transform_spec(decorator):
hook = dataclasses_plugin.dataclass_tag_callback
if hook:
hook(ClassDefContext(defn, decorator, self))
Expand Down Expand Up @@ -6456,6 +6458,35 @@ def set_future_import_flags(self, module_name: str) -> None:
def is_future_flag_set(self, flag: str) -> bool:
return self.modules[self.cur_mod_id].is_future_flag_set(flag)

def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSpec:
    """Build a DataclassTransformSpec from the arguments passed to the given call to
    typing.dataclass_transform.

    Starts from a default-constructed spec and overrides one attribute per recognized
    keyword argument. Problems are reported through `self.fail` and the offending
    argument is skipped, so a spec is always returned.
    """
    parameters = DataclassTransformSpec()
    for name, value in zip(call.arg_names, call.args):
        # Positional arguments have no name to match against; skip them rather than
        # emitting a confusing diagnostic that mentions "None".
        # NOTE(review): presumably the regular call check against the stub signature
        # reports positional use -- confirm.
        if name is None:
            continue

        # field_specifiers is currently the only non-boolean argument; check for it first
        # so the rest of the block can fall through to handling booleans
        if name == "field_specifiers":
            self.fail('"field_specifiers" support is currently unimplemented', call)
            continue

        boolean = self.parse_bool(value)
        if boolean is None:
            self.fail(f'"{name}" argument must be a True or False literal', call)
            continue

        if name == "eq_default":
            parameters.eq_default = boolean
        elif name == "order_default":
            parameters.order_default = boolean
        elif name == "kw_only_default":
            parameters.kw_only_default = boolean
        elif name == "frozen_default":
            parameters.frozen_default = boolean
        else:
            self.fail(f'Unrecognized dataclass_transform parameter "{name}"', call)

    return parameters


def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike:
if isinstance(sig, CallableType):
Expand Down Expand Up @@ -6645,21 +6676,3 @@ def halt(self, reason: str = ...) -> NoReturn:
return isinstance(stmt, PassStmt) or (
isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, EllipsisExpr)
)


def is_dataclass_transform_decorator(node: Node | None) -> bool:
    """Tell whether `node` resolves to a function flagged as a dataclass transform.

    Reference expressions are followed to the node they point at, and call
    expressions are unwrapped to their callee, repeatedly, until neither applies.
    The result is True only if what remains is a Decorator whose function has
    `is_dataclass_transform` set.
    """
    # Like dataclasses.dataclass, transform-based decorators can be applied either with or
    # without parameters; ie, both of these forms are accepted:
    #
    #   @typing.dataclass_transform
    #   class Foo: ...
    #   @typing.dataclass_transform(eq=True, order=True, ...)
    #   class Bar: ...
    #
    # The call wrapper of the second variant is peeled off in the loop below.
    current = node
    while True:
        if isinstance(current, RefExpr):
            current = current.node
        elif isinstance(current, CallExpr):
            current = current.callee
        else:
            break

    return isinstance(current, Decorator) and current.func.is_dataclass_transform
Loading

0 comments on commit 9e85f9b

Please sign in to comment.