Skip to content

Commit

Permalink
re-architect the tag visitor traits
Browse files Browse the repository at this point in the history
  • Loading branch information
RalfJung committed Oct 4, 2022
1 parent e0a4915 commit d2552d2
Show file tree
Hide file tree
Showing 15 changed files with 306 additions and 222 deletions.
12 changes: 12 additions & 0 deletions src/concurrency/data_race.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,12 @@ pub struct VClockAlloc {
alloc_ranges: RefCell<RangeMap<MemoryCellClocks>>,
}

impl VisitTags for VClockAlloc {
fn visit_tags(&self, _visit: &mut dyn FnMut(SbTag)) {
// No tags here.
}
}

impl VClockAlloc {
/// Create a new data-race detector for newly allocated memory.
pub fn new_allocation(
Expand Down Expand Up @@ -1239,6 +1245,12 @@ pub struct GlobalState {
pub track_outdated_loads: bool,
}

impl VisitTags for GlobalState {
fn visit_tags(&self, _visit: &mut dyn FnMut(SbTag)) {
// We don't have any tags.
}
}

impl GlobalState {
/// Create a new global state, setup with just thread-id=0
/// advanced to timestamp = 1.
Expand Down
41 changes: 18 additions & 23 deletions src/concurrency/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub enum SchedulingAction {

/// Timeout callbacks can be created by synchronization primitives to tell the
/// scheduler that they should be called once some period of time passes.
pub trait MachineCallback<'mir, 'tcx>: VisitMachineValues {
pub trait MachineCallback<'mir, 'tcx>: VisitTags {
fn call(&self, ecx: &mut InterpCx<'mir, 'tcx, MiriMachine<'mir, 'tcx>>) -> InterpResult<'tcx>;
}

Expand Down Expand Up @@ -183,25 +183,21 @@ impl<'mir, 'tcx> Thread<'mir, 'tcx> {
}
}

impl VisitMachineValues for Thread<'_, '_> {
fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
impl VisitTags for Thread<'_, '_> {
fn visit_tags(&self, visit: &mut dyn FnMut(SbTag)) {
let Thread { panic_payload, last_error, stack, state: _, thread_name: _, join_status: _ } =
self;

if let Some(payload) = panic_payload {
visit.visit(*payload);
}
if let Some(error) = last_error {
visit.visit(**error);
}
panic_payload.visit_tags(visit);
last_error.visit_tags(visit);
for frame in stack {
frame.visit_machine_values(visit)
frame.visit_tags(visit)
}
}
}

impl VisitMachineValues for Frame<'_, '_, Provenance, FrameData<'_>> {
fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
impl VisitTags for Frame<'_, '_, Provenance, FrameData<'_>> {
fn visit_tags(&self, visit: &mut dyn FnMut(SbTag)) {
let Frame {
return_place,
locals,
Expand All @@ -210,21 +206,20 @@ impl VisitMachineValues for Frame<'_, '_, Provenance, FrameData<'_>> {
instance: _,
return_to_block: _,
loc: _,
// There are some private fields we cannot access; they contain no tags.
..
} = self;

// Return place.
if let Place::Ptr(mplace) = **return_place {
visit.visit(mplace);
}
return_place.visit_tags(visit);
// Locals.
for local in locals.iter() {
if let LocalValue::Live(value) = &local.value {
visit.visit(value);
value.visit_tags(visit);
}
}

extra.visit_machine_values(visit);
extra.visit_tags(visit);
}
}

Expand Down Expand Up @@ -300,8 +295,8 @@ impl<'mir, 'tcx> Default for ThreadManager<'mir, 'tcx> {
}
}

impl VisitMachineValues for ThreadManager<'_, '_> {
fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
impl VisitTags for ThreadManager<'_, '_> {
fn visit_tags(&self, visit: &mut dyn FnMut(SbTag)) {
let ThreadManager {
threads,
thread_local_alloc_ids,
Expand All @@ -312,13 +307,13 @@ impl VisitMachineValues for ThreadManager<'_, '_> {
} = self;

for thread in threads {
thread.visit_machine_values(visit);
thread.visit_tags(visit);
}
for ptr in thread_local_alloc_ids.borrow().values().copied() {
visit.visit(ptr);
for ptr in thread_local_alloc_ids.borrow().values() {
ptr.visit_tags(visit);
}
for callback in timeout_callbacks.values() {
callback.callback.visit_machine_values(visit);
callback.callback.visit_tags(visit);
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions src/concurrency/weak_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,15 @@ pub struct StoreBufferAlloc {
store_buffers: RefCell<RangeObjectMap<StoreBuffer>>,
}

impl VisitMachineValues for StoreBufferAlloc {
fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
for val in self
.store_buffers
impl VisitTags for StoreBufferAlloc {
fn visit_tags(&self, visit: &mut dyn FnMut(SbTag)) {
let Self { store_buffers } = self;
for val in store_buffers
.borrow()
.iter()
.flat_map(|buf| buf.buffer.iter().map(|element| &element.val))
{
visit.visit(val);
val.visit_tags(visit);
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions src/intptrcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ pub struct GlobalStateInner {
provenance_mode: ProvenanceMode,
}

impl VisitTags for GlobalStateInner {
fn visit_tags(&self, _visit: &mut dyn FnMut(SbTag)) {
// Nothing to visit here.
}
}

impl GlobalStateInner {
pub fn new(config: &MiriConfig) -> Self {
GlobalStateInner {
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ pub use crate::range_map::RangeMap;
pub use crate::stacked_borrows::{
CallId, EvalContextExt as StackedBorEvalContextExt, Item, Permission, SbTag, Stack, Stacks,
};
pub use crate::tag_gc::{EvalContextExt as _, ProvenanceVisitor, VisitMachineValues};
pub use crate::tag_gc::{EvalContextExt as _, VisitTags};

/// Insert rustc arguments at the beginning of the argument list that Miri wants to be
/// set per default, for maximal validation power.
Expand Down
93 changes: 58 additions & 35 deletions src/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,12 @@ impl<'tcx> std::fmt::Debug for FrameData<'tcx> {
}
}

impl VisitMachineValues for FrameData<'_> {
fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
let FrameData { catch_unwind, stacked_borrows: _, timing: _ } = self;
impl VisitTags for FrameData<'_> {
fn visit_tags(&self, visit: &mut dyn FnMut(SbTag)) {
let FrameData { catch_unwind, stacked_borrows, timing: _ } = self;

if let Some(catch_unwind) = catch_unwind {
catch_unwind.visit_machine_values(visit);
}
catch_unwind.visit_tags(visit);
stacked_borrows.visit_tags(visit);
}
}

Expand Down Expand Up @@ -261,17 +260,13 @@ pub struct AllocExtra {
pub weak_memory: Option<weak_memory::AllocExtra>,
}

impl VisitMachineValues for AllocExtra {
fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
let AllocExtra { stacked_borrows, data_race: _, weak_memory } = self;

if let Some(stacked_borrows) = stacked_borrows {
stacked_borrows.borrow().visit_machine_values(visit);
}
impl VisitTags for AllocExtra {
fn visit_tags(&self, visit: &mut dyn FnMut(SbTag)) {
let AllocExtra { stacked_borrows, data_race, weak_memory } = self;

if let Some(weak_memory) = weak_memory {
weak_memory.visit_machine_values(visit);
}
stacked_borrows.visit_tags(visit);
data_race.visit_tags(visit);
weak_memory.visit_tags(visit);
}
}

Expand Down Expand Up @@ -615,8 +610,9 @@ impl<'mir, 'tcx> MiriMachine<'mir, 'tcx> {
}
}

impl VisitMachineValues for MiriMachine<'_, '_> {
fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
impl VisitTags for MiriMachine<'_, '_> {
fn visit_tags(&self, visit: &mut dyn FnMut(SbTag)) {
#[rustfmt::skip]
let MiriMachine {
threads,
tls,
Expand All @@ -626,25 +622,52 @@ impl VisitMachineValues for MiriMachine<'_, '_> {
cmd_line,
extern_statics,
dir_handler,
..
stacked_borrows,
data_race,
intptrcast,
file_handler,
tcx: _,
isolated_op: _,
validate: _,
enforce_abi: _,
clock: _,
layouts: _,
static_roots: _,
profiler: _,
string_cache: _,
exported_symbols_cache: _,
panic_on_unsupported: _,
backtrace_style: _,
local_crates: _,
rng: _,
tracked_alloc_ids: _,
check_alignment: _,
cmpxchg_weak_failure_rate: _,
mute_stdout_stderr: _,
weak_memory: _,
preemption_rate: _,
report_progress: _,
basic_block_count: _,
#[cfg(unix)]
external_so_lib: _,
gc_interval: _,
since_gc: _,
num_cpus: _,
} = self;

threads.visit_machine_values(visit);
tls.visit_machine_values(visit);
env_vars.visit_machine_values(visit);
dir_handler.visit_machine_values(visit);

if let Some(argc) = argc {
visit.visit(argc);
}
if let Some(argv) = argv {
visit.visit(argv);
}
if let Some(cmd_line) = cmd_line {
visit.visit(cmd_line);
}
for ptr in extern_statics.values().copied() {
visit.visit(ptr);
threads.visit_tags(visit);
tls.visit_tags(visit);
env_vars.visit_tags(visit);
dir_handler.visit_tags(visit);
file_handler.visit_tags(visit);
data_race.visit_tags(visit);
stacked_borrows.visit_tags(visit);
intptrcast.visit_tags(visit);
argc.visit_tags(visit);
argv.visit_tags(visit);
cmd_line.visit_tags(visit);
for ptr in extern_statics.values() {
ptr.visit_tags(visit);
}
}
}
Expand Down
10 changes: 4 additions & 6 deletions src/shims/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@ pub struct EnvVars<'tcx> {
pub(crate) environ: Option<MPlaceTy<'tcx, Provenance>>,
}

impl VisitMachineValues for EnvVars<'_> {
fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
impl VisitTags for EnvVars<'_> {
fn visit_tags(&self, visit: &mut dyn FnMut(SbTag)) {
let EnvVars { map, environ } = self;

environ.visit_tags(visit);
for ptr in map.values() {
visit.visit(*ptr);
}
if let Some(env) = environ {
visit.visit(**env);
ptr.visit_tags(visit);
}
}
}
Expand Down
11 changes: 6 additions & 5 deletions src/shims/panic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ pub struct CatchUnwindData<'tcx> {
ret: mir::BasicBlock,
}

impl VisitMachineValues for CatchUnwindData<'_> {
fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
let CatchUnwindData { catch_fn, data, dest: _, ret: _ } = self;
visit.visit(catch_fn);
visit.visit(data);
impl VisitTags for CatchUnwindData<'_> {
fn visit_tags(&self, visit: &mut dyn FnMut(SbTag)) {
let CatchUnwindData { catch_fn, data, dest, ret: _ } = self;
catch_fn.visit_tags(visit);
data.visit_tags(visit);
dest.visit_tags(visit);
}
}

Expand Down
16 changes: 8 additions & 8 deletions src/shims/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
this.register_timeout_callback(
active_thread,
Time::Monotonic(timeout_time),
Box::new(Callback { active_thread }),
Box::new(UnblockCallback { thread_to_unblock: active_thread }),
);

Ok(0)
Expand All @@ -242,24 +242,24 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
this.register_timeout_callback(
active_thread,
Time::Monotonic(timeout_time),
Box::new(Callback { active_thread }),
Box::new(UnblockCallback { thread_to_unblock: active_thread }),
);

Ok(())
}
}

struct Callback {
active_thread: ThreadId,
struct UnblockCallback {
thread_to_unblock: ThreadId,
}

impl VisitMachineValues for Callback {
fn visit_machine_values(&self, _visit: &mut ProvenanceVisitor) {}
impl VisitTags for UnblockCallback {
fn visit_tags(&self, _visit: &mut dyn FnMut(SbTag)) {}
}

impl<'mir, 'tcx: 'mir> MachineCallback<'mir, 'tcx> for Callback {
impl<'mir, 'tcx: 'mir> MachineCallback<'mir, 'tcx> for UnblockCallback {
fn call(&self, ecx: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> {
ecx.unblock_thread(self.active_thread);
ecx.unblock_thread(self.thread_to_unblock);
Ok(())
}
}
8 changes: 4 additions & 4 deletions src/shims/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,15 +235,15 @@ impl<'tcx> TlsData<'tcx> {
}
}

impl VisitMachineValues for TlsData<'_> {
fn visit_machine_values(&self, visit: &mut ProvenanceVisitor) {
impl VisitTags for TlsData<'_> {
fn visit_tags(&self, visit: &mut dyn FnMut(SbTag)) {
let TlsData { keys, macos_thread_dtors, next_key: _, dtors_running: _ } = self;

for scalar in keys.values().flat_map(|v| v.data.values()) {
visit.visit(scalar);
scalar.visit_tags(visit);
}
for (_, scalar) in macos_thread_dtors.values() {
visit.visit(scalar);
scalar.visit_tags(visit);
}
}
}
Expand Down
Loading

0 comments on commit d2552d2

Please sign in to comment.