Skip to content

Commit

Permalink
Keep __init__ in dataclasses and add more tests
Browse files Browse the repository at this point in the history
We cannot safely remove `__init__` and depend on the plugin
because its signature depends on dataclass field assignments to
`dataclasses.field` and these assignments are not included in the stub
  • Loading branch information
hamdanal committed Jul 14, 2023
1 parent faae3c0 commit 691b8e9
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 12 deletions.
15 changes: 13 additions & 2 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,8 +701,11 @@ def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
self.clear_decorators()

def visit_func_def(self, o: FuncDef) -> None:
if self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated:
# Skip methods generated by the @dataclass decorator
is_dataclass_generated = (
self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated
)
if is_dataclass_generated and o.name != "__init__":
# Skip methods generated by the @dataclass decorator (except for __init__)
return
if (
self.is_private_name(o.name, o.fullname)
Expand Down Expand Up @@ -769,6 +772,12 @@ def visit_func_def(self, o: FuncDef) -> None:
else:
arg = name + annotation
args.append(arg)
if o.name == "__init__" and is_dataclass_generated and "**" in args:
# The dataclass plugin generates invalid nameless "*" and "**" arguments
new_name = "".join(a.split(":", 1)[0] for a in args).replace("*", "")
args[args.index("*")] = f"*{new_name}_" # this name is guaranteed to be unique
args[args.index("**")] = f"**{new_name}__" # same here

retname = None
if o.name != "__init__" and isinstance(o.unanalyzed_type, CallableType):
if isinstance(get_proper_type(o.unanalyzed_type.ret_type), AnyType):
Expand Down Expand Up @@ -1413,6 +1422,8 @@ def is_private_name(self, name: str, fullname: str | None = None) -> bool:
return False
if fullname in EXTRA_EXPORTED:
return False
if name == "_":
return False
return name.startswith("_") and (not name.endswith("__") or name in IGNORED_DUNDERS)

def is_private_member(self, fullname: str) -> bool:
Expand Down
90 changes: 80 additions & 10 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -3318,17 +3318,25 @@ def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...
class X(_Incomplete): ...
class Y(_Incomplete): ...


[case testDataclass]
import dataclasses
import dataclasses as dcs
from dataclasses import dataclass
from dataclasses import dataclass, InitVar, KW_ONLY
from dataclasses import dataclass as dc
from typing import ClassVar

@dataclasses.dataclass
class X:
a: int
b: str = "hello"
c: ClassVar
d: ClassVar = 200
f: list[int] = field(init=False, default_factory=list)
g: int = field(default=2, kw_only=True)
_: KW_ONLY
h: int = 1
i: InitVar[str]
j: InitVar = 100
non_field = None

@dcs.dataclass
Expand All @@ -3340,15 +3348,27 @@ class Z: ...
@dc
class W: ...

@dataclass(init=False, repr=False)
class V: ...

[out]
import dataclasses
import dataclasses as dcs
from dataclasses import dataclass, dataclass as dc
from dataclasses import InitVar, KW_ONLY, dataclass, dataclass as dc
from typing import ClassVar

@dataclasses.dataclass
class X:
a: int
b: str = ...
c: ClassVar
d: ClassVar = ...
f: list[int] = ...
g: int = ...
_: KW_ONLY
h: int = ...
i: InitVar[str]
j: InitVar = ...
non_field = ...

@dcs.dataclass
Expand All @@ -3357,18 +3377,51 @@ class Y: ...
class Z: ...
@dc
class W: ...
@dataclass(init=False, repr=False)
class V: ...

[case testDataclassWithKeywords]
from dataclasses import dataclass
[case testDataclass_semanal]
from dataclasses import dataclass, InitVar, KW_ONLY
from typing import ClassVar

@dataclass(init=False)
class X: ...
@dataclass
class X:
a: int
b: str = "hello"
c: ClassVar
d: ClassVar = 200
f: list[int] = field(init=False, default_factory=list)
g: int = field(default=2, kw_only=True)
_: KW_ONLY
h: int = 1
i: InitVar[str]
j: InitVar = 100
non_field = None

@dataclass(init=False, repr=False, frozen=True)
class Y: ...

[out]
from dataclasses import dataclass
from dataclasses import InitVar, KW_ONLY, dataclass
from typing import ClassVar

@dataclass(init=False)
class X: ...
@dataclass
class X:
a: int
b: str = ...
c: ClassVar
d: ClassVar = ...
f: list[int] = ...
g: int = ...
_: KW_ONLY
h: int = ...
i: InitVar[str]
j: InitVar = ...
non_field = ...
def __init__(self, a, b, f, g, *, h, i, j) -> None: ...

@dataclass(init=False, repr=False, frozen=True)
class Y: ...

[case testDataclassWithExplicitGeneratedMethodsOverrides_semanal]
from dataclasses import dataclass
Expand All @@ -3387,3 +3440,20 @@ class X:
a: int
def __init__(self, a: int, b: str = ...) -> None: ...
def __post_init__(self) -> None: ...

[case testDataclassInheritsFromAny_semanal]
from dataclasses import dataclass
import missing

@dataclass
class X(missing.Base):
a: int

[out]
import missing
from dataclasses import dataclass

@dataclass
class X(missing.Base):
a: int
def __init__(self, *selfa_, a, **selfa__) -> None: ...

0 comments on commit 691b8e9

Please sign in to comment.