Skip to content

Commit

Permalink
stubgen: generate valid dataclass stubs
Browse files Browse the repository at this point in the history
Fixes #12441
  • Loading branch information
hamdanal committed Jul 8, 2023
1 parent 6cd8c00 commit faae3c0
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 0 deletions.
30 changes: 30 additions & 0 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
OverloadedFuncDef,
Statement,
StrExpr,
TempNode,
TupleExpr,
TypeInfo,
UnaryExpr,
Expand Down Expand Up @@ -650,6 +651,7 @@ def __init__(
self.defined_names: set[str] = set()
# Short names of methods defined in the body of the current class
self.method_names: set[str] = set()
self.processing_dataclass = False

def visit_mypy_file(self, o: MypyFile) -> None:
self.module = o.fullname # Current module being processed
Expand Down Expand Up @@ -699,6 +701,9 @@ def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
self.clear_decorators()

def visit_func_def(self, o: FuncDef) -> None:
if self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated:
# Skip methods generated by the @dataclass decorator
return
if (
self.is_private_name(o.name, o.fullname)
or self.is_not_in_all(o.name)
Expand Down Expand Up @@ -890,6 +895,9 @@ def visit_class_def(self, o: ClassDef) -> None:
if not self._indent and self._state != EMPTY:
sep = len(self._output)
self.add("\n")
decorators = self.get_class_decorators(o)
for d in decorators:
self.add(f"{self._indent}@{d}\n")
self.add(f"{self._indent}class {o.name}")
self.record_name(o.name)
base_types = self.get_base_types(o)
Expand Down Expand Up @@ -921,6 +929,7 @@ def visit_class_def(self, o: ClassDef) -> None:
else:
self._state = CLASS
self.method_names = set()
self.processing_dataclass = False

def get_base_types(self, cdef: ClassDef) -> list[str]:
"""Get list of base classes for a class."""
Expand Down Expand Up @@ -967,6 +976,21 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
base_types.append(f"{name}={value.accept(p)}")
return base_types

def get_class_decorators(self, cdef: ClassDef) -> list[str]:
decorators: list[str] = []
p = AliasPrinter(self)
for d in cdef.decorators:
if self.is_dataclass(d):
decorators.append(d.accept(p))
self.import_tracker.require_name(get_qualified_name(d))
self.processing_dataclass = True
return decorators

def is_dataclass(self, expr: Expression) -> bool:
if isinstance(expr, CallExpr):
expr = expr.callee
return self.get_fullname(expr) == "dataclasses.dataclass"

def visit_block(self, o: Block) -> None:
# Unreachable statements may be partially uninitialized and that may
# cause trouble.
Expand Down Expand Up @@ -1323,8 +1347,14 @@ def get_init(
# Final without type argument is invalid in stubs.
final_arg = self.get_str_type_of_node(rvalue)
typename += f"[{final_arg}]"
elif self.processing_dataclass:
# attribute without annotation is not a dataclass field, don't add annotation.
return f"{self._indent}{lvalue} = ...\n"
else:
typename = self.get_str_type_of_node(rvalue)
if self.processing_dataclass and not (isinstance(rvalue, TempNode) and rvalue.no_rhs):
# dataclass field with default value, keep the initializer.
return f"{self._indent}{lvalue}: {typename} = ...\n"
return f"{self._indent}{lvalue}: {typename}\n"

def add(self, string: str) -> None:
Expand Down
70 changes: 70 additions & 0 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -3317,3 +3317,73 @@ def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...

class X(_Incomplete): ...
class Y(_Incomplete): ...


[case testDataclass]
import dataclasses
import dataclasses as dcs
from dataclasses import dataclass
from dataclasses import dataclass as dc

@dataclasses.dataclass
class X:
a: int
b: str = "hello"
non_field = None

@dcs.dataclass
class Y: ...

@dataclass
class Z: ...

@dc
class W: ...

[out]
import dataclasses
import dataclasses as dcs
from dataclasses import dataclass, dataclass as dc

@dataclasses.dataclass
class X:
a: int
b: str = ...
non_field = ...

@dcs.dataclass
class Y: ...
@dataclass
class Z: ...
@dc
class W: ...

[case testDataclassWithKeywords]
from dataclasses import dataclass

@dataclass(init=False)
class X: ...

[out]
from dataclasses import dataclass

@dataclass(init=False)
class X: ...

[case testDataclassWithExplicitGeneratedMethodsOverrides_semanal]
from dataclasses import dataclass

@dataclass
class X:
a: int
def __init__(self, a: int, b: str = ...) -> None: ...
def __post_init__(self) -> None: ...

[out]
from dataclasses import dataclass

@dataclass
class X:
a: int
def __init__(self, a: int, b: str = ...) -> None: ...
def __post_init__(self) -> None: ...

0 comments on commit faae3c0

Please sign in to comment.