Skip to content

Commit 76244d4

Browse files
committed
Make jump threading state sparse.
1 parent 1834f5a commit 76244d4

File tree

3 files changed

+86
-38
lines changed

3 files changed

+86
-38
lines changed

compiler/rustc_mir_dataflow/src/framework/lattice.rs

+14
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ pub trait MeetSemiLattice: Eq {
7676
/// A set that has a "bottom" element, which is less than or equal to any other element.
7777
pub trait HasBottom {
7878
const BOTTOM: Self;
79+
80+
fn is_bottom(&self) -> bool;
7981
}
8082

8183
/// A set that has a "top" element, which is greater than or equal to any other element.
@@ -114,6 +116,10 @@ impl MeetSemiLattice for bool {
114116

115117
impl HasBottom for bool {
116118
const BOTTOM: Self = false;
119+
120+
fn is_bottom(&self) -> bool {
121+
!self
122+
}
117123
}
118124

119125
impl HasTop for bool {
@@ -267,6 +273,10 @@ impl<T: Clone + Eq> MeetSemiLattice for FlatSet<T> {
267273

268274
impl<T> HasBottom for FlatSet<T> {
269275
const BOTTOM: Self = Self::Bottom;
276+
277+
fn is_bottom(&self) -> bool {
278+
matches!(self, Self::Bottom)
279+
}
270280
}
271281

272282
impl<T> HasTop for FlatSet<T> {
@@ -291,6 +301,10 @@ impl<T> MaybeReachable<T> {
291301

292302
impl<T> HasBottom for MaybeReachable<T> {
293303
const BOTTOM: Self = MaybeReachable::Unreachable;
304+
305+
fn is_bottom(&self) -> bool {
306+
matches!(self, Self::Unreachable)
307+
}
294308
}
295309

296310
impl<T: HasTop> HasTop for MaybeReachable<T> {

compiler/rustc_mir_dataflow/src/value_analysis.rs

+58-33
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ use std::collections::VecDeque;
3636
use std::fmt::{Debug, Formatter};
3737
use std::ops::Range;
3838

39-
use rustc_data_structures::fx::FxHashMap;
39+
use rustc_data_structures::fx::{FxHashMap, StdEntry};
4040
use rustc_data_structures::stack::ensure_sufficient_stack;
4141
use rustc_index::bit_set::BitSet;
4242
use rustc_index::IndexVec;
@@ -342,8 +342,7 @@ impl<'tcx, T: ValueAnalysis<'tcx>> AnalysisDomain<'tcx> for ValueAnalysisWrapper
342342
fn initialize_start_block(&self, body: &Body<'tcx>, state: &mut Self::Domain) {
343343
// The initial state maps all tracked places of argument projections to ⊤ and the rest to ⊥.
344344
assert!(matches!(state, State::Unreachable));
345-
let values = StateData::from_elem_n(T::Value::BOTTOM, self.0.map().value_count);
346-
*state = State::Reachable(values);
345+
*state = State::new_reachable();
347346
for arg in body.args_iter() {
348347
state.flood(PlaceRef { local: arg, projection: &[] }, self.0.map());
349348
}
@@ -415,30 +414,54 @@ rustc_index::newtype_index!(
415414

416415
/// See [`State`].
417416
#[derive(PartialEq, Eq, Debug)]
418-
struct StateData<V> {
419-
map: IndexVec<ValueIndex, V>,
417+
pub struct StateData<V> {
418+
bottom: V,
419+
/// This map only contains values that are not `⊥`.
420+
map: FxHashMap<ValueIndex, V>,
420421
}
421422

422-
impl<V: Clone> StateData<V> {
423-
fn from_elem_n(elem: V, n: usize) -> StateData<V> {
424-
StateData { map: IndexVec::from_elem_n(elem, n) }
423+
impl<V: HasBottom> StateData<V> {
424+
fn new() -> StateData<V> {
425+
StateData { bottom: V::BOTTOM, map: FxHashMap::default() }
426+
}
427+
428+
fn get(&self, idx: ValueIndex) -> &V {
429+
self.map.get(&idx).unwrap_or(&self.bottom)
430+
}
431+
432+
fn insert(&mut self, idx: ValueIndex, elem: V) {
433+
if elem.is_bottom() {
434+
self.map.remove(&idx);
435+
} else {
436+
self.map.insert(idx, elem);
437+
}
425438
}
426439
}
427440

428441
impl<V: Clone> Clone for StateData<V> {
429442
fn clone(&self) -> Self {
430-
StateData { map: self.map.clone() }
443+
StateData { bottom: self.bottom.clone(), map: self.map.clone() }
431444
}
432445

433446
fn clone_from(&mut self, source: &Self) {
434-
// We go through `raw` here, because `IndexVec` currently has a naive `clone_from`.
435-
self.map.raw.clone_from(&source.map.raw)
447+
self.map.clone_from(&source.map)
436448
}
437449
}
438450

439-
impl<V: JoinSemiLattice + Clone> JoinSemiLattice for StateData<V> {
451+
impl<V: JoinSemiLattice + Clone + HasBottom> JoinSemiLattice for StateData<V> {
440452
fn join(&mut self, other: &Self) -> bool {
441-
self.map.join(&other.map)
453+
let mut changed = false;
454+
#[allow(rustc::potential_query_instability)]
455+
for (i, v) in other.map.iter() {
456+
match self.map.entry(*i) {
457+
StdEntry::Vacant(e) => {
458+
e.insert(v.clone());
459+
changed = true
460+
}
461+
StdEntry::Occupied(e) => changed |= e.into_mut().join(v),
462+
}
463+
}
464+
changed
442465
}
443466
}
444467

@@ -476,15 +499,19 @@ impl<V: Clone> Clone for State<V> {
476499
}
477500
}
478501

479-
impl<V: Clone> State<V> {
480-
pub fn new(init: V, map: &Map) -> State<V> {
481-
State::Reachable(StateData::from_elem_n(init, map.value_count))
502+
impl<V: Clone + HasBottom> State<V> {
503+
pub fn new_reachable() -> State<V> {
504+
State::Reachable(StateData::new())
482505
}
483506

484-
pub fn all(&self, f: impl Fn(&V) -> bool) -> bool {
507+
pub fn all_bottom(&self) -> bool {
485508
match self {
486-
State::Unreachable => true,
487-
State::Reachable(ref values) => values.map.iter().all(f),
509+
State::Unreachable => false,
510+
State::Reachable(ref values) =>
511+
{
512+
#[allow(rustc::potential_query_instability)]
513+
values.map.values().all(V::is_bottom)
514+
}
488515
}
489516
}
490517

@@ -533,9 +560,7 @@ impl<V: Clone> State<V> {
533560
value: V,
534561
) {
535562
let State::Reachable(values) = self else { return };
536-
map.for_each_aliasing_place(place, tail_elem, &mut |vi| {
537-
values.map[vi] = value.clone();
538-
});
563+
map.for_each_aliasing_place(place, tail_elem, &mut |vi| values.insert(vi, value.clone()));
539564
}
540565

541566
/// Low-level method that assigns to a place.
@@ -556,7 +581,7 @@ impl<V: Clone> State<V> {
556581
pub fn insert_value_idx(&mut self, target: PlaceIndex, value: V, map: &Map) {
557582
let State::Reachable(values) = self else { return };
558583
if let Some(value_index) = map.places[target].value_index {
559-
values.map[value_index] = value;
584+
values.insert(value_index, value)
560585
}
561586
}
562587

@@ -575,7 +600,7 @@ impl<V: Clone> State<V> {
575600
// already been performed.
576601
if let Some(target_value) = map.places[target].value_index {
577602
if let Some(source_value) = map.places[source].value_index {
578-
values.map[target_value] = values.map[source_value].clone();
603+
values.insert(target_value, values.get(source_value).clone());
579604
}
580605
}
581606
for target_child in map.children(target) {
@@ -631,7 +656,7 @@ impl<V: Clone> State<V> {
631656
pub fn try_get_idx(&self, place: PlaceIndex, map: &Map) -> Option<V> {
632657
match self {
633658
State::Reachable(values) => {
634-
map.places[place].value_index.map(|v| values.map[v].clone())
659+
map.places[place].value_index.map(|v| values.get(v).clone())
635660
}
636661
State::Unreachable => None,
637662
}
@@ -688,7 +713,7 @@ impl<V: Clone> State<V> {
688713
{
689714
match self {
690715
State::Reachable(values) => {
691-
map.places[place].value_index.map(|v| values.map[v].clone()).unwrap_or(V::TOP)
716+
map.places[place].value_index.map(|v| values.get(v).clone()).unwrap_or(V::TOP)
692717
}
693718
State::Unreachable => {
694719
// Because this is unreachable, we can return any value we want.
@@ -698,7 +723,7 @@ impl<V: Clone> State<V> {
698723
}
699724
}
700725

701-
impl<V: JoinSemiLattice + Clone> JoinSemiLattice for State<V> {
726+
impl<V: JoinSemiLattice + Clone + HasBottom> JoinSemiLattice for State<V> {
702727
fn join(&mut self, other: &Self) -> bool {
703728
match (&mut *self, other) {
704729
(_, State::Unreachable) => false,
@@ -1228,7 +1253,7 @@ where
12281253
}
12291254
}
12301255

1231-
fn debug_with_context_rec<V: Debug + Eq>(
1256+
fn debug_with_context_rec<V: Debug + Eq + HasBottom>(
12321257
place: PlaceIndex,
12331258
place_str: &str,
12341259
new: &StateData<V>,
@@ -1238,11 +1263,11 @@ fn debug_with_context_rec<V: Debug + Eq>(
12381263
) -> std::fmt::Result {
12391264
if let Some(value) = map.places[place].value_index {
12401265
match old {
1241-
None => writeln!(f, "{}: {:?}", place_str, new.map[value])?,
1266+
None => writeln!(f, "{}: {:?}", place_str, new.get(value))?,
12421267
Some(old) => {
1243-
if new.map[value] != old.map[value] {
1244-
writeln!(f, "\u{001f}-{}: {:?}", place_str, old.map[value])?;
1245-
writeln!(f, "\u{001f}+{}: {:?}", place_str, new.map[value])?;
1268+
if new.get(value) != old.get(value) {
1269+
writeln!(f, "\u{001f}-{}: {:?}", place_str, old.get(value))?;
1270+
writeln!(f, "\u{001f}+{}: {:?}", place_str, new.get(value))?;
12461271
}
12471272
}
12481273
}
@@ -1274,7 +1299,7 @@ fn debug_with_context_rec<V: Debug + Eq>(
12741299
Ok(())
12751300
}
12761301

1277-
fn debug_with_context<V: Debug + Eq>(
1302+
fn debug_with_context<V: Debug + Eq + HasBottom>(
12781303
new: &StateData<V>,
12791304
old: Option<&StateData<V>>,
12801305
map: &Map,

compiler/rustc_mir_transform/src/jump_threading.rs

+14-5
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ use rustc_middle::mir::visit::Visitor;
4747
use rustc_middle::mir::*;
4848
use rustc_middle::ty::layout::LayoutOf;
4949
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
50+
use rustc_mir_dataflow::lattice::HasBottom;
5051
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
5152
use rustc_span::DUMMY_SP;
5253
use rustc_target::abi::{TagEncoding, Variants};
@@ -158,9 +159,17 @@ impl Condition {
158159
}
159160
}
160161

161-
#[derive(Copy, Clone, Debug, Default)]
162+
#[derive(Copy, Clone, Debug)]
162163
struct ConditionSet<'a>(&'a [Condition]);
163164

165+
impl HasBottom for ConditionSet<'_> {
166+
const BOTTOM: Self = ConditionSet(&[]);
167+
168+
fn is_bottom(&self) -> bool {
169+
self.0.is_empty()
170+
}
171+
}
172+
164173
impl<'a> ConditionSet<'a> {
165174
fn iter(self) -> impl Iterator<Item = Condition> + 'a {
166175
self.0.iter().copied()
@@ -177,7 +186,7 @@ impl<'a> ConditionSet<'a> {
177186

178187
impl<'tcx, 'a> TOFinder<'tcx, 'a> {
179188
fn is_empty(&self, state: &State<ConditionSet<'a>>) -> bool {
180-
state.all(|cs| cs.0.is_empty())
189+
state.all_bottom()
181190
}
182191

183192
/// Recursion entry point to find threading opportunities.
@@ -198,7 +207,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
198207
debug!(?discr);
199208

200209
let cost = CostChecker::new(self.tcx, self.param_env, None, self.body);
201-
let mut state = State::new(ConditionSet::default(), self.map);
210+
let mut state = State::new_reachable();
202211

203212
let conds = if let Some((value, then, else_)) = targets.as_static_if() {
204213
let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
@@ -255,7 +264,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
255264
// _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`.
256265
// _1 = 6
257266
if let Some((lhs, tail)) = self.mutated_statement(stmt) {
258-
state.flood_with_tail_elem(lhs.as_ref(), tail, self.map, ConditionSet::default());
267+
state.flood_with_tail_elem(lhs.as_ref(), tail, self.map, ConditionSet::BOTTOM);
259268
}
260269
}
261270

@@ -609,7 +618,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
609618
// We can recurse through this terminator.
610619
let mut state = state();
611620
if let Some(place_to_flood) = place_to_flood {
612-
state.flood_with(place_to_flood.as_ref(), self.map, ConditionSet::default());
621+
state.flood_with(place_to_flood.as_ref(), self.map, ConditionSet::BOTTOM);
613622
}
614623
self.find_opportunity(bb, state, cost.clone(), depth + 1);
615624
}

0 commit comments

Comments
 (0)