Skip to content

Commit

Permalink
Move round_* functions from shims::x86::sse41 module to shims::x86
Browse files Browse the repository at this point in the history
  • Loading branch information
eduardosm committed Dec 7, 2023
1 parent c37f4d6 commit 8c5882e
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 84 deletions.
83 changes: 83 additions & 0 deletions src/tools/miri/src/shims/x86/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,89 @@ fn unary_op_ps<'tcx>(
Ok(())
}

// Rounds the first element of `right` according to `rounding`
// and copies the remaining elements from `left`.
fn round_first<'tcx, F: rustc_apfloat::Float>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
rounding: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);

let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;

let op0: F = this.read_scalar(&this.project_index(&right, 0)?)?.to_float()?;
let res = op0.round_to_integral(rounding).value;
this.write_scalar(
Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
&this.project_index(&dest, 0)?,
)?;

for i in 1..dest_len {
this.copy_op(
&this.project_index(&left, i)?,
&this.project_index(&dest, i)?,
/*allow_transmute*/ false,
)?;
}

Ok(())
}

// Rounds all elements of `op` according to `rounding`.
fn round_all<'tcx, F: rustc_apfloat::Float>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
op: &OpTy<'tcx, Provenance>,
rounding: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

assert_eq!(dest_len, op_len);

let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;

for i in 0..dest_len {
let op: F = this.read_scalar(&this.project_index(&op, i)?)?.to_float()?;
let res = op.round_to_integral(rounding).value;
this.write_scalar(
Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
&this.project_index(&dest, i)?,
)?;
}

Ok(())
}

/// Gets equivalent `rustc_apfloat::Round` from rounding mode immediate of
/// `round.{ss,sd,ps,pd}` intrinsics.
fn rounding_from_imm<'tcx>(rounding: i32) -> InterpResult<'tcx, rustc_apfloat::Round> {
// The fourth bit of `rounding` only affects the SSE status
// register, which cannot be accessed from Miri (or from Rust,
// for that matter), so we can ignore it.
match rounding & !0b1000 {
// When the third bit is 0, the rounding mode is determined by the
// first two bits.
0b000 => Ok(rustc_apfloat::Round::NearestTiesToEven),
0b001 => Ok(rustc_apfloat::Round::TowardNegative),
0b010 => Ok(rustc_apfloat::Round::TowardPositive),
0b011 => Ok(rustc_apfloat::Round::TowardZero),
// When the third bit is 1, the rounding mode is determined by the
// SSE status register. Since we do not support modifying it from
// Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
0b100..=0b111 => Ok(rustc_apfloat::Round::NearestTiesToEven),
rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"),
}
}

/// Converts each element of `op` from floating point to signed integer.
///
/// When the input value is NaN or out of range, fall back to minimum value.
Expand Down
85 changes: 1 addition & 84 deletions src/tools/miri/src/shims/x86/sse41.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use rustc_middle::mir;
use rustc_span::Symbol;
use rustc_target::abi::Size;
use rustc_target::spec::abi::Abi;

use super::{round_all, round_first};
use crate::*;
use shims::foreign_items::EmulateForeignItemResult;

Expand Down Expand Up @@ -283,86 +283,3 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
Ok(EmulateForeignItemResult::NeedsJumping)
}
}

// Rounds the first element of `right` according to `rounding`
// and copies the remaining elements from `left`.
fn round_first<'tcx, F: rustc_apfloat::Float>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
rounding: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);

let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;

let op0: F = this.read_scalar(&this.project_index(&right, 0)?)?.to_float()?;
let res = op0.round_to_integral(rounding).value;
this.write_scalar(
Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
&this.project_index(&dest, 0)?,
)?;

for i in 1..dest_len {
this.copy_op(
&this.project_index(&left, i)?,
&this.project_index(&dest, i)?,
/*allow_transmute*/ false,
)?;
}

Ok(())
}

// Rounds all elements of `op` according to `rounding`.
fn round_all<'tcx, F: rustc_apfloat::Float>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
op: &OpTy<'tcx, Provenance>,
rounding: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

assert_eq!(dest_len, op_len);

let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;

for i in 0..dest_len {
let op: F = this.read_scalar(&this.project_index(&op, i)?)?.to_float()?;
let res = op.round_to_integral(rounding).value;
this.write_scalar(
Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
&this.project_index(&dest, i)?,
)?;
}

Ok(())
}

/// Gets equivalent `rustc_apfloat::Round` from rounding mode immediate of
/// `round.{ss,sd,ps,pd}` intrinsics.
fn rounding_from_imm<'tcx>(rounding: i32) -> InterpResult<'tcx, rustc_apfloat::Round> {
// The fourth bit of `rounding` only affects the SSE status
// register, which cannot be accessed from Miri (or from Rust,
// for that matter), so we can ignore it.
match rounding & !0b1000 {
// When the third bit is 0, the rounding mode is determined by the
// first two bits.
0b000 => Ok(rustc_apfloat::Round::NearestTiesToEven),
0b001 => Ok(rustc_apfloat::Round::TowardNegative),
0b010 => Ok(rustc_apfloat::Round::TowardPositive),
0b011 => Ok(rustc_apfloat::Round::TowardZero),
// When the third bit is 1, the rounding mode is determined by the
// SSE status register. Since we do not support modifying it from
// Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
0b100..=0b111 => Ok(rustc_apfloat::Round::NearestTiesToEven),
rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"),
}
}

0 comments on commit 8c5882e

Please sign in to comment.