Skip to content

Commit

Permalink
ApplyTypeAnnotationsVisitor: fix default value of keyword only and po…
Browse files Browse the repository at this point in the history
…sitional-only args (Instagram#314)

* ApplyTypeAnnotationsVisitor: fix default values of keyword only args

* ApplyTypeAnnotationsVisitor: fix default values of positional-only args

Co-authored-by: hauntsaninja <>
  • Loading branch information
hauntsaninja authored Jun 16, 2020
1 parent c023fa7 commit 5992a7d
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 1 deletion.
9 changes: 8 additions & 1 deletion libcst/codemod/visitors/_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,14 @@ def update_annotation(
return annotations.parameters.with_changes(
params=update_annotation(
updated_node.params.params, annotations.parameters.params
)
),
kwonly_params=update_annotation(
updated_node.params.kwonly_params, annotations.parameters.kwonly_params
),
posonly_params=update_annotation(
updated_node.params.posonly_params,
annotations.parameters.posonly_params,
),
)

def _insert_empty_line(
Expand Down
71 changes: 71 additions & 0 deletions libcst/codemod/visitors/tests/test_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
# LICENSE file in the root directory of this source tree.
#

import sys
import textwrap
import unittest
from typing import Type

from libcst import parse_module
Expand All @@ -31,6 +33,25 @@ def foo() -> int:
return 1
""",
),
(
"""
def foo(
b: str, c: int = ..., *, d: str = ..., e: int, f: int = ...
) -> int: ...
""",
"""
def foo(
b, c=5, *, d="a", e, f=10
) -> int:
return 1
""",
"""
def foo(
b: str, c: int=5, *, d: str="a", e: int, f: int=10
) -> int:
return 1
""",
),
(
"""
import bar
Expand Down Expand Up @@ -612,6 +633,56 @@ def test_annotate_functions(self, stub: str, before: str, after: str) -> None:
)
self.assertCodemod(before, after, context_override=context)

@data_provider(
(
(
"""
def foo(
a: int, /, b: str, c: int = ..., *, d: str = ..., e: int, f: int = ...
) -> int: ...
""",
"""
def foo(
a, /, b, c=5, *, d="a", e, f=10
) -> int:
return 1
""",
"""
def foo(
a: int, /, b: str, c: int=5, *, d: str="a", e: int, f: int=10
) -> int:
return 1
""",
),
(
"""
def foo(
a: int, b: int = ..., /, c: int = ..., *, d: str = ..., e: int, f: int = ...
) -> int: ...
""",
"""
def foo(
a, b = 5, /, c = 10, *, d = "a", e, f = 20
) -> int:
return 1
""",
"""
def foo(
a: int, b: int = 5, /, c: int = 10, *, d: str = "a", e: int, f: int = 20
) -> int:
return 1
""",
),
)
)
@unittest.skipIf(sys.version_info < (3, 8), "Unsupported Python version")
def test_annotate_functions_py38(self, stub: str, before: str, after: str) -> None:
context = CodemodContext()
ApplyTypeAnnotationsVisitor.store_stub_in_context(
context, parse_module(textwrap.dedent(stub.rstrip()))
)
self.assertCodemod(before, after, context_override=context)

@data_provider(
(
(
Expand Down

0 comments on commit 5992a7d

Please sign in to comment.