Skip to content

Commit

Permalink
Add argument for ignoring existing annotations.
Browse files Browse the repository at this point in the history
This will allow us to override existing types based on the stub.
  • Loading branch information
pradeep90 authored and jimmylai committed May 12, 2020
1 parent 93a389b commit bfcc456
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 14 deletions.
55 changes: 42 additions & 13 deletions libcst/codemod/visitors/_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,15 @@ class ApplyTypeAnnotationsVisitor(ContextAwareTransformer):
This is one of the transforms that is available automatically to you when
running a codemod. To use it in this manner, import
:class:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor` and then call the static
:meth:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor.add_stub_to_context` method,
:meth:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor.store_stub_in_context` method,
giving it the current context (found as ``self.context`` for all subclasses of
:class:`~libcst.codemod.Codemod`), the stub module from which you wish to add annotations.
For example, you can store the type annotation ``int`` for ``x`` using::
stub_module = parse_module("x: int = ...")
ApplyTypeAnnotationsVisitor.add_stub_to_context(self.context, stub_module)
ApplyTypeAnnotationsVisitor.store_stub_in_context(self.context, stub_module)
You can apply the type annotation using::
Expand All @@ -223,33 +223,52 @@ class ApplyTypeAnnotationsVisitor(ContextAwareTransformer):
x: int = 1
If the function or attribute already has a type annotation, it will not be overwritten.
To overwrite existing annotations when applying annotations from a stub,
use the keyword argument ``overwrite_existing_annotations=True`` when
constructing the codemod or when calling ``store_stub_in_context``.
"""

CONTEXT_KEY = "ApplyTypeAnnotationsVisitor"

def __init__(
self, context: CodemodContext, annotations: Optional[Annotations] = None
self,
context: CodemodContext,
annotations: Optional[Annotations] = None,
overwrite_existing_annotations: bool = False,
) -> None:
super().__init__(context)
# Qualifier for storing the canonical name of the current function.
self.qualifier: List[str] = []
self.annotations: Annotations = annotations or Annotations()
self.toplevel_annotations: Dict[str, cst.Annotation] = {}
self.visited_classes: Set[str] = set()
self.overwrite_existing_annotations = overwrite_existing_annotations

# We use this to determine the end of the import block so that we can
# insert top-level annotations.
self.import_statements: List[cst.ImportFrom] = []

@staticmethod
def add_stub_to_context(context: CodemodContext, stub: cst.Module) -> None:
def store_stub_in_context(
context: CodemodContext,
stub: cst.Module,
overwrite_existing_annotations: bool = False,
) -> None:
"""
Add a stub module to the :class:`~libcst.codemod.CodemodContext` so
Store a stub module in the :class:`~libcst.codemod.CodemodContext` so
that type annotations from the stub can be applied in a later
invocation of this class.
If the ``overwrite_existing_annotations`` flag is ``True``, the
codemod will overwrite any existing annotations.
If you call this function multiple times, only the last values of
``stub`` and ``overwrite_existing_annotations`` will take effect.
"""
context.scratch.setdefault(ApplyTypeAnnotationsVisitor.CONTEXT_KEY, []).append(
stub
context.scratch[ApplyTypeAnnotationsVisitor.CONTEXT_KEY] = (
stub,
overwrite_existing_annotations,
)

def transform_module_impl(self, tree: cst.Module) -> cst.Module:
Expand All @@ -262,8 +281,14 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module:
tree.visit(import_gatherer)
existing_import_names = _get_import_names(import_gatherer.all_imports)

stubs = self.context.scratch.get(ApplyTypeAnnotationsVisitor.CONTEXT_KEY, [])
for stub in stubs:
context_contents = self.context.scratch.get(
ApplyTypeAnnotationsVisitor.CONTEXT_KEY
)
if context_contents:
stub, overwrite_existing_annotations = context_contents
self.overwrite_existing_annotations = (
self.overwrite_existing_annotations or overwrite_existing_annotations
)
visitor = TypeCollector(existing_import_names, self.context)
stub.visit(visitor)
self.annotations.function_annotations.update(visitor.function_annotations)
Expand Down Expand Up @@ -339,7 +364,8 @@ def _update_parameters(
self, annotations: FunctionAnnotation, updated_node: cst.FunctionDef
) -> cst.Parameters:
# Update params and default params with annotations
# don't override existing annotations or default values
# Don't override existing annotations or default values unless asked
# to overwrite existing annotations.
def update_annotation(
parameters: Sequence[cst.Param], annotations: Sequence[cst.Param]
) -> List[cst.Param]:
Expand All @@ -350,7 +376,9 @@ def update_annotation(
parameter_annotations[parameter.name.value] = parameter.annotation
for parameter in parameters:
key = parameter.name.value
if key in parameter_annotations and not parameter.annotation:
if key in parameter_annotations and (
self.overwrite_existing_annotations or not parameter.annotation
):
parameter = parameter.with_changes(
annotation=parameter_annotations[key]
)
Expand Down Expand Up @@ -409,8 +437,9 @@ def leave_FunctionDef(
self.qualifier.pop()
if key in self.annotations.function_annotations:
function_annotation = self.annotations.function_annotations[key]
# Only add new annotation if one doesn't already exist
if not updated_node.returns:
# Only add new annotation if explicitly told to overwrite existing
# annotations or if one doesn't already exist.
if self.overwrite_existing_annotations or not updated_node.returns:
updated_node = updated_node.with_changes(
returns=function_annotation.returns
)
Expand Down
40 changes: 39 additions & 1 deletion libcst/codemod/visitors/tests/test_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,45 @@ def foo() -> typing.Sequence[int]:
)
def test_annotate_functions(self, stub: str, before: str, after: str) -> None:
context = CodemodContext()
ApplyTypeAnnotationsVisitor.add_stub_to_context(
ApplyTypeAnnotationsVisitor.store_stub_in_context(
context, parse_module(textwrap.dedent(stub.rstrip()))
)
self.assertCodemod(before, after, context_override=context)

@data_provider(
(
(
"""
def fully_annotated_with_different_stub(a: bool, b: bool) -> str: ...
""",
"""
def fully_annotated_with_different_stub(a: int, b: str) -> bool:
return 'hello'
""",
"""
def fully_annotated_with_different_stub(a: bool, b: bool) -> str:
return 'hello'
""",
),
)
)
def test_annotate_functions_with_existing_annotations(
self, stub: str, before: str, after: str
) -> None:
context = CodemodContext()
ApplyTypeAnnotationsVisitor.store_stub_in_context(
context, parse_module(textwrap.dedent(stub.rstrip()))
)
# Test setting the overwrite flag on the codemod instance.
self.assertCodemod(
before, after, context_override=context, overwrite_existing_annotations=True
)

# Test setting the flag when storing the stub in the context.
context = CodemodContext()
ApplyTypeAnnotationsVisitor.store_stub_in_context(
context,
parse_module(textwrap.dedent(stub.rstrip())),
overwrite_existing_annotations=True,
)
self.assertCodemod(before, after, context_override=context)

0 comments on commit bfcc456

Please sign in to comment.