Skip to content

Commit

Permalink
Make get_qualified_names_for and __contains__ not change the scope
Browse files Browse the repository at this point in the history
  • Loading branch information
zsol authored and jimmylai committed Mar 2, 2020
1 parent 922cc72 commit a7200dc
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 2 deletions.
28 changes: 26 additions & 2 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,13 +312,18 @@ def _getitem_from_self_or_parent(self, name: str) -> Tuple[BaseAssignment, ...]:
"""Overridden by ClassScope to hide it's assignments from child scopes."""
return self[name]

def _contains_in_self_or_parent(self, name: str) -> bool:
"""Overridden by ClassScope to hide it's assignments from child scopes."""
return name in self

def _record_assignment_as_parent(self, name: str, node: cst.CSTNode) -> None:
"""Overridden by ClassScope to forward 'nonlocal' assignments from child scopes."""
self.record_assignment(name, node)

@abc.abstractmethod
def __contains__(self, name: str) -> bool:
""" Check if the name str exist in current scope by ``name in scope``. """
return len(self[name]) > 0
...

@abc.abstractmethod
def __getitem__(self, name: str) -> Tuple[BaseAssignment, ...]:
Expand Down Expand Up @@ -407,7 +412,8 @@ def f(self) -> "c":
if full_name is None:
return results
parts = full_name.split(".")
for assignment in self[parts[0]]:
assignments = self[parts[0]] if parts[0] in self else set()
for assignment in assignments:
if isinstance(assignment, Assignment):
assignment_node = assignment.node
if isinstance(assignment_node, (cst.Import, cst.ImportFrom)):
Expand Down Expand Up @@ -446,6 +452,11 @@ def __init__(self) -> None:
self.globals: Scope = self # must be defined before Scope.__init__ is called
super().__init__(parent=self)

def __contains__(self, name: str) -> bool:
return hasattr(builtins, name) or (
name in self._assignments and len(self._assignments[name]) > 0
)

def __getitem__(self, name: str) -> Tuple[BaseAssignment, ...]:
if hasattr(builtins, name):
if not any(
Expand Down Expand Up @@ -490,6 +501,13 @@ def record_assignment(self, name: str, node: cst.CSTNode) -> None:
else:
super().record_assignment(name, node)

def __contains__(self, name: str) -> bool:
if name in self._scope_overwrites:
return name in self._scope_overwrites[name]
if name in self._assignments:
return len(self._assignments[name]) > 0
return self.parent._contains_in_self_or_parent(name)

def __getitem__(self, name: str) -> Tuple[BaseAssignment, ...]:
if name in self._scope_overwrites:
return self._scope_overwrites[name]._getitem_from_self_or_parent(name)
Expand Down Expand Up @@ -538,6 +556,12 @@ def _getitem_from_self_or_parent(self, name: str) -> Tuple[BaseAssignment, ...]:
"""
return self.parent._getitem_from_self_or_parent(name)

def _contains_in_self_or_parent(self, name: str) -> bool:
"""
See :meth:`_getitem_from_self_or_parent`
"""
return self.parent._contains_in_self_or_parent(name)


# even though we don't override the constructor.
class ComprehensionScope(LocalScope):
Expand Down
48 changes: 48 additions & 0 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ComprehensionScope,
FunctionScope,
GlobalScope,
LocalScope,
QualifiedName,
QualifiedNameSource,
Scope,
Expand Down Expand Up @@ -996,3 +997,50 @@ def test_keyword_arg_in_call(self) -> None:
scope = scopes[call]
self.assertIsInstance(scope, GlobalScope)
self.assertEqual(len(scope["arg"]), 0) # no assignment should exist

def test_global_contains_is_read_only(self) -> None:
gscope = GlobalScope()
before_assignments = list(gscope._assignments.items())
before_accesses = list(gscope._accesses.items())
self.assertFalse("doesnt_exist" in gscope)
self.assertEqual(list(gscope._accesses.items()), before_accesses)
self.assertEqual(list(gscope._assignments.items()), before_assignments)

def test_contains_is_read_only(self) -> None:
for s in [LocalScope, FunctionScope, ClassScope, ComprehensionScope]:
with self.subTest(scope=s):
gscope = GlobalScope()
scope = s(parent=gscope, node=cst.Name("lol"))
before_assignments = list(scope._assignments.items())
before_accesses = list(scope._accesses.items())
before_overwrites = list(scope._scope_overwrites.items())
before_parent_assignments = list(scope.parent._assignments.items())
before_parent_accesses = list(scope.parent._accesses.items())

self.assertFalse("doesnt_exist" in scope)
self.assertEqual(list(scope._accesses.items()), before_accesses)
self.assertEqual(list(scope._assignments.items()), before_assignments)
self.assertEqual(
list(scope._scope_overwrites.items()), before_overwrites
)
self.assertEqual(
list(scope.parent._assignments.items()), before_parent_assignments
)
self.assertEqual(
list(scope.parent._accesses.items()), before_parent_accesses
)

def test_get_qualified_names_for_is_read_only(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
import a
import b
"""
)
a = m.body[0]
scope = scopes[a]
assignments_len_before = len(scope._assignments)
accesses_len_before = len(scope._accesses)
scope.get_qualified_names_for("doesnt_exist")
self.assertEqual(len(scope._assignments), assignments_len_before)
self.assertEqual(len(scope._accesses), accesses_len_before)

0 comments on commit a7200dc

Please sign in to comment.