Skip to content

Commit

Permalink
Auto merge of #116199 - Urgau:simplify-invalid_ref_casting, r=cjgillot
Browse files Browse the repository at this point in the history
Simplify some of the logic in the `invalid_reference_casting` lint

This PR simplifies 2 areas of the logic for the `invalid_reference_casting` lint:
 - The init detection: we now use the newly added `expr_or_init` function instead of a manual detection
 - The ref-to-mut-ptr casting detection logic: I simplified this logic by caring less hardly about the order of the casting operations

Those two simplifications permits us to detect more cases, as can be seen in the test output changes.
  • Loading branch information
bors committed Sep 28, 2023
2 parents 42faef5 + 1b2c1a8 commit 1393ef1
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 134 deletions.
39 changes: 38 additions & 1 deletion compiler/rustc_lint/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,40 @@ impl<'tcx> LateContext<'tcx> {
})
}

/// If the given expression is a local binding, find the initializer expression.
/// If that initializer expression is another local binding, find its initializer again.
///
/// This process repeats as long as possible (but usually no more than once).
/// Type-check adjustments are not taken in account in this function.
///
/// Examples:
/// ```
/// let abc = 1;
/// let def = abc + 2;
/// // ^^^^^^^ output
/// let def = def;
/// dbg!(def);
/// // ^^^ input
/// ```
pub fn expr_or_init<'a>(&self, mut expr: &'a hir::Expr<'tcx>) -> &'a hir::Expr<'tcx> {
expr = expr.peel_blocks();

while let hir::ExprKind::Path(ref qpath) = expr.kind
&& let Some(parent_node) = match self.qpath_res(qpath, expr.hir_id) {
Res::Local(hir_id) => self.tcx.hir().find_parent(hir_id),
_ => None,
}
&& let Some(init) = match parent_node {
hir::Node::Expr(expr) => Some(expr),
hir::Node::Local(hir::Local { init, .. }) => *init,
_ => None
}
{
expr = init.peel_blocks();
}
expr
}

/// If the given expression is a local binding, find the initializer expression.
/// If that initializer expression is another local or **outside** (`const`/`static`)
/// binding, find its initializer again.
Expand All @@ -1338,7 +1372,10 @@ impl<'tcx> LateContext<'tcx> {
/// dbg!(def);
/// // ^^^ input
/// ```
pub fn expr_or_init<'a>(&self, mut expr: &'a hir::Expr<'tcx>) -> &'a hir::Expr<'tcx> {
pub fn expr_or_init_with_outside_body<'a>(
&self,
mut expr: &'a hir::Expr<'tcx>,
) -> &'a hir::Expr<'tcx> {
expr = expr.peel_blocks();

while let hir::ExprKind::Path(ref qpath) = expr.kind
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_lint/src/invalid_from_utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ impl<'tcx> LateLintPass<'tcx> for InvalidFromUtf8 {
)
};

