Skip to content

Commit c8a9c34

Browse files
committed
Auto merge of rust-lang#72962 - lcnr:ObligationCause-lrc, r=ecstatic-morse
store `ObligationCause` on the heap Stores `ObligationCause` on the heap using an `Rc`. This PR trades off some transient memory allocations to reduce the size of–and thus the number of instructions required to memcpy–a few widely used data structures in trait solving.
2 parents f315c35 + ea668d9 commit c8a9c34

File tree

10 files changed

+82
-38
lines changed

10 files changed

+82
-38
lines changed

src/librustc_infer/traits/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ pub type TraitObligation<'tcx> = Obligation<'tcx, ty::PolyTraitPredicate<'tcx>>;
5959

6060
// `PredicateObligation` is used a lot. Make sure it doesn't unintentionally get bigger.
6161
#[cfg(target_arch = "x86_64")]
62-
static_assert_size!(PredicateObligation<'_>, 88);
62+
static_assert_size!(PredicateObligation<'_>, 48);
6363

6464
pub type Obligations<'tcx, O> = Vec<Obligation<'tcx, O>>;
6565
pub type PredicateObligations<'tcx> = Vec<PredicateObligation<'tcx>>;

src/librustc_infer/traits/util.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,12 @@ fn predicate_obligation<'tcx>(
142142
predicate: ty::Predicate<'tcx>,
143143
span: Option<Span>,
144144
) -> PredicateObligation<'tcx> {
145-
let mut cause = ObligationCause::dummy();
146-
if let Some(span) = span {
147-
cause.span = span;
148-
}
145+
let cause = if let Some(span) = span {
146+
ObligationCause::dummy_with_span(span)
147+
} else {
148+
ObligationCause::dummy()
149+
};
150+
149151
Obligation { cause, param_env: ty::ParamEnv::empty(), recursion_depth: 0, predicate }
150152
}
151153

src/librustc_middle/traits/mod.rs

+46-5
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ use rustc_span::{Span, DUMMY_SP};
2020
use smallvec::SmallVec;
2121

2222
use std::borrow::Cow;
23-
use std::fmt::Debug;
23+
use std::fmt;
24+
use std::ops::Deref;
2425
use std::rc::Rc;
2526

2627
pub use self::select::{EvaluationCache, EvaluationResult, OverflowError, SelectionCache};
@@ -80,8 +81,39 @@ pub enum Reveal {
8081
}
8182

8283
/// The reason why we incurred this obligation; used for error reporting.
83-
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
84+
///
85+
/// As the happy path does not care about this struct, storing this on the heap
86+
/// ends up increasing performance.
87+
///
88+
/// We do not want to intern this as there are a lot of obligation causes which
89+
/// only live for a short period of time.
90+
#[derive(Clone, PartialEq, Eq, Hash)]
8491
pub struct ObligationCause<'tcx> {
92+
/// `None` for `ObligationCause::dummy`, `Some` otherwise.
93+
data: Option<Rc<ObligationCauseData<'tcx>>>,
94+
}
95+
96+
const DUMMY_OBLIGATION_CAUSE_DATA: ObligationCauseData<'static> =
97+
ObligationCauseData { span: DUMMY_SP, body_id: hir::CRATE_HIR_ID, code: MiscObligation };
98+
99+
// Correctly format `ObligationCause::dummy`.
100+
impl<'tcx> fmt::Debug for ObligationCause<'tcx> {
101+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102+
ObligationCauseData::fmt(self, f)
103+
}
104+
}
105+
106+
impl Deref for ObligationCause<'tcx> {
107+
type Target = ObligationCauseData<'tcx>;
108+
109+
#[inline(always)]
110+
fn deref(&self) -> &Self::Target {
111+
self.data.as_deref().unwrap_or(&DUMMY_OBLIGATION_CAUSE_DATA)
112+
}
113+
}
114+
115+
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
116+
pub struct ObligationCauseData<'tcx> {
85117
pub span: Span,
86118

87119
/// The ID of the fn body that triggered this obligation. This is
@@ -102,15 +134,24 @@ impl<'tcx> ObligationCause<'tcx> {
102134
body_id: hir::HirId,
103135
code: ObligationCauseCode<'tcx>,
104136
) -> ObligationCause<'tcx> {
105-
ObligationCause { span, body_id, code }
137+
ObligationCause { data: Some(Rc::new(ObligationCauseData { span, body_id, code })) }
106138
}
107139

108140
pub fn misc(span: Span, body_id: hir::HirId) -> ObligationCause<'tcx> {
109-
ObligationCause { span, body_id, code: MiscObligation }
141+
ObligationCause::new(span, body_id, MiscObligation)
110142
}
111143

