Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

-Ztrait-solver=next: stop depending on old solver #113317

Merged
merged 2 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 13 additions & 24 deletions compiler/rustc_infer/src/infer/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,18 @@ impl<'tcx> InferCtxt<'tcx> {
recursion_depth: usize,
obligations: &mut Vec<PredicateObligation<'tcx>>,
) -> Ty<'tcx> {
if self.next_trait_solver() {
// FIXME(-Ztrait-solver=next): Instead of branching here,
// completely change the normalization routine with the new solver.
//
// The new solver correctly handles projection equality so this hack
// is not necessary. if re-enabled it should emit `PredicateKind::AliasRelate`
// not `PredicateKind::Clause(ClauseKind::Projection(..))` as in the new solver
// `Projection` is used as `normalizes-to` which will fail for `<T as Trait>::Assoc eq ?0`.
return projection_ty.to_ty(self.tcx);
} else {
let def_id = projection_ty.def_id;
let ty_var = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::NormalizeProjectionType,
span: self.tcx.def_span(def_id),
});
let projection =
ty::Binder::dummy(ty::PredicateKind::Clause(ty::ClauseKind::Projection(
ty::ProjectionPredicate { projection_ty, term: ty_var.into() },
)));
let obligation =
Obligation::with_depth(self.tcx, cause, recursion_depth, param_env, projection);
obligations.push(obligation);
ty_var
}
debug_assert!(!self.next_trait_solver());
let def_id = projection_ty.def_id;
let ty_var = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::NormalizeProjectionType,
span: self.tcx.def_span(def_id),
});
let projection = ty::Binder::dummy(ty::PredicateKind::Clause(ty::ClauseKind::Projection(
ty::ProjectionPredicate { projection_ty, term: ty_var.into() },
)));
let obligation =
Obligation::with_depth(self.tcx, cause, recursion_depth, param_env, projection);
obligations.push(obligation);
ty_var
}
}
2 changes: 2 additions & 0 deletions compiler/rustc_trait_selection/src/solve/normalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ impl<'tcx> FallibleTypeFolder<TyCtxt<'tcx>> for NormalizationFolder<'_, 'tcx> {
fn try_fold_ty(&mut self, ty: Ty<'tcx>) -> Result<Ty<'tcx>, Self::Error> {
let reveal = self.at.param_env.reveal();
let infcx = self.at.infcx;
debug_assert_eq!(ty, infcx.shallow_resolve(ty));
if !needs_normalization(&ty, reveal) {
return Ok(ty);
}
Expand Down Expand Up @@ -192,6 +193,7 @@ impl<'tcx> FallibleTypeFolder<TyCtxt<'tcx>> for NormalizationFolder<'_, 'tcx> {
fn try_fold_const(&mut self, ct: ty::Const<'tcx>) -> Result<ty::Const<'tcx>, Self::Error> {
let reveal = self.at.param_env.reveal();
let infcx = self.at.infcx;
debug_assert_eq!(ct, infcx.shallow_resolve(ct));
if !needs_normalization(&ct, reveal) {
return Ok(ct);
}
Expand Down
22 changes: 16 additions & 6 deletions compiler/rustc_trait_selection/src/traits/coherence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use std::fmt::Debug;
use std::iter;
use std::ops::ControlFlow;

use super::query::evaluate_obligation::InferCtxtExt;
use super::NormalizeExt;

/// Whether we do the orphan check relative to this crate or
Expand Down Expand Up @@ -290,19 +291,28 @@ fn impl_intersection_has_impossible_obligation<'cx, 'tcx>(
) -> bool {
let infcx = selcx.infcx;

let obligation_guaranteed_to_fail = move |obligation: &PredicateObligation<'tcx>| {
if infcx.next_trait_solver() {
infcx.evaluate_obligation(obligation).map_or(false, |result| !result.may_apply())
} else {
// We use `evaluate_root_obligation` to correctly track
// intercrate ambiguity clauses. We do not need this in the
// new solver.
selcx.evaluate_root_obligation(obligation).map_or(
false, // Overflow has occurred, and treat the obligation as possibly holding.
|result| !result.may_apply(),
)
}
};

let opt_failing_obligation = [&impl1_header.predicates, &impl2_header.predicates]
.into_iter()
.flatten()
.map(|&predicate| {
Obligation::new(infcx.tcx, ObligationCause::dummy(), param_env, predicate)
})
.chain(obligations)
.find(|o| {
selcx.evaluate_root_obligation(o).map_or(
false, // Overflow has occurred, and treat the obligation as possibly holding.
|result| !result.may_apply(),
)
});
.find(obligation_guaranteed_to_fail);

if let Some(failing_obligation) = opt_failing_obligation {
debug!("overlap: obligation unsatisfiable {:?}", failing_obligation);
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ impl<'a, 'b, 'tcx> AssocTypeNormalizer<'a, 'b, 'tcx> {
depth: usize,
obligations: &'a mut Vec<PredicateObligation<'tcx>>,
) -> AssocTypeNormalizer<'a, 'b, 'tcx> {
debug_assert!(!selcx.infcx.next_trait_solver());
AssocTypeNormalizer {
selcx,
param_env,
Expand Down Expand Up @@ -1122,6 +1123,7 @@ fn opt_normalize_projection_type<'a, 'b, 'tcx>(
obligations: &mut Vec<PredicateObligation<'tcx>>,
) -> Result<Option<Term<'tcx>>, InProgress> {
let infcx = selcx.infcx;
debug_assert!(!selcx.infcx.next_trait_solver());
// Don't use the projection cache in intercrate mode -
// the `infcx` may be re-used between intercrate in non-intercrate
// mode, which could lead to using incorrect cache results.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ impl<'tcx> InferCtxtExt<'tcx> for InferCtxt<'tcx> {
}
})
} else {
assert!(!self.intercrate);
let c_pred = self.canonicalize_query_keep_static(
param_env.and(obligation.predicate),
&mut _orig_values,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::solve;
use crate::traits::query::NoSolution;
use crate::traits::wf;
use crate::traits::ObligationCtxt;
Expand All @@ -6,6 +7,7 @@ use rustc_infer::infer::canonical::Canonical;
use rustc_infer::infer::outlives::components::{push_outlives_components, Component};
use rustc_infer::traits::query::OutlivesBound;
use rustc_middle::infer::canonical::CanonicalQueryResponse;
use rustc_middle::traits::ObligationCause;
use rustc_middle::ty::{self, ParamEnvAnd, Ty, TyCtxt, TypeVisitableExt};
use rustc_span::def_id::CRATE_DEF_ID;
use rustc_span::source_map::DUMMY_SP;
Expand Down Expand Up @@ -164,19 +166,29 @@ pub fn compute_implied_outlives_bounds_inner<'tcx>(

// We lazily compute the outlives components as
// `select_all_or_error` constrains inference variables.
let implied_bounds = outlives_bounds
.into_iter()
.flat_map(|ty::OutlivesPredicate(a, r_b)| match a.unpack() {
ty::GenericArgKind::Lifetime(r_a) => vec![OutlivesBound::RegionSubRegion(r_b, r_a)],
let mut implied_bounds = Vec::new();
for ty::OutlivesPredicate(a, r_b) in outlives_bounds {
match a.unpack() {
ty::GenericArgKind::Lifetime(r_a) => {
implied_bounds.push(OutlivesBound::RegionSubRegion(r_b, r_a))
}
ty::GenericArgKind::Type(ty_a) => {
let ty_a = ocx.infcx.resolve_vars_if_possible(ty_a);
let mut ty_a = ocx.infcx.resolve_vars_if_possible(ty_a);
// Need to manually normalize in the new solver as `wf::obligations` does not.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why isn't that done? I see that you're bailing out in WfPredicates::normalize, but not why

Copy link
Contributor Author

@lcnr lcnr Jul 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extended the comment in wf.rs, the reasons are

  • the new solver can only deeply normalize if there is no ambiguity, so there is nothing we could use in wf
  • we never have to normalize obligations in the new solver, but we do have to normalize the inputs to region solving

if ocx.infcx.next_trait_solver() {
ty_a = solve::deeply_normalize(
ocx.infcx.at(&ObligationCause::dummy(), param_env),
ty_a,
)
.map_err(|_errs| NoSolution)?;
}
let mut components = smallvec![];
push_outlives_components(tcx, ty_a, &mut components);
implied_bounds_from_components(r_b, components)
implied_bounds.extend(implied_bounds_from_components(r_b, components))
}
ty::GenericArgKind::Const(_) => unreachable!(),
})
.collect();
}
}

Ok(implied_bounds)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
/// `FnPtr`, when we wanted to report that it doesn't implement `Trait`.
#[instrument(level = "trace", skip(self), ret)]
fn reject_fn_ptr_impls(
&self,
&mut self,
impl_def_id: DefId,
obligation: &TraitObligation<'tcx>,
impl_self_ty: Ty<'tcx>,
Expand Down Expand Up @@ -464,7 +464,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
ty::PredicateKind::Clause(ty::ClauseKind::Trait(pred))
})),
);
if let Ok(r) = self.infcx.evaluate_obligation(&obligation) {
if let Ok(r) = self.evaluate_root_obligation(&obligation) {
if !r.may_apply() {
return true;
}
Expand Down
65 changes: 21 additions & 44 deletions compiler/rustc_trait_selection/src/traits/select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ use rustc_hir as hir;
use rustc_hir::def_id::DefId;
use rustc_infer::infer::DefineOpaqueTypes;
use rustc_infer::infer::LateBoundRegionConversionTime;
use rustc_infer::traits::TraitEngine;
use rustc_infer::traits::TraitEngineExt;
use rustc_middle::dep_graph::{DepKind, DepNodeIndex};
use rustc_middle::mir::interpret::ErrorHandled;
use rustc_middle::ty::abstract_const::NotConstEvaluatable;
Expand Down Expand Up @@ -312,6 +310,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
stack: &TraitObligationStack<'o, 'tcx>,
) -> SelectionResult<'tcx, SelectionCandidate<'tcx>> {
debug_assert!(!self.infcx.next_trait_solver());
// Watch out for overflow. This intentionally bypasses (and does
// not update) the cache.
self.check_recursion_limit(&stack.obligation, &stack.obligation)?;
Expand Down Expand Up @@ -526,21 +525,20 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
/// Evaluates whether the obligation `obligation` can be satisfied
/// and returns an `EvaluationResult`. This is meant for the
/// *initial* call.
///
/// Do not use this directly, use `infcx.evaluate_obligation` instead.
pub fn evaluate_root_obligation(
&mut self,
obligation: &PredicateObligation<'tcx>,
) -> Result<EvaluationResult, OverflowError> {
debug_assert!(!self.infcx.next_trait_solver());
self.evaluation_probe(|this| {
let goal =
this.infcx.resolve_vars_if_possible((obligation.predicate, obligation.param_env));
let mut result = if this.infcx.next_trait_solver() {
this.evaluate_predicates_recursively_in_new_solver([obligation.clone()])?
} else {
this.evaluate_predicate_recursively(
TraitObligationStackList::empty(&ProvisionalEvaluationCache::default()),
obligation.clone(),
)?
};
let mut result = this.evaluate_predicate_recursively(
TraitObligationStackList::empty(&ProvisionalEvaluationCache::default()),
obligation.clone(),
)?;
// If the predicate has done any inference, then downgrade the
// result to ambiguous.
if this.infcx.shallow_resolve(goal) != goal {
Expand Down Expand Up @@ -587,42 +585,19 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
where
I: IntoIterator<Item = PredicateObligation<'tcx>> + std::fmt::Debug,
{
if self.infcx.next_trait_solver() {
self.evaluate_predicates_recursively_in_new_solver(predicates)
} else {
let mut result = EvaluatedToOk;
for mut obligation in predicates {
obligation.set_depth_from_parent(stack.depth());
let eval = self.evaluate_predicate_recursively(stack, obligation.clone())?;
if let EvaluatedToErr = eval {
// fast-path - EvaluatedToErr is the top of the lattice,
// so we don't need to look on the other predicates.
return Ok(EvaluatedToErr);
} else {
result = cmp::max(result, eval);
}
let mut result = EvaluatedToOk;
for mut obligation in predicates {
obligation.set_depth_from_parent(stack.depth());
let eval = self.evaluate_predicate_recursively(stack, obligation.clone())?;
if let EvaluatedToErr = eval {
// fast-path - EvaluatedToErr is the top of the lattice,
// so we don't need to look on the other predicates.
return Ok(EvaluatedToErr);
} else {
result = cmp::max(result, eval);
}
Ok(result)
}
}

/// Evaluates the predicates using the new solver when `-Ztrait-solver=next` is enabled
fn evaluate_predicates_recursively_in_new_solver(
&mut self,
predicates: impl IntoIterator<Item = PredicateObligation<'tcx>>,
) -> Result<EvaluationResult, OverflowError> {
let mut fulfill_cx = crate::solve::FulfillmentCtxt::new(self.infcx);
fulfill_cx.register_predicate_obligations(self.infcx, predicates);
// True errors
// FIXME(-Ztrait-solver=next): Overflows are reported as ambig here, is that OK?
if !fulfill_cx.select_where_possible(self.infcx).is_empty() {
return Ok(EvaluatedToErr);
}
if !fulfill_cx.select_all_or_error(self.infcx).is_empty() {
return Ok(EvaluatedToAmbig);
}
// Regions and opaques are handled in the `evaluation_probe` by looking at the snapshot
Ok(EvaluatedToOk)
Ok(result)
}

#[instrument(
Expand All @@ -636,6 +611,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
previous_stack: TraitObligationStackList<'o, 'tcx>,
obligation: PredicateObligation<'tcx>,
) -> Result<EvaluationResult, OverflowError> {
debug_assert!(!self.infcx.next_trait_solver());
// `previous_stack` stores a `TraitObligation`, while `obligation` is
// a `PredicateObligation`. These are distinct types, so we can't
// use any `Option` combinator method that would force them to be
Expand Down Expand Up @@ -1182,6 +1158,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
stack: &TraitObligationStack<'o, 'tcx>,
) -> Result<EvaluationResult, OverflowError> {
debug_assert!(!self.infcx.next_trait_solver());
// In intercrate mode, whenever any of the generics are unbound,
// there can always be an impl. Even if there are no impls in
// this crate, perhaps the type would be unified with
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_trait_selection/src/traits/wf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,11 @@ impl<'a, 'tcx> WfPredicates<'a, 'tcx> {
}

fn normalize(self, infcx: &InferCtxt<'tcx>) -> Vec<traits::PredicateObligation<'tcx>> {
// Do not normalize `wf` obligations with the new solver.
if infcx.next_trait_solver() {
return self.out;
}

let cause = self.cause(traits::WellFormed(None));
let param_env = self.param_env;
let mut obligations = Vec::with_capacity(self.out.len());
Expand Down