Skip to content

Commit

Permalink
Deeply check that method signatures match, and allow for nested RPITITs
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Sep 9, 2022
1 parent 1f03ede commit cdf7807
Show file tree
Hide file tree
Showing 13 changed files with 231 additions and 52 deletions.
9 changes: 2 additions & 7 deletions compiler/rustc_ast_lowering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1358,10 +1358,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
}
ImplTraitContext::InTrait => {
self.lower_impl_trait_in_trait(span, def_node_id, |lctx| {
lctx.lower_param_bounds(
bounds,
ImplTraitContext::Disallowed(ImplTraitPosition::Trait),
)
lctx.lower_param_bounds(bounds, ImplTraitContext::InTrait)
})
}
ImplTraitContext::Universal => {
Expand Down Expand Up @@ -1559,8 +1556,6 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
) -> hir::TyKind<'hir> {
let opaque_ty_def_id = self.local_def_id(opaque_ty_node_id);
self.with_hir_id_owner(opaque_ty_node_id, |lctx| {
// FIXME(RPITIT): This should be a more descriptive ImplTraitPosition, i.e. nested RPITIT
// FIXME(RPITIT): We _also_ should support this eventually
let hir_bounds = lower_bounds(lctx);
let rpitit_placeholder = hir::ImplTraitPlaceholder { bounds: hir_bounds };
let rpitit_item = hir::Item {
Expand Down Expand Up @@ -2073,7 +2068,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
let bound = lctx.lower_async_fn_output_type_to_future_bound(
output,
output.span(),
ImplTraitContext::Disallowed(ImplTraitPosition::TraitReturn),
ImplTraitContext::InTrait,
);
arena_vec![lctx; bound]
});
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_middle/src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ macro_rules! arena_types {
[decode] impl_source: rustc_middle::traits::ImplSource<'tcx, ()>,

[] dep_kind: rustc_middle::dep_graph::DepKindStruct<'tcx>,

[] trait_impl_trait_tys: rustc_data_structures::fx::FxHashMap<rustc_hir::def_id::DefId, rustc_middle::ty::Ty<'tcx>>,
]);
)
}
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ rustc_queries! {
separate_provide_extern
}

query compare_predicates_and_trait_impl_trait_tys(key: DefId)
-> Result<&'tcx FxHashMap<DefId, Ty<'tcx>>, ErrorGuaranteed>
{
desc { "better description please" }
separate_provide_extern
}

