Skip to content

Commit

Permalink
StubGenerator.add_typing_import returns the name
Browse files Browse the repository at this point in the history
  • Loading branch information
KotlinIsland committed Aug 19, 2023
1 parent b02ddf1 commit 8ec5d8d
Showing 1 changed file with 22 additions and 33 deletions.
55 changes: 22 additions & 33 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,25 +786,20 @@ def visit_func_def(self, o: FuncDef) -> None:
elif o.name in KNOWN_MAGIC_METHODS_RETURN_TYPES:
retname = KNOWN_MAGIC_METHODS_RETURN_TYPES[o.name]
elif has_yield_expression(o) or has_yield_from_expression(o):
self.add_typing_import("Generator")
generator_name = self.add_typing_import("Generator")
yield_name = "None"
send_name = "None"
return_name = "None"
if has_yield_from_expression(o):
self.add_typing_import("Incomplete")
yield_name = send_name = self.typing_name("Incomplete")
yield_name = send_name = self.add_typing_import("Incomplete")
else:
for expr, in_assignment in all_yield_expressions(o):
if expr.expr is not None and not self.is_none_expr(expr.expr):
self.add_typing_import("Incomplete")
yield_name = self.typing_name("Incomplete")
yield_name = self.add_typing_import("Incomplete")
if in_assignment:
self.add_typing_import("Incomplete")
send_name = self.typing_name("Incomplete")
send_name = self.add_typing_import("Incomplete")
if has_return_statement(o):
self.add_typing_import("Incomplete")
return_name = self.typing_name("Incomplete")
generator_name = self.typing_name("Generator")
return_name = self.add_typing_import("Incomplete")
retname = f"{generator_name}[{yield_name}, {send_name}, {return_name}]"
elif not has_return_statement(o) and o.abstract_status == NOT_ABSTRACT:
retname = "None"
Expand Down Expand Up @@ -965,21 +960,19 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
nt_fields = self._get_namedtuple_fields(base)
assert isinstance(base.args[0], StrExpr)
typename = base.args[0].value
if nt_fields is not None:
fields_str = ", ".join(f"({f!r}, {t})" for f, t in nt_fields)
namedtuple_name = self.typing_name("NamedTuple")
base_types.append(f"{namedtuple_name}({typename!r}, [{fields_str}])")
self.add_typing_import("NamedTuple")
else:
if nt_fields is None:
# Invalid namedtuple() call, cannot determine fields
base_types.append(self.typing_name("Incomplete"))
base_types.append(self.add_typing_import("Incomplete"))
continue
fields_str = ", ".join(f"({f!r}, {t})" for f, t in nt_fields)
namedtuple_name = self.add_typing_import("NamedTuple")
base_types.append(f"{namedtuple_name}({typename!r}, [{fields_str}])")
elif self.is_typed_namedtuple(base):
base_types.append(base.accept(p))
else:
# At this point, we don't know what the base class is, so we
# just use Incomplete as the base class.
base_types.append(self.typing_name("Incomplete"))
self.add_typing_import("Incomplete")
base_types.append(self.add_typing_import("Incomplete"))
for name, value in cdef.keywords.items():
if name == "metaclass":
continue # handled separately
Expand Down Expand Up @@ -1059,9 +1052,9 @@ def _get_namedtuple_fields(self, call: CallExpr) -> list[tuple[str, str]] | None
field_names.append(field.value)
else:
return None # Invalid namedtuple fields type
if field_names:
self.add_typing_import("Incomplete")
incomplete = self.typing_name("Incomplete")
if not field_names:
return []
incomplete = self.add_typing_import("Incomplete")
return [(field_name, incomplete) for field_name in field_names]
elif self.is_typed_namedtuple(call):
fields_arg = call.args[1]
Expand Down Expand Up @@ -1092,8 +1085,7 @@ def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
if fields is None:
self.annotate_as_incomplete(lvalue)
return
self.add_typing_import("NamedTuple")
bases = self.typing_name("NamedTuple")
bases = self.add_typing_import("NamedTuple")
# TODO: Add support for generic NamedTuples. Requires `Generic` as base class.
class_def = f"{self._indent}class {lvalue.name}({bases}):"
if len(fields) == 0:
Expand Down Expand Up @@ -1143,14 +1135,13 @@ def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
total = arg
else:
items.append((arg_name, arg))
self.add_typing_import("TypedDict")
bases = self.add_typing_import("TypedDict")
p = AliasPrinter(self)
if any(not key.isidentifier() or keyword.iskeyword(key) for key, _ in items):
# Keep the call syntax if there are non-identifier or reserved keyword keys.
self.add(f"{self._indent}{lvalue.name} = {rvalue.accept(p)}\n")
self._state = VAR
else:
bases = self.typing_name("TypedDict")
# TODO: Add support for generic TypedDicts. Requires `Generic` as base class.
if total is not None:
bases += f", total={total.accept(p)}"
Expand All @@ -1167,8 +1158,7 @@ def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
self._state = CLASS

def annotate_as_incomplete(self, lvalue: NameExpr) -> None:
self.add_typing_import("Incomplete")
self.add(f"{self._indent}{lvalue.name}: {self.typing_name('Incomplete')}\n")
self.add(f"{self._indent}{lvalue.name}: {self.add_typing_import('Incomplete')}\n")
self._state = VAR

def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool:
Expand Down Expand Up @@ -1384,13 +1374,14 @@ def typing_name(self, name: str) -> str:
else:
return name

def add_typing_import(self, name: str) -> None:
def add_typing_import(self, name: str) -> str:
"""Add a name to be imported for typing, unless it's imported already.
The import will be internal to the stub.
"""
name = self.typing_name(name)
self.import_tracker.require_name(name)
return name

def add_import_line(self, line: str) -> None:
"""Add a line of text to the import section, unless it's already there."""
Expand Down Expand Up @@ -1448,11 +1439,9 @@ def get_str_type_of_node(
if isinstance(rvalue, NameExpr) and rvalue.name in ("True", "False"):
return "bool"
if can_infer_optional and isinstance(rvalue, NameExpr) and rvalue.name == "None":
self.add_typing_import("Incomplete")
return f"{self.typing_name('Incomplete')} | None"
return f"{self.add_typing_import('Incomplete')} | None"
if can_be_any:
self.add_typing_import("Incomplete")
return self.typing_name("Incomplete")
return self.add_typing_import("Incomplete")
else:
return ""

Expand Down

0 comments on commit 8ec5d8d

Please sign in to comment.