Skip to content

Commit

Permalink
interpret/allocation: fix aliasing issue in interpreter and refactor …
Browse files Browse the repository at this point in the history
…getters a bit

- rename mutating functions to be more scary
- add a new raw bytes getter
  • Loading branch information
RalfJung committed Mar 15, 2024
1 parent ee03c28 commit 21b8f06
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 22 deletions.
2 changes: 1 addition & 1 deletion compiler/rustc_const_eval/src/interpret/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ pub macro compile_time_machine(<$mir: lifetime, $tcx: lifetime>) {

type AllocExtra = ();
type FrameExtra = ();
type Bytes = Box<[u8]>;
type Bytes = Vec<u8>;

#[inline(always)]
fn ignore_optional_overflow_checks(_ecx: &InterpCx<$mir, $tcx, Self>) -> bool {
Expand Down
21 changes: 16 additions & 5 deletions compiler/rustc_const_eval/src/interpret/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,15 @@ pub struct Memory<'mir, 'tcx, M: Machine<'mir, 'tcx>> {
/// A reference to some allocation that was already bounds-checked for the given region
/// and had the on-access machine hooks run.
#[derive(Copy, Clone)]
pub struct AllocRef<'a, 'tcx, Prov: Provenance, Extra, Bytes: AllocBytes = Box<[u8]>> {
pub struct AllocRef<'a, 'tcx, Prov: Provenance, Extra, Bytes: AllocBytes = Vec<u8>> {
alloc: &'a Allocation<Prov, Extra, Bytes>,
range: AllocRange,
tcx: TyCtxt<'tcx>,
alloc_id: AllocId,
}
/// A reference to some allocation that was already bounds-checked for the given region
/// and had the on-access machine hooks run.
pub struct AllocRefMut<'a, 'tcx, Prov: Provenance, Extra, Bytes: AllocBytes = Box<[u8]>> {
pub struct AllocRefMut<'a, 'tcx, Prov: Provenance, Extra, Bytes: AllocBytes = Vec<u8>> {
alloc: &'a mut Allocation<Prov, Extra, Bytes>,
range: AllocRange,
tcx: TyCtxt<'tcx>,
Expand Down Expand Up @@ -1157,11 +1157,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
};

// Side-step AllocRef and directly access the underlying bytes more efficiently.
// (We are staying inside the bounds here so all is good.)
// (We are staying inside the bounds here and all bytes do get overwritten so all is good.)
let alloc_id = alloc_ref.alloc_id;
let bytes = alloc_ref
.alloc
.get_bytes_mut(&alloc_ref.tcx, alloc_ref.range)
.get_bytes_unchecked_for_overwrite(&alloc_ref.tcx, alloc_ref.range)
.map_err(move |e| e.to_interp_error(alloc_id))?;
// `zip` would stop when the first iterator ends; we want to definitely
// cover all of `bytes`.
Expand All @@ -1182,6 +1182,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
self.mem_copy_repeatedly(src, dest, size, 1, nonoverlapping)
}

/// Performs `num_copies` many copies of `size` many bytes from `src` to `dest + i*size` (where
/// `i` is the index of the copy).
///
/// Either `nonoverlapping` must be true or `num_copies` must be 1; doing repeated copies that
/// may overlap is not supported.
pub fn mem_copy_repeatedly(
&mut self,
src: Pointer<Option<M::Provenance>>,
Expand Down Expand Up @@ -1243,8 +1248,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
(dest_alloc_id, dest_prov),
dest_range,
)?;
// Yes we do overwrite all bytes in `dest_bytes`.
let dest_bytes = dest_alloc
.get_bytes_mut_ptr(&tcx, dest_range)
.get_bytes_unchecked_for_overwrite_ptr(&tcx, dest_range)
.map_err(|e| e.to_interp_error(dest_alloc_id))?
.as_mut_ptr();

Expand Down Expand Up @@ -1278,6 +1284,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
}
}
}
if num_copies > 1 {
assert!(nonoverlapping, "multi-copy only supported in non-overlapping mode");
}

let size_in_bytes = size.bytes_usize();
// For particularly large arrays (where this is perf-sensitive) it's common that
Expand All @@ -1290,6 +1299,8 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
} else if src_alloc_id == dest_alloc_id {
let mut dest_ptr = dest_bytes;
for _ in 0..num_copies {
// Here we rely on `src` and `dest` being non-overlapping if there is more than
// one copy.
ptr::copy(src_bytes, dest_ptr, size_in_bytes);
dest_ptr = dest_ptr.add(size_in_bytes);
}
Expand Down
52 changes: 37 additions & 15 deletions compiler/rustc_middle/src/mir/interpret/allocation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,33 @@ pub trait AllocBytes:
/// Create a zeroed `AllocBytes` of the specified size and alignment.
/// Returns `None` if we ran out of memory on the host.
fn zeroed(size: Size, _align: Align) -> Option<Self>;

/// Gives direct access to the raw underlying storage.
///
/// Crucially this pointer is compatible with:
/// - other pointers retunred by this method, and
/// - references returned from `deref()`, as long as there was no write.
fn as_mut_ptr(&mut self) -> *mut u8;
}

