Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

assume value ranges in transmute #109993

Merged
merged 2 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 67 additions & 5 deletions compiler/rustc_codegen_ssa/src/mir/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use rustc_middle::mir::Operand;
use rustc_middle::ty::cast::{CastTy, IntTy};
use rustc_middle::ty::layout::{HasTyCtxt, LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, adjustment::PointerCast, Instance, Ty, TyCtxt};
use rustc_session::config::OptLevel;
use rustc_span::source_map::{Span, DUMMY_SP};
use rustc_target::abi::{self, FIRST_VARIANT};

Expand Down Expand Up @@ -231,10 +232,16 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
(ScalarOrZst::Scalar(in_scalar), ScalarOrZst::Scalar(out_scalar))
if in_scalar.size(self.cx) == out_scalar.size(self.cx) =>
{
let operand_bty = bx.backend_type(operand.layout);
let cast_bty = bx.backend_type(cast);
Some(OperandValue::Immediate(
self.transmute_immediate(bx, imm, in_scalar, out_scalar, cast_bty),
))
Some(OperandValue::Immediate(self.transmute_immediate(
bx,
imm,
in_scalar,
operand_bty,
out_scalar,
cast_bty,
)))
}
_ => None,
}
Expand All @@ -250,11 +257,13 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
&& in_a.size(self.cx) == out_a.size(self.cx)
&& in_b.size(self.cx) == out_b.size(self.cx)
{
let in_a_ibty = bx.scalar_pair_element_backend_type(operand.layout, 0, false);
let in_b_ibty = bx.scalar_pair_element_backend_type(operand.layout, 1, false);
let out_a_ibty = bx.scalar_pair_element_backend_type(cast, 0, false);
let out_b_ibty = bx.scalar_pair_element_backend_type(cast, 1, false);
Some(OperandValue::Pair(
self.transmute_immediate(bx, imm_a, in_a, out_a, out_a_ibty),
self.transmute_immediate(bx, imm_b, in_b, out_b, out_b_ibty),
self.transmute_immediate(bx, imm_a, in_a, in_a_ibty, out_a, out_a_ibty),
self.transmute_immediate(bx, imm_b, in_b, in_b_ibty, out_b, out_b_ibty),
))
} else {
None
Expand All @@ -273,13 +282,21 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
bx: &mut Bx,
mut imm: Bx::Value,
from_scalar: abi::Scalar,
from_backend_ty: Bx::Type,
to_scalar: abi::Scalar,
to_backend_ty: Bx::Type,
) -> Bx::Value {
debug_assert_eq!(from_scalar.size(self.cx), to_scalar.size(self.cx));

use abi::Primitive::*;
imm = bx.from_immediate(imm);

// When scalars are passed by value, there's no metadata recording their
// valid ranges. For example, `char`s are passed as just `i32`, with no
// way for LLVM to know that they're 0x10FFFF at most. Thus we assume
// the range of the input value too, not just the output range.
self.assume_scalar_range(bx, imm, from_scalar, from_backend_ty);
oli-obk marked this conversation as resolved.
Show resolved Hide resolved

imm = match (from_scalar.primitive(), to_scalar.primitive()) {
(Int(..) | F32 | F64, Int(..) | F32 | F64) => bx.bitcast(imm, to_backend_ty),
(Pointer(..), Pointer(..)) => bx.pointercast(imm, to_backend_ty),
Expand All @@ -294,10 +311,55 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
bx.bitcast(int_imm, to_backend_ty)
}
};
self.assume_scalar_range(bx, imm, to_scalar, to_backend_ty);
oli-obk marked this conversation as resolved.
Show resolved Hide resolved
imm = bx.to_immediate_scalar(imm, to_scalar);
imm
}

