Skip to content

Commit

Permalink
Add NamedTuple to dataclass conversion codemod (Instagram#299)
Browse files Browse the repository at this point in the history
Add NamedTuple to dataclass conversion codemod.
  • Loading branch information
josieesh authored May 27, 2020
1 parent 3c5aa26 commit 7462265
Show file tree
Hide file tree
Showing 2 changed files with 253 additions and 0 deletions.
74 changes: 74 additions & 0 deletions libcst/codemod/commands/convert_namedtuple_to_dataclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# pyre-strict
from typing import List, Optional, Sequence

import libcst as cst
from libcst.codemod import VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
from libcst.metadata import (
ProviderT,
QualifiedName,
QualifiedNameProvider,
QualifiedNameSource,
)


class ConvertNamedTupleToDataclassCommand(VisitorBasedCodemodCommand):
"""
Convert NamedTuple class declarations to Python 3.7 dataclasses.
This only performs a conversion at the class declaration level.
It does not perform type annotation conversions, nor does it convert
NamedTuple-specific attributes and methods.
"""

DESCRIPTION: str = "Convert NamedTuple class declarations to Python 3.7 dataclasses using the @dataclass decorator."
METADATA_DEPENDENCIES: Sequence[ProviderT] = (QualifiedNameProvider,)

# The 'NamedTuple' we are interested in
qualified_namedtuple: QualifiedName = QualifiedName(
name="typing.NamedTuple", source=QualifiedNameSource.IMPORT
)

def leave_ClassDef(
self, original_node: cst.ClassDef, updated_node: cst.ClassDef
) -> cst.ClassDef:
new_bases: List[cst.Arg] = []
namedtuple_base: Optional[cst.Arg] = None

# Need to examine the original node's bases since they are directly tied to import metadata
for base_class in original_node.bases:
# Compare the base class's qualified name against the expected typing.NamedTuple
if not QualifiedNameProvider.has_name(
self, base_class.value, self.qualified_namedtuple
):
# Keep all bases that are not of type typing.NamedTuple
new_bases.append(base_class)
else:
namedtuple_base = base_class

# We still want to return the updated node in case some of its children have been modified
if namedtuple_base is None:
return updated_node

AddImportsVisitor.add_needed_import(self.context, "dataclasses", "dataclass")
RemoveImportsVisitor.remove_unused_import_by_node(
self.context, namedtuple_base.value
)

call = cst.ensure_type(
cst.parse_expression(
"dataclass(frozen=True)", config=self.module.config_for_parsing
),
cst.Call,
)
return updated_node.with_changes(
lpar=cst.MaybeSentinel.DEFAULT,
rpar=cst.MaybeSentinel.DEFAULT,
bases=new_bases,
decorators=[*original_node.decorators, cst.Decorator(decorator=call)],
)
179 changes: 179 additions & 0 deletions libcst/codemod/commands/tests/test_convert_namedtuple_to_dataclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# pyre-strict
from libcst.codemod import CodemodTest
from libcst.codemod.commands.convert_namedtuple_to_dataclass import (
ConvertNamedTupleToDataclassCommand,
)


class ConvertNamedTupleToDataclassCommandTest(CodemodTest):

TRANSFORM = ConvertNamedTupleToDataclassCommand

def test_no_change(self) -> None:
"""
Should result in no change as there are no children of NamedTuple.
"""

before = """
@dataclass(frozen=True)
class Foo:
pass
"""
after = """
@dataclass(frozen=True)
class Foo:
pass
"""
self.assertCodemod(before, after)

def test_change(self) -> None:
"""
Should remove the NamedTuple import along with its use as a base class for Foo.
Should import dataclasses.dataclass and annotate Foo.
"""

before = """
from typing import NamedTuple
class Foo(NamedTuple):
pass
"""
after = """
from dataclasses import dataclass
@dataclass(frozen=True)
class Foo:
pass
"""
self.assertCodemod(before, after)

def test_with_decorator_already(self) -> None:
"""
Should retain existing decorator.
"""

before = """
from typing import NamedTuple
@other_decorator
class Foo(NamedTuple):
pass
"""
after = """
from dataclasses import dataclass
@other_decorator
@dataclass(frozen=True)
class Foo:
pass
"""
self.assertCodemod(before, after)

def test_multiple_bases(self) -> None:
"""
Should retain all existing bases other than NamedTuple.
"""

before = """
from typing import NamedTuple
class Foo(NamedTuple, OtherBase, YetAnotherBase):
pass
"""
after = """
from dataclasses import dataclass
@dataclass(frozen=True)
class Foo(OtherBase, YetAnotherBase):
pass
"""
self.assertCodemod(before, after)

def test_nested_classes(self) -> None:
"""
Should perform expected changes on inner classes.
"""

before = """
from typing import NamedTuple
class OuterClass:
class InnerClass(NamedTuple):
pass
"""
after = """
from dataclasses import dataclass
class OuterClass:
@dataclass(frozen=True)
class InnerClass:
pass
"""
self.assertCodemod(before, after)

def test_aliased_object_import(self) -> None:
"""
Should detect aliased NamedTuple object import and base.
"""

before = """
from typing import NamedTuple as nt
class Foo(nt):
pass
"""
after = """
from dataclasses import dataclass
@dataclass(frozen=True)
class Foo:
pass
"""
self.assertCodemod(before, after)

def test_aliased_module_import(self) -> None:
"""
Should detect aliased `typing` module import and base.
"""

before = """
import typing as typ
class Foo(typ.NamedTuple):
pass
"""
after = """
from dataclasses import dataclass
@dataclass(frozen=True)
class Foo:
pass
"""
self.assertCodemod(before, after)

def test_other_unused_imports_not_removed(self) -> None:
"""
Should not remove any imports other than NamedTuple, even if they are also unused.
"""

before = """
from typing import NamedTuple
import SomeUnusedImport
class Foo(NamedTuple):
pass
"""
after = """
import SomeUnusedImport
from dataclasses import dataclass
@dataclass(frozen=True)
class Foo:
pass
"""
self.assertCodemod(before, after)

0 comments on commit 7462265

Please sign in to comment.