From 8ec5d8d59c10f919fa96846eb29d6702bee71aa2 Mon Sep 17 00:00:00 2001 From: KotlinIsland Date: Sat, 19 Aug 2023 19:57:20 +1000 Subject: [PATCH] StubGenerator.add_typing_import returns the name --- mypy/stubgen.py | 55 ++++++++++++++++++++----------------------------- 1 file changed, 22 insertions(+), 33 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index b6fc3e8b7377..aca836c52ce8 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -786,25 +786,20 @@ def visit_func_def(self, o: FuncDef) -> None: elif o.name in KNOWN_MAGIC_METHODS_RETURN_TYPES: retname = KNOWN_MAGIC_METHODS_RETURN_TYPES[o.name] elif has_yield_expression(o) or has_yield_from_expression(o): - self.add_typing_import("Generator") + generator_name = self.add_typing_import("Generator") yield_name = "None" send_name = "None" return_name = "None" if has_yield_from_expression(o): - self.add_typing_import("Incomplete") - yield_name = send_name = self.typing_name("Incomplete") + yield_name = send_name = self.add_typing_import("Incomplete") else: for expr, in_assignment in all_yield_expressions(o): if expr.expr is not None and not self.is_none_expr(expr.expr): - self.add_typing_import("Incomplete") - yield_name = self.typing_name("Incomplete") + yield_name = self.add_typing_import("Incomplete") if in_assignment: - self.add_typing_import("Incomplete") - send_name = self.typing_name("Incomplete") + send_name = self.add_typing_import("Incomplete") if has_return_statement(o): - self.add_typing_import("Incomplete") - return_name = self.typing_name("Incomplete") - generator_name = self.typing_name("Generator") + return_name = self.add_typing_import("Incomplete") retname = f"{generator_name}[{yield_name}, {send_name}, {return_name}]" elif not has_return_statement(o) and o.abstract_status == NOT_ABSTRACT: retname = "None" @@ -965,21 +960,19 @@ def get_base_types(self, cdef: ClassDef) -> list[str]: nt_fields = self._get_namedtuple_fields(base) assert isinstance(base.args[0], StrExpr) typename = base.args[0].value - if nt_fields is not None: - fields_str = ", ".join(f"({f!r}, {t})" for f, t in nt_fields) - namedtuple_name = self.typing_name("NamedTuple") - base_types.append(f"{namedtuple_name}({typename!r}, [{fields_str}])") - self.add_typing_import("NamedTuple") - else: + if nt_fields is None: # Invalid namedtuple() call, cannot determine fields - base_types.append(self.typing_name("Incomplete")) + base_types.append(self.add_typing_import("Incomplete")) + continue + fields_str = ", ".join(f"({f!r}, {t})" for f, t in nt_fields) + namedtuple_name = self.add_typing_import("NamedTuple") + base_types.append(f"{namedtuple_name}({typename!r}, [{fields_str}])") elif self.is_typed_namedtuple(base): base_types.append(base.accept(p)) else: # At this point, we don't know what the base class is, so we # just use Incomplete as the base class. - base_types.append(self.typing_name("Incomplete")) - self.add_typing_import("Incomplete") + base_types.append(self.add_typing_import("Incomplete")) for name, value in cdef.keywords.items(): if name == "metaclass": continue # handled separately @@ -1059,9 +1052,9 @@ def _get_namedtuple_fields(self, call: CallExpr) -> list[tuple[str, str]] | None field_names.append(field.value) else: return None # Invalid namedtuple fields type - if field_names: - self.add_typing_import("Incomplete") - incomplete = self.typing_name("Incomplete") + if not field_names: + return [] + incomplete = self.add_typing_import("Incomplete") return [(field_name, incomplete) for field_name in field_names] elif self.is_typed_namedtuple(call): fields_arg = call.args[1] @@ -1092,8 +1085,7 @@ def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None: if fields is None: self.annotate_as_incomplete(lvalue) return - self.add_typing_import("NamedTuple") - bases = self.typing_name("NamedTuple") + bases = self.add_typing_import("NamedTuple") # TODO: Add support for generic NamedTuples. Requires `Generic` as base class. class_def = f"{self._indent}class {lvalue.name}({bases}):" if len(fields) == 0: @@ -1143,14 +1135,13 @@ def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None: total = arg else: items.append((arg_name, arg)) - self.add_typing_import("TypedDict") + bases = self.add_typing_import("TypedDict") p = AliasPrinter(self) if any(not key.isidentifier() or keyword.iskeyword(key) for key, _ in items): # Keep the call syntax if there are non-identifier or reserved keyword keys. self.add(f"{self._indent}{lvalue.name} = {rvalue.accept(p)}\n") self._state = VAR else: - bases = self.typing_name("TypedDict") # TODO: Add support for generic TypedDicts. Requires `Generic` as base class. if total is not None: bases += f", total={total.accept(p)}" @@ -1167,8 +1158,7 @@ def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None: self._state = CLASS def annotate_as_incomplete(self, lvalue: NameExpr) -> None: - self.add_typing_import("Incomplete") - self.add(f"{self._indent}{lvalue.name}: {self.typing_name('Incomplete')}\n") + self.add(f"{self._indent}{lvalue.name}: {self.add_typing_import('Incomplete')}\n") self._state = VAR def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool: @@ -1384,13 +1374,14 @@ def typing_name(self, name: str) -> str: else: return name - def add_typing_import(self, name: str) -> None: + def add_typing_import(self, name: str) -> str: """Add a name to be imported for typing, unless it's imported already. The import will be internal to the stub. """ name = self.typing_name(name) self.import_tracker.require_name(name) + return name def add_import_line(self, line: str) -> None: """Add a line of text to the import section, unless it's already there.""" @@ -1448,11 +1439,9 @@ def get_str_type_of_node( if isinstance(rvalue, NameExpr) and rvalue.name in ("True", "False"): return "bool" if can_infer_optional and isinstance(rvalue, NameExpr) and rvalue.name == "None": - self.add_typing_import("Incomplete") - return f"{self.typing_name('Incomplete')} | None" + return f"{self.add_typing_import('Incomplete')} | None" if can_be_any: - self.add_typing_import("Incomplete") - return self.typing_name("Incomplete") + return self.add_typing_import("Incomplete") else: return ""