diff --git a/crates/ruff/src/rules/flake8_return/helpers.rs b/crates/ruff/src/rules/flake8_return/helpers.rs index 307d8936c324c6..bf115210fb3288 100644 --- a/crates/ruff/src/rules/flake8_return/helpers.rs +++ b/crates/ruff/src/rules/flake8_return/helpers.rs @@ -1,4 +1,5 @@ use ruff_text_size::TextSize; +use rustpython_parser::ast; use rustpython_parser::ast::{Expr, Ranged, Stmt}; use ruff_python_ast::source_code::Locator; @@ -6,15 +7,14 @@ use ruff_python_whitespace::UniversalNewlines; /// Return `true` if a function's return statement include at least one /// non-`None` value. -pub(super) fn result_exists(returns: &[(&Stmt, Option<&Expr>)]) -> bool { - returns.iter().any(|(_, expr)| { - expr.map(|expr| { +pub(super) fn result_exists(returns: &[&ast::StmtReturn]) -> bool { + returns.iter().any(|stmt| { + stmt.value.as_deref().map_or(false, |value| { !matches!( - expr, - Expr::Constant(ref constant) if constant.value.is_none() + value, + Expr::Constant(constant) if constant.value.is_none() ) }) - .unwrap_or(false) }) } @@ -26,12 +26,11 @@ pub(super) fn result_exists(returns: &[(&Stmt, Option<&Expr>)]) -> bool { /// This method assumes that the statement is the last statement in its body; specifically, that /// the statement isn't followed by a semicolon, followed by a multi-line statement. pub(super) fn end_of_last_statement(stmt: &Stmt, locator: &Locator) -> TextSize { - // End-of-file, so just return the end of the statement. if stmt.end() == locator.text_len() { + // End-of-file, so just return the end of the statement. stmt.end() - } - // Otherwise, find the end of the last line that's "part of" the statement. - else { + } else { + // Otherwise, find the end of the last line that's "part of" the statement. let contents = locator.after(stmt.end()); for line in contents.universal_newlines() { diff --git a/crates/ruff/src/rules/flake8_return/rules/function.rs b/crates/ruff/src/rules/flake8_return/rules/function.rs index 9bfb85ade3b4bb..c2affd8ed25668 100644 --- a/crates/ruff/src/rules/flake8_return/rules/function.rs +++ b/crates/ruff/src/rules/flake8_return/rules/function.rs @@ -1,5 +1,3 @@ -use itertools::Itertools; -use ruff_text_size::{TextRange, TextSize}; use rustpython_parser::ast::{self, Constant, Expr, Ranged, Stmt}; use ruff_diagnostics::{AlwaysAutofixableViolation, Violation}; @@ -139,8 +137,8 @@ impl AlwaysAutofixableViolation for ImplicitReturn { } /// ## What it does -/// Checks for variable assignments that are unused between the assignment and -/// a `return` of the variable. +/// Checks for variable assignments that immediately precede a `return` of the +/// assigned variable. /// /// ## Why is this bad? /// The variable assignment is not necessary as the value can be returned @@ -159,12 +157,15 @@ impl AlwaysAutofixableViolation for ImplicitReturn { /// return 1 /// ``` #[violation] -pub struct UnnecessaryAssign; +pub struct UnnecessaryAssign { + name: String, +} impl Violation for UnnecessaryAssign { #[derive_message_formats] fn message(&self) -> String { - format!("Unnecessary variable assignment before `return` statement") + let UnnecessaryAssign { name } = self; + format!("Unnecessary assignment to `{name}` before `return` statement") } } @@ -326,8 +327,8 @@ impl Violation for SuperfluousElseBreak { /// RET501 fn unnecessary_return_none(checker: &mut Checker, stack: &Stack) { - for (stmt, expr) in &stack.returns { - let Some(expr) = expr else { + for stmt in &stack.returns { + let Some(expr) = stmt.value.as_deref() else { continue; }; if !matches!( @@ -339,10 +340,9 @@ fn unnecessary_return_none(checker: &mut Checker, stack: &Stack) { ) { continue; } - let mut diagnostic = Diagnostic::new(UnnecessaryReturnNone, stmt.range()); + let mut diagnostic = Diagnostic::new(UnnecessaryReturnNone, stmt.range); if checker.patch(diagnostic.kind.rule()) { - #[allow(deprecated)] - diagnostic.set_fix(Fix::unspecified(Edit::range_replacement( + diagnostic.set_fix(Fix::automatic(Edit::range_replacement( "return".to_string(), stmt.range(), ))); @@ -353,16 +353,15 @@ fn unnecessary_return_none(checker: &mut Checker, stack: &Stack) { /// RET502 fn implicit_return_value(checker: &mut Checker, stack: &Stack) { - for (stmt, expr) in &stack.returns { - if expr.is_some() { + for stmt in &stack.returns { + if stmt.value.is_some() { continue; } - let mut diagnostic = Diagnostic::new(ImplicitReturnValue, stmt.range()); + let mut diagnostic = Diagnostic::new(ImplicitReturnValue, stmt.range); if checker.patch(diagnostic.kind.rule()) { - #[allow(deprecated)] - diagnostic.set_fix(Fix::unspecified(Edit::range_replacement( + diagnostic.set_fix(Fix::automatic(Edit::range_replacement( "return None".to_string(), - stmt.range(), + stmt.range, ))); } checker.diagnostics.push(diagnostic); @@ -417,8 +416,7 @@ fn implicit_return(checker: &mut Checker, stmt: &Stmt) { content.push_str(checker.stylist.line_ending().as_str()); content.push_str(indent); content.push_str("return None"); - #[allow(deprecated)] - diagnostic.set_fix(Fix::unspecified(Edit::insertion( + diagnostic.set_fix(Fix::suggested(Edit::insertion( content, end_of_last_statement(stmt, checker.locator), ))); @@ -456,8 +454,7 @@ fn implicit_return(checker: &mut Checker, stmt: &Stmt) { content.push_str(checker.stylist.line_ending().as_str()); content.push_str(indent); content.push_str("return None"); - #[allow(deprecated)] - diagnostic.set_fix(Fix::unspecified(Edit::insertion( + diagnostic.set_fix(Fix::suggested(Edit::insertion( content, end_of_last_statement(stmt, checker.locator), ))); @@ -494,8 +491,7 @@ fn implicit_return(checker: &mut Checker, stmt: &Stmt) { content.push_str(checker.stylist.line_ending().as_str()); content.push_str(indent); content.push_str("return None"); - #[allow(deprecated)] - diagnostic.set_fix(Fix::unspecified(Edit::insertion( + diagnostic.set_fix(Fix::suggested(Edit::insertion( content, end_of_last_statement(stmt, checker.locator), ))); @@ -506,129 +502,51 @@ fn implicit_return(checker: &mut Checker, stmt: &Stmt) { } } -/// Return `true` if the `id` has multiple declarations within the function. -fn has_multiple_declarations(id: &str, stack: &Stack) -> bool { - stack - .declarations - .get(&id) - .map_or(false, |declarations| declarations.len() > 1) -} - -/// Return `true` if the `id` has a (read) reference between the `return_location` and its -/// preceding declaration. -fn has_references_before_next_declaration( - id: &str, - return_range: TextRange, - stack: &Stack, -) -> bool { - let mut declaration_before_return: Option = None; - let mut declaration_after_return: Option = None; - if let Some(assignments) = stack.declarations.get(&id) { - for location in assignments.iter().sorted() { - if *location > return_range.start() { - declaration_after_return = Some(*location); - break; - } - declaration_before_return = Some(*location); - } - } - - // If there is no declaration before the return, then the variable must be declared in - // some other way (e.g., a function argument). No need to check for references. - let Some(declaration_before_return) = declaration_before_return else { - return true; - }; - - if let Some(references) = stack.references.get(&id) { - for location in references { - if return_range.contains(*location) { - continue; - } - - if declaration_before_return < *location { - if let Some(declaration_after_return) = declaration_after_return { - if *location <= declaration_after_return { - return true; - } - } else { - return true; - } - } - } - } +/// RET504 +fn unnecessary_assign(checker: &mut Checker, stack: &Stack) { + for (stmt_assign, stmt_return) in &stack.assignments { + // Identify, e.g., `return x`. + let Some(value) = stmt_return.value.as_ref() else { + continue; + }; - false -} + let Expr::Name(ast::ExprName { id: returned_id, .. }) = value.as_ref() else { + continue; + }; -/// Return `true` if the `id` has a read or write reference within a `try` or loop body. -fn has_references_or_declarations_within_try_or_loop(id: &str, stack: &Stack) -> bool { - if let Some(references) = stack.references.get(&id) { - for location in references { - for try_range in &stack.tries { - if try_range.contains(*location) { - return true; - } - } - for loop_range in &stack.loops { - if loop_range.contains(*location) { - return true; - } - } - } - } - if let Some(references) = stack.declarations.get(&id) { - for location in references { - for try_range in &stack.tries { - if try_range.contains(*location) { - return true; - } - } - for loop_range in &stack.loops { - if loop_range.contains(*location) { - return true; - } - } + // Identify, e.g., `x = 1`. + if stmt_assign.targets.len() > 1 { + continue; } - } - false -} -/// RET504 -fn unnecessary_assign(checker: &mut Checker, stack: &Stack, expr: &Expr) { - if let Expr::Name(ast::ExprName { id, .. }) = expr { - if !stack.assigned_names.contains(id.as_str()) { - return; - } + let Some(target) = stmt_assign.targets.first() else { + continue; + }; - if !stack.references.contains_key(id.as_str()) { - checker - .diagnostics - .push(Diagnostic::new(UnnecessaryAssign, expr.range())); - return; - } + let Expr::Name(ast::ExprName { id: assigned_id, .. }) = target else { + continue; + }; - if has_multiple_declarations(id, stack) - || has_references_before_next_declaration(id, expr.range(), stack) - || has_references_or_declarations_within_try_or_loop(id, stack) - { - return; + if returned_id != assigned_id { + continue; } - if stack.non_locals.contains(id.as_str()) { - return; + if stack.non_locals.contains(assigned_id.as_str()) { + continue; } - checker - .diagnostics - .push(Diagnostic::new(UnnecessaryAssign, expr.range())); + checker.diagnostics.push(Diagnostic::new( + UnnecessaryAssign { + name: assigned_id.to_string(), + }, + value.range(), + )); } } /// RET505, RET506, RET507, RET508 -fn superfluous_else_node(checker: &mut Checker, stmt: &Stmt, branch: Branch) -> bool { - let Stmt::If(ast::StmtIf { body, .. }) = stmt else { - return false; - }; +fn superfluous_else_node(checker: &mut Checker, stmt: &ast::StmtIf, branch: Branch) -> bool { + let ast::StmtIf { body, .. } = stmt; for child in body { if child.is_return_stmt() { let diagnostic = Diagnostic::new( @@ -708,7 +626,7 @@ pub(crate) fn function(checker: &mut Checker, body: &[Stmt], returns: Option<&Ex }; // Avoid false positives for generators. - if !stack.yields.is_empty() { + if stack.is_generator { return; } @@ -737,11 +655,7 @@ pub(crate) fn function(checker: &mut Checker, body: &[Stmt], returns: Option<&Ex } if checker.enabled(Rule::UnnecessaryAssign) { - for (_, expr) in &stack.returns { - if let Some(expr) = expr { - unnecessary_assign(checker, &stack, expr); - } - } + unnecessary_assign(checker, &stack); } } else { if checker.enabled(Rule::UnnecessaryReturnNone) { diff --git a/crates/ruff/src/rules/flake8_return/snapshots/ruff__rules__flake8_return__tests__RET501_RET501.py.snap b/crates/ruff/src/rules/flake8_return/snapshots/ruff__rules__flake8_return__tests__RET501_RET501.py.snap index ce6692be56a277..d0eeb67af00e5c 100644 --- a/crates/ruff/src/rules/flake8_return/snapshots/ruff__rules__flake8_return__tests__RET501_RET501.py.snap +++ b/crates/ruff/src/rules/flake8_return/snapshots/ruff__rules__flake8_return__tests__RET501_RET501.py.snap @@ -10,7 +10,7 @@ RET501.py:4:5: RET501 [*] Do not explicitly `return None` in function if it is t | = help: Remove explicit `return None` -ℹ Suggested fix +ℹ Fix 1 1 | def x(y): 2 2 | if not y: 3 3 | return @@ -29,7 +29,7 @@ RET501.py:14:9: RET501 [*] Do not explicitly `return None` in function if it is | = help: Remove explicit `return None` -ℹ Suggested fix +ℹ Fix 11 11 | 12 12 | def get(self, key: str) -> None: 13 13 | print(f"{key} not found") diff --git a/crates/ruff/src/rules/flake8_return/snapshots/ruff__rules__flake8_return__tests__RET502_RET502.py.snap b/crates/ruff/src/rules/flake8_return/snapshots/ruff__rules__flake8_return__tests__RET502_RET502.py.snap index 0265e4d405d3bb..6cc2760a270d4f 100644 --- a/crates/ruff/src/rules/flake8_return/snapshots/ruff__rules__flake8_return__tests__RET502_RET502.py.snap +++ b/crates/ruff/src/rules/flake8_return/snapshots/ruff__rules__flake8_return__tests__RET502_RET502.py.snap @@ -11,7 +11,7 @@ RET502.py:3:9: RET502 [*] Do not implicitly `return None` in function able to re | = help: Add explicit `None` return value -ℹ Suggested fix +ℹ Fix 1 1 | def x(y): 2 2 | if not y: 3 |- return # error diff --git a/crates/ruff/src/rules/flake8_return/snapshots/ruff__rules__flake8_return__tests__RET504_RET504.py.snap b/crates/ruff/src/rules/flake8_return/snapshots/ruff__rules__flake8_return__tests__RET504_RET504.py.snap index 69332b49762370..0d00ceb935f880 100644 --- a/crates/ruff/src/rules/flake8_return/snapshots/ruff__rules__flake8_return__tests__RET504_RET504.py.snap +++ b/crates/ruff/src/rules/flake8_return/snapshots/ruff__rules__flake8_return__tests__RET504_RET504.py.snap @@ -1,7 +1,7 @@ --- source: crates/ruff/src/rules/flake8_return/mod.rs --- -RET504.py:6:12: RET504 Unnecessary variable assignment before `return` statement +RET504.py:6:12: RET504 Unnecessary assignment to `a` before `return` statement | 4 | def x(): 5 | a = 1 @@ -9,7 +9,39 @@ RET504.py:6:12: RET504 Unnecessary variable assignment before `return` statement | ^ RET504 | -RET504.py:250:12: RET504 Unnecessary variable assignment before `return` statement +RET504.py:23:12: RET504 Unnecessary assignment to `formatted` before `return` statement + | +21 | # clean up after any blank components +22 | formatted = formatted.replace("()", "").replace(" ", " ").strip() +23 | return formatted + | ^^^^^^^^^ RET504 + | + +RET504.py:219:12: RET504 Unnecessary assignment to `app` before `return` statement + | +217 | return "Hello, World!" +218 | +219 | return app + | ^^^ RET504 + | + +RET504.py:228:12: RET504 Unnecessary assignment to `y` before `return` statement + | +226 | return x +227 | +228 | return y + | ^ RET504 + | + +RET504.py:245:12: RET504 Unnecessary assignment to `queryset` before `return` statement + | +243 | queryset = Model.filter(a=1) +244 | queryset = queryset.filter(c=3) +245 | return queryset + | ^^^^^^^^ RET504 + | + +RET504.py:250:12: RET504 Unnecessary assignment to `queryset` before `return` statement | 248 | def get_queryset(): 249 | queryset = Model.filter(a=1) @@ -17,7 +49,7 @@ RET504.py:250:12: RET504 Unnecessary variable assignment before `return` stateme | ^^^^^^^^ RET504 | -RET504.py:268:12: RET504 Unnecessary variable assignment before `return` statement +RET504.py:268:12: RET504 Unnecessary assignment to `val` before `return` statement | 266 | return val 267 | val = 1 diff --git a/crates/ruff/src/rules/flake8_return/visitor.rs b/crates/ruff/src/rules/flake8_return/visitor.rs index c6ac93c4ced0ca..a869736c2826f4 100644 --- a/crates/ruff/src/rules/flake8_return/visitor.rs +++ b/crates/ruff/src/rules/flake8_return/visitor.rs @@ -1,118 +1,60 @@ -use ruff_text_size::{TextRange, TextSize}; -use rustc_hash::{FxHashMap, FxHashSet}; -use rustpython_parser::ast::{self, Expr, Identifier, Ranged, Stmt}; +use rustc_hash::FxHashSet; +use rustpython_parser::ast::{self, Expr, Identifier, Stmt}; use ruff_python_ast::visitor; use ruff_python_ast::visitor::Visitor; #[derive(Default)] pub(crate) struct Stack<'a> { - pub(crate) returns: Vec<(&'a Stmt, Option<&'a Expr>)>, - pub(crate) yields: Vec<&'a Expr>, - pub(crate) elses: Vec<&'a Stmt>, - pub(crate) elifs: Vec<&'a Stmt>, - /// The names that are assigned to in the current scope (e.g., anything on the left-hand side of - /// an assignment). - pub(crate) assigned_names: FxHashSet<&'a str>, - /// The names that are declared in the current scope, and the ranges of those declarations - /// (e.g., assignments, but also function and class definitions). - pub(crate) declarations: FxHashMap<&'a str, Vec>, - pub(crate) references: FxHashMap<&'a str, Vec>, + /// The `return` statements in the current function. + pub(crate) returns: Vec<&'a ast::StmtReturn>, + /// The `else` statements in the current function. + pub(crate) elses: Vec<&'a ast::StmtIf>, + /// The `elif` statements in the current function. + pub(crate) elifs: Vec<&'a ast::StmtIf>, + /// The non-local variables in the current function. pub(crate) non_locals: FxHashSet<&'a str>, - pub(crate) loops: Vec, - pub(crate) tries: Vec, + /// Whether the current function is a generator. + pub(crate) is_generator: bool, + /// The `assignment`-to-`return` statement pairs in the current function. + pub(crate) assignments: Vec<(&'a ast::StmtAssign, &'a ast::StmtReturn)>, } #[derive(Default)] pub(crate) struct ReturnVisitor<'a> { + /// The current stack of nodes. pub(crate) stack: Stack<'a>, + /// The preceding sibling of the current node. + sibling: Option<&'a Stmt>, + /// The parent nodes of the current node. parents: Vec<&'a Stmt>, } -impl<'a> ReturnVisitor<'a> { - fn visit_assign_target(&mut self, expr: &'a Expr) { - match expr { - Expr::Tuple(ast::ExprTuple { elts, .. }) => { - for elt in elts { - self.visit_assign_target(elt); - } - return; - } - Expr::Name(ast::ExprName { id, .. }) => { - self.stack.assigned_names.insert(id.as_str()); - self.stack - .declarations - .entry(id) - .or_insert_with(Vec::new) - .push(expr.start()); - return; - } - Expr::Attribute(_) => { - // Attribute assignments are often side-effects (e.g., `self.property = value`), - // so we conservatively treat them as references to every known - // variable. - for name in self.stack.declarations.keys() { - self.stack - .references - .entry(name) - .or_insert_with(Vec::new) - .push(expr.start()); - } - } - _ => {} - } - visitor::walk_expr(self, expr); - } -} - impl<'a> Visitor<'a> for ReturnVisitor<'a> { fn visit_stmt(&mut self, stmt: &'a Stmt) { match stmt { - Stmt::Global(ast::StmtGlobal { names, range: _ }) - | Stmt::Nonlocal(ast::StmtNonlocal { names, range: _ }) => { - self.stack - .non_locals - .extend(names.iter().map(Identifier::as_str)); - } - Stmt::ClassDef(ast::StmtClassDef { - decorator_list, - name, - .. - }) => { - // Mark a declaration. - self.stack - .declarations - .entry(name.as_str()) - .or_insert_with(Vec::new) - .push(stmt.start()); - - // Don't recurse into the body, but visit the decorators, etc. + Stmt::ClassDef(ast::StmtClassDef { decorator_list, .. }) => { + // Visit the decorators, etc. for decorator in decorator_list { visitor::walk_decorator(self, decorator); } + + // But don't recurse into the body. + return; } Stmt::FunctionDef(ast::StmtFunctionDef { - name, args, decorator_list, returns, .. }) | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { - name, args, decorator_list, returns, .. }) => { - // Mark a declaration. - self.stack - .declarations - .entry(name.as_str()) - .or_insert_with(Vec::new) - .push(stmt.start()); - - // Don't recurse into the body, but visit the decorators, etc. + // Visit the decorators, etc. for decorator in decorator_list { visitor::walk_decorator(self, decorator); } @@ -120,17 +62,26 @@ impl<'a> Visitor<'a> for ReturnVisitor<'a> { visitor::walk_expr(self, returns); } visitor::walk_arguments(self, args); + + // But don't recurse into the body. + return; } - Stmt::Return(ast::StmtReturn { value, range: _ }) => { + Stmt::Global(ast::StmtGlobal { names, range: _ }) + | Stmt::Nonlocal(ast::StmtNonlocal { names, range: _ }) => { self.stack - .returns - .push((stmt, value.as_ref().map(|expr| &**expr))); + .non_locals + .extend(names.iter().map(Identifier::as_str)); + } + Stmt::Return(stmt_return) => { + // If the `return` statement is preceded by an `assignment` statement, then the + // `assignment` statement may be redundant. + if let Some(stmt_assign) = self.sibling.and_then(Stmt::as_assign_stmt) { + self.stack.assignments.push((stmt_assign, stmt_return)); + } - self.parents.push(stmt); - visitor::walk_stmt(self, stmt); - self.parents.pop(); + self.stack.returns.push(stmt_return); } - Stmt::If(ast::StmtIf { orelse, .. }) => { + Stmt::If(stmt_if) => { let is_elif_arm = self.parents.iter().any(|parent| { if let Stmt::If(ast::StmtIf { orelse, .. }) = parent { orelse.len() == 1 && &orelse[0] == stmt @@ -141,88 +92,40 @@ impl<'a> Visitor<'a> for ReturnVisitor<'a> { if !is_elif_arm { let has_elif = - orelse.len() == 1 && matches!(orelse.first().unwrap(), Stmt::If(_)); - let has_else = !orelse.is_empty(); + stmt_if.orelse.len() == 1 && stmt_if.orelse.first().unwrap().is_if_stmt(); + let has_else = !stmt_if.orelse.is_empty(); if has_elif { // `stmt` is an `if` block followed by an `elif` clause. - self.stack.elifs.push(stmt); + self.stack.elifs.push(stmt_if); } else if has_else { // `stmt` is an `if` block followed by an `else` clause. - self.stack.elses.push(stmt); + self.stack.elses.push(stmt_if); } } - - self.parents.push(stmt); - visitor::walk_stmt(self, stmt); - self.parents.pop(); - } - Stmt::Assign(ast::StmtAssign { targets, value, .. }) => { - if let Expr::Name(ast::ExprName { id, .. }) = value.as_ref() { - self.stack - .references - .entry(id) - .or_insert_with(Vec::new) - .push(value.start()); - } - - visitor::walk_expr(self, value); - - if let Some(target) = targets.first() { - // Skip unpacking assignments, like `x, y = my_object`. - if target.is_tuple_expr() && !value.is_tuple_expr() { - return; - } - - self.visit_assign_target(target); - } - } - Stmt::For(_) | Stmt::AsyncFor(_) | Stmt::While(_) => { - self.stack.loops.push(stmt.range()); - - self.parents.push(stmt); - visitor::walk_stmt(self, stmt); - self.parents.pop(); - } - Stmt::Try(_) | Stmt::TryStar(_) => { - self.stack.tries.push(stmt.range()); - - self.parents.push(stmt); - visitor::walk_stmt(self, stmt); - self.parents.pop(); - } - _ => { - self.parents.push(stmt); - visitor::walk_stmt(self, stmt); - self.parents.pop(); } + _ => {} } + + self.sibling = Some(stmt); + self.parents.push(stmt); + visitor::walk_stmt(self, stmt); + self.parents.pop(); } fn visit_expr(&mut self, expr: &'a Expr) { match expr { - Expr::Call(_) => { - // Arbitrary function calls can have side effects, so we conservatively treat - // every function call as a reference to every known variable. - for name in self.stack.declarations.keys() { - self.stack - .references - .entry(name) - .or_insert_with(Vec::new) - .push(expr.start()); - } - } - Expr::Name(ast::ExprName { id, .. }) => { - self.stack - .references - .entry(id) - .or_insert_with(Vec::new) - .push(expr.start()); - } Expr::YieldFrom(_) | Expr::Yield(_) => { - self.stack.yields.push(expr); + self.stack.is_generator = true; } _ => visitor::walk_expr(self, expr), } } + + fn visit_body(&mut self, body: &'a [Stmt]) { + let sibling = self.sibling; + self.sibling = None; + visitor::walk_body(self, body); + self.sibling = sibling; + } } diff --git a/crates/ruff_python_ast/src/helpers.rs b/crates/ruff_python_ast/src/helpers.rs index 6b4a2ddbff37cc..205a8cfc735ab1 100644 --- a/crates/ruff_python_ast/src/helpers.rs +++ b/crates/ruff_python_ast/src/helpers.rs @@ -1202,10 +1202,8 @@ pub fn first_colon_range(range: TextRange, locator: &Locator) -> Option Option { - let Stmt::If(ast::StmtIf { body, orelse, .. } )= stmt else { - return None; - }; +pub fn elif_else_range(stmt: &ast::StmtIf, locator: &Locator) -> Option { + let ast::StmtIf { body, orelse, .. } = stmt; let start = body.last().expect("Expected body to be non-empty").end(); @@ -1619,7 +1617,7 @@ mod tests { use anyhow::Result; use ruff_text_size::{TextLen, TextRange, TextSize}; - use rustpython_ast::Suite; + use rustpython_ast::{Stmt, Suite}; use rustpython_parser::ast::Cmpop; use rustpython_parser::Parse; @@ -1819,6 +1817,7 @@ elif b: .trim_start(); let program = Suite::parse(contents, "")?; let stmt = program.first().unwrap(); + let stmt = Stmt::as_if_stmt(stmt).unwrap(); let locator = Locator::new(contents); let range = elif_else_range(stmt, &locator).unwrap(); assert_eq!(range.start(), TextSize::from(14)); @@ -1833,6 +1832,7 @@ else: .trim_start(); let program = Suite::parse(contents, "")?; let stmt = program.first().unwrap(); + let stmt = Stmt::as_if_stmt(stmt).unwrap(); let locator = Locator::new(contents); let range = elif_else_range(stmt, &locator).unwrap(); assert_eq!(range.start(), TextSize::from(14));