Skip to content

Commit 9a6c04f

Browse files
committed
Handle discriminants in dataflow-const-prop.
1 parent cd3649b commit 9a6c04f

File tree

6 files changed

+305
-59
lines changed

6 files changed

+305
-59
lines changed

compiler/rustc_mir_dataflow/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#![feature(associated_type_defaults)]
22
#![feature(box_patterns)]
33
#![feature(exact_size_is_empty)]
4+
#![feature(let_chains)]
45
#![feature(min_specialization)]
56
#![feature(once_cell)]
67
#![feature(stmt_expr_attributes)]

compiler/rustc_mir_dataflow/src/value_analysis.rs

Lines changed: 142 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,8 @@ pub trait ValueAnalysis<'tcx> {
6565
StatementKind::Assign(box (place, rvalue)) => {
6666
self.handle_assign(*place, rvalue, state);
6767
}
68-
StatementKind::SetDiscriminant { .. } => {
69-
// Could treat this as writing a constant to a pseudo-place.
70-
// But discriminants are currently not tracked, so we do nothing.
71-
// Related: https://github.com/rust-lang/unsafe-code-guidelines/issues/84
68+
StatementKind::SetDiscriminant { box ref place, .. } => {
69+
state.flood_discr(place.as_ref(), self.map());
7270
}
7371
StatementKind::Intrinsic(box intrinsic) => {
7472
self.handle_intrinsic(intrinsic, state);
@@ -447,34 +445,39 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
447445
}
448446

449447
pub fn flood_with(&mut self, place: PlaceRef<'_>, map: &Map, value: V) {
450-
if let Some(root) = map.find(place) {
451-
self.flood_idx_with(root, map, value);
452-
}
448+
let StateData::Reachable(values) = &mut self.0 else { return };
449+
map.for_each_aliasing_place(place, None, &mut |place| {
450+
if let Some(vi) = map.places[place].value_index {
451+
values[vi] = value.clone();
452+
}
453+
});
453454
}
454455

455456
pub fn flood(&mut self, place: PlaceRef<'_>, map: &Map) {
456457
self.flood_with(place, map, V::top())
457458
}
458459

459-
pub fn flood_idx_with(&mut self, place: PlaceIndex, map: &Map, value: V) {
460+
pub fn flood_discr_with(&mut self, place: PlaceRef<'_>, map: &Map, value: V) {
460461
let StateData::Reachable(values) = &mut self.0 else { return };
461-
map.preorder_invoke(place, &mut |place| {
462+
map.for_each_aliasing_place(place, Some(TrackElem::Discriminant), &mut |place| {
462463
if let Some(vi) = map.places[place].value_index {
463464
values[vi] = value.clone();
464465
}
465466
});
466467
}
467468

468-
pub fn flood_idx(&mut self, place: PlaceIndex, map: &Map) {
469-
self.flood_idx_with(place, map, V::top())
469+
pub fn flood_discr(&mut self, place: PlaceRef<'_>, map: &Map) {
470+
self.flood_discr_with(place, map, V::top())
470471
}
471472

472473
/// Copies `source` to `target`, including all tracked places beneath.
473474
///
474475
/// If `target` contains a place that is not contained in `source`, it will be overwritten with
475476
/// Top. Also, because this will copy all entries one after another, it may only be used for
476477
/// places that are non-overlapping or identical.
477-
pub fn assign_place_idx(&mut self, target: PlaceIndex, source: PlaceIndex, map: &Map) {
478+
///
479+
/// The target place must have been flooded before calling this method.
480+
fn assign_place_idx(&mut self, target: PlaceIndex, source: PlaceIndex, map: &Map) {
478481
let StateData::Reachable(values) = &mut self.0 else { return };
479482

480483
// If both places are tracked, we copy the value to the target. If the target is tracked,
@@ -492,26 +495,28 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
492495
let projection = map.places[target_child].proj_elem.unwrap();
493496
if let Some(source_child) = map.projections.get(&(source, projection)) {
494497
self.assign_place_idx(target_child, *source_child, map);
495-
} else {
496-
self.flood_idx(target_child, map);
497498
}
498499
}
499500
}
500501

501502
pub fn assign(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map) {
503+
self.flood(target, map);
502504
if let Some(target) = map.find(target) {
503505
self.assign_idx(target, result, map);
504-
} else {
505-
// We don't track this place nor any projections, assignment can be ignored.
506506
}
507507
}
508508

509+
pub fn assign_discr(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map) {
510+
self.flood_discr(target, map);
511+
if let Some(target) = map.find_discr(target) {
512+
self.assign_idx(target, result, map);
513+
}
514+
}
515+
516+
/// The target place must have been flooded before calling this method.
509517
pub fn assign_idx(&mut self, target: PlaceIndex, result: ValueOrPlace<V>, map: &Map) {
510518
match result {
511519
ValueOrPlace::Value(value) => {
512-
// First flood the target place in case we also track any projections (although
513-
// this scenario is currently not well-supported by the API).
514-
self.flood_idx(target, map);
515520
let StateData::Reachable(values) = &mut self.0 else { return };
516521
if let Some(value_index) = map.places[target].value_index {
517522
values[value_index] = value;
@@ -526,6 +531,14 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
526531
map.find(place).map(|place| self.get_idx(place, map)).unwrap_or(V::top())
527532
}
528533

534+
/// Retrieve the value stored for a place, or ⊤ if it is not tracked.
535+
pub fn get_discr(&self, place: PlaceRef<'_>, map: &Map) -> V {
536+
match map.find_discr(place) {
537+
Some(place) => self.get_idx(place, map),
538+
None => V::top(),
539+
}
540+
}
541+
529542
/// Retrieve the value stored for a place index, or ⊤ if it is not tracked.
530543
pub fn get_idx(&self, place: PlaceIndex, map: &Map) -> V {
531544
match &self.0 {
@@ -582,7 +595,6 @@ impl Map {
582595
/// This is currently the only way to create a [`Map`]. The way in which the tracked places are
583596
/// chosen is an implementation detail and may not be relied upon (other than that their type
584597
/// passes the filter).
585-
#[instrument(skip_all, level = "debug")]
586598
pub fn from_filter<'tcx>(
587599
tcx: TyCtxt<'tcx>,
588600
body: &Body<'tcx>,
@@ -614,7 +626,7 @@ impl Map {
614626

615627
/// Potentially register the (local, projection) place and its fields, recursively.
616628
///
617-
/// Invariant: The projection must only contain fields.
629+
/// Invariant: The projection must only contain trackable elements.
618630
fn register_with_filter_rec<'tcx>(
619631
&mut self,
620632
tcx: TyCtxt<'tcx>,
@@ -623,21 +635,46 @@ impl Map {
623635
ty: Ty<'tcx>,
624636
filter: &mut impl FnMut(Ty<'tcx>) -> bool,
625637
) {
626-
if filter(ty) {
627-
// We know that the projection only contains trackable elements.
628-
let place = self.make_place(local, projection).unwrap();
638+
// We know that the projection only contains trackable elements.
639+
let place = self.make_place(local, projection).unwrap();
629640

630-
// Allocate a value slot if it doesn't have one.
631-
if self.places[place].value_index.is_none() {
632-
self.places[place].value_index = Some(self.value_count.into());
633-
self.value_count += 1;
641+
// Allocate a value slot if it doesn't have one, and the user requested one.
642+
if self.places[place].value_index.is_none() && filter(ty) {
643+
self.places[place].value_index = Some(self.value_count.into());
644+
self.value_count += 1;
645+
}
646+
647+
if ty.is_enum() {
648+
let discr_ty = ty.discriminant_ty(tcx);
649+
if filter(discr_ty) {
650+
let discr = *self
651+
.projections
652+
.entry((place, TrackElem::Discriminant))
653+
.or_insert_with(|| {
654+
// Prepend new child to the linked list.
655+
let next = self.places.push(PlaceInfo::new(Some(TrackElem::Discriminant)));
656+
self.places[next].next_sibling = self.places[place].first_child;
657+
self.places[place].first_child = Some(next);
658+
next
659+
});
660+
661+
// Allocate a value slot if it doesn't have one.
662+
if self.places[discr].value_index.is_none() {
663+
self.places[discr].value_index = Some(self.value_count.into());
664+
self.value_count += 1;
665+
}
634666
}
635667
}
636668

637669
// Recurse with all fields of this place.
638670
iter_fields(ty, tcx, |variant, field, ty| {
639-
if variant.is_some() {
640-
// Downcasts are currently not supported.
671+
if let Some(variant) = variant {
672+
projection.push(PlaceElem::Downcast(None, variant));
673+
let _ = self.make_place(local, projection);
674+
projection.push(PlaceElem::Field(field, ty));
675+
self.register_with_filter_rec(tcx, local, projection, ty, filter);
676+
projection.pop();
677+
projection.pop();
641678
return;
642679
}
643680
projection.push(PlaceElem::Field(field, ty));
@@ -694,13 +731,77 @@ impl Map {
694731
Some(index)
695732
}
696733

734+
/// Locates the given place, if it exists in the tree.
735+
pub fn find_discr(&self, place: PlaceRef<'_>) -> Option<PlaceIndex> {
736+
let index = self.find(place)?;
737+
self.apply(index, TrackElem::Discriminant)
738+
}
739+
697740
/// Iterate over all direct children.
698741
pub fn children(&self, parent: PlaceIndex) -> impl Iterator<Item = PlaceIndex> + '_ {
699742
Children::new(self, parent)
700743
}
701744

745+
/// Invoke a function on the given place and all places that may alias it.
746+
///
747+
/// In particular, when the given place has a variant downcast, we invoke the function on all
748+
/// the other variants.
749+
///
750+
/// `tail_elem` allows to support discriminants that are not a place in MIR, but that we track
751+
/// as such.
752+
fn for_each_aliasing_place(
753+
&self,
754+
place: PlaceRef<'_>,
755+
tail_elem: Option<TrackElem>,
756+
f: &mut impl FnMut(PlaceIndex),
757+
) {
758+
let Some(&Some(mut index)) = self.locals.get(place.local) else {
759+
// The local is not tracked at all, nothing to invalidate.
760+
return;
761+
};
762+
let elems = place
763+
.projection
764+
.iter()
765+
.map(|&elem| elem.try_into())
766+
.chain(tail_elem.map(Ok).into_iter());
767+
for elem in elems {
768+
let Ok(elem) = elem else { return };
769+
let sub = self.apply(index, elem);
770+
if let TrackElem::Variant(..) | TrackElem::Discriminant = elem {
771+
// Writing to an enum variant field invalidates the other variants and the discriminant.
772+
self.for_each_variant_sibling(index, sub, f);
773+
}
774+
if let Some(sub) = sub {
775+
index = sub
776+
} else {
777+
return;
778+
}
779+
}
780+
self.preorder_invoke(index, f);
781+
}
782+
783+
/// Invoke the given function on all the descendants of the given place, except one branch.
784+
pub fn for_each_variant_sibling(
785+
&self,
786+
parent: PlaceIndex,
787+
preserved_child: Option<PlaceIndex>,
788+
f: &mut impl FnMut(PlaceIndex),
789+
) {
790+
for sibling in self.children(parent) {
791+
let elem = self.places[sibling].proj_elem;
792+
// Only invalidate variants and discriminant. Fields (for generators) are not
793+
// invalidated by assignment to a variant.
794+
if let Some(TrackElem::Variant(..) | TrackElem::Discriminant) = elem
795+
// Only invalidate the other variants, the current one is fine.
796+
&& Some(sibling) != preserved_child
797+
{
798+
self.preorder_invoke(sibling, f);
799+
}
800+
}
801+
}
802+
702803
/// Invoke a function on the given place and all descendants.
703-
pub fn preorder_invoke(&self, root: PlaceIndex, f: &mut impl FnMut(PlaceIndex)) {
804+
fn preorder_invoke(&self, root: PlaceIndex, f: &mut impl FnMut(PlaceIndex)) {
704805
f(root);
705806
for child in self.children(root) {
706807
self.preorder_invoke(child, f);
@@ -759,6 +860,7 @@ impl<'a> Iterator for Children<'a> {
759860
}
760861

761862
/// Used as the result of an operand or r-value.
863+
#[derive(Debug)]
762864
pub enum ValueOrPlace<V> {
763865
Value(V),
764866
Place(PlaceIndex),
@@ -776,6 +878,8 @@ impl<V: HasTop> ValueOrPlace<V> {
776878
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
777879
pub enum TrackElem {
778880
Field(Field),
881+
Variant(VariantIdx),
882+
Discriminant,
779883
}
780884

781885
impl<V, T> TryFrom<ProjectionElem<V, T>> for TrackElem {
@@ -784,6 +888,7 @@ impl<V, T> TryFrom<ProjectionElem<V, T>> for TrackElem {
784888
fn try_from(value: ProjectionElem<V, T>) -> Result<Self, Self::Error> {
785889
match value {
786890
ProjectionElem::Field(field, _) => Ok(TrackElem::Field(field)),
891+
ProjectionElem::Downcast(_, idx) => Ok(TrackElem::Variant(idx)),
787892
_ => Err(()),
788893
}
789894
}
@@ -900,6 +1005,12 @@ fn debug_with_context_rec<V: Debug + Eq>(
9001005
for child in map.children(place) {
9011006
let info_elem = map.places[child].proj_elem.unwrap();
9021007
let child_place_str = match info_elem {
1008+
TrackElem::Discriminant => {
1009+
format!("discriminant({})", place_str)
1010+
}
1011+
TrackElem::Variant(idx) => {
1012+
format!("({} as {:?})", place_str, idx)
1013+
}
9031014
TrackElem::Field(field) => {
9041015
if place_str.starts_with('*') {
9051016
format!("({}).{}", place_str, field.index())

0 commit comments

Comments
 (0)