// Default `bytes` for `Allocation` is a `Box<[u8]>`.
impl AllocBytes for Box<[u8]> {
/// Default `bytes` for `Allocation` is a `Vec<u8>`.
///
/// We use `Vec`, not `Box`, since we need `Vec::as_mut_ptr` and how it interacts with other
/// pointers to the backing buffer. `Box` has no corresponding method.
impl AllocBytes for Vec<u8> {
fn from_bytes<'a>(slice: impl Into<Cow<'a, [u8]>>, _align: Align) -> Self {
Box::<[u8]>::from(slice.into())
slice.into().into_owned()
}

fn zeroed(size: Size, _align: Align) -> Option<Self> {
let bytes = Box::<[u8]>::try_new_zeroed_slice(size.bytes_usize()).ok()?;
// SAFETY: the box was zero-allocated, which is a valid initial value for Box<[u8]>
let bytes = unsafe { bytes.assume_init() };
Some(bytes)
Some(bytes.into())
}

fn as_mut_ptr(&mut self) -> *mut u8 {
Vec::as_mut_ptr(self)
}
}

Expand All @@ -62,7 +76,7 @@ impl AllocBytes for Box<[u8]> {
// hashed. (see the `Hash` impl below for more details), so the impl is not derived.
#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable)]
#[derive(HashStable)]
pub struct Allocation<Prov: Provenance = CtfeProvenance, Extra = (), Bytes = Box<[u8]>> {
pub struct Allocation<Prov: Provenance = CtfeProvenance, Extra = (), Bytes = Vec<u8>> {
/// The actual bytes of the allocation.
/// Note that the bytes of a pointer represent the offset of the pointer.
bytes: Bytes,
Expand Down Expand Up @@ -399,10 +413,6 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>

/// Byte accessors.
impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes> {
pub fn base_addr(&self) -> *const u8 {
self.bytes.as_ptr()
}

/// This is the entirely abstraction-violating way to just grab the raw bytes without
/// caring about provenance or initialization.
///
Expand Down Expand Up @@ -452,13 +462,14 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
Ok(self.get_bytes_unchecked(range))
}

/// Just calling this already marks everything as defined and removes provenance,
/// so be sure to actually put data there!
/// This is the entirely abstraction-violating way to just get mutable access to the raw bytes.
/// Just calling this already marks everything as defined and removes provenance, so be sure to
/// actually overwrite all the data there!
///
/// It is the caller's responsibility to check bounds and alignment beforehand.
/// Most likely, you want to use the `PlaceTy` and `OperandTy`-based methods
/// on `InterpCx` instead.
pub fn get_bytes_mut(
pub fn get_bytes_unchecked_for_overwrite(
&mut self,
cx: &impl HasDataLayout,
range: AllocRange,
Expand All @@ -469,8 +480,9 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
Ok(&mut self.bytes[range.start.bytes_usize()..range.end().bytes_usize()])
}

/// A raw pointer variant of `get_bytes_mut` that avoids invalidating existing aliases into this memory.
pub fn get_bytes_mut_ptr(
/// A raw pointer variant of `get_bytes_unchecked_for_overwrite` that avoids invalidating existing immutable aliases
/// into this memory.
pub fn get_bytes_unchecked_for_overwrite_ptr(
&mut self,
cx: &impl HasDataLayout,
range: AllocRange,
Expand All @@ -479,10 +491,19 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
self.provenance.clear(range, cx)?;

assert!(range.end().bytes_usize() <= self.bytes.len()); // need to do our own bounds-check
// Cruciall, we go via `AllocBytes::as_mut_ptr`, not `AllocBytes::deref_mut`.
let begin_ptr = self.bytes.as_mut_ptr().wrapping_add(range.start.bytes_usize());
let len = range.end().bytes_usize() - range.start.bytes_usize();
Ok(ptr::slice_from_raw_parts_mut(begin_ptr, len))
}

/// This gives direct mutable access to the entire buffer, just exposing their internal state
/// without reseting anything. Directly exposes `AllocBytes::as_mut_ptr`. Only works if
/// `OFFSET_IS_ADDR` is true.
pub fn get_bytes_unchecked_raw_mut(&mut self) -> *mut u8 {
assert!(Prov::OFFSET_IS_ADDR);
self.bytes.as_mut_ptr()
}
}

/// Reading and writing.
Expand Down Expand Up @@ -589,7 +610,8 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
};

let endian = cx.data_layout().endian;
let dst = self.get_bytes_mut(cx, range)?;
// Yes we do overwrite all the bytes in `dst`.
let dst = self.get_bytes_unchecked_for_overwrite(cx, range)?;
write_target_uint(endian, dst, bytes).unwrap();

// See if we have to also store some provenance.
Expand Down
2 changes: 1 addition & 1 deletion src/tools/miri/src/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,7 @@ impl<'mir, 'tcx> Machine<'mir, 'tcx> for MiriMachine<'mir, 'tcx> {

type Provenance = Provenance;
type ProvenanceExtra = ProvenanceExtra;
type Bytes = Box<[u8]>;
type Bytes = Vec<u8>;

type MemoryMap = MonoHashMap<
AllocId,
Expand Down

0 comments on commit 21b8f06

Please sign in to comment.