Skip to content

Commit 8752b43

Browse files
authored
Rollup merge of #77754 - bugadani:find_map_relevant_impl, r=matthewjasper
Add TraitDef::find_map_relevant_impl This PR adds a method to `TraitDef`. While `for_each_relevant_impl` covers the general use case, sometimes it's not necessary to scan through all the relevant implementations, so this PR introduces a new method, `find_map_relevant_impl`. I've also replaced the `for_each_relevant_impl` calls where possible. I'm hoping for a tiny bit of efficiency gain here and there.
2 parents 8368588 + 217d6f9 commit 8752b43

File tree

5 files changed

+38
-27
lines changed

5 files changed

+38
-27
lines changed

compiler/rustc_middle/src/ty/trait_def.rs

+25-3
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,26 @@ impl<'tcx> TyCtxt<'tcx> {
123123
self_ty: Ty<'tcx>,
124124
mut f: F,
125125
) {
126+
let _: Option<()> = self.find_map_relevant_impl(def_id, self_ty, |did| {
127+
f(did);
128+
None
129+
});
130+
}
131+
132+
/// Applies function to every impl that could possibly match the self type `self_ty` and returns
133+
/// the first non-none value.
134+
pub fn find_map_relevant_impl<T, F: FnMut(DefId) -> Option<T>>(
135+
self,
136+
def_id: DefId,
137+
self_ty: Ty<'tcx>,
138+
mut f: F,
139+
) -> Option<T> {
126140
let impls = self.trait_impls_of(def_id);
127141

128142
for &impl_def_id in impls.blanket_impls.iter() {
129-
f(impl_def_id);
143+
if let result @ Some(_) = f(impl_def_id) {
144+
return result;
145+
}
130146
}
131147

132148
// simplify_type(.., false) basically replaces type parameters and
@@ -157,14 +173,20 @@ impl<'tcx> TyCtxt<'tcx> {
157173
if let Some(simp) = fast_reject::simplify_type(self, self_ty, true) {
158174
if let Some(impls) = impls.non_blanket_impls.get(&simp) {
159175
for &impl_def_id in impls {
160-
f(impl_def_id);
176+
if let result @ Some(_) = f(impl_def_id) {
177+
return result;
178+
}
161179
}
162180
}
163181
} else {
164182
for &impl_def_id in impls.non_blanket_impls.values().flatten() {
165-
f(impl_def_id);
183+
if let result @ Some(_) = f(impl_def_id) {
184+
return result;
185+
}
166186
}
167187
}
188+
189+
None
168190
}
169191

170192
/// Returns an iterator containing all impls

compiler/rustc_middle/src/ty/util.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -346,14 +346,14 @@ impl<'tcx> TyCtxt<'tcx> {
346346
let drop_trait = self.lang_items().drop_trait()?;
347347
self.ensure().coherent_trait(drop_trait);
348348

349-
let mut dtor_did = None;
350349
let ty = self.type_of(adt_did);
351-
self.for_each_relevant_impl(drop_trait, ty, |impl_did| {
350+
let dtor_did = self.find_map_relevant_impl(drop_trait, ty, |impl_did| {
352351
if let Some(item) = self.associated_items(impl_did).in_definition_order().next() {
353352
if validate(self, impl_did).is_ok() {
354-
dtor_did = Some(item.def_id);
353+
return Some(item.def_id);
355354
}
356355
}
356+
None
357357
});
358358

359359
Some(ty::Destructor { did: dtor_did? })

compiler/rustc_mir/src/transform/check_const_item_mutation.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ impl<'a, 'tcx> ConstMutationChecker<'a, 'tcx> {
3434

3535
fn is_const_item_without_destructor(&self, local: Local) -> Option<DefId> {
3636
let def_id = self.is_const_item(local)?;
37-
let mut any_dtor = |_tcx, _def_id| Ok(());
3837

3938
// We avoid linting mutation of a const item if the const's type has a
4039
// Drop impl. The Drop logic observes the mutation which was performed.
@@ -54,7 +53,7 @@ impl<'a, 'tcx> ConstMutationChecker<'a, 'tcx> {
5453
//
5554
// #[const_mutation_allowed]
5655
// pub const LOG: Log = Log { msg: "" };
57-
match self.tcx.calculate_dtor(def_id, &mut any_dtor) {
56+
match self.tcx.calculate_dtor(def_id, &mut |_, _| Ok(())) {
5857
Some(_) => None,
5958
None => Some(def_id),
6059
}

compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs

+3-9
Original file line numberDiff line numberDiff line change
@@ -1384,17 +1384,11 @@ impl<'a, 'tcx> InferCtxtPrivExt<'tcx> for InferCtxt<'a, 'tcx> {
13841384
trait_ref: &ty::PolyTraitRef<'tcx>,
13851385
) {
13861386
let get_trait_impl = |trait_def_id| {
1387-
let mut trait_impl = None;
1388-
self.tcx.for_each_relevant_impl(
1387+
self.tcx.find_map_relevant_impl(
13891388
trait_def_id,
13901389
trait_ref.skip_binder().self_ty(),
1391-
|impl_def_id| {
1392-
if trait_impl.is_none() {
1393-
trait_impl = Some(impl_def_id);
1394-
}
1395-
},
1396-
);
1397-
trait_impl
1390+
|impl_def_id| Some(impl_def_id),
1391+
)
13981392
};
13991393
let required_trait_path = self.tcx.def_path_str(trait_ref.def_id());
14001394
let all_traits = self.tcx.all_traits(LOCAL_CRATE);

src/librustdoc/passes/collect_intra_doc_links.rs

+6-10
Original file line numberDiff line numberDiff line change
@@ -650,14 +650,9 @@ fn traits_implemented_by(cx: &DocContext<'_>, type_: DefId, module: DefId) -> Fx
650650
let ty = cx.tcx.type_of(type_);
651651
let iter = in_scope_traits.iter().flat_map(|&trait_| {
652652
trace!("considering explicit impl for trait {:?}", trait_);
653-
let mut saw_impl = false;
654-
// Look at each trait implementation to see if it's an impl for `did`
655-
cx.tcx.for_each_relevant_impl(trait_, ty, |impl_| {
656-
// FIXME: this is inefficient, find a way to short-circuit for_each_* so this doesn't take as long
657-
if saw_impl {
658-
return;
659-
}
660653

654+
// Look at each trait implementation to see if it's an impl for `did`
655+
cx.tcx.find_map_relevant_impl(trait_, ty, |impl_| {
661656
let trait_ref = cx.tcx.impl_trait_ref(impl_).expect("this is not an inherent impl");
662657
// Check if these are the same type.
663658
let impl_type = trait_ref.self_ty();
@@ -668,7 +663,7 @@ fn traits_implemented_by(cx: &DocContext<'_>, type_: DefId, module: DefId) -> Fx
668663
type_
669664
);
670665
// Fast path: if this is a primitive simple `==` will work
671-
saw_impl = impl_type == ty
666+
let saw_impl = impl_type == ty
672667
|| match impl_type.kind() {
673668
// Check if these are the same def_id
674669
ty::Adt(def, _) => {
@@ -678,8 +673,9 @@ fn traits_implemented_by(cx: &DocContext<'_>, type_: DefId, module: DefId) -> Fx
678673
ty::Foreign(def_id) => *def_id == type_,
679674
_ => false,
680675
};
681-
});
682-
if saw_impl { Some(trait_) } else { None }
676+
677+
if saw_impl { Some(trait_) } else { None }
678+
})
683679
});
684680
iter.collect()
685681
}

0 commit comments

Comments
 (0)