query analysis(key: ()) -> Result<(), ErrorGuaranteed> {
eval_always
desc { "running analysis passes on this crate" }
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2484,6 +2484,14 @@ impl<'tcx> TyCtxt<'tcx> {
pub fn is_const_default_method(self, def_id: DefId) -> bool {
matches!(self.trait_of_item(def_id), Some(trait_id) if self.has_attr(trait_id, sym::const_trait))
}

pub fn impl_trait_in_trait_parent(self, mut def_id: DefId) -> DefId {
while let def_kind = self.def_kind(def_id) && def_kind != DefKind::AssocFn {
debug_assert_eq!(def_kind, DefKind::ImplTraitPlaceholder);
def_id = self.parent(def_id);
}
def_id
}
}

/// Yields the parent function's `LocalDefId` if `def_id` is an `impl Trait` definition.
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_middle/src/ty/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,13 @@ impl<'tcx> TyCtxt<'tcx> {
ty::EarlyBinder(self.type_of(def_id))
}

pub fn bound_trait_impl_trait_tys(
self,
def_id: DefId,
) -> ty::EarlyBinder<Result<&'tcx FxHashMap<DefId, Ty<'tcx>>, ErrorGuaranteed>> {
ty::EarlyBinder(self.compare_predicates_and_trait_impl_trait_tys(def_id))
}

pub fn bound_fn_sig(self, def_id: DefId) -> ty::EarlyBinder<ty::PolyFnSig<'tcx>> {
ty::EarlyBinder(self.fn_sig(def_id))
}
Expand Down
13 changes: 6 additions & 7 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,7 @@ fn assemble_candidate_for_impl_trait_in_trait<'cx, 'tcx>(
) {
let tcx = selcx.tcx();
if tcx.def_kind(obligation.predicate.item_def_id) == DefKind::ImplTraitPlaceholder {
let trait_fn_def_id = tcx.parent(obligation.predicate.item_def_id);
let trait_fn_def_id = tcx.impl_trait_in_trait_parent(obligation.predicate.item_def_id);
let trait_def_id = tcx.parent(trait_fn_def_id);
let trait_substs =
obligation.predicate.substs.truncate_to(tcx, tcx.generics_of(trait_def_id));
Expand Down Expand Up @@ -2176,11 +2176,6 @@ fn confirm_impl_trait_in_trait_candidate<'tcx>(
let impl_fn_def_id = leaf_def.item.def_id;
let impl_fn_substs = obligation.predicate.substs.rebase_onto(tcx, trait_fn_def_id, data.substs);

let sig = tcx
.bound_fn_sig(impl_fn_def_id)
.map_bound(|fn_sig| tcx.liberate_late_bound_regions(impl_fn_def_id, fn_sig))
.subst(tcx, impl_fn_substs);

let cause = ObligationCause::new(
obligation.cause.span,
obligation.cause.body_id,
Expand Down Expand Up @@ -2217,7 +2212,11 @@ fn confirm_impl_trait_in_trait_candidate<'tcx>(
selcx,
obligation.param_env,
cause.clone(),
sig.output(),
tcx.bound_trait_impl_trait_tys(impl_fn_def_id)
.map_bound(|tys| {
tys.map_or_else(|_| tcx.ty_error(), |tys| tys[&obligation.predicate.item_def_id])
})
.subst(tcx, impl_fn_substs),
&mut obligations,
);

Expand Down
147 changes: 110 additions & 37 deletions compiler/rustc_typeck/src/check/compare_method.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
use super::potentially_plural_count;
use crate::errors::LifetimesOrBoundsMismatchOnTrait;
use rustc_data_structures::fx::FxHashSet;
use hir::def_id::DefId;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_errors::{pluralize, struct_span_err, Applicability, DiagnosticId, ErrorGuaranteed};
use rustc_hir as hir;
use rustc_hir::def::{DefKind, Res};
use rustc_hir::intravisit;
use rustc_hir::{GenericParamKind, ImplItemKind, TraitItemKind};
use rustc_infer::infer::outlives::env::OutlivesEnvironment;
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
use rustc_infer::infer::{self, TyCtxtInferExt};
use rustc_infer::traits::util;
use rustc_middle::ty::error::{ExpectedFound, TypeError};
use rustc_middle::ty::subst::{InternalSubsts, Subst};
use rustc_middle::ty::util::ExplicitSelf;
use rustc_middle::ty::{self, DefIdTree};
use rustc_middle::ty::{
self, DefIdTree, Ty, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitable,
};
use rustc_middle::ty::{GenericParamDefKind, ToPredicate, TyCtxt};
use rustc_span::Span;
use rustc_trait_selection::traits::error_reporting::InferCtxtExt;
Expand Down Expand Up @@ -64,10 +68,7 @@ pub(crate) fn compare_impl_method<'tcx>(
return;
}

if let Err(_) = compare_predicate_entailment(tcx, impl_m, impl_m_span, trait_m, impl_trait_ref)
{
return;
}
tcx.ensure().compare_predicates_and_trait_impl_trait_tys(impl_m.def_id);
}

