Skip to content

Commit

Permalink
Auto merge of rust-lang#106395 - compiler-errors:rework-predicates, r…
Browse files Browse the repository at this point in the history
…=eholk

Rework some `predicates_of`/`{Generic,Instantiated}Predicates` code

1. Make `instantiate_own` return an iterator, since it's a bit more efficient and easier to work with
2. Remove `bound_{explicit,}_predicates_of` -- these `bound_` methods in particular were a bit awkward to work with since `ty::GenericPredicates` *already* acts kinda like an `EarlyBinder` with its own `instantiate_*` methods, and had only a few call sites anyways.
3. Implement `IntoIterator` for `InstantiatedPredicates`, since it's *very* commonly being `zip`'d together.
  • Loading branch information
bors committed Jan 16, 2023
2 parents 41edaac + 90df86f commit d12412c
Show file tree
Hide file tree
Showing 22 changed files with 155 additions and 156 deletions.
48 changes: 21 additions & 27 deletions compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -673,40 +673,34 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
let tcx = self.infcx.tcx;

// Find out if the predicates show that the type is a Fn or FnMut
let find_fn_kind_from_did = |predicates: ty::EarlyBinder<
&[(ty::Predicate<'tcx>, Span)],
>,
substs| {
predicates.0.iter().find_map(|(pred, _)| {
let pred = if let Some(substs) = substs {
predicates.rebind(*pred).subst(tcx, substs).kind().skip_binder()
} else {
pred.kind().skip_binder()
};
if let ty::PredicateKind::Clause(ty::Clause::Trait(pred)) = pred && pred.self_ty() == ty {
if Some(pred.def_id()) == tcx.lang_items().fn_trait() {
return Some(hir::Mutability::Not);
} else if Some(pred.def_id()) == tcx.lang_items().fn_mut_trait() {
return Some(hir::Mutability::Mut);
}
let find_fn_kind_from_did = |(pred, _): (ty::Predicate<'tcx>, _)| {
if let ty::PredicateKind::Clause(ty::Clause::Trait(pred)) = pred.kind().skip_binder()
&& pred.self_ty() == ty
{
if Some(pred.def_id()) == tcx.lang_items().fn_trait() {
return Some(hir::Mutability::Not);
} else if Some(pred.def_id()) == tcx.lang_items().fn_mut_trait() {
return Some(hir::Mutability::Mut);
}
None
})
}
None
};

// If the type is opaque/param/closure, and it is Fn or FnMut, let's suggest (mutably)
// borrowing the type, since `&mut F: FnMut` iff `F: FnMut` and similarly for `Fn`.
// These types seem reasonably opaque enough that they could be substituted with their
// borrowed variants in a function body when we see a move error.
let borrow_level = match ty.kind() {
ty::Param(_) => find_fn_kind_from_did(
tcx.bound_explicit_predicates_of(self.mir_def_id().to_def_id())
.map_bound(|p| p.predicates),
None,
),
ty::Alias(ty::Opaque, ty::AliasTy { def_id, substs, .. }) => {
find_fn_kind_from_did(tcx.bound_explicit_item_bounds(*def_id), Some(*substs))
}
let borrow_level = match *ty.kind() {
ty::Param(_) => tcx
.explicit_predicates_of(self.mir_def_id().to_def_id())
.predicates
.iter()
.copied()
.find_map(find_fn_kind_from_did),
ty::Alias(ty::Opaque, ty::AliasTy { def_id, substs, .. }) => tcx
.bound_explicit_item_bounds(def_id)
.subst_iter_copied(tcx, substs)
.find_map(find_fn_kind_from_did),
ty::Closure(_, substs) => match substs.as_closure().kind() {
ty::ClosureKind::Fn => Some(hir::Mutability::Not),
ty::ClosureKind::FnMut => Some(hir::Mutability::Mut),
Expand Down
6 changes: 1 addition & 5 deletions compiler/rustc_borrowck/src/type_check/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
instantiated_predicates: ty::InstantiatedPredicates<'tcx>,
locations: Locations,
) {
for (predicate, span) in instantiated_predicates
.predicates
.into_iter()
.zip(instantiated_predicates.spans.into_iter())
{
for (predicate, span) in instantiated_predicates {
debug!(?predicate);
let category = ConstraintCategory::Predicate(span);
let predicate = self.normalize_with_category(predicate, locations, category);
Expand Down
25 changes: 13 additions & 12 deletions compiler/rustc_hir_analysis/src/check/compare_impl_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,11 @@ fn compare_method_predicate_entailment<'tcx>(
//
// We then register the obligations from the impl_m and check to see
// if all constraints hold.
hybrid_preds
.predicates
.extend(trait_m_predicates.instantiate_own(tcx, trait_to_placeholder_substs).predicates);
hybrid_preds.predicates.extend(
trait_m_predicates
.instantiate_own(tcx, trait_to_placeholder_substs)
.map(|(predicate, _)| predicate),
);

// Construct trait parameter environment and then shift it into the placeholder viewpoint.
// The key step here is to update the caller_bounds's predicates to be
Expand All @@ -230,7 +232,7 @@ fn compare_method_predicate_entailment<'tcx>(
debug!("compare_impl_method: caller_bounds={:?}", param_env.caller_bounds());

let impl_m_own_bounds = impl_m_predicates.instantiate_own(tcx, impl_to_placeholder_substs);
for (predicate, span) in iter::zip(impl_m_own_bounds.predicates, impl_m_own_bounds.spans) {
for (predicate, span) in impl_m_own_bounds {
let normalize_cause = traits::ObligationCause::misc(span, impl_m_hir_id);
let predicate = ocx.normalize(&normalize_cause, param_env, predicate);

Expand Down Expand Up @@ -1828,8 +1830,7 @@ fn compare_type_predicate_entailment<'tcx>(
check_region_bounds_on_impl_item(tcx, impl_ty, trait_ty, false)?;

let impl_ty_own_bounds = impl_ty_predicates.instantiate_own(tcx, impl_substs);

if impl_ty_own_bounds.is_empty() {
if impl_ty_own_bounds.len() == 0 {
// Nothing to check.
return Ok(());
}
Expand All @@ -1844,9 +1845,11 @@ fn compare_type_predicate_entailment<'tcx>(
// associated type in the trait are assumed.
let impl_predicates = tcx.predicates_of(impl_ty_predicates.parent.unwrap());
let mut hybrid_preds = impl_predicates.instantiate_identity(tcx);
hybrid_preds
.predicates
.extend(trait_ty_predicates.instantiate_own(tcx, trait_to_impl_substs).predicates);
hybrid_preds.predicates.extend(
trait_ty_predicates
.instantiate_own(tcx, trait_to_impl_substs)
.map(|(predicate, _)| predicate),
);

debug!("compare_type_predicate_entailment: bounds={:?}", hybrid_preds);

Expand All @@ -1862,9 +1865,7 @@ fn compare_type_predicate_entailment<'tcx>(

debug!("compare_type_predicate_entailment: caller_bounds={:?}", param_env.caller_bounds());

assert_eq!(impl_ty_own_bounds.predicates.len(), impl_ty_own_bounds.spans.len());
for (span, predicate) in std::iter::zip(impl_ty_own_bounds.spans, impl_ty_own_bounds.predicates)
{
for (predicate, span) in impl_ty_own_bounds {
let cause = ObligationCause::misc(span, impl_ty_hir_id);
let predicate = ocx.normalize(&cause, param_env, predicate);

Expand Down
29 changes: 13 additions & 16 deletions compiler/rustc_hir_analysis/src/check/wfcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use rustc_trait_selection::traits::{
};

use std::cell::LazyCell;
use std::iter;
use std::ops::{ControlFlow, Deref};

pub(super) struct WfCheckingCtxt<'a, 'tcx> {
Expand Down Expand Up @@ -1310,7 +1309,7 @@ fn check_where_clauses<'tcx>(wfcx: &WfCheckingCtxt<'_, 'tcx>, span: Span, def_id
let infcx = wfcx.infcx;
let tcx = wfcx.tcx();

let predicates = tcx.bound_predicates_of(def_id.to_def_id());
let predicates = tcx.predicates_of(def_id.to_def_id());
let generics = tcx.generics_of(def_id);

let is_our_default = |def: &ty::GenericParamDef| match def.kind {
Expand Down Expand Up @@ -1411,7 +1410,6 @@ fn check_where_clauses<'tcx>(wfcx: &WfCheckingCtxt<'_, 'tcx>, span: Span, def_id

// Now we build the substituted predicates.
let default_obligations = predicates
.0
.predicates
.iter()
.flat_map(|&(pred, sp)| {
Expand Down Expand Up @@ -1442,13 +1440,13 @@ fn check_where_clauses<'tcx>(wfcx: &WfCheckingCtxt<'_, 'tcx>, span: Span, def_id
}
let mut param_count = CountParams::default();
let has_region = pred.visit_with(&mut param_count).is_break();
let substituted_pred = predicates.rebind(pred).subst(tcx, substs);
let substituted_pred = ty::EarlyBinder(pred).subst(tcx, substs);
// Don't check non-defaulted params, dependent defaults (including lifetimes)
// or preds with multiple params.
if substituted_pred.has_non_region_param() || param_count.params.len() > 1 || has_region
{
None
} else if predicates.0.predicates.iter().any(|&(p, _)| p == substituted_pred) {
} else if predicates.predicates.iter().any(|&(p, _)| p == substituted_pred) {
// Avoid duplication of predicates that contain no parameters, for example.
None
} else {
Expand All @@ -1474,22 +1472,21 @@ fn check_where_clauses<'tcx>(wfcx: &WfCheckingCtxt<'_, 'tcx>, span: Span, def_id
traits::Obligation::new(tcx, cause, wfcx.param_env, pred)
});

let predicates = predicates.0.instantiate_identity(tcx);
let predicates = predicates.instantiate_identity(tcx);

let predicates = wfcx.normalize(span, None, predicates);

debug!(?predicates.predicates);
assert_eq!(predicates.predicates.len(), predicates.spans.len());
let wf_obligations =
iter::zip(&predicates.predicates, &predicates.spans).flat_map(|(&p, &sp)| {
traits::wf::predicate_obligations(
infcx,
wfcx.param_env.without_const(),
wfcx.body_id,
p,
sp,
)
});
let wf_obligations = predicates.into_iter().flat_map(|(p, sp)| {
traits::wf::predicate_obligations(
infcx,
wfcx.param_env.without_const(),
wfcx.body_id,
p,
sp,
)
});

let obligations: Vec<_> = wf_obligations.chain(default_obligations).collect();
wfcx.register_obligations(obligations);
Expand Down
8 changes: 3 additions & 5 deletions compiler/rustc_hir_typeck/src/callee.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,14 +375,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
if self.tcx.has_attr(def_id, sym::rustc_evaluate_where_clauses) {
let predicates = self.tcx.predicates_of(def_id);
let predicates = predicates.instantiate(self.tcx, subst);
for (predicate, predicate_span) in
predicates.predicates.iter().zip(&predicates.spans)
{
for (predicate, predicate_span) in predicates {
let obligation = Obligation::new(
self.tcx,
ObligationCause::dummy_with_span(callee_expr.span),
self.param_env,
*predicate,
predicate,
);
let result = self.evaluate_obligation(&obligation);
self.tcx
Expand All @@ -391,7 +389,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
callee_expr.span,
&format!("evaluate({:?}) = {:?}", predicate, result),
)
.span_label(*predicate_span, "predicate")
.span_label(predicate_span, "predicate")
.emit();
}
}
Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2140,8 +2140,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
// FIXME(compiler-errors): This could be problematic if something has two
// fn-like predicates with different args, but callable types really never
// do that, so it's OK.
for (predicate, span) in
std::iter::zip(instantiated.predicates, instantiated.spans)
for (predicate, span) in instantiated
{
if let ty::PredicateKind::Clause(ty::Clause::Trait(pred)) = predicate.kind().skip_binder()
&& pred.self_ty().peel_refs() == callee_ty
Expand Down
10 changes: 5 additions & 5 deletions compiler/rustc_hir_typeck/src/method/confirm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use rustc_middle::ty::{InternalSubsts, UserSubsts, UserType};
use rustc_span::{Span, DUMMY_SP};
use rustc_trait_selection::traits;

use std::iter;
use std::ops::Deref;

struct ConfirmContext<'a, 'tcx> {
Expand Down Expand Up @@ -101,7 +100,7 @@ impl<'a, 'tcx> ConfirmContext<'a, 'tcx> {
let filler_substs = rcvr_substs
.extend_to(self.tcx, pick.item.def_id, |def, _| self.tcx.mk_param_from_def(def));
let illegal_sized_bound = self.predicates_require_illegal_sized_bound(
&self.tcx.predicates_of(pick.item.def_id).instantiate(self.tcx, filler_substs),
self.tcx.predicates_of(pick.item.def_id).instantiate(self.tcx, filler_substs),
);

// Unify the (adjusted) self type with what the method expects.
Expand Down Expand Up @@ -565,7 +564,7 @@ impl<'a, 'tcx> ConfirmContext<'a, 'tcx> {

fn predicates_require_illegal_sized_bound(
&self,
predicates: &ty::InstantiatedPredicates<'tcx>,
predicates: ty::InstantiatedPredicates<'tcx>,
) -> Option<Span> {
let sized_def_id = self.tcx.lang_items().sized_trait()?;

Expand All @@ -575,10 +574,11 @@ impl<'a, 'tcx> ConfirmContext<'a, 'tcx> {
ty::PredicateKind::Clause(ty::Clause::Trait(trait_pred))
if trait_pred.def_id() == sized_def_id =>
{
let span = iter::zip(&predicates.predicates, &predicates.spans)
let span = predicates
.iter()
.find_map(
|(p, span)| {
if *p == obligation.predicate { Some(*span) } else { None }
if p == obligation.predicate { Some(span) } else { None }
},
)
.unwrap_or(rustc_span::DUMMY_SP);
Expand Down
5 changes: 2 additions & 3 deletions compiler/rustc_infer/src/infer/error_reporting/note.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,8 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {

let Ok(trait_predicates) = self
.tcx
.bound_explicit_predicates_of(trait_item_def_id)
.map_bound(|p| p.predicates)
.subst_iter_copied(self.tcx, trait_item_substs)
.explicit_predicates_of(trait_item_def_id)
.instantiate_own(self.tcx, trait_item_substs)
.map(|(pred, _)| {
if pred.is_suggestable(self.tcx, false) {
Ok(pred.to_string())
Expand Down
12 changes: 3 additions & 9 deletions compiler/rustc_middle/src/ty/generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,15 +341,9 @@ impl<'tcx> GenericPredicates<'tcx> {
&self,
tcx: TyCtxt<'tcx>,
substs: SubstsRef<'tcx>,
) -> InstantiatedPredicates<'tcx> {
InstantiatedPredicates {
predicates: self
.predicates
.iter()
.map(|(p, _)| EarlyBinder(*p).subst(tcx, substs))
.collect(),
spans: self.predicates.iter().map(|(_, sp)| *sp).collect(),
}
) -> impl Iterator<Item = (Predicate<'tcx>, Span)> + DoubleEndedIterator + ExactSizeIterator
{
EarlyBinder(self.predicates).subst_iter_copied(tcx, substs)
}

#[instrument(level = "debug", skip(self, tcx))]
Expand Down
29 changes: 29 additions & 0 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,35 @@ impl<'tcx> InstantiatedPredicates<'tcx> {
pub fn is_empty(&self) -> bool {
self.predicates.is_empty()
}

pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter {
(&self).into_iter()
}
}

impl<'tcx> IntoIterator for InstantiatedPredicates<'tcx> {
type Item = (Predicate<'tcx>, Span);

type IntoIter = std::iter::Zip<std::vec::IntoIter<Predicate<'tcx>>, std::vec::IntoIter<Span>>;

fn into_iter(self) -> Self::IntoIter {
debug_assert_eq!(self.predicates.len(), self.spans.len());
std::iter::zip(self.predicates, self.spans)
}
}

impl<'a, 'tcx> IntoIterator for &'a InstantiatedPredicates<'tcx> {
type Item = (Predicate<'tcx>, Span);

type IntoIter = std::iter::Zip<
std::iter::Copied<std::slice::Iter<'a, Predicate<'tcx>>>,
std::iter::Copied<std::slice::Iter<'a, Span>>,
>;

fn into_iter(self) -> Self::IntoIter {
debug_assert_eq!(self.predicates.len(), self.spans.len());
std::iter::zip(self.predicates.iter().copied(), self.spans.iter().copied())
}
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, HashStable, TyEncodable, TyDecodable, Lift)]
Expand Down
15 changes: 15 additions & 0 deletions compiler/rustc_middle/src/ty/subst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,13 @@ where
}
}

impl<'tcx, I: IntoIterator> ExactSizeIterator for SubstIter<'_, 'tcx, I>
where
I::IntoIter: ExactSizeIterator,
I::Item: TypeFoldable<'tcx>,
{
}

impl<'tcx, 's, I: IntoIterator> EarlyBinder<I>
where
I::Item: Deref,
Expand Down Expand Up @@ -686,6 +693,14 @@ where
}
}

impl<'tcx, I: IntoIterator> ExactSizeIterator for SubstIterCopied<'_, 'tcx, I>
where
I::IntoIter: ExactSizeIterator,
I::Item: Deref,
<I::Item as Deref>::Target: Copy + TypeFoldable<'tcx>,
{
}

pub struct EarlyBinderIter<T> {
t: T,
}
Expand Down
Loading

0 comments on commit d12412c

Please sign in to comment.