fn assume_scalar_range(
&self,
bx: &mut Bx,
imm: Bx::Value,
scalar: abi::Scalar,
backend_ty: Bx::Type,
) {
if matches!(self.cx.sess().opts.optimize, OptLevel::No | OptLevel::Less)
// For now, the critical niches are all over `Int`eger values.
// Should floating-point values or pointers ever get more complex
// niches, then this code will probably want to handle them too.
|| !matches!(scalar.primitive(), abi::Primitive::Int(..))
oli-obk marked this conversation as resolved.
Show resolved Hide resolved
|| scalar.is_always_valid(self.cx)
{
return;
}

let abi::WrappingRange { start, end } = scalar.valid_range(self.cx);

if start <= end {
if start > 0 {
let low = bx.const_uint_big(backend_ty, start);
let cmp = bx.icmp(IntPredicate::IntUGE, imm, low);
bx.assume(cmp);
}

let type_max = scalar.size(self.cx).unsigned_int_max();
if end < type_max {
let high = bx.const_uint_big(backend_ty, end);
let cmp = bx.icmp(IntPredicate::IntULE, imm, high);
bx.assume(cmp);
}
} else {
let low = bx.const_uint_big(backend_ty, start);
let cmp_low = bx.icmp(IntPredicate::IntUGE, imm, low);

let high = bx.const_uint_big(backend_ty, end);
let cmp_high = bx.icmp(IntPredicate::IntULE, imm, high);

let or = bx.or(cmp_low, cmp_high);
bx.assume(or);
}
}