/// This function is best explained by example. Consider a trait:
Expand Down Expand Up @@ -136,13 +137,15 @@ pub(crate) fn compare_impl_method<'tcx>(
///
/// Finally we register each of these predicates as an obligation and check that
/// they hold.
fn compare_predicate_entailment<'tcx>(
pub(super) fn compare_predicates_and_trait_impl_trait_tys<'tcx>(
tcx: TyCtxt<'tcx>,
impl_m: &ty::AssocItem,
impl_m_span: Span,
trait_m: &ty::AssocItem,
impl_trait_ref: ty::TraitRef<'tcx>,
) -> Result<(), ErrorGuaranteed> {
def_id: DefId,
) -> Result<&'tcx FxHashMap<DefId, Ty<'tcx>>, ErrorGuaranteed> {
let impl_m = tcx.opt_associated_item(def_id).unwrap();
let impl_m_span = tcx.def_span(def_id);
let trait_m = tcx.opt_associated_item(impl_m.trait_item_def_id.unwrap()).unwrap();
let impl_trait_ref = tcx.impl_trait_ref(impl_m.impl_container(tcx).unwrap()).unwrap();

let trait_to_impl_substs = impl_trait_ref.substs;

// This node-id should be used for the `body_id` field on each
Expand All @@ -161,6 +164,7 @@ fn compare_predicate_entailment<'tcx>(
kind: impl_m.kind,
},
);
let return_span = tcx.hir().fn_decl_by_hir_id(impl_m_hir_id).unwrap().output.span();

// Create mapping from impl to placeholder.
let impl_to_placeholder_substs = InternalSubsts::identity_for_item(tcx, impl_m.def_id);
Expand Down Expand Up @@ -266,6 +270,13 @@ fn compare_predicate_entailment<'tcx>(

let trait_sig = tcx.bound_fn_sig(trait_m.def_id).subst(tcx, trait_to_placeholder_substs);
let trait_sig = tcx.liberate_late_bound_regions(impl_m.def_id, trait_sig);
let mut collector =
ImplTraitInTraitCollector::new(&ocx, return_span, param_env, impl_m_hir_id);
// FIXME(RPITIT): This should only be needed on the output type, but
// RPITIT placeholders shouldn't show up anywhere except for there,
// so I think this is fine.
let trait_sig = trait_sig.fold_with(&mut collector);

// Next, add all inputs and output as well-formed tys. Importantly,
// we have to do this before normalization, since the normalized ty may
// not contain the input parameters. See issue #87748.
Expand Down Expand Up @@ -391,30 +402,6 @@ fn compare_predicate_entailment<'tcx>(
return Err(diag.emit());
}

// Check that an impl's fn return satisfies the bounds of the
// FIXME(RPITIT): Generalize this to nested impl traits
if let ty::Projection(proj) = tcx.fn_sig(trait_m.def_id).skip_binder().output().kind()
&& tcx.def_kind(proj.item_def_id) == DefKind::ImplTraitPlaceholder
{
let return_span = tcx.hir().fn_decl_by_hir_id(impl_m_hir_id).unwrap().output.span();

for (predicate, span) in tcx
.bound_explicit_item_bounds(proj.item_def_id)
.transpose_iter()
.map(|pred| pred.map_bound(|pred| *pred).subst(tcx, trait_to_placeholder_substs))
{
ocx.register_obligation(traits::Obligation::new(
traits::ObligationCause::new(
return_span,
impl_m_hir_id,
ObligationCauseCode::BindingObligation(proj.item_def_id, span),
),
param_env,
predicate,
));
}
}

// Check that all obligations are satisfied by the implementation's
// version.
let errors = ocx.select_all_or_error();
Expand All @@ -435,10 +422,96 @@ fn compare_predicate_entailment<'tcx>(
&outlives_environment,
);

Ok(())
let mut collected_tys = FxHashMap::default();
for (def_id, ty) in collector.types {
match infcx.fully_resolve(ty) {
Ok(ty) => {
collected_tys.insert(def_id, ty);
}
Err(err) => {
tcx.sess.delay_span_bug(
return_span,
format!("could not fully resolve: {ty} => {err:?}"),
);
collected_tys.insert(def_id, tcx.ty_error());
}
}
}

Ok(&*tcx.arena.alloc(collected_tys))
})
}

struct ImplTraitInTraitCollector<'a, 'tcx> {
ocx: &'a ObligationCtxt<'a, 'tcx>,
types: FxHashMap<DefId, Ty<'tcx>>,
span: Span,
param_env: ty::ParamEnv<'tcx>,
body_id: hir::HirId,
}

