Skip to content

Commit

Permalink
Fix NotEqual position issue (Instagram#325)
Browse files Browse the repository at this point in the history
Co-authored-by: Jimmy Lai <jimmylai@fb.com>
  • Loading branch information
jimmylai and jimmylai authored Jun 29, 2020
1 parent 2d56ba7 commit 3a7ffaf
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
10 changes: 4 additions & 6 deletions libcst/_nodes/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def _get_token(self) -> str:

@add_slots
@dataclass(frozen=True)
class NotEqual(BaseCompOp):
class NotEqual(BaseCompOp, _BaseOneTokenOp):
"""
A comparison operator that can be used in a :class:`Comparison` expression.
Expand All @@ -691,7 +691,7 @@ def _validate(self) -> None:
if self.value not in ["!=", "<>"]:
raise CSTValidationError("Invalid value for NotEqual node.")

def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "BaseCompOp":
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "NotEqual":
return self.__class__(
whitespace_before=visit_required(
self, "whitespace_before", self.whitespace_before, visitor
Expand All @@ -702,10 +702,8 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "BaseCompOp":
),
)

def _codegen_impl(self, state: CodegenState) -> None:
self.whitespace_before._codegen(state)
state.add_token(self.value)
self.whitespace_after._codegen(state)
def _get_token(self) -> str:
return self.value


@add_slots
Expand Down
24 changes: 22 additions & 2 deletions libcst/metadata/tests/test_position_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import libcst as cst
from libcst import parse_module
from libcst._batched_visitor import BatchableCSTVisitor
from libcst._visitors import CSTTransformer
from libcst._visitors import CSTVisitor
from libcst.metadata import (
CodeRange,
MetadataWrapper,
Expand Down Expand Up @@ -38,7 +38,7 @@ def test_visitor_provider(self) -> None:
"""
test = self

class DependentVisitor(CSTTransformer):
class DependentVisitor(CSTVisitor):
METADATA_DEPENDENCIES = (PositionProvider,)

def visit_Pass(self, node: cst.Pass) -> None:
Expand All @@ -49,6 +49,26 @@ def visit_Pass(self, node: cst.Pass) -> None:
wrapper = MetadataWrapper(parse_module("pass"))
wrapper.visit(DependentVisitor())

def test_equal_range(self) -> None:
test = self
expected_range = CodeRange((1, 4), (1, 6))

class EqualPositionVisitor(CSTVisitor):
METADATA_DEPENDENCIES = (PositionProvider,)

def visit_Equal(self, node: cst.Equal) -> None:
test.assertEqual(
self.get_metadata(PositionProvider, node), expected_range
)

def visit_NotEqual(self, node: cst.NotEqual) -> None:
test.assertEqual(
self.get_metadata(PositionProvider, node), expected_range
)

MetadataWrapper(parse_module("var == 1")).visit(EqualPositionVisitor())
MetadataWrapper(parse_module("var != 1")).visit(EqualPositionVisitor())

def test_batchable_provider(self) -> None:
test = self

Expand Down

0 comments on commit 3a7ffaf

Please sign in to comment.