Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stubgen: Fix call-based namedtuple omitted from class bases #14680

Merged
merged 9 commits into from
May 6, 2023
120 changes: 99 additions & 21 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@
TupleExpr,
TypeInfo,
UnaryExpr,
is_StrExpr_list,
)
from mypy.options import Options as MypyOptions
from mypy.stubdoc import Sig, find_unique_signatures, parse_all_signatures
Expand All @@ -129,6 +128,7 @@
from mypy.types import (
OVERLOAD_NAMES,
TPDICT_NAMES,
TYPED_NAMEDTUPLE_NAMES,
AnyType,
CallableType,
Instance,
Expand Down Expand Up @@ -400,10 +400,12 @@ def visit_str_expr(self, node: StrExpr) -> str:
def visit_index_expr(self, node: IndexExpr) -> str:
base = node.base.accept(self)
index = node.index.accept(self)
if len(index) > 2 and index.startswith("(") and index.endswith(")"):
index = index[1:-1]
return f"{base}[{index}]"

def visit_tuple_expr(self, node: TupleExpr) -> str:
return ", ".join(n.accept(self) for n in node.items)
return f"({', '.join(n.accept(self) for n in node.items)})"

def visit_list_expr(self, node: ListExpr) -> str:
return f"[{', '.join(n.accept(self) for n in node.items)}]"
Expand Down Expand Up @@ -1010,6 +1012,37 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
elif isinstance(base, IndexExpr):
p = AliasPrinter(self)
base_types.append(base.accept(p))
elif isinstance(base, CallExpr):
# namedtuple(typename, fields), NamedTuple(typename, fields) calls can
# be used as a base class. The first argument is a string literal that
# is usually the same as the class name.
hamdanal marked this conversation as resolved.
Show resolved Hide resolved
#
# Note:
# A call-based named tuple as a base class cannot be safely converted to
# a class-based NamedTuple definition because class attributes defined
# in the body of the class inheriting from the named tuple call are not
# namedtuple fields at runtime.
if self.is_namedtuple(base):
nt_fields = self._get_namedtuple_fields(base)
assert isinstance(base.args[0], StrExpr)
typename = base.args[0].value
if nt_fields is not None:
# A valid namedtuple() call, use NamedTuple() instead with
# Incomplete as field types
fields_str = ", ".join(f"({f!r}, {t})" for f, t in nt_fields)
else:
# Invalid namedtuple() call, cannot determine fields
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like you generate NamedTuple(name, [Incomplete]) as the base in this case. That seems wrong, as it's an invalid type. I think we should make Incomplete the base in this case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, is this code path tested?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call. I fixed it. The tests now cover the invalid namedtuple call code path.

fields_str = "Incomplete"
base_types.append(f"NamedTuple({typename!r}, [{fields_str}])")
self.add_typing_import("NamedTuple")
elif self.is_typed_namedtuple(base):
p = AliasPrinter(self)
base_types.append(base.accept(p))
else:
# At this point, we don't know what the base class is, so we
# just use Incomplete as the base class.
base_types.append("Incomplete")
self.import_tracker.require_name("Incomplete")
return base_types

def visit_block(self, o: Block) -> None:
Expand All @@ -1022,8 +1055,11 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
foundl = []

for lvalue in o.lvalues:
if isinstance(lvalue, NameExpr) and self.is_namedtuple(o.rvalue):
assert isinstance(o.rvalue, CallExpr)
if (
isinstance(lvalue, NameExpr)
and isinstance(o.rvalue, CallExpr)
and (self.is_namedtuple(o.rvalue) or self.is_typed_namedtuple(o.rvalue))
):
self.process_namedtuple(lvalue, o.rvalue)
continue
if (
Expand Down Expand Up @@ -1069,37 +1105,79 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
if all(foundl):
self._state = VAR

def is_namedtuple(self, expr: Expression) -> bool:
if not isinstance(expr, CallExpr):
return False
def is_namedtuple(self, expr: CallExpr) -> bool:
callee = expr.callee
return (isinstance(callee, NameExpr) and callee.name.endswith("namedtuple")) or (
isinstance(callee, MemberExpr) and callee.name == "namedtuple"
return (
isinstance(callee, NameExpr)
and (self.refers_to_fullname(callee.name, "collections.namedtuple"))
) or (
isinstance(callee, MemberExpr)
and isinstance(callee.expr, NameExpr)
and f"{callee.expr.name}.{callee.name}" == "collections.namedtuple"
)

def is_typed_namedtuple(self, expr: CallExpr) -> bool:
callee = expr.callee
return (
isinstance(callee, NameExpr)
and self.refers_to_fullname(callee.name, TYPED_NAMEDTUPLE_NAMES)
) or (
isinstance(callee, MemberExpr)
and isinstance(callee.expr, NameExpr)
and f"{callee.expr.name}.{callee.name}" in TYPED_NAMEDTUPLE_NAMES
)

def _get_namedtuple_fields(self, call: CallExpr) -> list[tuple[str, str]] | None:
if self.is_namedtuple(call):
fields_arg = call.args[1]
if isinstance(fields_arg, StrExpr):
field_names = fields_arg.value.replace(",", " ").split()
elif isinstance(fields_arg, (ListExpr, TupleExpr)):
field_names = []
for field in fields_arg.items:
if not isinstance(field, StrExpr):
return None
field_names.append(field.value)
else:
return None # Invalid namedtuple fields type
if field_names:
self.import_tracker.require_name("Incomplete")
return [(field_name, "Incomplete") for field_name in field_names]
elif self.is_typed_namedtuple(call):
fields_arg = call.args[1]
if not isinstance(fields_arg, (ListExpr, TupleExpr)):
return None
fields: list[tuple[str, str]] = []
hamdanal marked this conversation as resolved.
Show resolved Hide resolved
b = AliasPrinter(self)
for field in fields_arg.items:
if not (isinstance(field, TupleExpr) and len(field.items) == 2):
return None
field_name, field_type = field.items
if not isinstance(field_name, StrExpr):
return None
fields.append((field_name.value, field_type.accept(b)))
return fields
else:
return None # Not a named tuple call

def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
if self._state != EMPTY:
self.add("\n")
if isinstance(rvalue.args[1], StrExpr):
items = rvalue.args[1].value.replace(",", " ").split()
elif isinstance(rvalue.args[1], (ListExpr, TupleExpr)):
list_items = rvalue.args[1].items
assert is_StrExpr_list(list_items)
items = [item.value for item in list_items]
else:
fields = self._get_namedtuple_fields(rvalue)
if fields is None:
self.add(f"{self._indent}{lvalue.name}: Incomplete")
self.import_tracker.require_name("Incomplete")
return
self.import_tracker.require_name("NamedTuple")
self.add(f"{self._indent}class {lvalue.name}(NamedTuple):")
if not items:
if len(fields) == 0:
self.add(" ...\n")
self._state = EMPTY_CLASS
else:
self.import_tracker.require_name("Incomplete")
self.add("\n")
for item in items:
self.add(f"{self._indent} {item}: Incomplete\n")
self._state = CLASS
for f_name, f_type in fields:
self.add(f"{self._indent} {f_name}: {f_type}\n")
self._state = CLASS

def is_typeddict(self, expr: CallExpr) -> bool:
callee = expr.callee
Expand Down
69 changes: 67 additions & 2 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -641,8 +641,9 @@ class A:
def _bar(cls) -> None: ...

[case testNamedtuple]
import collections, x
import collections, typing, x
X = collections.namedtuple('X', ['a', 'b'])
Y = typing.NamedTuple('Y', [('a', int), ('b', str)])
[out]
from _typeshed import Incomplete
from typing import NamedTuple
Expand All @@ -651,14 +652,21 @@ class X(NamedTuple):
a: Incomplete
b: Incomplete

class Y(NamedTuple):
a: int
b: str

[case testEmptyNamedtuple]
import collections
import collections, typing
X = collections.namedtuple('X', [])
Y = typing.NamedTuple('Y', [])
[out]
from typing import NamedTuple

class X(NamedTuple): ...

class Y(NamedTuple): ...

[case testNamedtupleAltSyntax]
from collections import namedtuple, xx
X = namedtuple('X', 'a b')
Expand Down Expand Up @@ -697,8 +705,10 @@ class X(NamedTuple):

[case testNamedtupleWithUnderscore]
from collections import namedtuple as _namedtuple
from typing import NamedTuple as _NamedTuple
def f(): ...
X = _namedtuple('X', 'a b')
Y = _NamedTuple('Y', [('a', int), ('b', str)])
def g(): ...
[out]
from _typeshed import Incomplete
Expand All @@ -710,6 +720,10 @@ class X(NamedTuple):
a: Incomplete
b: Incomplete

class Y(NamedTuple):
a: int
b: str

def g() -> None: ...

[case testNamedtupleBaseClass]
Expand All @@ -728,10 +742,14 @@ class Y(_X): ...

[case testNamedtupleAltSyntaxFieldsTuples]
from collections import namedtuple, xx
from typing import NamedTuple
X = namedtuple('X', ())
Y = namedtuple('Y', ('a',))
Z = namedtuple('Z', ('a', 'b', 'c', 'd', 'e'))
xx
R = NamedTuple('R', ())
S = NamedTuple('S', (('a', int),))
T = NamedTuple('T', (('a', int), ('b', str)))
[out]
from _typeshed import Incomplete
from typing import NamedTuple
Expand All @@ -748,13 +766,60 @@ class Z(NamedTuple):
d: Incomplete
e: Incomplete

class R(NamedTuple): ...

class S(NamedTuple):
a: int

class T(NamedTuple):
a: int
b: str

[case testDynamicNamedTuple]
from collections import namedtuple
from typing import NamedTuple
N = namedtuple('N', ['x', 'y'] + ['z'])
M = NamedTuple('M', [('x', int), ('y', str)] + [('z', float)])
[out]
from _typeshed import Incomplete

N: Incomplete
M: Incomplete

[case testNamedTupleInClassBases]
import collections, typing
from collections import namedtuple
from typing import NamedTuple
class X(namedtuple('X', ['a', 'b'])): ...
class Y(NamedTuple('Y', [('a', int), ('b', str)])): ...
class R(collections.namedtuple('R', ['a', 'b'])): ...
class S(typing.NamedTuple('S', [('a', int), ('b', str)])): ...
[out]
import typing
from _typeshed import Incomplete
from typing import NamedTuple

class X(NamedTuple('X', [('a', Incomplete), ('b', Incomplete)])): ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, in typeshed we prefer to avoid call-based namedtuples wherever possible, so we would usually do something like this for this kind of thing:

class _XBase(NamedTuple):
    a: Incomplete
    b: Incomplete

class X(_XBase): ...

I'm guessing it might be hard to achieve that from stubgen, though?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, call-based syntax isn't desirable. It is however tricky to get the conversion to a stub-only class definition right. In fact I tried briefly when I started with the PR then decided it is safer and much simpler to keep the call-based definition.

Note that even though this syntax is ugly/undesirable, it type-checks perfectly fine by both mypy and pyright and this PR still fixes the issue where stubgen generated wrong stubs without any warning. It also makes a posterior step of manual conversion to a class definition much simpler as the information is now there in the stub instead of having to grep the python source for all namedtuple bases in class definitions.

Having said that, I don't mind implementing this, whether in this PR or in a follow up one. I do like to get the opinion of a stubgen maintainer/expert before working on it though as it will require some work. If the answer is do it, I have a couple of questions regarding this step that I can ask then.

Copy link
Collaborator Author

@hamdanal hamdanal May 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AlexWaygood are you comfortable enough this change will be welcome so that I can work on it or should we ping someone?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to generate the stubs as in this PR. If you fix the merge conflict, I can review this PR.

class Y(NamedTuple('Y', [('a', int), ('b', str)])): ...
class R(NamedTuple('R', [('a', Incomplete), ('b', Incomplete)])): ...
class S(typing.NamedTuple('S', [('a', int), ('b', str)])): ...

[case testNotNamedTuple]
from not_collections import namedtuple
from not_typing import NamedTuple
from collections import notnamedtuple
from typing import NotNamedTuple
X = namedtuple('X', ['a', 'b'])
Y = notnamedtuple('Y', ['a', 'b'])
Z = NamedTuple('Z', [('a', int), ('b', str)])
W = NotNamedTuple('W', [('a', int), ('b', str)])
[out]
from _typeshed import Incomplete

X: Incomplete
Y: Incomplete
Z: Incomplete
W: Incomplete

[case testArbitraryBaseClass]
import x
Expand Down