Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement unification of const abstract impls #104803

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/rustc_feature/src/active.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,8 @@ declare_features! (
(active, half_open_range_patterns_in_slices, "1.66.0", Some(67264), None),
/// Allows `if let` guard in match arms.
(active, if_let_guard, "1.47.0", Some(51114), None),
/// Allow multiple const-generic impls to unify for traits which are abstract.
(active, impl_exhaustive_const_traits, "1.65.0", Some(104806), None),
/// Allows `impl Trait` to be used inside associated types (RFC 2515).
(active, impl_trait_in_assoc_type, "1.70.0", Some(63063), None),
/// Allows `impl Trait` as output type in `Fn` traits in return position of functions.
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_middle/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -653,27 +653,32 @@ pub enum ImplSource<'tcx, N> {

/// Successful resolution for a builtin impl.
Builtin(BuiltinImplSource, Vec<N>),
/// Impl Source for an Abstract Const unification.
Exhaustive(Vec<N>),
}

impl<'tcx, N> ImplSource<'tcx, N> {
pub fn nested_obligations(self) -> Vec<N> {
match self {
ImplSource::UserDefined(i) => i.nested,
ImplSource::Param(_, n) | ImplSource::Builtin(_, n) => n,
ImplSource::Exhaustive(n) => n,
}
}

pub fn borrow_nested_obligations(&self) -> &[N] {
match self {
ImplSource::UserDefined(i) => &i.nested,
ImplSource::Param(_, n) | ImplSource::Builtin(_, n) => &n,
ImplSource::Exhaustive(n) => &n,
}
}

pub fn borrow_nested_obligations_mut(&mut self) -> &mut [N] {
match self {
ImplSource::UserDefined(i) => &mut i.nested,
ImplSource::Param(_, n) | ImplSource::Builtin(_, n) => n,
ImplSource::Exhaustive(ref mut n) => n,
}
}

Expand All @@ -691,6 +696,7 @@ impl<'tcx, N> ImplSource<'tcx, N> {
ImplSource::Builtin(source, n) => {
ImplSource::Builtin(source, n.into_iter().map(f).collect())
}
ImplSource::Exhaustive(n) => ImplSource::Exhaustive(n.into_iter().map(f).collect()),
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_middle/src/traits/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ pub enum SelectionCandidate<'tcx> {

/// Implementation of `const Destruct`, optionally from a custom `impl const Drop`.
ConstDestructCandidate(Option<DefId>),

/// Candidate which is generated for a abstract const, unifying other impls if they
/// exhaustively cover all values for a type.
/// Will never actually be used, by construction.
ExhaustiveCandidate(ty::PolyTraitPredicate<'tcx>),
}

/// The result of trait evaluation. The order is important
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/traits/structural_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ impl<'tcx, N: fmt::Debug> fmt::Debug for traits::ImplSource<'tcx, N> {
super::ImplSource::Param(ct, n) => {
write!(f, "ImplSourceParamData({n:?}, {ct:?})")
}
super::ImplSource::Exhaustive(ref n) => write!(f, "Exhaustive({:?})", n),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,7 @@ symbols! {
if_let_guard,
if_while_or_patterns,
ignore,
impl_exhaustive_const_traits,
impl_header_lifetime_elision,
impl_lint_pass,
impl_trait_in_assoc_type,
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1926,6 +1926,7 @@ fn assemble_candidates_from_impls<'cx, 'tcx>(
// why we special case object types.
false
}
ImplSource::Exhaustive(..) => false,
ImplSource::Builtin(BuiltinImplSource::TraitUpcasting { .. }, _)
| ImplSource::Builtin(BuiltinImplSource::TupleUnsizing, _) => {
// These traits have no associated types.
Expand Down Expand Up @@ -2007,6 +2008,7 @@ fn confirm_select_candidate<'cx, 'tcx>(
}
ImplSource::Builtin(BuiltinImplSource::Object { .. }, _)
| ImplSource::Param(..)
| ImplSource::Exhaustive(..)
| ImplSource::Builtin(BuiltinImplSource::TraitUpcasting { .. }, _)
| ImplSource::Builtin(BuiltinImplSource::TupleUnsizing, _) => {
// we don't create Select candidates with this kind of resolution
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use rustc_hir as hir;
use rustc_infer::traits::ObligationCause;
use rustc_infer::traits::{Obligation, PolyTraitObligation, SelectionError};
use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_middle::ty::{self, Ty, TypeVisitableExt};
use rustc_middle::ty::{self, ConstKind, Ty, TypeVisitableExt};

use crate::traits;
use crate::traits::query::evaluate_obligation::InferCtxtExt;
Expand Down Expand Up @@ -118,6 +118,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {

self.assemble_closure_candidates(obligation, &mut candidates);
self.assemble_fn_pointer_candidates(obligation, &mut candidates);
self.assemble_candidates_from_exhaustive_impls(obligation, &mut candidates);
self.assemble_candidates_from_impls(obligation, &mut candidates);
self.assemble_candidates_from_object_ty(obligation, &mut candidates);
}
Expand Down Expand Up @@ -348,6 +349,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {

let drcx = DeepRejectCtxt { treat_obligation_params: TreatParams::ForLookup };
let obligation_args = obligation.predicate.skip_binder().trait_ref.args;
// disallow any adts to have recursive types in the LHS
if let ty::Adt(_, args) = obligation.predicate.skip_binder().self_ty().kind() {
if args.consts().any(|c| matches!(c.kind(), ConstKind::Expr(_))) {
return;
}
}
self.tcx().for_each_relevant_impl(
obligation.predicate.def_id(),
obligation.predicate.skip_binder().trait_ref.self_ty(),
Expand Down Expand Up @@ -465,6 +472,54 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
}
false
}
/// When constructing an impl over a generic const enum (i.e. bool = { true, false })
/// If all possible variants of an enum are implemented AND the obligation is over that
/// variant,
fn assemble_candidates_from_exhaustive_impls(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
candidates: &mut SelectionCandidateSet<'tcx>,
) {
if !self.tcx().features().impl_exhaustive_const_traits {
return;
}

// see note in `assemble_candidates_from_impls`.
if obligation.predicate.references_error() {
return;
}

// note: ow = otherwise
// - check if trait has abstract const argument(s) which is (are) enum or bool, ow return
// - check if trait enum is non-exhaustive, ow return
// - construct required set of possible combinations, with false unless present
// for each relevant trait
// - check if is same trait
// - set combo as present
// If all required sets are present, add candidate impl generic over all combinations.

let query = obligation.predicate.skip_binder().self_ty();
let ty::Adt(_adt_def, adt_substs) = query.kind() else {
return;
};

let Some(ct) = adt_substs
.consts()
.filter(|ct| {
matches!(ct.kind(), ty::ConstKind::Unevaluated(..) | ty::ConstKind::Param(_))
})
.next()
else {
return;
};

// explicitly gate certain types which are exhaustive
if !super::exhaustive_types(self.tcx(), ct.ty(), |_| {}) {
return;
}

candidates.vec.push(ExhaustiveCandidate(obligation.predicate));
}

fn assemble_candidates_from_auto_impls(
&mut self,
Expand Down
77 changes: 76 additions & 1 deletion compiler/rustc_trait_selection/src/traits/select/confirmation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use rustc_infer::infer::{DefineOpaqueTypes, InferOk};
use rustc_middle::traits::{BuiltinImplSource, SelectionOutputTypeParameterMismatch};
use rustc_middle::ty::{
self, GenericArgs, GenericArgsRef, GenericParamDefKind, ToPolyTraitRef, ToPredicate,
TraitPredicate, Ty, TyCtxt, TypeVisitableExt,
TraitPredicate, TraitRef, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeVisitableExt,
};
use rustc_span::def_id::DefId;

Expand Down Expand Up @@ -120,6 +120,10 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let data = self.confirm_const_destruct_candidate(obligation, def_id)?;
ImplSource::Builtin(BuiltinImplSource::Misc, data)
}
ExhaustiveCandidate(candidate) => {
let obligations = self.confirm_exhaustive_candidate(obligation, candidate);
ImplSource::Exhaustive(obligations)
}
};

// The obligations returned by confirmation are recursively evaluated
Expand Down Expand Up @@ -1360,4 +1364,75 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {

Ok(nested)
}

/// Generates obligations for a lazy candidate
/// Where the obligations generated are all possible values for the type of a
/// `ConstKind::Unevaluated(..)`.
/// In the future, it would be nice to extend this to inductive proofs.
#[allow(unused_variables, unused_mut, dead_code)]
fn confirm_exhaustive_candidate(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
candidate: ty::PolyTraitPredicate<'tcx>,
) -> Vec<PredicateObligation<'tcx>> {
let mut obligations = vec![];
let tcx = self.tcx();

let query = obligation.predicate.skip_binder().trait_ref.self_ty();
let ty::Adt(_adt_def, adt_substs) = query.kind() else {
return obligations;
};
// Find one adt subst at a time, then will be handled recursively.
let const_to_replace = adt_substs
.consts()
.find(|ct| {
matches!(ct.kind(), ty::ConstKind::Unevaluated(..) | ty::ConstKind::Param(_))
})
.unwrap();
let is_exhaustive = super::exhaustive_types(tcx, const_to_replace.ty(), |val| {
let predicate = candidate.map_bound(|pt_ref| {
let query = pt_ref.self_ty();
let mut folder = Folder { tcx, replace: const_to_replace, with: val };
let mut new_poly_trait_ref = pt_ref.clone();
new_poly_trait_ref.trait_ref = TraitRef::new(
tcx,
pt_ref.trait_ref.def_id,
[query.fold_with(&mut folder).into()]
.into_iter()
.chain(pt_ref.trait_ref.args.iter().skip(1)),
);
new_poly_trait_ref
});

let ob = Obligation::new(
self.tcx(),
obligation.cause.clone(),
obligation.param_env,
predicate,
);
obligations.push(ob);
});

// should only allow exhaustive types in
// candidate_assembly::assemble_candidates_from_exhaustive_impls
assert!(is_exhaustive);

obligations
}
}

/// Folder for replacing specific const values in `substs`.
struct Folder<'tcx> {
tcx: TyCtxt<'tcx>,
replace: ty::Const<'tcx>,
with: ty::Const<'tcx>,
}

impl<'tcx> TypeFolder<TyCtxt<'tcx>> for Folder<'tcx> {
fn interner(&self) -> TyCtxt<'tcx> {
self.tcx
}
fn fold_const(&mut self, c: ty::Const<'tcx>) -> ty::Const<'tcx> {
if c == self.replace { self.with } else { c }
}
}
52 changes: 52 additions & 0 deletions compiler/rustc_trait_selection/src/traits/select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1806,6 +1806,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
// This is a fix for #53123 and prevents winnowing from accidentally extending the
// lifetime of a variable.
match (&other.candidate, &victim.candidate) {
(_, ExhaustiveCandidate(..)) => DropVictim::Yes,
(ExhaustiveCandidate(..), _) => DropVictim::No,
// FIXME(@jswrenn): this should probably be more sophisticated
(TransmutabilityCandidate, _) | (_, TransmutabilityCandidate) => DropVictim::No,

Expand Down Expand Up @@ -2604,6 +2606,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
assert_eq!(predicates.parent, None);
let predicates = predicates.instantiate_own(tcx, args);
let mut obligations = Vec::with_capacity(predicates.len());

for (index, (predicate, span)) in predicates.into_iter().enumerate() {
let cause =
if Some(parent_trait_pred.def_id()) == tcx.lang_items().coerce_unsized_trait() {
Expand Down Expand Up @@ -2999,3 +3002,52 @@ fn bind_generator_hidden_types_above<'tcx>(
));
ty::Binder::bind_with_vars(hidden_types, bound_vars)
}

// For a given type, will run a function over all constants for that type if permitted.
// returns false if not permitted. Callers should not rely on the order.
fn exhaustive_types<'tcx>(
tcx: TyCtxt<'tcx>,
ty: Ty<'tcx>,
mut f: impl FnMut(ty::Const<'tcx>),
) -> bool {
use std::mem::transmute;
match ty.kind() {
ty::Bool => {
for v in [true, false].into_iter() {
f(ty::Const::from_bool(tcx, v));
}
}
// Should always compile, as this is never instantiable
ty::Never => {}
ty::Adt(adt_def, _substs) => {
if adt_def.is_payloadfree() {
return true;
}
if adt_def.is_variant_list_non_exhaustive() {
return false;
}

// FIXME(julianknodt): here need to create constants for each variant
return false;
}

ty::Int(ty::IntTy::I8) => {
for v in -128i8..127i8 {
let c = ty::Const::from_bits(
tcx,
unsafe { transmute(v as i128) },
ty::ParamEnv::empty().and(ty),
);
f(c);
}
}
ty::Uint(ty::UintTy::U8) => {
for v in 0u8..=255u8 {
let c = ty::Const::from_bits(tcx, v as u128, ty::ParamEnv::empty().and(ty));
f(c);
}
}
_ => return false,
}
true
}
1 change: 1 addition & 0 deletions compiler/rustc_ty_utils/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ fn resolve_associated_item<'tcx>(
}
traits::ImplSource::Param(..)
| traits::ImplSource::Builtin(BuiltinImplSource::TraitUpcasting { .. }, _)
| traits::ImplSource::Exhaustive(..)
| traits::ImplSource::Builtin(BuiltinImplSource::TupleUnsizing, _) => None,
})
}
Expand Down
31 changes: 31 additions & 0 deletions tests/ui/const-generics/bool_cond.normal.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
error[E0277]: the trait bound `ConstOption<usize, { N <= 0 }>: Default` is not satisfied
--> $DIR/bool_cond.rs:42:5
|
LL | #[derive(Default)]
| ------- in this derive macro expansion
...
LL | _a: ConstOption<usize, { N <= 0 }>,
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Default` is not implemented for `ConstOption<usize, { N <= 0 }>`
|
= help: the following other types implement trait `Default`:
ConstOption<T, false>
ConstOption<T, true>
= note: this error originates in the derive macro `Default` (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0277]: the trait bound `ConstOption<usize, { N <= 1 }>: Default` is not satisfied
--> $DIR/bool_cond.rs:44:5
|
LL | #[derive(Default)]
| ------- in this derive macro expansion
...
LL | _b: ConstOption<usize, { N <= 1 }>,
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Default` is not implemented for `ConstOption<usize, { N <= 1 }>`
|
= help: the following other types implement trait `Default`:
ConstOption<T, false>
ConstOption<T, true>
= note: this error originates in the derive macro `Default` (in Nightly builds, run with -Z macro-backtrace for more info)

error: aborting due to 2 previous errors

For more information about this error, try `rustc --explain E0277`.
Loading
Loading