pub fn codegen_rvalue_unsized(
&mut self,
bx: &mut Bx,
Expand Down
184 changes: 184 additions & 0 deletions tests/codegen/intrinsics/transmute-niched.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
// revisions: OPT DBG
// [OPT] compile-flags: -C opt-level=3 -C no-prepopulate-passes
// [DBG] compile-flags: -C opt-level=0 -C no-prepopulate-passes
// only-64bit (so I don't need to worry about usize)
// min-llvm-version: 15.0 # this test assumes `ptr`s

#![crate_type = "lib"]

use std::mem::transmute;
use std::num::NonZeroU32;

#[repr(u8)]
pub enum SmallEnum {
A = 10,
B = 11,
C = 12,
}

// CHECK-LABEL: @check_to_enum(
#[no_mangle]
pub unsafe fn check_to_enum(x: i8) -> SmallEnum {
// OPT: %0 = icmp uge i8 %x, 10
// OPT: call void @llvm.assume(i1 %0)
// OPT: %1 = icmp ule i8 %x, 12
// OPT: call void @llvm.assume(i1 %1)
// DBG-NOT: icmp
// DBG-NOT: assume
// CHECK: ret i8 %x

transmute(x)
}

// CHECK-LABEL: @check_from_enum(
#[no_mangle]
pub unsafe fn check_from_enum(x: SmallEnum) -> i8 {
// OPT: %0 = icmp uge i8 %x, 10
// OPT: call void @llvm.assume(i1 %0)
// OPT: %1 = icmp ule i8 %x, 12
// OPT: call void @llvm.assume(i1 %1)
// DBG-NOT: icmp
// DBG-NOT: assume
// CHECK: ret i8 %x

transmute(x)
}

// CHECK-LABEL: @check_to_ordering(
#[no_mangle]
pub unsafe fn check_to_ordering(x: u8) -> std::cmp::Ordering {
// OPT: %0 = icmp uge i8 %x, -1
// OPT: %1 = icmp ule i8 %x, 1
// OPT: %2 = or i1 %0, %1
// OPT: call void @llvm.assume(i1 %2)
// DBG-NOT: icmp
// DBG-NOT: assume
// CHECK: ret i8 %x

transmute(x)
}

// CHECK-LABEL: @check_from_ordering(
#[no_mangle]
pub unsafe fn check_from_ordering(x: std::cmp::Ordering) -> u8 {
// OPT: %0 = icmp uge i8 %x, -1
// OPT: %1 = icmp ule i8 %x, 1
// OPT: %2 = or i1 %0, %1
// OPT: call void @llvm.assume(i1 %2)
// DBG-NOT: icmp
// DBG-NOT: assume
// CHECK: ret i8 %x

transmute(x)
}

#[repr(i32)]
pub enum Minus100ToPlus100 {
A = -100,
B = -90,
C = -80,
D = -70,
E = -60,
F = -50,
G = -40,
H = -30,
I = -20,
J = -10,
K = 0,
L = 10,
M = 20,
N = 30,
O = 40,
P = 50,
Q = 60,
R = 70,
S = 80,
T = 90,
U = 100,
}

// CHECK-LABEL: @check_enum_from_char(
#[no_mangle]
pub unsafe fn check_enum_from_char(x: char) -> Minus100ToPlus100 {
// OPT: %0 = icmp ule i32 %x, 1114111
// OPT: call void @llvm.assume(i1 %0)
// OPT: %1 = icmp uge i32 %x, -100
// OPT: %2 = icmp ule i32 %x, 100
// OPT: %3 = or i1 %1, %2
// OPT: call void @llvm.assume(i1 %3)
// DBG-NOT: icmp
// DBG-NOT: assume
// CHECK: ret i32 %x

transmute(x)
}

// CHECK-LABEL: @check_enum_to_char(
#[no_mangle]
pub unsafe fn check_enum_to_char(x: Minus100ToPlus100) -> char {
// OPT: %0 = icmp uge i32 %x, -100
// OPT: %1 = icmp ule i32 %x, 100
// OPT: %2 = or i1 %0, %1
// OPT: call void @llvm.assume(i1 %2)
// OPT: %3 = icmp ule i32 %x, 1114111
// OPT: call void @llvm.assume(i1 %3)
// DBG-NOT: icmp
// DBG-NOT: assume
// CHECK: ret i32 %x

transmute(x)
}

// CHECK-LABEL: @check_swap_pair(
#[no_mangle]
pub unsafe fn check_swap_pair(x: (char, NonZeroU32)) -> (NonZeroU32, char) {
// OPT: %0 = icmp ule i32 %x.0, 1114111
// OPT: call void @llvm.assume(i1 %0)
// OPT: %1 = icmp uge i32 %x.0, 1
// OPT: call void @llvm.assume(i1 %1)
// OPT: %2 = icmp uge i32 %x.1, 1
// OPT: call void @llvm.assume(i1 %2)
// OPT: %3 = icmp ule i32 %x.1, 1114111
// OPT: call void @llvm.assume(i1 %3)
// DBG-NOT: icmp
// DBG-NOT: assume
// CHECK: %[[P1:.+]] = insertvalue { i32, i32 } poison, i32 %x.0, 0
// CHECK: %[[P2:.+]] = insertvalue { i32, i32 } %[[P1]], i32 %x.1, 1
// CHECK: ret { i32, i32 } %[[P2]]

transmute(x)
}

// CHECK-LABEL: @check_bool_from_ordering(
#[no_mangle]
pub unsafe fn check_bool_from_ordering(x: std::cmp::Ordering) -> bool {
// OPT: %0 = icmp uge i8 %x, -1
// OPT: %1 = icmp ule i8 %x, 1
// OPT: %2 = or i1 %0, %1
// OPT: call void @llvm.assume(i1 %2)
// OPT: %3 = icmp ule i8 %x, 1
// OPT: call void @llvm.assume(i1 %3)
// DBG-NOT: icmp
// DBG-NOT: assume
// CHECK: %[[R:.+]] = trunc i8 %x to i1
// CHECK: ret i1 %[[R]]

transmute(x)
}

// CHECK-LABEL: @check_bool_to_ordering(
#[no_mangle]
pub unsafe fn check_bool_to_ordering(x: bool) -> std::cmp::Ordering {
// CHECK: %0 = zext i1 %x to i8
// OPT: %1 = icmp ule i8 %0, 1
// OPT: call void @llvm.assume(i1 %1)
// OPT: %2 = icmp uge i8 %0, -1
// OPT: %3 = icmp ule i8 %0, 1
// OPT: %4 = or i1 %2, %3
// OPT: call void @llvm.assume(i1 %4)
// DBG-NOT: icmp
// DBG-NOT: assume
// CHECK: ret i8 %0

transmute(x)
}
8 changes: 4 additions & 4 deletions tests/codegen/intrinsics/transmute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,17 +169,17 @@ pub unsafe fn check_aggregate_from_bool(x: bool) -> Aggregate8 {
#[no_mangle]
pub unsafe fn check_byte_to_bool(x: u8) -> bool {
// CHECK-NOT: alloca
// CHECK: %0 = trunc i8 %x to i1
// CHECK: ret i1 %0
// CHECK: %[[R:.+]] = trunc i8 %x to i1
// CHECK: ret i1 %[[R]]
transmute(x)
}

// CHECK-LABEL: @check_byte_from_bool(
#[no_mangle]
pub unsafe fn check_byte_from_bool(x: bool) -> u8 {
// CHECK-NOT: alloca
// CHECK: %0 = zext i1 %x to i8
// CHECK: ret i8 %0
// CHECK: %[[R:.+]] = zext i1 %x to i8
// CHECK: ret i8 %[[R:.+]]
transmute(x)
}

Expand Down
Loading