Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

interpret/allocation: fix aliasing issue in interpreter and refactor getters a bit #122537

Merged
merged 1 commit into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions compiler/rustc_const_eval/src/interpret/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1157,11 +1157,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
};

// Side-step AllocRef and directly access the underlying bytes more efficiently.
// (We are staying inside the bounds here so all is good.)
// (We are staying inside the bounds here and all bytes do get overwritten so all is good.)
let alloc_id = alloc_ref.alloc_id;
let bytes = alloc_ref
.alloc
.get_bytes_mut(&alloc_ref.tcx, alloc_ref.range)
.get_bytes_unchecked_for_overwrite(&alloc_ref.tcx, alloc_ref.range)
.map_err(move |e| e.to_interp_error(alloc_id))?;
// `zip` would stop when the first iterator ends; we want to definitely
// cover all of `bytes`.
Expand All @@ -1182,6 +1182,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
self.mem_copy_repeatedly(src, dest, size, 1, nonoverlapping)
}

/// Performs `num_copies` many copies of `size` many bytes from `src` to `dest + i*size` (where
/// `i` is the index of the copy).
///
/// Either `nonoverlapping` must be true or `num_copies` must be 1; doing repeated copies that
/// may overlap is not supported.
pub fn mem_copy_repeatedly(
&mut self,
src: Pointer<Option<M::Provenance>>,
Expand Down Expand Up @@ -1243,8 +1248,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
(dest_alloc_id, dest_prov),
dest_range,
)?;
// Yes we do overwrite all bytes in `dest_bytes`.
let dest_bytes = dest_alloc
.get_bytes_mut_ptr(&tcx, dest_range)
.get_bytes_unchecked_for_overwrite_ptr(&tcx, dest_range)
.map_err(|e| e.to_interp_error(dest_alloc_id))?
.as_mut_ptr();

Expand Down Expand Up @@ -1278,6 +1284,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
}
}
}
if num_copies > 1 {
assert!(nonoverlapping, "multi-copy only supported in non-overlapping mode");
}

let size_in_bytes = size.bytes_usize();
// For particularly large arrays (where this is perf-sensitive) it's common that
Expand All @@ -1290,6 +1299,8 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
} else if src_alloc_id == dest_alloc_id {
let mut dest_ptr = dest_bytes;
for _ in 0..num_copies {
// Here we rely on `src` and `dest` being non-overlapping if there is more than
// one copy.
ptr::copy(src_bytes, dest_ptr, size_in_bytes);
dest_ptr = dest_ptr.add(size_in_bytes);
}
Expand Down
42 changes: 31 additions & 11 deletions compiler/rustc_middle/src/mir/interpret/allocation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,16 @@ pub trait AllocBytes:
/// Create a zeroed `AllocBytes` of the specified size and alignment.
/// Returns `None` if we ran out of memory on the host.
fn zeroed(size: Size, _align: Align) -> Option<Self>;

/// Gives direct access to the raw underlying storage.
///
/// Crucially this pointer is compatible with:
/// - other pointers returned by this method, and
/// - references returned from `deref()`, as long as there was no write.
fn as_mut_ptr(&mut self) -> *mut u8;
}

// Default `bytes` for `Allocation` is a `Box<[u8]>`.
/// Default `bytes` for `Allocation` is a `Box<[u8]>`.
impl AllocBytes for Box<[u8]> {
fn from_bytes<'a>(slice: impl Into<Cow<'a, [u8]>>, _align: Align) -> Self {
Box::<[u8]>::from(slice.into())
Expand All @@ -51,6 +58,11 @@ impl AllocBytes for Box<[u8]> {
let bytes = unsafe { bytes.assume_init() };
Some(bytes)
}

/// Returns a raw pointer to the first byte of the boxed slice.
///
/// The pointer is produced by taking the raw address of the `[u8]` place behind the box
/// and casting it to `*mut u8`; no `&mut [u8]` is materialized along the way.
fn as_mut_ptr(&mut self) -> *mut u8 {
// Carefully avoiding any intermediate references: `**self` names the `[u8]` place,
// `addr_of_mut!` yields a `*mut [u8]` to it, and `cast` drops the slice metadata.
ptr::addr_of_mut!(**self).cast()
}
}

/// This type represents an Allocation in the Miri/CTFE core engine.
Expand Down Expand Up @@ -399,10 +411,6 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>

/// Byte accessors.
impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes> {
pub fn base_addr(&self) -> *const u8 {
self.bytes.as_ptr()
}
oli-obk marked this conversation as resolved.
Show resolved Hide resolved

/// This is the entirely abstraction-violating way to just grab the raw bytes without
/// caring about provenance or initialization.
///
Expand Down Expand Up @@ -452,13 +460,14 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
Ok(self.get_bytes_unchecked(range))
}

/// Just calling this already marks everything as defined and removes provenance,
/// so be sure to actually put data there!
/// This is the entirely abstraction-violating way to just get mutable access to the raw bytes.
/// Just calling this already marks everything as defined and removes provenance, so be sure to
/// actually overwrite all the data there!
///
/// It is the caller's responsibility to check bounds and alignment beforehand.
/// Most likely, you want to use the `PlaceTy` and `OperandTy`-based methods
/// on `InterpCx` instead.
pub fn get_bytes_mut(
pub fn get_bytes_unchecked_for_overwrite(
&mut self,
cx: &impl HasDataLayout,
range: AllocRange,
Expand All @@ -469,8 +478,9 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
Ok(&mut self.bytes[range.start.bytes_usize()..range.end().bytes_usize()])
}

/// A raw pointer variant of `get_bytes_mut` that avoids invalidating existing aliases into this memory.
pub fn get_bytes_mut_ptr(
/// A raw pointer variant of `get_bytes_unchecked_for_overwrite` that avoids invalidating existing immutable aliases
/// into this memory.
pub fn get_bytes_unchecked_for_overwrite_ptr(
&mut self,
cx: &impl HasDataLayout,
range: AllocRange,
Expand All @@ -479,10 +489,19 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
self.provenance.clear(range, cx)?;

assert!(range.end().bytes_usize() <= self.bytes.len()); // need to do our own bounds-check
// Crucially, we go via `AllocBytes::as_mut_ptr`, not `AllocBytes::deref_mut`.
let begin_ptr = self.bytes.as_mut_ptr().wrapping_add(range.start.bytes_usize());
let len = range.end().bytes_usize() - range.start.bytes_usize();
Ok(ptr::slice_from_raw_parts_mut(begin_ptr, len))
}

/// This gives direct mutable access to the entire buffer, just exposing its internal state
/// without resetting anything. Directly exposes `AllocBytes::as_mut_ptr`.
///
/// Unlike the `get_bytes_unchecked_for_overwrite*` accessors, this does not take a range
/// and performs no bookkeeping before handing out the pointer.
///
/// # Panics
///
/// Panics if `Prov::OFFSET_IS_ADDR` is false.
pub fn get_bytes_unchecked_raw_mut(&mut self) -> *mut u8 {
// Only meaningful when `OFFSET_IS_ADDR` holds; enforce that up front.
assert!(Prov::OFFSET_IS_ADDR);
self.bytes.as_mut_ptr()
}
}

/// Reading and writing.
Expand Down Expand Up @@ -589,7 +608,8 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
};

let endian = cx.data_layout().endian;
let dst = self.get_bytes_mut(cx, range)?;
// Yes we do overwrite all the bytes in `dst`.
let dst = self.get_bytes_unchecked_for_overwrite(cx, range)?;
write_target_uint(endian, dst, bytes).unwrap();

// See if we have to also store some provenance.
Expand Down
Loading