Skip to content

Commit

Permalink
Restore the 0.4.1 behavior for libcst.helpers.get_absolute_module (In…
Browse files Browse the repository at this point in the history
  • Loading branch information
lpetre authored May 11, 2022
1 parent 460698a commit 149599e
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 21 deletions.
6 changes: 3 additions & 3 deletions libcst/codemod/commands/remove_unused_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from libcst import Import, ImportFrom, ImportStar, Module
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import GatherCommentsVisitor, RemoveImportsVisitor
from libcst.helpers import get_absolute_module_for_import
from libcst.helpers import get_absolute_module_from_package_for_import
from libcst.metadata import PositionProvider, ProviderT

DEFAULT_SUPPRESS_COMMENT_REGEX = (
Expand Down Expand Up @@ -74,8 +74,8 @@ def _handle_import(self, node: Union[Import, ImportFrom]) -> None:
asname=alias.evaluated_alias,
)
else:
module_name = get_absolute_module_for_import(
self.context.full_module_name, node
module_name = get_absolute_module_from_package_for_import(
self.context.full_package_name, node
)
if module_name is None:
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions libcst/codemod/visitors/_add_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from libcst.codemod._visitor import ContextAwareTransformer
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
from libcst.codemod.visitors._imports import ImportItem
from libcst.helpers import get_absolute_module_for_import
from libcst.helpers import get_absolute_module_from_package_for_import


class AddImportsVisitor(ContextAwareTransformer):
Expand Down Expand Up @@ -214,7 +214,7 @@ def leave_ImportFrom(
return updated_node

# Get the module we're importing as a string, see if we have work to do.
module = get_absolute_module_for_import(
module = get_absolute_module_from_package_for_import(
self.context.full_package_name, updated_node
)
if (
Expand Down
6 changes: 4 additions & 2 deletions libcst/codemod/visitors/_gather_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareVisitor
from libcst.codemod.visitors._imports import ImportItem
from libcst.helpers import get_absolute_module_for_import
from libcst.helpers import get_absolute_module_from_package_for_import


class GatherImportsVisitor(ContextAwareVisitor):
Expand Down Expand Up @@ -85,7 +85,9 @@ def visit_ImportFrom(self, node: libcst.ImportFrom) -> None:
self.all_imports.append(node)

# Get the module we're importing as a string.
module = get_absolute_module_for_import(self.context.full_package_name, node)
module = get_absolute_module_from_package_for_import(
self.context.full_package_name, node
)
if module is None:
# Can't get the absolute import from relative, so we can't
# support this.
Expand Down
6 changes: 4 additions & 2 deletions libcst/codemod/visitors/_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass, replace
from typing import Optional

from libcst.helpers import get_absolute_module
from libcst.helpers import get_absolute_module_from_package


@dataclass(frozen=True)
Expand Down Expand Up @@ -39,5 +39,7 @@ def resolve_relative(self, package_name: Optional[str]) -> "ImportItem":
mod = replace(mod, module_name="", obj_name=mod.module_name)
if package_name is None:
return mod
m = get_absolute_module(package_name, mod.module_name or None, self.relative)
m = get_absolute_module_from_package(
package_name, mod.module_name or None, self.relative
)
return mod if m is None else replace(mod, module_name=m, relative=0)
11 changes: 7 additions & 4 deletions libcst/codemod/visitors/_remove_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareTransformer, ContextAwareVisitor
from libcst.codemod.visitors._gather_unused_imports import GatherUnusedImportsVisitor
from libcst.helpers import get_absolute_module_for_import, get_full_name_for_node
from libcst.helpers import (
get_absolute_module_from_package_for_import,
get_full_name_for_node,
)
from libcst.metadata import Assignment, ProviderT, ScopeProvider


Expand Down Expand Up @@ -38,7 +41,7 @@ def _remove_imports_from_importfrom_stmt(
# We don't handle removing this, so ignore it.
return

module_name = get_absolute_module_for_import(
module_name = get_absolute_module_from_package_for_import(
self.context.full_package_name, import_node
)
if module_name is None:
Expand Down Expand Up @@ -248,7 +251,7 @@ def remove_unused_import_by_node(
if isinstance(names, cst.ImportStar):
# We don't handle removing this, so ignore it.
return
module_name = get_absolute_module_for_import(
module_name = get_absolute_module_from_package_for_import(
context.full_package_name, node
)
if module_name is None:
Expand Down Expand Up @@ -415,7 +418,7 @@ def leave_ImportFrom(
return updated_node

# Make sure we actually know the absolute module.
module_name = get_absolute_module_for_import(
module_name = get_absolute_module_from_package_for_import(
self.context.full_package_name, updated_node
)
if module_name is None or module_name not in self.unused_obj_imports:
Expand Down
6 changes: 6 additions & 0 deletions libcst/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
get_absolute_module,
get_absolute_module_for_import,
get_absolute_module_for_import_or_raise,
get_absolute_module_from_package,
get_absolute_module_from_package_for_import,
get_absolute_module_from_package_for_import_or_raise,
insert_header_comments,
ModuleNameAndPackage,
)
Expand All @@ -28,6 +31,9 @@
"get_absolute_module",
"get_absolute_module_for_import",
"get_absolute_module_for_import_or_raise",
"get_absolute_module_from_package",
"get_absolute_module_from_package_for_import",
"get_absolute_module_from_package_for_import_or_raise",
"get_full_name_for_node",
"get_full_name_for_node_or_raise",
"ensure_type",
Expand Down
57 changes: 53 additions & 4 deletions libcst/helpers/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,55 @@ def insert_header_comments(node: Module, comments: List[str]) -> Module:


def get_absolute_module(
current_module: Optional[str], module_name: Optional[str], num_dots: int
) -> Optional[str]:
if num_dots == 0:
# This is an absolute import, so the module is correct.
return module_name
if current_module is None:
# We don't actually have the current module available, so we can't compute
# the absolute module from relative.
return None
# We have the current module, as well as the relative, let's compute the base.
modules = current_module.split(".")
if len(modules) < num_dots:
# This relative import goes past the base of the repository, so we can't calculate it.
return None
base_module = ".".join(modules[:-num_dots])
# Finally, if the module name was supplied, append it to the end.
if module_name is not None:
# If we went all the way to the top, the base module should be empty, so we
# should return the relative bit as absolute. Otherwise, combine the base
# module and module name using a dot separator.
base_module = (
f"{base_module}.{module_name}" if len(base_module) > 0 else module_name
)
# If they tried to import all the way to the root, return None. Otherwise,
# return the module itself.
return base_module if len(base_module) > 0 else None


def get_absolute_module_for_import(
current_module: Optional[str], import_node: ImportFrom
) -> Optional[str]:
# First, let's try to grab the module name, regardless of relative status.
module = import_node.module
module_name = get_full_name_for_node(module) if module is not None else None
# Now, get the relative import location if it exists.
num_dots = len(import_node.relative)
return get_absolute_module(current_module, module_name, num_dots)


def get_absolute_module_for_import_or_raise(
current_module: Optional[str], import_node: ImportFrom
) -> str:
module = get_absolute_module_for_import(current_module, import_node)
if module is None:
raise Exception(f"Unable to compute absolute module for {import_node}")
return module


def get_absolute_module_from_package(
current_package: Optional[str], module_name: Optional[str], num_dots: int
) -> Optional[str]:
if num_dots == 0:
Expand All @@ -55,21 +104,21 @@ def get_absolute_module(
return "{}.{}".format(base, module_name) if module_name else base


def get_absolute_module_for_import(
def get_absolute_module_from_package_for_import(
current_package: Optional[str], import_node: ImportFrom
) -> Optional[str]:
# First, let's try to grab the module name, regardless of relative status.
module = import_node.module
module_name = get_full_name_for_node(module) if module is not None else None
# Now, get the relative import location if it exists.
num_dots = len(import_node.relative)
return get_absolute_module(current_package, module_name, num_dots)
return get_absolute_module_from_package(current_package, module_name, num_dots)


def get_absolute_module_for_import_or_raise(
def get_absolute_module_from_package_for_import_or_raise(
current_package: Optional[str], import_node: ImportFrom
) -> str:
module = get_absolute_module_for_import(current_package, import_node)
module = get_absolute_module_from_package_for_import(current_package, import_node)
if module is None:
raise Exception(f"Unable to compute absolute module for {import_node}")
return module
Expand Down
55 changes: 51 additions & 4 deletions libcst/helpers/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
calculate_module_and_package,
get_absolute_module_for_import,
get_absolute_module_for_import_or_raise,
get_absolute_module_from_package_for_import,
get_absolute_module_from_package_for_import_or_raise,
insert_header_comments,
ModuleNameAndPackage,
)
Expand Down Expand Up @@ -67,6 +69,44 @@ def test_insert_header_comments(self) -> None:
insert_header_comments(node, inserted_comments).code, expected_code
)

@data_provider(
(
# Simple imports that are already absolute.
(None, "from a.b import c", "a.b"),
("x.y.z", "from a.b import c", "a.b"),
# Relative import that can't be resolved due to missing module.
(None, "from ..w import c", None),
# Relative import that goes past the module level.
("x", "from ...y import z", None),
("x.y.z", "from .....w import c", None),
("x.y.z", "from ... import c", None),
# Correct resolution of absolute from relative modules.
("x.y.z", "from . import c", "x.y"),
("x.y.z", "from .. import c", "x"),
("x.y.z", "from .w import c", "x.y.w"),
("x.y.z", "from ..w import c", "x.w"),
("x.y.z", "from ...w import c", "w"),
)
)
def test_get_absolute_module(
self,
module: Optional[str],
importfrom: str,
output: Optional[str],
) -> None:
node = ensure_type(cst.parse_statement(importfrom), cst.SimpleStatementLine)
assert len(node.body) == 1, "Unexpected number of statements!"
import_node = ensure_type(node.body[0], cst.ImportFrom)

self.assertEqual(get_absolute_module_for_import(module, import_node), output)
if output is None:
with self.assertRaises(Exception):
get_absolute_module_for_import_or_raise(module, import_node)
else:
self.assertEqual(
get_absolute_module_for_import_or_raise(module, import_node), output
)

@data_provider(
(
# Simple imports that are already absolute.
Expand Down Expand Up @@ -94,7 +134,7 @@ def test_insert_header_comments(self) -> None:
("x/y/z/__init__.py", "from ...w import c", "x.w"),
)
)
def test_get_absolute_module(
def test_get_absolute_module_from_package(
self,
filename: Optional[str],
importfrom: str,
Expand All @@ -108,13 +148,20 @@ def test_get_absolute_module(
assert len(node.body) == 1, "Unexpected number of statements!"
import_node = ensure_type(node.body[0], cst.ImportFrom)

self.assertEqual(get_absolute_module_for_import(package, import_node), output)
self.assertEqual(
get_absolute_module_from_package_for_import(package, import_node), output
)
if output is None:
with self.assertRaises(Exception):
get_absolute_module_for_import_or_raise(package, import_node)
get_absolute_module_from_package_for_import_or_raise(
package, import_node
)
else:
self.assertEqual(
get_absolute_module_for_import_or_raise(package, import_node), output
get_absolute_module_from_package_for_import_or_raise(
package, import_node
),
output,
)

@data_provider(
Expand Down

0 comments on commit 149599e

Please sign in to comment.