Skip to content

Commit 1834f5a

Browse files
committed
Swap encapsulation of DCP state.
1 parent 2975a21 commit 1834f5a

File tree

1 file changed

+81
-68
lines changed

1 file changed

+81
-68
lines changed

compiler/rustc_mir_dataflow/src/value_analysis.rs

+81-68
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ use std::ops::Range;
3939
use rustc_data_structures::fx::FxHashMap;
4040
use rustc_data_structures::stack::ensure_sufficient_stack;
4141
use rustc_index::bit_set::BitSet;
42-
use rustc_index::{IndexSlice, IndexVec};
42+
use rustc_index::IndexVec;
4343
use rustc_middle::bug;
4444
use rustc_middle::mir::visit::{MutatingUseContext, PlaceContext, Visitor};
4545
use rustc_middle::mir::*;
@@ -336,14 +336,14 @@ impl<'tcx, T: ValueAnalysis<'tcx>> AnalysisDomain<'tcx> for ValueAnalysisWrapper
336336
const NAME: &'static str = T::NAME;
337337

338338
fn bottom_value(&self, _body: &Body<'tcx>) -> Self::Domain {
339-
State(StateData::Unreachable)
339+
State::Unreachable
340340
}
341341

342342
fn initialize_start_block(&self, body: &Body<'tcx>, state: &mut Self::Domain) {
343343
// The initial state maps all tracked places of argument projections to ⊤ and the rest to ⊥.
344-
assert!(matches!(state.0, StateData::Unreachable));
345-
let values = IndexVec::from_elem_n(T::Value::BOTTOM, self.0.map().value_count);
346-
*state = State(StateData::Reachable(values));
344+
assert!(matches!(state, State::Unreachable));
345+
let values = StateData::from_elem_n(T::Value::BOTTOM, self.0.map().value_count);
346+
*state = State::Reachable(values);
347347
for arg in body.args_iter() {
348348
state.flood(PlaceRef { local: arg, projection: &[] }, self.0.map());
349349
}
@@ -415,27 +415,30 @@ rustc_index::newtype_index!(
415415

416416
/// See [`State`].
417417
#[derive(PartialEq, Eq, Debug)]
418-
enum StateData<V> {
419-
Reachable(IndexVec<ValueIndex, V>),
420-
Unreachable,
418+
struct StateData<V> {
419+
map: IndexVec<ValueIndex, V>,
420+
}
421+
422+
impl<V: Clone> StateData<V> {
423+
fn from_elem_n(elem: V, n: usize) -> StateData<V> {
424+
StateData { map: IndexVec::from_elem_n(elem, n) }
425+
}
421426
}
422427

423428
impl<V: Clone> Clone for StateData<V> {
424429
fn clone(&self) -> Self {
425-
match self {
426-
Self::Reachable(x) => Self::Reachable(x.clone()),
427-
Self::Unreachable => Self::Unreachable,
428-
}
430+
StateData { map: self.map.clone() }
429431
}
430432

431433
fn clone_from(&mut self, source: &Self) {
432-
match (&mut *self, source) {
433-
(Self::Reachable(x), Self::Reachable(y)) => {
434-
// We go through `raw` here, because `IndexVec` currently has a naive `clone_from`.
435-
x.raw.clone_from(&y.raw);
436-
}
437-
_ => *self = source.clone(),
438-
}
434+
// We go through `raw` here, because `IndexVec` currently has a naive `clone_from`.
435+
self.map.raw.clone_from(&source.map.raw)
436+
}
437+
}
438+
439+
impl<V: JoinSemiLattice + Clone> JoinSemiLattice for StateData<V> {
440+
fn join(&mut self, other: &Self) -> bool {
441+
self.map.join(&other.map)
439442
}
440443
}
441444

@@ -450,33 +453,43 @@ impl<V: Clone> Clone for StateData<V> {
450453
///
451454
/// Flooding means assigning a value (by default `⊤`) to all tracked projections of a given place.
452455
#[derive(PartialEq, Eq, Debug)]
453-
pub struct State<V>(StateData<V>);
456+
pub enum State<V> {
457+
Unreachable,
458+
Reachable(StateData<V>),
459+
}
454460

455461
impl<V: Clone> Clone for State<V> {
456462
fn clone(&self) -> Self {
457-
Self(self.0.clone())
463+
match self {
464+
Self::Reachable(x) => Self::Reachable(x.clone()),
465+
Self::Unreachable => Self::Unreachable,
466+
}
458467
}
459468

460469
fn clone_from(&mut self, source: &Self) {
461-
self.0.clone_from(&source.0);
470+
match (&mut *self, source) {
471+
(Self::Reachable(x), Self::Reachable(y)) => {
472+
x.clone_from(&y);
473+
}
474+
_ => *self = source.clone(),
475+
}
462476
}
463477
}
464478

465479
impl<V: Clone> State<V> {
466480
pub fn new(init: V, map: &Map) -> State<V> {
467-
let values = IndexVec::from_elem_n(init, map.value_count);
468-
State(StateData::Reachable(values))
481+
State::Reachable(StateData::from_elem_n(init, map.value_count))
469482
}
470483

471484
pub fn all(&self, f: impl Fn(&V) -> bool) -> bool {
472-
match self.0 {
473-
StateData::Unreachable => true,
474-
StateData::Reachable(ref values) => values.iter().all(f),
485+
match self {
486+
State::Unreachable => true,
487+
State::Reachable(ref values) => values.map.iter().all(f),
475488
}
476489
}
477490

478491
fn is_reachable(&self) -> bool {
479-
matches!(&self.0, StateData::Reachable(_))
492+
matches!(self, State::Reachable(_))
480493
}
481494

482495
/// Assign `value` to all places that are contained in `place` or may alias one.
@@ -519,9 +532,9 @@ impl<V: Clone> State<V> {
519532
map: &Map,
520533
value: V,
521534
) {
522-
let StateData::Reachable(values) = &mut self.0 else { return };
535+
let State::Reachable(values) = self else { return };
523536
map.for_each_aliasing_place(place, tail_elem, &mut |vi| {
524-
values[vi] = value.clone();
537+
values.map[vi] = value.clone();
525538
});
526539
}
527540

@@ -541,9 +554,9 @@ impl<V: Clone> State<V> {
541554
///
542555
/// The target place must have been flooded before calling this method.
543556
pub fn insert_value_idx(&mut self, target: PlaceIndex, value: V, map: &Map) {
544-
let StateData::Reachable(values) = &mut self.0 else { return };
557+
let State::Reachable(values) = self else { return };
545558
if let Some(value_index) = map.places[target].value_index {
546-
values[value_index] = value;
559+
values.map[value_index] = value;
547560
}
548561
}
549562

@@ -555,14 +568,14 @@ impl<V: Clone> State<V> {
555568
///
556569
/// The target place must have been flooded before calling this method.
557570
pub fn insert_place_idx(&mut self, target: PlaceIndex, source: PlaceIndex, map: &Map) {
558-
let StateData::Reachable(values) = &mut self.0 else { return };
571+
let State::Reachable(values) = self else { return };
559572

560573
// If both places are tracked, we copy the value to the target.
561574
// If the target is tracked, but the source is not, we do nothing, as invalidation has
562575
// already been performed.
563576
if let Some(target_value) = map.places[target].value_index {
564577
if let Some(source_value) = map.places[source].value_index {
565-
values[target_value] = values[source_value].clone();
578+
values.map[target_value] = values.map[source_value].clone();
566579
}
567580
}
568581
for target_child in map.children(target) {
@@ -616,11 +629,11 @@ impl<V: Clone> State<V> {
616629

617630
/// Retrieve the value stored for a place index, or `None` if it is not tracked.
618631
pub fn try_get_idx(&self, place: PlaceIndex, map: &Map) -> Option<V> {
619-
match &self.0 {
620-
StateData::Reachable(values) => {
621-
map.places[place].value_index.map(|v| values[v].clone())
632+
match self {
633+
State::Reachable(values) => {
634+
map.places[place].value_index.map(|v| values.map[v].clone())
622635
}
623-
StateData::Unreachable => None,
636+
State::Unreachable => None,
624637
}
625638
}
626639

@@ -631,10 +644,10 @@ impl<V: Clone> State<V> {
631644
where
632645
V: HasBottom + HasTop,
633646
{
634-
match &self.0 {
635-
StateData::Reachable(_) => self.try_get(place, map).unwrap_or(V::TOP),
647+
match self {
648+
State::Reachable(_) => self.try_get(place, map).unwrap_or(V::TOP),
636649
// Because this is unreachable, we can return any value we want.
637-
StateData::Unreachable => V::BOTTOM,
650+
State::Unreachable => V::BOTTOM,
638651
}
639652
}
640653

@@ -645,10 +658,10 @@ impl<V: Clone> State<V> {
645658
where
646659
V: HasBottom + HasTop,
647660
{
648-
match &self.0 {
649-
StateData::Reachable(_) => self.try_get_discr(place, map).unwrap_or(V::TOP),
661+
match self {
662+
State::Reachable(_) => self.try_get_discr(place, map).unwrap_or(V::TOP),
650663
// Because this is unreachable, we can return any value we want.
651-
StateData::Unreachable => V::BOTTOM,
664+
State::Unreachable => V::BOTTOM,
652665
}
653666
}
654667

@@ -659,10 +672,10 @@ impl<V: Clone> State<V> {
659672
where
660673
V: HasBottom + HasTop,
661674
{
662-
match &self.0 {
663-
StateData::Reachable(_) => self.try_get_len(place, map).unwrap_or(V::TOP),
675+
match self {
676+
State::Reachable(_) => self.try_get_len(place, map).unwrap_or(V::TOP),
664677
// Because this is unreachable, we can return any value we want.
665-
StateData::Unreachable => V::BOTTOM,
678+
State::Unreachable => V::BOTTOM,
666679
}
667680
}
668681

@@ -673,11 +686,11 @@ impl<V: Clone> State<V> {
673686
where
674687
V: HasBottom + HasTop,
675688
{
676-
match &self.0 {
677-
StateData::Reachable(values) => {
678-
map.places[place].value_index.map(|v| values[v].clone()).unwrap_or(V::TOP)
689+
match self {
690+
State::Reachable(values) => {
691+
map.places[place].value_index.map(|v| values.map[v].clone()).unwrap_or(V::TOP)
679692
}
680-
StateData::Unreachable => {
693+
State::Unreachable => {
681694
// Because this is unreachable, we can return any value we want.
682695
V::BOTTOM
683696
}
@@ -687,13 +700,13 @@ impl<V: Clone> State<V> {
687700

688701
impl<V: JoinSemiLattice + Clone> JoinSemiLattice for State<V> {
689702
fn join(&mut self, other: &Self) -> bool {
690-
match (&mut self.0, &other.0) {
691-
(_, StateData::Unreachable) => false,
692-
(StateData::Unreachable, _) => {
703+
match (&mut *self, other) {
704+
(_, State::Unreachable) => false,
705+
(State::Unreachable, _) => {
693706
*self = other.clone();
694707
true
695708
}
696-
(StateData::Reachable(this), StateData::Reachable(other)) => this.join(other),
709+
(State::Reachable(this), State::Reachable(ref other)) => this.join(other),
697710
}
698711
}
699712
}
@@ -1194,9 +1207,9 @@ where
11941207
T::Value: Debug,
11951208
{
11961209
fn fmt_with(&self, ctxt: &ValueAnalysisWrapper<T>, f: &mut Formatter<'_>) -> std::fmt::Result {
1197-
match &self.0 {
1198-
StateData::Reachable(values) => debug_with_context(values, None, ctxt.0.map(), f),
1199-
StateData::Unreachable => write!(f, "unreachable"),
1210+
match self {
1211+
State::Reachable(values) => debug_with_context(values, None, ctxt.0.map(), f),
1212+
State::Unreachable => write!(f, "unreachable"),
12001213
}
12011214
}
12021215

@@ -1206,8 +1219,8 @@ where
12061219
ctxt: &ValueAnalysisWrapper<T>,
12071220
f: &mut Formatter<'_>,
12081221
) -> std::fmt::Result {
1209-
match (&self.0, &old.0) {
1210-
(StateData::Reachable(this), StateData::Reachable(old)) => {
1222+
match (self, old) {
1223+
(State::Reachable(this), State::Reachable(old)) => {
12111224
debug_with_context(this, Some(old), ctxt.0.map(), f)
12121225
}
12131226
_ => Ok(()), // Consider printing something here.
@@ -1218,18 +1231,18 @@ where
12181231
fn debug_with_context_rec<V: Debug + Eq>(
12191232
place: PlaceIndex,
12201233
place_str: &str,
1221-
new: &IndexSlice<ValueIndex, V>,
1222-
old: Option<&IndexSlice<ValueIndex, V>>,
1234+
new: &StateData<V>,
1235+
old: Option<&StateData<V>>,
12231236
map: &Map,
12241237
f: &mut Formatter<'_>,
12251238
) -> std::fmt::Result {
12261239
if let Some(value) = map.places[place].value_index {
12271240
match old {
1228-
None => writeln!(f, "{}: {:?}", place_str, new[value])?,
1241+
None => writeln!(f, "{}: {:?}", place_str, new.map[value])?,
12291242
Some(old) => {
1230-
if new[value] != old[value] {
1231-
writeln!(f, "\u{001f}-{}: {:?}", place_str, old[value])?;
1232-
writeln!(f, "\u{001f}+{}: {:?}", place_str, new[value])?;
1243+
if new.map[value] != old.map[value] {
1244+
writeln!(f, "\u{001f}-{}: {:?}", place_str, old.map[value])?;
1245+
writeln!(f, "\u{001f}+{}: {:?}", place_str, new.map[value])?;
12331246
}
12341247
}
12351248
}
@@ -1262,8 +1275,8 @@ fn debug_with_context_rec<V: Debug + Eq>(
12621275
}
12631276

12641277
fn debug_with_context<V: Debug + Eq>(
1265-
new: &IndexSlice<ValueIndex, V>,
1266-
old: Option<&IndexSlice<ValueIndex, V>>,
1278+
new: &StateData<V>,
1279+
old: Option<&StateData<V>>,
12671280
map: &Map,
12681281
f: &mut Formatter<'_>,
12691282
) -> std::fmt::Result {

0 commit comments

Comments
 (0)