diff --git a/resources/test/fixtures/flake8_simplify/SIM103.py b/resources/test/fixtures/flake8_simplify/SIM103.py index c3101611d9c22..f867006131879 100644 --- a/resources/test/fixtures/flake8_simplify/SIM103.py +++ b/resources/test/fixtures/flake8_simplify/SIM103.py @@ -6,6 +6,14 @@ def f(): return False +def f(): + # SIM103 + if a == b: + return True + else: + return False + + def f(): # SIM103 if a: diff --git a/src/rules/flake8_simplify/rules/ast_if.rs b/src/rules/flake8_simplify/rules/ast_if.rs index 1c9c7a540384a..b8db4a098df49 100644 --- a/src/rules/flake8_simplify/rules/ast_if.rs +++ b/src/rules/flake8_simplify/rules/ast_if.rs @@ -184,16 +184,21 @@ pub fn return_bool_condition_directly(checker: &mut Checker, stmt: &Stmt) { && matches!(else_return, Bool::False) && !has_comments(stmt, checker.locator) { - let return_stmt = create_stmt(StmtKind::Return { - value: Some(Box::new(create_expr(ExprKind::Call { - func: Box::new(create_expr(ExprKind::Name { - id: "bool".to_string(), - ctx: ExprContext::Load, - })), - args: vec![(**test).clone()], - keywords: vec![], - }))), - }); + let return_stmt = match test.node { + ExprKind::Compare { .. } => create_stmt(StmtKind::Return { + value: Some(test.clone()), + }), + _ => create_stmt(StmtKind::Return { + value: Some(Box::new(create_expr(ExprKind::Call { + func: Box::new(create_expr(ExprKind::Name { + id: "bool".to_string(), + ctx: ExprContext::Load, + })), + args: vec![(**test).clone()], + keywords: vec![], + }))), + }), + }; diagnostic.amend(Fix::replacement( unparse_stmt(&return_stmt, checker.stylist), stmt.location, diff --git a/src/rules/flake8_simplify/snapshots/ruff__rules__flake8_simplify__tests__SIM103_SIM103.py.snap b/src/rules/flake8_simplify/snapshots/ruff__rules__flake8_simplify__tests__SIM103_SIM103.py.snap index 22c0a7af3f0e0..36d8dc8829f56 100644 --- a/src/rules/flake8_simplify/snapshots/ruff__rules__flake8_simplify__tests__SIM103_SIM103.py.snap +++ b/src/rules/flake8_simplify/snapshots/ruff__rules__flake8_simplify__tests__SIM103_SIM103.py.snap @@ -21,52 +21,71 @@ expression: diagnostics row: 6 column: 20 parent: ~ +- kind: + ReturnBoolConditionDirectly: + cond: a == b + location: + row: 11 + column: 4 + end_location: + row: 14 + column: 20 + fix: + content: + - return a == b + location: + row: 11 + column: 4 + end_location: + row: 14 + column: 20 + parent: ~ - kind: ReturnBoolConditionDirectly: cond: b location: - row: 13 + row: 21 column: 4 end_location: - row: 16 + row: 24 column: 20 fix: content: - return bool(b) location: - row: 13 + row: 21 column: 4 end_location: - row: 16 + row: 24 column: 20 parent: ~ - kind: ReturnBoolConditionDirectly: cond: b location: - row: 24 + row: 32 column: 8 end_location: - row: 27 + row: 35 column: 24 fix: content: - return bool(b) location: - row: 24 + row: 32 column: 8 end_location: - row: 27 + row: 35 column: 24 parent: ~ - kind: ReturnBoolConditionDirectly: cond: a location: - row: 49 + row: 57 column: 4 end_location: - row: 52 + row: 60 column: 19 fix: ~ parent: ~