diff --git a/src/librustc_mir/transform/nll/infer.rs b/src/librustc_mir/transform/nll/infer.rs new file mode 100644 index 0000000000000..e6e00f295ca11 --- /dev/null +++ b/src/librustc_mir/transform/nll/infer.rs @@ -0,0 +1,222 @@ +// Copyright 2017 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use super::{Region, RegionIndex}; +use std::mem; +use rustc::infer::InferCtxt; +use rustc::mir::{Location, Mir}; +use rustc_data_structures::indexed_vec::{Idx, IndexVec}; +use rustc_data_structures::fx::FxHashSet; + +pub struct InferenceContext { + definitions: IndexVec, + constraints: IndexVec, + errors: IndexVec, +} + +pub struct InferenceError { + pub constraint_point: Location, + pub name: (), // FIXME(nashenas88) RegionName +} + +newtype_index!(InferenceErrorIndex); + +struct VarDefinition { + name: (), // FIXME(nashenas88) RegionName + value: Region, + capped: bool, +} + +impl VarDefinition { + pub fn new(value: Region) -> Self { + Self { + name: (), + value, + capped: false, + } + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct Constraint { + sub: RegionIndex, + sup: RegionIndex, + point: Location, +} + +newtype_index!(ConstraintIndex); + +impl InferenceContext { + pub fn new(values: IndexVec) -> Self { + Self { + definitions: values.into_iter().map(VarDefinition::new).collect(), + constraints: IndexVec::new(), + errors: IndexVec::new(), + } + } + + #[allow(dead_code)] + pub fn cap_var(&mut self, v: RegionIndex) { + self.definitions[v].capped = true; + } + + #[allow(dead_code)] + pub fn add_live_point(&mut self, v: RegionIndex, point: Location) { + debug!("add_live_point({:?}, {:?})", v, point); + let definition = &mut self.definitions[v]; + if definition.value.add_point(point) { + if definition.capped { + self.errors.push(InferenceError { + constraint_point: point, + name: definition.name, + }); + } + } + } + + #[allow(dead_code)] + pub fn add_outlives(&mut self, sup: RegionIndex, sub: RegionIndex, point: Location) { + debug!("add_outlives({:?}: {:?} @ {:?}", sup, sub, point); + self.constraints.push(Constraint { sup, sub, point }); + } + + #[allow(dead_code)] + pub fn region(&self, v: RegionIndex) -> &Region { + &self.definitions[v].value + } + + pub fn solve<'a, 'gcx, 'tcx>( + &mut self, + infcx: &'a InferCtxt<'a, 'gcx, 'tcx>, + mir: &'a Mir<'tcx>, + ) -> IndexVec + where + 'gcx: 'tcx + 'a, + 'tcx: 'a, + { + let mut changed = true; + let mut dfs = Dfs::new(infcx, mir); + while changed { + changed = false; + for constraint in &self.constraints { + let sub = &self.definitions[constraint.sub].value.clone(); + let sup_def = &mut self.definitions[constraint.sup]; + debug!("constraint: {:?}", constraint); + debug!(" sub (before): {:?}", sub); + debug!(" sup (before): {:?}", sup_def.value); + + if dfs.copy(sub, &mut sup_def.value, constraint.point) { + changed = true; + if sup_def.capped { + // This is kind of a hack, but when we add a + // constraint, the "point" is always the point + // AFTER the action that induced the + // constraint. So report the error on the + // action BEFORE that. + assert!(constraint.point.statement_index > 0); + let p = Location { + block: constraint.point.block, + statement_index: constraint.point.statement_index - 1, + }; + + self.errors.push(InferenceError { + constraint_point: p, + name: sup_def.name, + }); + } + } + + debug!(" sup (after) : {:?}", sup_def.value); + debug!(" changed : {:?}", changed); + } + debug!("\n"); + } + + mem::replace(&mut self.errors, IndexVec::new()) + } +} + +struct Dfs<'a, 'gcx: 'tcx + 'a, 'tcx: 'a> { + #[allow(dead_code)] + infcx: &'a InferCtxt<'a, 'gcx, 'tcx>, + mir: &'a Mir<'tcx>, +} + +impl<'a, 'gcx: 'tcx, 'tcx: 'a> Dfs<'a, 'gcx, 'tcx> { + fn new(infcx: &'a InferCtxt<'a, 'gcx, 'tcx>, mir: &'a Mir<'tcx>) -> Self { + Self { infcx, mir } + } + + fn copy( + &mut self, + from_region: &Region, + to_region: &mut Region, + start_point: Location, + ) -> bool { + let mut changed = false; + + let mut stack = vec![]; + let mut visited = FxHashSet(); + + stack.push(start_point); + while let Some(p) = stack.pop() { + debug!(" dfs: p={:?}", p); + + if !from_region.may_contain(p) { + debug!(" not in from-region"); + continue; + } + + if !visited.insert(p) { + debug!(" already visited"); + continue; + } + + changed |= to_region.add_point(p); + + let block_data = &self.mir[p.block]; + let successor_points = if p.statement_index < block_data.statements.len() { + vec![Location { + statement_index: p.statement_index + 1, + ..p + }] + } else { + block_data.terminator() + .successors() + .iter() + .map(|&basic_block| Location { + statement_index: 0, + block: basic_block, + }) + .collect::>() + }; + + if successor_points.is_empty() { + // FIXME handle free regions + // If we reach the END point in the graph, then copy + // over any skolemized end points in the `from_region` + // and make sure they are included in the `to_region`. + // for region_decl in self.infcx.tcx.tables.borrow().free_region_map() { + // // FIXME(nashenas88) figure out skolemized_end points + // let block = self.env.graph.skolemized_end(region_decl.name); + // let skolemized_end_point = Location { + // block, + // statement_index: 0, + // }; + // changed |= to_region.add_point(skolemized_end_point); + // } + } else { + stack.extend(successor_points); + } + } + + changed + } +} diff --git a/src/librustc_mir/transform/nll/mod.rs b/src/librustc_mir/transform/nll/mod.rs index 4925b1fcfed28..805e9c976e4f0 100644 --- a/src/librustc_mir/transform/nll/mod.rs +++ b/src/librustc_mir/transform/nll/mod.rs @@ -8,13 +8,14 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. +use self::infer::InferenceContext; use rustc::ty::TypeFoldable; use rustc::ty::subst::{Kind, Substs}; use rustc::ty::{Ty, TyCtxt, ClosureSubsts, RegionVid, RegionKind}; use rustc::mir::{Mir, Location, Rvalue, BasicBlock, Statement, StatementKind}; use rustc::mir::visit::{MutVisitor, Lookup}; use rustc::mir::transform::{MirPass, MirSource}; -use rustc::infer::{self, InferCtxt}; +use rustc::infer::{self as rustc_infer, InferCtxt}; use rustc::util::nodemap::FxHashSet; use rustc_data_structures::indexed_vec::{IndexVec, Idx}; use syntax_pos::DUMMY_SP; @@ -24,15 +25,18 @@ use std::fmt; use util as mir_util; use self::mir_util::PassWhere; +mod infer; + #[allow(dead_code)] struct NLLVisitor<'a, 'gcx: 'a + 'tcx, 'tcx: 'a> { lookup_map: HashMap, regions: IndexVec, - infcx: InferCtxt<'a, 'gcx, 'tcx>, + #[allow(dead_code)] + infcx: &'a InferCtxt<'a, 'gcx, 'tcx>, } impl<'a, 'gcx, 'tcx> NLLVisitor<'a, 'gcx, 'tcx> { - pub fn new(infcx: InferCtxt<'a, 'gcx, 'tcx>) -> Self { + pub fn new(infcx: &'a InferCtxt<'a, 'gcx, 'tcx>) -> Self { NLLVisitor { infcx, lookup_map: HashMap::new(), @@ -40,14 +44,14 @@ impl<'a, 'gcx, 'tcx> NLLVisitor<'a, 'gcx, 'tcx> { } } - pub fn into_results(self) -> HashMap { - self.lookup_map + pub fn into_results(self) -> (HashMap, IndexVec) { + (self.lookup_map, self.regions) } fn renumber_regions(&mut self, value: &T) -> T where T: TypeFoldable<'tcx> { self.infcx.tcx.fold_regions(value, &mut false, |_region, _depth| { self.regions.push(Region::default()); - self.infcx.next_region_var(infer::MiscVariable(DUMMY_SP)) + self.infcx.next_region_var(rustc_infer::MiscVariable(DUMMY_SP)) }) } @@ -147,7 +151,7 @@ impl MirPass for NLL { tcx.infer_ctxt().enter(|infcx| { // Clone mir so we can mutate it without disturbing the rest of the compiler let mut renumbered_mir = mir.clone(); - let mut visitor = NLLVisitor::new(infcx); + let mut visitor = NLLVisitor::new(&infcx); visitor.visit_mir(&mut renumbered_mir); mir_util::dump_mir(tcx, None, "nll", &0, source, mir, |pass_where, out| { if let PassWhere::BeforeCFG = pass_where { @@ -157,13 +161,15 @@ impl MirPass for NLL { } Ok(()) }); - let _results = visitor.into_results(); + let (_lookup_map, regions) = visitor.into_results(); + let mut inference_context = InferenceContext::new(regions); + inference_context.solve(&infcx, &renumbered_mir); }) } } #[derive(Clone, Default, PartialEq, Eq)] -struct Region { +pub struct Region { points: FxHashSet, } @@ -173,6 +179,14 @@ impl fmt::Debug for Region { } } +impl Region { + pub fn add_point(&mut self, point: Location) -> bool { + self.points.insert(point) + } + pub fn may_contain(&self, point: Location) -> bool { + self.points.contains(&point) + } +} newtype_index!(RegionIndex);