impl<'a, 'tcx> ImplTraitInTraitCollector<'a, 'tcx> {
fn new(
ocx: &'a ObligationCtxt<'a, 'tcx>,
span: Span,
param_env: ty::ParamEnv<'tcx>,
body_id: hir::HirId,
) -> Self {
ImplTraitInTraitCollector { ocx, types: FxHashMap::default(), span, param_env, body_id }
}
}

impl<'tcx> TypeFolder<'tcx> for ImplTraitInTraitCollector<'_, 'tcx> {
fn tcx<'a>(&'a self) -> TyCtxt<'tcx> {
self.ocx.infcx.tcx
}

fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
if let ty::Projection(proj) = ty.kind()
&& self.tcx().def_kind(proj.item_def_id) == DefKind::ImplTraitPlaceholder
{
if let Some(ty) = self.types.get(&proj.item_def_id) {
return *ty;
}
//FIXME(RPITIT): Deny nested RPITIT in substs too
if proj.substs.has_escaping_bound_vars() {
bug!("FIXME(RPITIT): error here");
}
// Replace with infer var
let infer_ty = self.ocx.infcx.next_ty_var(TypeVariableOrigin {
span: self.span,
kind: TypeVariableOriginKind::MiscVariable,
});
self.types.insert(proj.item_def_id, infer_ty);
// Recurse into bounds
for pred in self.tcx().bound_explicit_item_bounds(proj.item_def_id).transpose_iter() {
let pred_span = pred.0.1;

let pred = pred.map_bound(|(pred, _)| *pred).subst(self.tcx(), proj.substs);
let pred = pred.fold_with(self);
let pred = self.ocx.normalize(
ObligationCause::misc(self.span, self.body_id),
self.param_env,
pred,
);

self.ocx.register_obligation(traits::Obligation::new(
ObligationCause::new(
self.span,
self.body_id,
ObligationCauseCode::BindingObligation(proj.item_def_id, pred_span),
),
self.param_env,
pred,
));
}
infer_ty
} else {
ty.super_fold_with(self)
}
}
}

fn check_region_bounds_on_impl_item<'tcx>(
tcx: TyCtxt<'tcx>,
impl_m: &ty::AssocItem,
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_typeck/src/check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ use crate::require_c_abi_if_c_variadic;
use crate::util::common::indenter;

use self::coercion::DynamicCoerceMany;
use self::compare_method::compare_predicates_and_trait_impl_trait_tys;
use self::region::region_scope_tree;
pub use self::Expectation::*;

Expand Down Expand Up @@ -249,6 +250,7 @@ pub fn provide(providers: &mut Providers) {
used_trait_imports,
check_mod_item_types,
region_scope_tree,
compare_predicates_and_trait_impl_trait_tys,
..*providers
};
}
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_typeck/src/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1602,7 +1602,10 @@ fn generics_of(tcx: TyCtxt<'_>, def_id: DefId) -> ty::Generics {
}
ItemKind::ImplTraitPlaceholder(_) => {
let parent_id = tcx.hir().get_parent_item(hir_id).to_def_id();
assert_eq!(tcx.def_kind(parent_id), DefKind::AssocFn);
assert!(matches!(
tcx.def_kind(parent_id),
DefKind::AssocFn | DefKind::ImplTraitPlaceholder
));
Some(parent_id)
}
_ => None,
Expand Down
16 changes: 16 additions & 0 deletions src/test/ui/impl-trait/in-trait/deep-match-works.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// check-pass

#![feature(return_position_impl_trait_in_trait)]
#![allow(incomplete_features)]

struct Wrapper<T>(T);

trait Foo {
fn bar() -> Wrapper<impl Sized>;
}

impl Foo for () {
fn bar() -> Wrapper<i32> { Wrapper(0) }
}

fn main() {}
15 changes: 15 additions & 0 deletions src/test/ui/impl-trait/in-trait/deep-match.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#![feature(return_position_impl_trait_in_trait)]
#![allow(incomplete_features)]

struct Wrapper<T>(T);

trait Foo {
fn bar() -> Wrapper<impl Sized>;
}

impl Foo for () {
fn bar() -> i32 { 0 }
//~^ ERROR method `bar` has an incompatible type for trait
}

fn main() {}
Loading

0 comments on commit cdf7807

Please sign in to comment.