144+
pub fn dummy_with_span(span: Span) -> ObligationCause<'tcx> {
145+
ObligationCause::new(span, hir::CRATE_HIR_ID, MiscObligation)
146+
}
147+
148+
#[inline(always)]
112149
pub fn dummy() -> ObligationCause<'tcx> {
113-
ObligationCause { span: DUMMY_SP, body_id: hir::CRATE_HIR_ID, code: MiscObligation }
150+
ObligationCause { data: None }
151+
}
152+
153+
pub fn make_mut(&mut self) -> &mut ObligationCauseData<'tcx> {
154+
Rc::make_mut(self.data.get_or_insert_with(|| Rc::new(DUMMY_OBLIGATION_CAUSE_DATA)))
114155
}
115156

116157
pub fn span(&self, tcx: TyCtxt<'tcx>) -> Span {

src/librustc_middle/traits/structural_impls.rs

+1-5
Original file line numberDiff line numberDiff line change
@@ -232,11 +232,7 @@ impl<'a, 'tcx> Lift<'tcx> for traits::DerivedObligationCause<'a> {
232232
impl<'a, 'tcx> Lift<'tcx> for traits::ObligationCause<'a> {
233233
type Lifted = traits::ObligationCause<'tcx>;
234234
fn lift_to_tcx(&self, tcx: TyCtxt<'tcx>) -> Option<Self::Lifted> {
235-
tcx.lift(&self.code).map(|code| traits::ObligationCause {
236-
span: self.span,
237-
body_id: self.body_id,
238-
code,
239-
})
235+
tcx.lift(&self.code).map(|code| traits::ObligationCause::new(self.span, self.body_id, code))
240236
}
241237
}
242238

src/librustc_mir/borrow_check/type_check/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1250,7 +1250,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
12501250
|infcx| {
12511251
let mut obligations = ObligationAccumulator::default();
12521252

1253-
let dummy_body_id = ObligationCause::dummy().body_id;
1253+
let dummy_body_id = hir::CRATE_HIR_ID;
12541254
let (output_ty, opaque_type_map) =
12551255
obligations.add(infcx.instantiate_opaque_types(
12561256
anon_owner_def_id,

src/librustc_trait_selection/traits/fulfill.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ pub struct PendingPredicateObligation<'tcx> {
8484

8585
// `PendingPredicateObligation` is used a lot. Make sure it doesn't unintentionally get bigger.
8686
#[cfg(target_arch = "x86_64")]
87-
static_assert_size!(PendingPredicateObligation<'_>, 112);
87+
static_assert_size!(PendingPredicateObligation<'_>, 72);
8888

8989
impl<'a, 'tcx> FulfillmentContext<'tcx> {
9090
/// Creates a new fulfillment context.

src/librustc_trait_selection/traits/misc.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ pub fn can_type_implement_copy(
4848
continue;
4949
}
5050
let span = tcx.def_span(field.did);
51-
let cause = ObligationCause { span, ..ObligationCause::dummy() };
51+
let cause = ObligationCause::dummy_with_span(span);
5252
let ctx = traits::FulfillmentContext::new();
5353
match traits::fully_normalize(&infcx, ctx, cause, param_env, &ty) {
5454
Ok(ty) => {

src/librustc_trait_selection/traits/wf.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ fn extend_cause_with_original_assoc_item_obligation<'tcx>(
205205
if let Some(impl_item_span) =
206206
items.iter().find(|item| item.ident == trait_assoc_item.ident).map(fix_span)
207207
{
208-
cause.span = impl_item_span;
208+
cause.make_mut().span = impl_item_span;
209209
}
210210
}
211211
}
@@ -222,7 +222,7 @@ fn extend_cause_with_original_assoc_item_obligation<'tcx>(
222222
items.iter().find(|i| i.ident == trait_assoc_item.ident).map(fix_span)
223223
})
224224
{
225-
cause.span = impl_item_span;
225+
cause.make_mut().span = impl_item_span;
226226
}
227227
}
228228
}
@@ -273,7 +273,8 @@ impl<'a, 'tcx> WfPredicates<'a, 'tcx> {
273273
parent_trait_ref,
274274
parent_code: Rc::new(obligation.cause.code.clone()),
275275
};
276-
cause.code = traits::ObligationCauseCode::DerivedObligation(derived_cause);
276+
cause.make_mut().code =
277+
traits::ObligationCauseCode::DerivedObligation(derived_cause);
277278
}
278279
extend_cause_with_original_assoc_item_obligation(
279280
tcx,

src/librustc_typeck/check/compare_method.rs

+18-14
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,16 @@ fn compare_predicate_entailment<'tcx>(
7878
// `regionck_item` expects.
7979
let impl_m_hir_id = tcx.hir().as_local_hir_id(impl_m.def_id.expect_local());
8080

81-
let cause = ObligationCause {
82-
span: impl_m_span,
83-
body_id: impl_m_hir_id,
84-
code: ObligationCauseCode::CompareImplMethodObligation {
81+
// We sometimes modify the span further down.
82+
let mut cause = ObligationCause::new(
83+
impl_m_span,
84+
impl_m_hir_id,
85+
ObligationCauseCode::CompareImplMethodObligation {
8586
item_name: impl_m.ident.name,
8687
impl_item_def_id: impl_m.def_id,
8788
trait_item_def_id: trait_m.def_id,
8889
},
89-
};
90+
);
9091

9192
// This code is best explained by example. Consider a trait:
9293
//
@@ -280,7 +281,7 @@ fn compare_predicate_entailment<'tcx>(
280281
&infcx, param_env, &terr, &cause, impl_m, impl_sig, trait_m, trait_sig,
281282
);
282283

283-
let cause = ObligationCause { span: impl_err_span, ..cause };
284+
cause.make_mut().span = impl_err_span;
284285

285286
let mut diag = struct_span_err!(
286287
tcx.sess,
@@ -965,8 +966,11 @@ crate fn compare_const_impl<'tcx>(
965966
// Compute placeholder form of impl and trait const tys.
966967
let impl_ty = tcx.type_of(impl_c.def_id);
967968
let trait_ty = tcx.type_of(trait_c.def_id).subst(tcx, trait_to_impl_substs);
968-
let mut cause = ObligationCause::misc(impl_c_span, impl_c_hir_id);
969-
cause.code = ObligationCauseCode::CompareImplConstObligation;
969+
let mut cause = ObligationCause::new(
970+
impl_c_span,
971+
impl_c_hir_id,
972+
ObligationCauseCode::CompareImplConstObligation,
973+
);
970974

971975
// There is no "body" here, so just pass dummy id.
972976
let impl_ty =
@@ -992,7 +996,7 @@ crate fn compare_const_impl<'tcx>(
992996

993997
// Locate the Span containing just the type of the offending impl
994998
match tcx.hir().expect_impl_item(impl_c_hir_id).kind {
995-
ImplItemKind::Const(ref ty, _) => cause.span = ty.span,
999+
ImplItemKind::Const(ref ty, _) => cause.make_mut().span = ty.span,
9961000
_ => bug!("{:?} is not a impl const", impl_c),
9971001
}
9981002

@@ -1095,15 +1099,15 @@ fn compare_type_predicate_entailment(
10951099
// `ObligationCause` (and the `FnCtxt`). This is what
10961100
// `regionck_item` expects.
10971101
let impl_ty_hir_id = tcx.hir().as_local_hir_id(impl_ty.def_id.expect_local());
1098-
let cause = ObligationCause {
1099-
span: impl_ty_span,
1100-
body_id: impl_ty_hir_id,
1101-
code: ObligationCauseCode::CompareImplTypeObligation {
1102+
let cause = ObligationCause::new(
1103+
impl_ty_span,
1104+
impl_ty_hir_id,
1105+
ObligationCauseCode::CompareImplTypeObligation {
11021106
item_name: impl_ty.ident.name,
11031107
impl_item_def_id: impl_ty.def_id,
11041108
trait_item_def_id: trait_ty.def_id,
11051109
},
1106-
};
1110+
);
11071111

11081112
debug!("compare_type_predicate_entailment: trait_to_impl_substs={:?}", trait_to_impl_substs);
11091113

src/librustc_typeck/check/mod.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -4218,7 +4218,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
42184218
if let (Some(ref_in), None) = (referenced_in.pop(), referenced_in.pop()) {
42194219
// We make sure that only *one* argument matches the obligation failure
42204220
// and we assign the obligation's span to its expression's.
4221-
error.obligation.cause.span = args[ref_in].span;
4221+
error.obligation.cause.make_mut().span = args[ref_in].span;
42224222
error.points_at_arg_span = true;
42234223
}
42244224
}
@@ -4261,7 +4261,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
42614261
let ty = AstConv::ast_ty_to_ty(self, hir_ty);
42624262
let ty = self.resolve_vars_if_possible(&ty);
42634263
if ty == predicate.skip_binder().self_ty() {
4264-
error.obligation.cause.span = hir_ty.span;
4264+
error.obligation.cause.make_mut().span = hir_ty.span;
42654265
}
42664266
}
42674267
}
@@ -5689,7 +5689,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
56895689
{
56905690
// This makes the error point at the bound, but we want to point at the argument
56915691
if let Some(span) = spans.get(i) {
5692-
obligation.cause.code = traits::BindingObligation(def_id, *span);
5692+
obligation.cause.make_mut().code = traits::BindingObligation(def_id, *span);
56935693
}
56945694
self.register_predicate(obligation);
56955695
}

0 commit comments

Comments
 (0)