Skip to content

Commit

Permalink
stubgen: Fix call-based namedtuple omitted from class bases (#14680)
Browse files Browse the repository at this point in the history
Fixes #9901
Fixes #13662

Fix inheriting from a call-based `collections.namedtuple` /
`typing.NamedTuple` definition that was omitted from the generated stub.

This automatically adds support for the call-based `NamedTuple` in
general not only in class bases (Closes #13788).

<details>
<summary>An example before and after</summary>
Input:

```python
import collections
import typing
from collections import namedtuple
from typing import NamedTuple

CollectionsCall = namedtuple("CollectionsCall", ["x", "y"])

class CollectionsClass(namedtuple("CollectionsClass", ["x", "y"])):
    def f(self, a):
        pass

class CollectionsDotClass(collections.namedtuple("CollectionsClass", ["x", "y"])):
    def f(self, a):
        pass

TypingCall = NamedTuple("TypingCall", [("x", int | None), ("y", int)])

class TypingClass(NamedTuple):
    x: int | None
    y: str

    def f(self, a):
        pass

class TypingClassWeird(NamedTuple("TypingClassWeird", [("x", int | None), ("y", str)])):
    z: float | None

    def f(self, a):
        pass

class TypingDotClassWeird(typing.NamedTuple("TypingClassWeird", [("x", int | None), ("y", str)])):
    def f(self, a):
        pass
```

Output diff (before and after):
```diff
diff --git a/before.pyi b/after.pyi
index c88530e2c..95ef843b4 100644
--- a/before.pyi
+++ b/after.pyi
@@ -1,26 +1,29 @@
+import typing
 from _typeshed import Incomplete
 from typing_extensions import NamedTuple

 class CollectionsCall(NamedTuple):
     x: Incomplete
     y: Incomplete

-class CollectionsClass:
+class CollectionsClass(NamedTuple('CollectionsClass', [('x', Incomplete), ('y', Incomplete)])):
     def f(self, a) -> None: ...

-class CollectionsDotClass:
+class CollectionsDotClass(NamedTuple('CollectionsClass', [('x', Incomplete), ('y', Incomplete)])):
     def f(self, a) -> None: ...

-TypingCall: Incomplete
+class TypingCall(NamedTuple):
+    x: int | None
+    y: int

 class TypingClass(NamedTuple):
     x: int | None
     y: str
     def f(self, a) -> None: ...

-class TypingClassWeird:
+class TypingClassWeird(NamedTuple('TypingClassWeird', [('x', int | None), ('y', str)])):
     z: float | None
     def f(self, a) -> None: ...

-class TypingDotClassWeird:
+class TypingDotClassWeird(typing.NamedTuple('TypingClassWeird', [('x', int | None), ('y', str)])):
     def f(self, a) -> None: ...
```
</details>
  • Loading branch information
hamdanal authored May 6, 2023
1 parent d710fdd commit 171e6f8
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 23 deletions.
120 changes: 99 additions & 21 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@
TupleExpr,
TypeInfo,
UnaryExpr,
is_StrExpr_list,
)
from mypy.options import Options as MypyOptions
from mypy.stubdoc import Sig, find_unique_signatures, parse_all_signatures
Expand All @@ -129,6 +128,7 @@
from mypy.types import (
OVERLOAD_NAMES,
TPDICT_NAMES,
TYPED_NAMEDTUPLE_NAMES,
AnyType,
CallableType,
Instance,
Expand Down Expand Up @@ -400,10 +400,12 @@ def visit_str_expr(self, node: StrExpr) -> str:
def visit_index_expr(self, node: IndexExpr) -> str:
base = node.base.accept(self)
index = node.index.accept(self)
if len(index) > 2 and index.startswith("(") and index.endswith(")"):
index = index[1:-1]
return f"{base}[{index}]"

def visit_tuple_expr(self, node: TupleExpr) -> str:
return ", ".join(n.accept(self) for n in node.items)
return f"({', '.join(n.accept(self) for n in node.items)})"

def visit_list_expr(self, node: ListExpr) -> str:
return f"[{', '.join(n.accept(self) for n in node.items)}]"
Expand Down Expand Up @@ -1010,6 +1012,37 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
elif isinstance(base, IndexExpr):
p = AliasPrinter(self)
base_types.append(base.accept(p))
elif isinstance(base, CallExpr):
# namedtuple(typename, fields), NamedTuple(typename, fields) calls can
# be used as a base class. The first argument is a string literal that
# is usually the same as the class name.
#
# Note:
# A call-based named tuple as a base class cannot be safely converted to
# a class-based NamedTuple definition because class attributes defined
# in the body of the class inheriting from the named tuple call are not
# namedtuple fields at runtime.
if self.is_namedtuple(base):
nt_fields = self._get_namedtuple_fields(base)
assert isinstance(base.args[0], StrExpr)
typename = base.args[0].value
if nt_fields is not None:
# A valid namedtuple() call, use NamedTuple() instead with
# Incomplete as field types
fields_str = ", ".join(f"({f!r}, {t})" for f, t in nt_fields)
base_types.append(f"NamedTuple({typename!r}, [{fields_str}])")
self.add_typing_import("NamedTuple")
else:
# Invalid namedtuple() call, cannot determine fields
base_types.append("Incomplete")
elif self.is_typed_namedtuple(base):
p = AliasPrinter(self)
base_types.append(base.accept(p))
else:
# At this point, we don't know what the base class is, so we
# just use Incomplete as the base class.
base_types.append("Incomplete")
self.import_tracker.require_name("Incomplete")
return base_types

def visit_block(self, o: Block) -> None:
Expand All @@ -1022,8 +1055,11 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
foundl = []

for lvalue in o.lvalues:
if isinstance(lvalue, NameExpr) and self.is_namedtuple(o.rvalue):
assert isinstance(o.rvalue, CallExpr)
if (
isinstance(lvalue, NameExpr)
and isinstance(o.rvalue, CallExpr)
and (self.is_namedtuple(o.rvalue) or self.is_typed_namedtuple(o.rvalue))
):
self.process_namedtuple(lvalue, o.rvalue)
continue
if (
Expand Down Expand Up @@ -1069,37 +1105,79 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
if all(foundl):
self._state = VAR

def is_namedtuple(self, expr: Expression) -> bool:
if not isinstance(expr, CallExpr):
return False
def is_namedtuple(self, expr: CallExpr) -> bool:
callee = expr.callee
return (isinstance(callee, NameExpr) and callee.name.endswith("namedtuple")) or (
isinstance(callee, MemberExpr) and callee.name == "namedtuple"
return (
isinstance(callee, NameExpr)
and (self.refers_to_fullname(callee.name, "collections.namedtuple"))
) or (
isinstance(callee, MemberExpr)
and isinstance(callee.expr, NameExpr)
and f"{callee.expr.name}.{callee.name}" == "collections.namedtuple"
)

def is_typed_namedtuple(self, expr: CallExpr) -> bool:
callee = expr.callee
return (
isinstance(callee, NameExpr)
and self.refers_to_fullname(callee.name, TYPED_NAMEDTUPLE_NAMES)
) or (
isinstance(callee, MemberExpr)
and isinstance(callee.expr, NameExpr)
and f"{callee.expr.name}.{callee.name}" in TYPED_NAMEDTUPLE_NAMES
)

def _get_namedtuple_fields(self, call: CallExpr) -> list[tuple[str, str]] | None:
if self.is_namedtuple(call):
fields_arg = call.args[1]
if isinstance(fields_arg, StrExpr):
field_names = fields_arg.value.replace(",", " ").split()
elif isinstance(fields_arg, (ListExpr, TupleExpr)):
field_names = []
for field in fields_arg.items:
if not isinstance(field, StrExpr):
return None
field_names.append(field.value)
else:
return None # Invalid namedtuple fields type
if field_names:
self.import_tracker.require_name("Incomplete")
return [(field_name, "Incomplete") for field_name in field_names]
elif self.is_typed_namedtuple(call):
fields_arg = call.args[1]
if not isinstance(fields_arg, (ListExpr, TupleExpr)):
return None
fields: list[tuple[str, str]] = []
b = AliasPrinter(self)
for field in fields_arg.items:
if not (isinstance(field, TupleExpr) and len(field.items) == 2):
return None
field_name, field_type = field.items
if not isinstance(field_name, StrExpr):
return None
fields.append((field_name.value, field_type.accept(b)))
return fields
else:
return None # Not a named tuple call

def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
if self._state != EMPTY:
self.add("\n")
if isinstance(rvalue.args[1], StrExpr):
items = rvalue.args[1].value.replace(",", " ").split()
elif isinstance(rvalue.args[1], (ListExpr, TupleExpr)):
list_items = rvalue.args[1].items
assert is_StrExpr_list(list_items)
items = [item.value for item in list_items]
else:
fields = self._get_namedtuple_fields(rvalue)
if fields is None:
self.add(f"{self._indent}{lvalue.name}: Incomplete")
self.import_tracker.require_name("Incomplete")
return
self.import_tracker.require_name("NamedTuple")
self.add(f"{self._indent}class {lvalue.name}(NamedTuple):")
if not items:
if len(fields) == 0:
self.add(" ...\n")
self._state = EMPTY_CLASS
else:
self.import_tracker.require_name("Incomplete")
self.add("\n")
for item in items:
self.add(f"{self._indent} {item}: Incomplete\n")
self._state = CLASS
for f_name, f_type in fields:
self.add(f"{self._indent} {f_name}: {f_type}\n")
self._state = CLASS

def is_typeddict(self, expr: CallExpr) -> bool:
callee = expr.callee
Expand Down
71 changes: 69 additions & 2 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -641,8 +641,9 @@ class A:
def _bar(cls) -> None: ...

[case testNamedtuple]
import collections, x
import collections, typing, x
X = collections.namedtuple('X', ['a', 'b'])
Y = typing.NamedTuple('Y', [('a', int), ('b', str)])
[out]
from _typeshed import Incomplete
from typing import NamedTuple
Expand All @@ -651,14 +652,21 @@ class X(NamedTuple):
a: Incomplete
b: Incomplete

class Y(NamedTuple):
a: int
b: str

[case testEmptyNamedtuple]
import collections
import collections, typing
X = collections.namedtuple('X', [])
Y = typing.NamedTuple('Y', [])
[out]
from typing import NamedTuple

class X(NamedTuple): ...

class Y(NamedTuple): ...

[case testNamedtupleAltSyntax]
from collections import namedtuple, xx
X = namedtuple('X', 'a b')
Expand Down Expand Up @@ -697,8 +705,10 @@ class X(NamedTuple):

[case testNamedtupleWithUnderscore]
from collections import namedtuple as _namedtuple
from typing import NamedTuple as _NamedTuple
def f(): ...
X = _namedtuple('X', 'a b')
Y = _NamedTuple('Y', [('a', int), ('b', str)])
def g(): ...
[out]
from _typeshed import Incomplete
Expand All @@ -710,6 +720,10 @@ class X(NamedTuple):
a: Incomplete
b: Incomplete

class Y(NamedTuple):
a: int
b: str

def g() -> None: ...

[case testNamedtupleBaseClass]
Expand All @@ -728,10 +742,14 @@ class Y(_X): ...

[case testNamedtupleAltSyntaxFieldsTuples]
from collections import namedtuple, xx
from typing import NamedTuple
X = namedtuple('X', ())
Y = namedtuple('Y', ('a',))
Z = namedtuple('Z', ('a', 'b', 'c', 'd', 'e'))
xx
R = NamedTuple('R', ())
S = NamedTuple('S', (('a', int),))
T = NamedTuple('T', (('a', int), ('b', str)))
[out]
from _typeshed import Incomplete
from typing import NamedTuple
Expand All @@ -748,13 +766,62 @@ class Z(NamedTuple):
d: Incomplete
e: Incomplete

class R(NamedTuple): ...

class S(NamedTuple):
a: int

class T(NamedTuple):
a: int
b: str

[case testDynamicNamedTuple]
from collections import namedtuple
from typing import NamedTuple
N = namedtuple('N', ['x', 'y'] + ['z'])
M = NamedTuple('M', [('x', int), ('y', str)] + [('z', float)])
class X(namedtuple('X', ['a', 'b'] + ['c'])): ...
[out]
from _typeshed import Incomplete

N: Incomplete
M: Incomplete
class X(Incomplete): ...

[case testNamedTupleInClassBases]
import collections, typing
from collections import namedtuple
from typing import NamedTuple
class X(namedtuple('X', ['a', 'b'])): ...
class Y(NamedTuple('Y', [('a', int), ('b', str)])): ...
class R(collections.namedtuple('R', ['a', 'b'])): ...
class S(typing.NamedTuple('S', [('a', int), ('b', str)])): ...
[out]
import typing
from _typeshed import Incomplete
from typing import NamedTuple

class X(NamedTuple('X', [('a', Incomplete), ('b', Incomplete)])): ...
class Y(NamedTuple('Y', [('a', int), ('b', str)])): ...
class R(NamedTuple('R', [('a', Incomplete), ('b', Incomplete)])): ...
class S(typing.NamedTuple('S', [('a', int), ('b', str)])): ...

[case testNotNamedTuple]
from not_collections import namedtuple
from not_typing import NamedTuple
from collections import notnamedtuple
from typing import NotNamedTuple
X = namedtuple('X', ['a', 'b'])
Y = notnamedtuple('Y', ['a', 'b'])
Z = NamedTuple('Z', [('a', int), ('b', str)])
W = NotNamedTuple('W', [('a', int), ('b', str)])
[out]
from _typeshed import Incomplete

X: Incomplete
Y: Incomplete
Z: Incomplete
W: Incomplete

[case testArbitraryBaseClass]
import x
Expand Down

0 comments on commit 171e6f8

Please sign in to comment.