let mut init = cx.expr_or_init(arg);
let mut init = cx.expr_or_init_with_outside_body(arg);
while let ExprKind::AddrOf(.., inner) = init.kind {
init = cx.expr_or_init(inner);
init = cx.expr_or_init_with_outside_body(inner);
}
match init.kind {
ExprKind::Lit(Spanned { node: lit, .. }) => {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_lint/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ late_lint_methods!(
BoxPointers: BoxPointers,
PathStatements: PathStatements,
LetUnderscore: LetUnderscore,
InvalidReferenceCasting: InvalidReferenceCasting::default(),
InvalidReferenceCasting: InvalidReferenceCasting,
// Depends on referenced function signatures in expressions
UnusedResults: UnusedResults,
NonUpperCaseGlobals: NonUpperCaseGlobals,
Expand Down
147 changes: 39 additions & 108 deletions compiler/rustc_lint/src/reference_casting.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use rustc_ast::Mutability;
use rustc_data_structures::fx::FxHashMap;
use rustc_hir::{def::Res, Expr, ExprKind, HirId, Local, QPath, StmtKind, UnOp};
use rustc_hir::{Expr, ExprKind, UnOp};
use rustc_middle::ty::{self, TypeAndMut};
use rustc_span::{sym, Span};
use rustc_span::sym;

use crate::{lints::InvalidReferenceCastingDiag, LateContext, LateLintPass, LintContext};

Expand Down Expand Up @@ -34,38 +33,18 @@ declare_lint! {
"casts of `&T` to `&mut T` without interior mutability"
}

#[derive(Default)]
pub struct InvalidReferenceCasting {
casted: FxHashMap<HirId, Span>,
}

impl_lint_pass!(InvalidReferenceCasting => [INVALID_REFERENCE_CASTING]);
declare_lint_pass!(InvalidReferenceCasting => [INVALID_REFERENCE_CASTING]);

impl<'tcx> LateLintPass<'tcx> for InvalidReferenceCasting {
fn check_stmt(&mut self, cx: &LateContext<'tcx>, stmt: &'tcx rustc_hir::Stmt<'tcx>) {
let StmtKind::Local(local) = stmt.kind else {
return;
};
let Local { init: Some(init), els: None, .. } = local else {
return;
};

if is_cast_from_const_to_mut(cx, init) {
self.casted.insert(local.pat.hir_id, init.span);
}
}

fn check_expr(&mut self, cx: &LateContext<'tcx>, expr: &'tcx Expr<'tcx>) {
let Some((is_assignment, e)) = is_operation_we_care_about(cx, expr) else {
return;
};

let orig_cast = if is_cast_from_const_to_mut(cx, e) {
None
} else if let ExprKind::Path(QPath::Resolved(_, path)) = e.kind
&& let Res::Local(hir_id) = &path.res
&& let Some(orig_cast) = self.casted.get(hir_id) {
Some(*orig_cast)
let init = cx.expr_or_init(e);

let orig_cast = if is_cast_from_const_to_mut(cx, init) {
if init.span != e.span { Some(init.span) } else { None }
} else {
return;
};
Expand Down Expand Up @@ -125,99 +104,51 @@ fn is_operation_we_care_about<'tcx>(
deref_assign_or_addr_of(e).or_else(|| ptr_write(cx, e))
}

fn is_cast_from_const_to_mut<'tcx>(cx: &LateContext<'tcx>, e: &'tcx Expr<'tcx>) -> bool {
let e = e.peel_blocks();
fn is_cast_from_const_to_mut<'tcx>(cx: &LateContext<'tcx>, orig_expr: &'tcx Expr<'tcx>) -> bool {
let mut need_check_freeze = false;
let mut e = orig_expr;

fn from_casts<'tcx>(
cx: &LateContext<'tcx>,
e: &'tcx Expr<'tcx>,
need_check_freeze: &mut bool,
) -> Option<&'tcx Expr<'tcx>> {
// <expr> as *mut ...
let mut e = if let ExprKind::Cast(e, t) = e.kind
&& let ty::RawPtr(TypeAndMut { mutbl: Mutability::Mut, .. }) = cx.typeck_results().node_type(t.hir_id).kind() {
e
// <expr>.cast_mut()
let end_ty = cx.typeck_results().node_type(orig_expr.hir_id);

// Bail out early if the end type is **not** a mutable pointer.
if !matches!(end_ty.kind(), ty::RawPtr(TypeAndMut { ty: _, mutbl: Mutability::Mut })) {
return false;
}

loop {
e = e.peel_blocks();
// <expr> as ...
e = if let ExprKind::Cast(expr, _) = e.kind {
expr
// <expr>.cast(), <expr>.cast_mut() or <expr>.cast_const()
} else if let ExprKind::MethodCall(_, expr, [], _) = e.kind
&& let Some(def_id) = cx.typeck_results().type_dependent_def_id(e.hir_id)
&& cx.tcx.is_diagnostic_item(sym::ptr_cast_mut, def_id) {
&& matches!(
cx.tcx.get_diagnostic_name(def_id),
Some(sym::ptr_cast | sym::const_ptr_cast | sym::ptr_cast_mut | sym::ptr_cast_const)
)
{
expr
// UnsafeCell::raw_get(<expr>)
// ptr::from_ref(<expr>), UnsafeCell::raw_get(<expr>) or mem::transmute<_, _>(<expr>)
} else if let ExprKind::Call(path, [arg]) = e.kind
&& let ExprKind::Path(ref qpath) = path.kind
&& let Some(def_id) = cx.qpath_res(qpath, path.hir_id).opt_def_id()
&& cx.tcx.is_diagnostic_item(sym::unsafe_cell_raw_get, def_id)
&& matches!(
cx.tcx.get_diagnostic_name(def_id),
Some(sym::ptr_from_ref | sym::unsafe_cell_raw_get | sym::transmute)
)
{
*need_check_freeze = true;
if cx.tcx.is_diagnostic_item(sym::unsafe_cell_raw_get, def_id) {
need_check_freeze = true;
}
arg
} else {
return None;
break;
};

let mut had_at_least_one_cast = false;
loop {
e = e.peel_blocks();
// <expr> as *mut/const ... or <expr> as <uint>
e = if let ExprKind::Cast(expr, t) = e.kind
&& matches!(cx.typeck_results().node_type(t.hir_id).kind(), ty::RawPtr(_) | ty::Uint(_)) {
had_at_least_one_cast = true;
expr
// <expr>.cast(), <expr>.cast_mut() or <expr>.cast_const()
} else if let ExprKind::MethodCall(_, expr, [], _) = e.kind
&& let Some(def_id) = cx.typeck_results().type_dependent_def_id(e.hir_id)
&& matches!(
cx.tcx.get_diagnostic_name(def_id),
Some(sym::ptr_cast | sym::const_ptr_cast | sym::ptr_cast_mut | sym::ptr_cast_const)
)
{
had_at_least_one_cast = true;
expr
// ptr::from_ref(<expr>) or UnsafeCell::raw_get(<expr>)
} else if let ExprKind::Call(path, [arg]) = e.kind
&& let ExprKind::Path(ref qpath) = path.kind
&& let Some(def_id) = cx.qpath_res(qpath, path.hir_id).opt_def_id()
&& matches!(
cx.tcx.get_diagnostic_name(def_id),
Some(sym::ptr_from_ref | sym::unsafe_cell_raw_get)
)
{
if cx.tcx.is_diagnostic_item(sym::unsafe_cell_raw_get, def_id) {
*need_check_freeze = true;
}
return Some(arg);
} else if had_at_least_one_cast {
return Some(e);
} else {
return None;
};
}
}

fn from_transmute<'tcx>(
cx: &LateContext<'tcx>,
e: &'tcx Expr<'tcx>,
) -> Option<&'tcx Expr<'tcx>> {
// mem::transmute::<_, *mut _>(<expr>)
if let ExprKind::Call(path, [arg]) = e.kind
&& let ExprKind::Path(ref qpath) = path.kind
&& let Some(def_id) = cx.qpath_res(qpath, path.hir_id).opt_def_id()
&& cx.tcx.is_diagnostic_item(sym::transmute, def_id)
&& let ty::RawPtr(TypeAndMut { mutbl: Mutability::Mut, .. }) = cx.typeck_results().node_type(e.hir_id).kind() {
Some(arg)
} else {
None
}
}

let mut need_check_freeze = false;
let Some(e) = from_casts(cx, e, &mut need_check_freeze).or_else(|| from_transmute(cx, e))
else {
return false;
};

let e = e.peel_blocks();
let node_type = cx.typeck_results().node_type(e.hir_id);
if let ty::Ref(_, inner_ty, Mutability::Not) = node_type.kind() {
let start_ty = cx.typeck_results().node_type(e.hir_id);
if let ty::Ref(_, inner_ty, Mutability::Not) = start_ty.kind() {
// If an UnsafeCell method is involved we need to additionaly check the
// inner type for the presence of the Freeze trait (ie does NOT contain
// an UnsafeCell), since in that case we would incorrectly lint on valid casts.
Expand Down
23 changes: 23 additions & 0 deletions tests/ui/lint/reference_casting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ unsafe fn ref_to_mut() {
//~^ ERROR casting `&T` to `&mut T` is undefined behavior
let _num = &mut *std::mem::transmute::<_, *mut i32>(num);
//~^ ERROR casting `&T` to `&mut T` is undefined behavior
let _num = &mut *(std::mem::transmute::<_, *mut i32>(num) as *mut i32);
//~^ ERROR casting `&T` to `&mut T` is undefined behavior
let _num = &mut *std::cell::UnsafeCell::raw_get(
//~^ ERROR casting `&T` to `&mut T` is undefined behavior
num as *const i32 as *const std::cell::UnsafeCell<i32>
Expand All @@ -47,8 +49,20 @@ unsafe fn ref_to_mut() {
let deferred = (std::ptr::from_ref(num) as *const i32 as *const i32).cast_mut() as *mut i32;
let _num = &mut *deferred;
//~^ ERROR casting `&T` to `&mut T` is undefined behavior
let deferred_rebind = deferred;
let _num = &mut *deferred_rebind;
//~^ ERROR casting `&T` to `&mut T` is undefined behavior
let _num = &mut *(num as *const _ as usize as *mut i32);
//~^ ERROR casting `&T` to `&mut T` is undefined behavior
let _num = &mut *(std::mem::transmute::<_, *mut _>(num as *const i32) as *mut i32);
//~^ ERROR casting `&T` to `&mut T` is undefined behavior

static NUM: &'static i32 = &2;
let num = NUM as *const i32 as *mut i32;
let num = num;
let num = num;
let _num = &mut *num;
//~^ ERROR casting `&T` to `&mut T` is undefined behavior

unsafe fn generic_ref_cast_mut<T>(this: &T) -> &mut T {
&mut *((this as *const _) as *mut _)
Expand Down Expand Up @@ -85,6 +99,8 @@ unsafe fn assign_to_ref() {
//~^ ERROR assigning to `&T` is undefined behavior
*std::mem::transmute::<_, *mut i32>(num) += 1;
//~^ ERROR assigning to `&T` is undefined behavior
*(std::mem::transmute::<_, *mut i32>(num) as *mut i32) += 1;
//~^ ERROR assigning to `&T` is undefined behavior
std::ptr::write(
//~^ ERROR assigning to `&T` is undefined behavior
std::mem::transmute::<*const i32, *mut i32>(num),
Expand All @@ -94,6 +110,9 @@ unsafe fn assign_to_ref() {
let value = num as *const i32 as *mut i32;
*value = 1;
//~^ ERROR assigning to `&T` is undefined behavior
let value_rebind = value;
*value_rebind = 1;
//~^ ERROR assigning to `&T` is undefined behavior
*(num as *const i32).cast::<i32>().cast_mut() = 2;
//~^ ERROR assigning to `&T` is undefined behavior
*(num as *const _ as usize as *mut i32) = 2;
Expand All @@ -111,6 +130,7 @@ unsafe fn assign_to_ref() {
}
}

const RAW_PTR: *mut u8 = 1 as *mut u8;
unsafe fn no_warn() {
let num = &3i32;
let mut_num = &mut 3i32;
Expand All @@ -125,6 +145,9 @@ unsafe fn no_warn() {
let mut value = 3;
let value: *const i32 = &mut value;
*(value as *const i16 as *mut i16) = 42;
*RAW_PTR = 42; // RAW_PTR is defined outside the function body,
// make sure we don't ICE on it when trying to
// determine if we should lint on it or not.

fn safe_as_mut<T>(x: &std::cell::UnsafeCell<T>) -> &mut T {
unsafe { &mut *std::cell::UnsafeCell::raw_get(x as *const _ as *const _) }
Expand Down
Loading

0 comments on commit 1393ef1

Please sign in to comment.