diff --git a/crates/ruff/resources/test/fixtures/pycodestyle/E712.py b/crates/ruff/resources/test/fixtures/pycodestyle/E712.py index 818afddaefe95..c0be4d7aa1c47 100644 --- a/crates/ruff/resources/test/fixtures/pycodestyle/E712.py +++ b/crates/ruff/resources/test/fixtures/pycodestyle/E712.py @@ -25,6 +25,12 @@ if res == True != False: pass +if(True) == TrueElement or x == TrueElement: + pass + +if (yield i) == True: + print("even") + #: Okay if x not in y: pass diff --git a/crates/ruff/src/rules/pycodestyle/helpers.rs b/crates/ruff/src/rules/pycodestyle/helpers.rs index 76a752b8f92fe..46b728b224533 100644 --- a/crates/ruff/src/rules/pycodestyle/helpers.rs +++ b/crates/ruff/src/rules/pycodestyle/helpers.rs @@ -1,8 +1,10 @@ -use ruff_python_ast::{CmpOp, Expr, Ranged}; -use ruff_text_size::{TextLen, TextRange}; use unicode_width::UnicodeWidthStr; +use ruff_python_ast::node::AnyNodeRef; +use ruff_python_ast::parenthesize::parenthesized_range; +use ruff_python_ast::{CmpOp, Expr, Ranged}; use ruff_source_file::{Line, Locator}; +use ruff_text_size::{TextLen, TextRange}; use crate::line_width::{LineLength, LineWidth, TabSize}; @@ -14,6 +16,7 @@ pub(super) fn generate_comparison( left: &Expr, ops: &[CmpOp], comparators: &[Expr], + parent: AnyNodeRef, locator: &Locator, ) -> String { let start = left.start(); @@ -21,7 +24,9 @@ pub(super) fn generate_comparison( let mut contents = String::with_capacity(usize::from(end - start)); // Add the left side of the comparison. - contents.push_str(locator.slice(left.range())); + contents.push_str(locator.slice( + parenthesized_range(left.into(), parent, locator.contents()).unwrap_or(left.range()), + )); for (op, comparator) in ops.iter().zip(comparators) { // Add the operator. @@ -39,7 +44,12 @@ pub(super) fn generate_comparison( }); // Add the right side of the comparison. - contents.push_str(locator.slice(comparator.range())); + contents.push_str( + locator.slice( + parenthesized_range(comparator.into(), parent, locator.contents()) + .unwrap_or(comparator.range()), + ), + ); } contents diff --git a/crates/ruff/src/rules/pycodestyle/rules/literal_comparisons.rs b/crates/ruff/src/rules/pycodestyle/rules/literal_comparisons.rs index d16fa579f74de..9c7f525b6ca72 100644 --- a/crates/ruff/src/rules/pycodestyle/rules/literal_comparisons.rs +++ b/crates/ruff/src/rules/pycodestyle/rules/literal_comparisons.rs @@ -279,8 +279,13 @@ pub(crate) fn literal_comparisons(checker: &mut Checker, compare: &ast::ExprComp .map(|(idx, op)| bad_ops.get(&idx).unwrap_or(op)) .copied() .collect::>(); - let content = - generate_comparison(&compare.left, &ops, &compare.comparators, checker.locator()); + let content = generate_comparison( + &compare.left, + &ops, + &compare.comparators, + compare.into(), + checker.locator(), + ); for diagnostic in &mut diagnostics { diagnostic.set_fix(Fix::suggested(Edit::range_replacement( content.to_string(), diff --git a/crates/ruff/src/rules/pycodestyle/rules/not_tests.rs b/crates/ruff/src/rules/pycodestyle/rules/not_tests.rs index 0ee126d046933..7aba04d41aa0e 100644 --- a/crates/ruff/src/rules/pycodestyle/rules/not_tests.rs +++ b/crates/ruff/src/rules/pycodestyle/rules/not_tests.rs @@ -94,7 +94,13 @@ pub(crate) fn not_tests(checker: &mut Checker, unary_op: &ast::ExprUnaryOp) { let mut diagnostic = Diagnostic::new(NotInTest, unary_op.operand.range()); if checker.patch(diagnostic.kind.rule()) { diagnostic.set_fix(Fix::automatic(Edit::range_replacement( - generate_comparison(left, &[CmpOp::NotIn], comparators, checker.locator()), + generate_comparison( + left, + &[CmpOp::NotIn], + comparators, + unary_op.into(), + checker.locator(), + ), unary_op.range(), ))); } @@ -106,7 +112,13 @@ pub(crate) fn not_tests(checker: &mut Checker, unary_op: &ast::ExprUnaryOp) { let mut diagnostic = Diagnostic::new(NotIsTest, unary_op.operand.range()); if checker.patch(diagnostic.kind.rule()) { diagnostic.set_fix(Fix::automatic(Edit::range_replacement( - generate_comparison(left, &[CmpOp::IsNot], comparators, checker.locator()), + generate_comparison( + left, + &[CmpOp::IsNot], + comparators, + unary_op.into(), + checker.locator(), + ), unary_op.range(), ))); } diff --git a/crates/ruff/src/rules/pycodestyle/snapshots/ruff__rules__pycodestyle__tests__E712_E712.py.snap b/crates/ruff/src/rules/pycodestyle/snapshots/ruff__rules__pycodestyle__tests__E712_E712.py.snap index e2a9be7b88be8..ba3f1143bd887 100644 --- a/crates/ruff/src/rules/pycodestyle/snapshots/ruff__rules__pycodestyle__tests__E712_E712.py.snap +++ b/crates/ruff/src/rules/pycodestyle/snapshots/ruff__rules__pycodestyle__tests__E712_E712.py.snap @@ -181,7 +181,7 @@ E712.py:22:5: E712 [*] Comparison to `True` should be `cond is True` or `if cond 20 20 | var = 1 if cond == True else -1 if cond == False else cond 21 21 | #: E712 22 |-if (True) == TrueElement or x == TrueElement: - 22 |+if True is TrueElement or x == TrueElement: + 22 |+if (True) is TrueElement or x == TrueElement: 23 23 | pass 24 24 | 25 25 | if res == True != False: @@ -204,7 +204,7 @@ E712.py:25:11: E712 [*] Comparison to `True` should be `cond is True` or `if con 25 |+if res is True is not False: 26 26 | pass 27 27 | -28 28 | #: Okay +28 28 | if(True) == TrueElement or x == TrueElement: E712.py:25:19: E712 [*] Comparison to `False` should be `cond is not False` or `if cond:` | @@ -224,6 +224,46 @@ E712.py:25:19: E712 [*] Comparison to `False` should be `cond is not False` or ` 25 |+if res is True is not False: 26 26 | pass 27 27 | -28 28 | #: Okay +28 28 | if(True) == TrueElement or x == TrueElement: + +E712.py:28:4: E712 [*] Comparison to `True` should be `cond is True` or `if cond:` + | +26 | pass +27 | +28 | if(True) == TrueElement or x == TrueElement: + | ^^^^ E712 +29 | pass + | + = help: Replace with `cond is True` + +ℹ Suggested fix +25 25 | if res == True != False: +26 26 | pass +27 27 | +28 |-if(True) == TrueElement or x == TrueElement: + 28 |+if(True) is TrueElement or x == TrueElement: +29 29 | pass +30 30 | +31 31 | if (yield i) == True: + +E712.py:31:17: E712 [*] Comparison to `True` should be `cond is True` or `if cond:` + | +29 | pass +30 | +31 | if (yield i) == True: + | ^^^^ E712 +32 | print("even") + | + = help: Replace with `cond is True` + +ℹ Suggested fix +28 28 | if(True) == TrueElement or x == TrueElement: +29 29 | pass +30 30 | +31 |-if (yield i) == True: + 31 |+if (yield i) is True: +32 32 | print("even") +33 33 | +34 34 | #: Okay diff --git a/crates/ruff_python_ast/src/lib.rs b/crates/ruff_python_ast/src/lib.rs index ac615c12803aa..d28f459dd4af4 100644 --- a/crates/ruff_python_ast/src/lib.rs +++ b/crates/ruff_python_ast/src/lib.rs @@ -12,6 +12,7 @@ pub mod identifier; pub mod imports; pub mod node; mod nodes; +pub mod parenthesize; pub mod relocate; pub mod statement_visitor; pub mod stmt_if; diff --git a/crates/ruff_python_ast/src/parenthesize.rs b/crates/ruff_python_ast/src/parenthesize.rs new file mode 100644 index 0000000000000..e7b4866fcbdbf --- /dev/null +++ b/crates/ruff_python_ast/src/parenthesize.rs @@ -0,0 +1,47 @@ +use ruff_python_trivia::{SimpleTokenKind, SimpleTokenizer}; +use ruff_text_size::{TextRange, TextSize}; + +use crate::node::AnyNodeRef; +use crate::{ExpressionRef, Ranged}; + +/// Returns the [`TextRange`] of a given expression including parentheses, if the expression is +/// parenthesized; or `None`, if the expression is not parenthesized. +pub fn parenthesized_range( + expr: ExpressionRef, + parent: AnyNodeRef, + contents: &str, +) -> Option { + // If the parent is a node that brings its own parentheses, exclude the closing parenthesis + // from our search range. Otherwise, we risk matching on calls, like `func(x)`, for which + // the open and close parentheses are part of the `Arguments` node. + // + // There are a few other nodes that may have their own parentheses, but are fine to exclude: + // - `Parameters`: The parameters to a function definition. Any expressions would represent + // default arguments, and so must be preceded by _at least_ the parameter name. As such, + // we won't mistake any parentheses for the opening and closing parentheses on the + // `Parameters` node itself. + // - `Tuple`: The elements of a tuple. The only risk is a single-element tuple (e.g., `(x,)`), + // which must have a trailing comma anyway. + let exclusive_parent_end = if parent.is_arguments() { + parent.end() - TextSize::new(1) + } else { + parent.end() + }; + + // First, test if there's a closing parenthesis because it tends to be cheaper. + let tokenizer = + SimpleTokenizer::new(contents, TextRange::new(expr.end(), exclusive_parent_end)); + let right = tokenizer.skip_trivia().next()?; + + if right.kind == SimpleTokenKind::RParen { + // Next, test for the opening parenthesis. + let mut tokenizer = + SimpleTokenizer::up_to_without_back_comment(expr.start(), contents).skip_trivia(); + let left = tokenizer.next_back()?; + if left.kind == SimpleTokenKind::LParen { + return Some(TextRange::new(left.start(), right.end())); + } + } + + None +} diff --git a/crates/ruff_python_ast/src/stmt_if.rs b/crates/ruff_python_ast/src/stmt_if.rs index 77a0164badcc6..c8a1e34e836ce 100644 --- a/crates/ruff_python_ast/src/stmt_if.rs +++ b/crates/ruff_python_ast/src/stmt_if.rs @@ -46,6 +46,3 @@ pub fn if_elif_branches(stmt_if: &StmtIf) -> impl Iterator }) })) } - -#[cfg(test)] -mod test {} diff --git a/crates/ruff_python_ast/tests/parenthesize.rs b/crates/ruff_python_ast/tests/parenthesize.rs new file mode 100644 index 0000000000000..eb1ef3850de7a --- /dev/null +++ b/crates/ruff_python_ast/tests/parenthesize.rs @@ -0,0 +1,76 @@ +use ruff_python_ast::parenthesize::parenthesized_range; +use ruff_python_parser::parse_expression; + +#[test] +fn test_parenthesized_name() { + let source_code = r#"(x) + 1"#; + let expr = parse_expression(source_code, "").unwrap(); + + let bin_op = expr.as_bin_op_expr().unwrap(); + let name = bin_op.left.as_ref(); + + let parenthesized = parenthesized_range(name.into(), bin_op.into(), source_code); + assert!(parenthesized.is_some()); +} + +#[test] +fn test_non_parenthesized_name() { + let source_code = r#"x + 1"#; + let expr = parse_expression(source_code, "").unwrap(); + + let bin_op = expr.as_bin_op_expr().unwrap(); + let name = bin_op.left.as_ref(); + + let parenthesized = parenthesized_range(name.into(), bin_op.into(), source_code); + assert!(parenthesized.is_none()); +} + +#[test] +fn test_parenthesized_argument() { + let source_code = r#"f((a))"#; + let expr = parse_expression(source_code, "").unwrap(); + + let call = expr.as_call_expr().unwrap(); + let arguments = &call.arguments; + let argument = arguments.args.first().unwrap(); + + let parenthesized = parenthesized_range(argument.into(), arguments.into(), source_code); + assert!(parenthesized.is_some()); +} + +#[test] +fn test_non_parenthesized_argument() { + let source_code = r#"f(a)"#; + let expr = parse_expression(source_code, "").unwrap(); + + let call = expr.as_call_expr().unwrap(); + let arguments = &call.arguments; + let argument = arguments.args.first().unwrap(); + + let parenthesized = parenthesized_range(argument.into(), arguments.into(), source_code); + assert!(parenthesized.is_none()); +} + +#[test] +fn test_parenthesized_tuple_member() { + let source_code = r#"(a, (b))"#; + let expr = parse_expression(source_code, "").unwrap(); + + let tuple = expr.as_tuple_expr().unwrap(); + let member = tuple.elts.last().unwrap(); + + let parenthesized = parenthesized_range(member.into(), tuple.into(), source_code); + assert!(parenthesized.is_some()); +} + +#[test] +fn test_non_parenthesized_tuple_member() { + let source_code = r#"(a, b)"#; + let expr = parse_expression(source_code, "").unwrap(); + + let tuple = expr.as_tuple_expr().unwrap(); + let member = tuple.elts.last().unwrap(); + + let parenthesized = parenthesized_range(member.into(), tuple.into(), source_code); + assert!(parenthesized.is_none()); +}