From d98cc5a162729727c5c90e98b27404eebff14bb5 Mon Sep 17 00:00:00 2001 From: Ali Hamdan Date: Fri, 19 May 2023 01:32:10 +0200 Subject: [PATCH] stubgen: fixes and simplifications (#15232) This PR refactors name resolution in stubgen to make it simple to use and fixes a bunch of related bugs especially when run in `--parse-only` mode. It mainly does the following: 1. Adds a `get_fullname` method that resolves the full name of both `NameExpr` and `MemberExpr` taking into account import aliases 2. Use this method everywhere manual resolution was done including the decorator processors 3. Simplify decorator processing by consolidating `process_decorator`, `process_name_expr_decorator`, and `process_member_expr_decorator` into a single method thanks to the common name resolution using `get_fullname` 4. Fix occurrences of some hard-coded implicitly-added names (like `Incomplete` from `_typeshed`) by using `self.typing_name()` and `self.add_typing_import` which take into account objects of the same name defined in the file 5. Fix some inconsistencies in generated empty lines. This improves the consistency within stubgen as well as with the black code style for stubs. Out of the 9 test cases added, only `testAbstractPropertyImportAlias` passes on master (added it any way because it was not tested). The other 8 all fail without this PR. The last test case `testUseTypingName` demonstrates point `4.`. In existing test cases, only white space changes in four of the cases due to point `5.` above were made. This will allow easier additions in the future as no manual name resolution has to be made any more. --- mypy/stubgen.py | 397 +++++++++++++----------------------- test-data/unit/stubgen.test | 201 +++++++++++++++++- 2 files changed, 337 insertions(+), 261 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 44b27e4751dd..32dc6a615f8c 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -74,6 +74,7 @@ ARG_STAR, ARG_STAR2, IS_ABSTRACT, + NOT_ABSTRACT, AssignmentStmt, Block, BytesExpr, @@ -104,6 +105,7 @@ TupleExpr, TypeInfo, UnaryExpr, + Var, ) from mypy.options import Options as MypyOptions from mypy.stubdoc import Sig, find_unique_signatures, parse_all_signatures @@ -652,7 +654,7 @@ def visit_mypy_file(self, o: MypyFile) -> None: self.referenced_names = find_referenced_names(o) known_imports = { "_typeshed": ["Incomplete"], - "typing": ["Any", "TypeVar"], + "typing": ["Any", "TypeVar", "NamedTuple"], "collections.abc": ["Generator"], "typing_extensions": ["TypedDict"], } @@ -673,34 +675,30 @@ def visit_mypy_file(self, o: MypyFile) -> None: self.add(f"# {name}\n") def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None: - """@property with setters and getters, or @overload chain""" + """@property with setters and getters, @overload chain and some others.""" overload_chain = False for item in o.items: if not isinstance(item, Decorator): continue - if self.is_private_name(item.func.name, item.func.fullname): continue - is_abstract, is_overload = self.process_decorator(item) - + self.process_decorator(item) if not overload_chain: - self.visit_func_def(item.func, is_abstract=is_abstract, is_overload=is_overload) - if is_overload: + self.visit_func_def(item.func) + if item.func.is_overload: overload_chain = True - elif is_overload: - self.visit_func_def(item.func, is_abstract=is_abstract, is_overload=is_overload) + elif item.func.is_overload: + self.visit_func_def(item.func) else: # skip the overload implementation and clear the decorator we just processed self.clear_decorators() - def visit_func_def( - self, o: FuncDef, is_abstract: bool = False, is_overload: bool = False - ) -> None: + def visit_func_def(self, o: FuncDef) -> None: if ( self.is_private_name(o.name, o.fullname) or self.is_not_in_all(o.name) - or (self.is_recorded_name(o.name) and not is_overload) + or (self.is_recorded_name(o.name) and not o.is_overload) ): self.clear_decorators() return @@ -777,23 +775,23 @@ def visit_func_def( elif o.name in KNOWN_MAGIC_METHODS_RETURN_TYPES: retname = KNOWN_MAGIC_METHODS_RETURN_TYPES[o.name] elif has_yield_expression(o): - self.add_abc_import("Generator") + self.add_typing_import("Generator") yield_name = "None" send_name = "None" return_name = "None" for expr, in_assignment in all_yield_expressions(o): if expr.expr is not None and not self.is_none_expr(expr.expr): self.add_typing_import("Incomplete") - yield_name = "Incomplete" + yield_name = self.typing_name("Incomplete") if in_assignment: self.add_typing_import("Incomplete") - send_name = "Incomplete" + send_name = self.typing_name("Incomplete") if has_return_statement(o): self.add_typing_import("Incomplete") - return_name = "Incomplete" + return_name = self.typing_name("Incomplete") generator_name = self.typing_name("Generator") retname = f"{generator_name}[{yield_name}, {send_name}, {return_name}]" - elif not has_return_statement(o) and not is_abstract: + elif not has_return_statement(o) and o.abstract_status == NOT_ABSTRACT: retname = "None" retfield = "" if retname is not None: @@ -809,151 +807,74 @@ def is_none_expr(self, expr: Expression) -> bool: def visit_decorator(self, o: Decorator) -> None: if self.is_private_name(o.func.name, o.func.fullname): return + self.process_decorator(o) + self.visit_func_def(o.func) - is_abstract, _ = self.process_decorator(o) - self.visit_func_def(o.func, is_abstract=is_abstract) - - def process_decorator(self, o: Decorator) -> tuple[bool, bool]: + def process_decorator(self, o: Decorator) -> None: """Process a series of decorators. Only preserve certain special decorators such as @abstractmethod. - - Return a pair of booleans: - - True if any of the decorators makes a method abstract. - - True if any of the decorators is typing.overload. """ - is_abstract = False - is_overload = False for decorator in o.original_decorators: - if isinstance(decorator, NameExpr): - i_is_abstract, i_is_overload = self.process_name_expr_decorator(decorator, o) - is_abstract = is_abstract or i_is_abstract - is_overload = is_overload or i_is_overload - elif isinstance(decorator, MemberExpr): - i_is_abstract, i_is_overload = self.process_member_expr_decorator(decorator, o) - is_abstract = is_abstract or i_is_abstract - is_overload = is_overload or i_is_overload - return is_abstract, is_overload - - def process_name_expr_decorator(self, expr: NameExpr, context: Decorator) -> tuple[bool, bool]: - """Process a function decorator of form @foo. - - Only preserve certain special decorators such as @abstractmethod. - - Return a pair of booleans: - - True if the decorator makes a method abstract. - - True if the decorator is typing.overload. - """ - is_abstract = False - is_overload = False - name = expr.name - if name in ("property", "staticmethod", "classmethod"): - self.add_decorator(name) - elif self.import_tracker.module_for.get(name) in ( - "asyncio", - "asyncio.coroutines", - "types", - ): - self.add_coroutine_decorator(context.func, name, name) - elif self.refers_to_fullname(name, "abc.abstractmethod"): - self.add_decorator(name) - self.import_tracker.require_name(name) - is_abstract = True - elif self.refers_to_fullname(name, "abc.abstractproperty"): - self.add_decorator("property") - self.add_decorator("abc.abstractmethod") - is_abstract = True - elif self.refers_to_fullname(name, "functools.cached_property"): - self.import_tracker.require_name(name) - self.add_decorator(name) - elif self.refers_to_fullname(name, OVERLOAD_NAMES): - self.add_decorator(name) - self.add_typing_import("overload") - is_overload = True - return is_abstract, is_overload - - def refers_to_fullname(self, name: str, fullname: str | tuple[str, ...]) -> bool: - if isinstance(fullname, tuple): - return any(self.refers_to_fullname(name, fname) for fname in fullname) - module, short = fullname.rsplit(".", 1) - return self.import_tracker.module_for.get(name) == module and ( - name == short or self.import_tracker.reverse_alias.get(name) == short - ) - - def process_member_expr_decorator( - self, expr: MemberExpr, context: Decorator - ) -> tuple[bool, bool]: - """Process a function decorator of form @foo.bar. - - Only preserve certain special decorators such as @abstractmethod. - - Return a pair of booleans: - - True if the decorator makes a method abstract. - - True if the decorator is typing.overload. - """ - is_abstract = False - is_overload = False - if expr.name == "setter" and isinstance(expr.expr, NameExpr): - self.add_decorator(f"{expr.expr.name}.setter") - elif ( - isinstance(expr.expr, NameExpr) - and ( - expr.expr.name == "abc" - or self.import_tracker.reverse_alias.get(expr.expr.name) == "abc" - ) - and expr.name in ("abstractmethod", "abstractproperty") - ): - if expr.name == "abstractproperty": - self.import_tracker.require_name(expr.expr.name) - self.add_decorator("property") - self.add_decorator(f"{expr.expr.name}.abstractmethod") - else: - self.import_tracker.require_name(expr.expr.name) - self.add_decorator(f"{expr.expr.name}.{expr.name}") - is_abstract = True - elif expr.name == "cached_property" and isinstance(expr.expr, NameExpr): - explicit_name = expr.expr.name - reverse = self.import_tracker.reverse_alias.get(explicit_name) - if reverse == "functools" or (reverse is None and explicit_name == "functools"): - if reverse is not None: - self.import_tracker.add_import(reverse, alias=explicit_name) - self.import_tracker.require_name(explicit_name) - self.add_decorator(f"{explicit_name}.{expr.name}") - elif expr.name == "coroutine": - if ( - isinstance(expr.expr, MemberExpr) - and expr.expr.name == "coroutines" - and isinstance(expr.expr.expr, NameExpr) - and ( - expr.expr.expr.name == "asyncio" - or self.import_tracker.reverse_alias.get(expr.expr.expr.name) == "asyncio" - ) + if not isinstance(decorator, (NameExpr, MemberExpr)): + continue + qualname = get_qualified_name(decorator) + fullname = self.get_fullname(decorator) + if fullname in ( + "builtins.property", + "builtins.staticmethod", + "builtins.classmethod", + "functools.cached_property", ): - self.add_coroutine_decorator( - context.func, - f"{expr.expr.expr.name}.coroutines.coroutine", - expr.expr.expr.name, - ) - elif isinstance(expr.expr, NameExpr) and ( - expr.expr.name in ("asyncio", "types") - or self.import_tracker.reverse_alias.get(expr.expr.name) - in ("asyncio", "asyncio.coroutines", "types") + self.add_decorator(qualname, require_name=True) + elif fullname in ( + "asyncio.coroutine", + "asyncio.coroutines.coroutine", + "types.coroutine", ): - self.add_coroutine_decorator( - context.func, expr.expr.name + ".coroutine", expr.expr.name - ) - elif ( - isinstance(expr.expr, NameExpr) - and ( - expr.expr.name in TYPING_MODULE_NAMES - or self.import_tracker.reverse_alias.get(expr.expr.name) in TYPING_MODULE_NAMES - ) - and expr.name == "overload" + o.func.is_awaitable_coroutine = True + self.add_decorator(qualname, require_name=True) + elif fullname == "abc.abstractmethod": + self.add_decorator(qualname, require_name=True) + o.func.abstract_status = IS_ABSTRACT + elif fullname in ( + "abc.abstractproperty", + "abc.abstractstaticmethod", + "abc.abstractclassmethod", + ): + abc_module = qualname.rpartition(".")[0] + if not abc_module: + self.import_tracker.add_import("abc") + builtin_decorator_replacement = fullname[len("abc.abstract") :] + self.add_decorator(builtin_decorator_replacement, require_name=False) + self.add_decorator(f"{abc_module or 'abc'}.abstractmethod", require_name=True) + o.func.abstract_status = IS_ABSTRACT + elif fullname in OVERLOAD_NAMES: + self.add_decorator(qualname, require_name=True) + o.func.is_overload = True + elif qualname.endswith(".setter"): + self.add_decorator(qualname, require_name=False) + + def get_fullname(self, expr: Expression) -> str: + """Return the full name resolving imports and import aliases.""" + if ( + self.analyzed + and isinstance(expr, (NameExpr, MemberExpr)) + and expr.fullname + and not (isinstance(expr.node, Var) and expr.node.is_suppressed_import) ): - self.import_tracker.require_name(expr.expr.name) - self.add_decorator(f"{expr.expr.name}.overload") - is_overload = True - return is_abstract, is_overload + return expr.fullname + name = get_qualified_name(expr) + if "." not in name: + real_module = self.import_tracker.module_for.get(name) + real_short = self.import_tracker.reverse_alias.get(name, name) + if real_module is None and real_short not in self.defined_names: + real_module = "builtins" # not imported and not defined, must be a builtin + else: + name_module, real_short = name.split(".", 1) + real_module = self.import_tracker.reverse_alias.get(name_module, name_module) + resolved_name = real_short if real_module is None else f"{real_module}.{real_short}" + return resolved_name def visit_class_def(self, o: ClassDef) -> None: self.method_names = find_method_names(o.defs.body) @@ -1004,12 +925,9 @@ def get_base_types(self, cdef: ClassDef) -> list[str]: base_types: list[str] = [] p = AliasPrinter(self) for base in cdef.base_type_exprs: - if isinstance(base, NameExpr): - if base.name != "object": - base_types.append(base.name) - elif isinstance(base, MemberExpr): - modname = get_qualified_name(base.expr) - base_types.append(f"{modname}.{base.name}") + if isinstance(base, (NameExpr, MemberExpr)): + if self.get_fullname(base) != "builtins.object": + base_types.append(get_qualified_name(base)) elif isinstance(base, IndexExpr): base_types.append(base.accept(p)) elif isinstance(base, CallExpr): @@ -1027,21 +945,20 @@ def get_base_types(self, cdef: ClassDef) -> list[str]: assert isinstance(base.args[0], StrExpr) typename = base.args[0].value if nt_fields is not None: - # A valid namedtuple() call, use NamedTuple() instead with - # Incomplete as field types fields_str = ", ".join(f"({f!r}, {t})" for f, t in nt_fields) - base_types.append(f"NamedTuple({typename!r}, [{fields_str}])") + namedtuple_name = self.typing_name("NamedTuple") + base_types.append(f"{namedtuple_name}({typename!r}, [{fields_str}])") self.add_typing_import("NamedTuple") else: # Invalid namedtuple() call, cannot determine fields - base_types.append("Incomplete") + base_types.append(self.typing_name("Incomplete")) elif self.is_typed_namedtuple(base): base_types.append(base.accept(p)) else: # At this point, we don't know what the base class is, so we # just use Incomplete as the base class. - base_types.append("Incomplete") - self.import_tracker.require_name("Incomplete") + base_types.append(self.typing_name("Incomplete")) + self.add_typing_import("Incomplete") for name, value in cdef.keywords.items(): if name == "metaclass": continue # handled separately @@ -1058,26 +975,20 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: foundl = [] for lvalue in o.lvalues: - if ( - isinstance(lvalue, NameExpr) - and isinstance(o.rvalue, CallExpr) - and (self.is_namedtuple(o.rvalue) or self.is_typed_namedtuple(o.rvalue)) - ): - self.process_namedtuple(lvalue, o.rvalue) - continue - if ( - isinstance(lvalue, NameExpr) - and isinstance(o.rvalue, CallExpr) - and self.is_typeddict(o.rvalue) - ): - self.process_typeddict(lvalue, o.rvalue) - continue + if isinstance(lvalue, NameExpr) and isinstance(o.rvalue, CallExpr): + if self.is_namedtuple(o.rvalue) or self.is_typed_namedtuple(o.rvalue): + self.process_namedtuple(lvalue, o.rvalue) + foundl.append(False) # state is updated in process_namedtuple + continue + if self.is_typeddict(o.rvalue): + self.process_typeddict(lvalue, o.rvalue) + foundl.append(False) # state is updated in process_typeddict + continue if ( isinstance(lvalue, NameExpr) and not self.is_private_name(lvalue.name) - and # it is never an alias with explicit annotation - not o.unanalyzed_type + and not o.unanalyzed_type and self.is_alias_expression(o.rvalue) ): self.process_typealias(lvalue, o.rvalue) @@ -1109,26 +1020,10 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: self._state = VAR def is_namedtuple(self, expr: CallExpr) -> bool: - callee = expr.callee - return ( - isinstance(callee, NameExpr) - and (self.refers_to_fullname(callee.name, "collections.namedtuple")) - ) or ( - isinstance(callee, MemberExpr) - and isinstance(callee.expr, NameExpr) - and f"{callee.expr.name}.{callee.name}" == "collections.namedtuple" - ) + return self.get_fullname(expr.callee) == "collections.namedtuple" def is_typed_namedtuple(self, expr: CallExpr) -> bool: - callee = expr.callee - return ( - isinstance(callee, NameExpr) - and self.refers_to_fullname(callee.name, TYPED_NAMEDTUPLE_NAMES) - ) or ( - isinstance(callee, MemberExpr) - and isinstance(callee.expr, NameExpr) - and f"{callee.expr.name}.{callee.name}" in TYPED_NAMEDTUPLE_NAMES - ) + return self.get_fullname(expr.callee) in TYPED_NAMEDTUPLE_NAMES def _get_namedtuple_fields(self, call: CallExpr) -> list[tuple[str, str]] | None: if self.is_namedtuple(call): @@ -1144,113 +1039,117 @@ def _get_namedtuple_fields(self, call: CallExpr) -> list[tuple[str, str]] | None else: return None # Invalid namedtuple fields type if field_names: - self.import_tracker.require_name("Incomplete") - return [(field_name, "Incomplete") for field_name in field_names] + self.add_typing_import("Incomplete") + incomplete = self.typing_name("Incomplete") + return [(field_name, incomplete) for field_name in field_names] elif self.is_typed_namedtuple(call): fields_arg = call.args[1] if not isinstance(fields_arg, (ListExpr, TupleExpr)): return None fields: list[tuple[str, str]] = [] - b = AliasPrinter(self) + p = AliasPrinter(self) for field in fields_arg.items: if not (isinstance(field, TupleExpr) and len(field.items) == 2): return None field_name, field_type = field.items if not isinstance(field_name, StrExpr): return None - fields.append((field_name.value, field_type.accept(b))) + fields.append((field_name.value, field_type.accept(p))) return fields else: return None # Not a named tuple call def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None: - if self._state != EMPTY: + if self._state == CLASS: self.add("\n") + + if not isinstance(rvalue.args[0], StrExpr): + self.annotate_as_incomplete(lvalue) + return + fields = self._get_namedtuple_fields(rvalue) if fields is None: - self.add(f"{self._indent}{lvalue.name}: Incomplete") - self.import_tracker.require_name("Incomplete") + self.annotate_as_incomplete(lvalue) return - self.import_tracker.require_name("NamedTuple") - self.add(f"{self._indent}class {lvalue.name}(NamedTuple):") + self.add_typing_import("NamedTuple") + bases = self.typing_name("NamedTuple") + # TODO: Add support for generic NamedTuples. Requires `Generic` as base class. + class_def = f"{self._indent}class {lvalue.name}({bases}):" if len(fields) == 0: - self.add(" ...\n") + self.add(f"{class_def} ...\n") self._state = EMPTY_CLASS else: - self.add("\n") + if self._state not in (EMPTY, CLASS): + self.add("\n") + self.add(f"{class_def}\n") for f_name, f_type in fields: self.add(f"{self._indent} {f_name}: {f_type}\n") self._state = CLASS def is_typeddict(self, expr: CallExpr) -> bool: - callee = expr.callee - return ( - isinstance(callee, NameExpr) and self.refers_to_fullname(callee.name, TPDICT_NAMES) - ) or ( - isinstance(callee, MemberExpr) - and isinstance(callee.expr, NameExpr) - and f"{callee.expr.name}.{callee.name}" in TPDICT_NAMES - ) + return self.get_fullname(expr.callee) in TPDICT_NAMES def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None: - if self._state != EMPTY: + if self._state == CLASS: self.add("\n") if not isinstance(rvalue.args[0], StrExpr): - self.add(f"{self._indent}{lvalue.name}: Incomplete") - self.import_tracker.require_name("Incomplete") + self.annotate_as_incomplete(lvalue) return items: list[tuple[str, Expression]] = [] total: Expression | None = None if len(rvalue.args) > 1 and rvalue.arg_kinds[1] == ARG_POS: if not isinstance(rvalue.args[1], DictExpr): - self.add(f"{self._indent}{lvalue.name}: Incomplete") - self.import_tracker.require_name("Incomplete") + self.annotate_as_incomplete(lvalue) return for attr_name, attr_type in rvalue.args[1].items: if not isinstance(attr_name, StrExpr): - self.add(f"{self._indent}{lvalue.name}: Incomplete") - self.import_tracker.require_name("Incomplete") + self.annotate_as_incomplete(lvalue) return items.append((attr_name.value, attr_type)) if len(rvalue.args) > 2: if rvalue.arg_kinds[2] != ARG_NAMED or rvalue.arg_names[2] != "total": - self.add(f"{self._indent}{lvalue.name}: Incomplete") - self.import_tracker.require_name("Incomplete") + self.annotate_as_incomplete(lvalue) return total = rvalue.args[2] else: for arg_name, arg in zip(rvalue.arg_names[1:], rvalue.args[1:]): if not isinstance(arg_name, str): - self.add(f"{self._indent}{lvalue.name}: Incomplete") - self.import_tracker.require_name("Incomplete") + self.annotate_as_incomplete(lvalue) return if arg_name == "total": total = arg else: items.append((arg_name, arg)) - self.import_tracker.require_name("TypedDict") + self.add_typing_import("TypedDict") p = AliasPrinter(self) if any(not key.isidentifier() or keyword.iskeyword(key) for key, _ in items): - # Keep the call syntax if there are non-identifier or keyword keys. + # Keep the call syntax if there are non-identifier or reserved keyword keys. self.add(f"{self._indent}{lvalue.name} = {rvalue.accept(p)}\n") self._state = VAR else: - bases = "TypedDict" + bases = self.typing_name("TypedDict") # TODO: Add support for generic TypedDicts. Requires `Generic` as base class. if total is not None: bases += f", total={total.accept(p)}" - self.add(f"{self._indent}class {lvalue.name}({bases}):") + class_def = f"{self._indent}class {lvalue.name}({bases}):" if len(items) == 0: - self.add(" ...\n") + self.add(f"{class_def} ...\n") self._state = EMPTY_CLASS else: - self.add("\n") + if self._state not in (EMPTY, CLASS): + self.add("\n") + self.add(f"{class_def}\n") for key, key_type in items: self.add(f"{self._indent} {key}: {key_type.accept(p)}\n") self._state = CLASS + def annotate_as_incomplete(self, lvalue: NameExpr) -> None: + self.add_typing_import("Incomplete") + self.add(f"{self._indent}{lvalue.name}: {self.typing_name('Incomplete')}\n") + self._state = VAR + def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool: """Return True for things that look like target for an alias. @@ -1258,10 +1157,9 @@ def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool: or module alias. """ # Assignment of TypeVar(...) are passed through - if ( - isinstance(expr, CallExpr) - and isinstance(expr.callee, NameExpr) - and expr.callee.name == "TypeVar" + if isinstance(expr, CallExpr) and self.get_fullname(expr.callee) in ( + "typing.TypeVar", + "typing_extensions.TypeVar", ): return True elif isinstance(expr, EllipsisExpr): @@ -1431,7 +1329,9 @@ def add(self, string: str) -> None: """Add text to generated stub.""" self._output.append(string) - def add_decorator(self, name: str) -> None: + def add_decorator(self, name: str, require_name: bool = False) -> None: + if require_name: + self.import_tracker.require_name(name) if not self._indent and self._state not in (EMPTY, FUNC): self._decorators.append("\n") self._decorators.append(f"{self._indent}@{name}\n") @@ -1447,15 +1347,7 @@ def typing_name(self, name: str) -> str: return name def add_typing_import(self, name: str) -> None: - """Add a name to be imported from typing, unless it's imported already. - - The import will be internal to the stub. - """ - name = self.typing_name(name) - self.import_tracker.require_name(name) - - def add_abc_import(self, name: str) -> None: - """Add a name to be imported from collections.abc, unless it's imported already. + """Add a name to be imported for typing, unless it's imported already. The import will be internal to the stub. """ @@ -1467,11 +1359,6 @@ def add_import_line(self, line: str) -> None: if line not in self._import_lines: self._import_lines.append(line) - def add_coroutine_decorator(self, func: FuncDef, name: str, require_name: str) -> None: - func.is_awaitable_coroutine = True - self.add_decorator(name) - self.import_tracker.require_name(require_name) - def output(self) -> str: """Return the text for the stub.""" imports = "" diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index ecace7270e95..1834284ef48e 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -664,7 +664,6 @@ Y = typing.NamedTuple('Y', []) from typing import NamedTuple class X(NamedTuple): ... - class Y(NamedTuple): ... [case testNamedtupleAltSyntax] @@ -786,6 +785,7 @@ from _typeshed import Incomplete N: Incomplete M: Incomplete + class X(Incomplete): ... [case testNamedTupleInClassBases] @@ -823,6 +823,29 @@ Y: Incomplete Z: Incomplete W: Incomplete +[case testNamedTupleFromImportAlias] +import collections as c +import typing as t +import typing_extensions as te +X = c.namedtuple('X', ['a', 'b']) +Y = t.NamedTuple('Y', [('a', int), ('b', str)]) +Z = te.NamedTuple('Z', [('a', int), ('b', str)]) +[out] +from _typeshed import Incomplete +from typing import NamedTuple + +class X(NamedTuple): + a: Incomplete + b: Incomplete + +class Y(NamedTuple): + a: int + b: str + +class Z(NamedTuple): + a: int + b: str + [case testArbitraryBaseClass] import x class D(x.C): ... @@ -870,6 +893,12 @@ class A(object): ... [out] class A: ... +[case testObjectBaseClassWithImport] +import builtins as b +class A(b.object): ... +[out] +class A: ... + [case testEmptyLines] def x(): ... def f(): @@ -1018,6 +1047,42 @@ from typing import TypeVar tv = TypeVar('tv', bound=bool, covariant=True) +[case TypeVarImportAlias] +from typing import TypeVar as t_TV +from typing_extensions import TypeVar as te_TV +from x import TypeVar as x_TV + +T = t_TV('T') +U = te_TV('U') +V = x_TV('V') + +[out] +from _typeshed import Incomplete +from typing import TypeVar as t_TV +from typing_extensions import TypeVar as te_TV + +T = t_TV('T') +U = te_TV('U') +V: Incomplete + +[case testTypeVarFromImportAlias] +import typing as t +import typing_extensions as te +import x + +T = t.TypeVar('T') +U = te.TypeVar('U') +V = x.TypeVar('V') + +[out] +import typing as t +import typing_extensions as te +from _typeshed import Incomplete + +T = t.TypeVar('T') +U = te.TypeVar('U') +V: Incomplete + [case testTypeAliasPreserved] alias = str @@ -1903,6 +1968,36 @@ class A: ... class C(A): def f(self) -> None: ... +[case testAbstractPropertyImportAlias] +import abc as abc_alias + +class A: + @abc_alias.abstractproperty + def x(self): pass + +[out] +import abc as abc_alias + +class A: + @property + @abc_alias.abstractmethod + def x(self): ... + +[case testAbstractPropertyFromImportAlias] +from abc import abstractproperty as ap + +class A: + @ap + def x(self): pass + +[out] +import abc + +class A: + @property + @abc.abstractmethod + def x(self): ... + [case testAbstractProperty1_semanal] import other import abc @@ -2814,6 +2909,27 @@ def g(x: int, y: int) -> int: ... @te.overload def g(x: t.Tuple[int, int]) -> int: ... +[case testOverloadFromImportAlias] +from typing import overload as t_overload +from typing_extensions import overload as te_overload + +@t_overload +def f(x: int, y: int) -> int: + ... + +@te_overload +def g(x: int, y: int) -> int: + ... + +[out] +from typing import overload as t_overload +from typing_extensions import overload as te_overload + +@t_overload +def f(x: int, y: int) -> int: ... +@te_overload +def g(x: int, y: int) -> int: ... + [case testProtocol_semanal] from typing import Protocol, TypeVar @@ -2969,9 +3085,7 @@ Z = TypedDict('X', {'a': int, 'in': str}) from typing import TypedDict X = TypedDict('X', {'a-b': int, 'c': str}) - Y = TypedDict('X', {'a-b': int, 'c': str}, total=False) - Z = TypedDict('X', {'a': int, 'in': str}) [case testEmptyTypeddict] @@ -2984,11 +3098,8 @@ W = typing.TypedDict('W', total=False) from typing_extensions import TypedDict class X(TypedDict): ... - class Y(TypedDict, total=False): ... - class Z(TypedDict): ... - class W(TypedDict, total=False): ... [case testTypeddictAliased] @@ -3013,6 +3124,22 @@ class Y(TypedDict): def g() -> None: ... +[case testTypeddictFromImportAlias] +import typing as t +import typing_extensions as te +X = t.TypedDict('X', {'a': int, 'b': str}) +Y = te.TypedDict('Y', {'a': int, 'b': str}) +[out] +from typing_extensions import TypedDict + +class X(TypedDict): + a: int + b: str + +class Y(TypedDict): + a: int + b: str + [case testNotTypeddict] from x import TypedDict import y @@ -3041,3 +3168,65 @@ T: Incomplete U: Incomplete V: Incomplete W: Incomplete + +[case testUseTypingName] +import collections +import typing +from typing import NamedTuple, TypedDict + +class Incomplete: ... +class Generator: ... +class NamedTuple: ... +class TypedDict: ... + +nt = collections.namedtuple("nt", "a b") +NT = typing.NamedTuple("NT", [("a", int), ("b", str)]) +NT1 = typing.NamedTuple("NT1", [("a", int)] + [("b", str)]) +NT2 = typing.NamedTuple("NT2", [(xx, int), ("b", str)]) +NT3 = typing.NamedTuple(xx, [("a", int), ("b", str)]) +TD = typing.TypedDict("TD", {"a": int, "b": str}) +TD1 = typing.TypedDict("TD1", {"a": int, "b": str}, totale=False) +TD2 = typing.TypedDict("TD2", {xx: int, "b": str}) +TD3 = typing.TypedDict(xx, {"a": int, "b": str}) + +def gen(): + y = yield x + return z + +class X(unknown_call("X", "a b")): ... +class Y(collections.namedtuple("Y", xx)): ... +[out] +from _typeshed import Incomplete as _Incomplete +from collections.abc import Generator as _Generator +from typing import NamedTuple as _NamedTuple +from typing_extensions import TypedDict as _TypedDict + +class Incomplete: ... +class Generator: ... +class NamedTuple: ... +class TypedDict: ... + +class nt(_NamedTuple): + a: _Incomplete + b: _Incomplete + +class NT(_NamedTuple): + a: int + b: str + +NT1: _Incomplete +NT2: _Incomplete +NT3: _Incomplete + +class TD(_TypedDict): + a: int + b: str + +TD1: _Incomplete +TD2: _Incomplete +TD3: _Incomplete + +def gen() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ... + +class X(_Incomplete): ... +class Y(_Incomplete): ...