Skip to content

Commit 8c1b039

Browse files
committed
Use a ConstValue instead.
1 parent 31d1010 commit 8c1b039

23 files changed

+720
-277
lines changed

compiler/rustc_mir_transform/src/dataflow_const_prop.rs

+159-96
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,23 @@
22
//!
33
//! Currently, this pass only propagates scalar values.
44
5-
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
5+
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable};
66
use rustc_data_structures::fx::FxHashMap;
77
use rustc_hir::def::DefKind;
88
use rustc_middle::mir::interpret::{AllocId, ConstAllocation, InterpResult, Scalar};
99
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
1010
use rustc_middle::mir::*;
11-
use rustc_middle::ty::layout::TyAndLayout;
11+
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
1212
use rustc_middle::ty::{self, Ty, TyCtxt};
1313
use rustc_mir_dataflow::value_analysis::{
1414
Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace,
1515
};
1616
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, Results, ResultsVisitor};
1717
use rustc_span::def_id::DefId;
1818
use rustc_span::DUMMY_SP;
19-
use rustc_target::abi::{FieldIdx, VariantIdx};
19+
use rustc_target::abi::{Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT};
2020

21+
use crate::const_prop::throw_machine_stop_str;
2122
use crate::MirPass;
2223

2324
// These constants are somewhat random guesses and have not been optimized.
@@ -553,107 +554,151 @@ impl<'tcx, 'locals> Collector<'tcx, 'locals> {
553554

554555
fn try_make_constant(
555556
&self,
557+
ecx: &mut InterpCx<'tcx, 'tcx, DummyMachine>,
556558
place: Place<'tcx>,
557559
state: &State<FlatSet<Scalar>>,
558560
map: &Map,
559561
) -> Option<Const<'tcx>> {
560562
let ty = place.ty(self.local_decls, self.patch.tcx).ty;
563+
let layout = ecx.layout_of(ty).ok()?;
564+
565+
if layout.is_zst() {
566+
return Some(Const::zero_sized(ty));
567+
}
568+
569+
if layout.is_unsized() {
570+
return None;
571+
}
572+
561573
let place = map.find(place.as_ref())?;
562-
if let FlatSet::Elem(Scalar::Int(value)) = state.get_idx(place, map) {
563-
Some(Const::Val(ConstValue::Scalar(value.into()), ty))
564-
} else {
565-
let valtree = self.try_make_valtree(place, ty, state, map)?;
566-
let constant = ty::Const::new_value(self.patch.tcx, valtree, ty);
567-
Some(Const::Ty(constant))
574+
if layout.abi.is_scalar()
575+
&& let Some(value) = propagatable_scalar(place, state, map)
576+
{
577+
return Some(Const::Val(ConstValue::Scalar(value), ty));
578+
}
579+
580+
if matches!(layout.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) {
581+
let alloc_id = ecx
582+
.intern_with_temp_alloc(layout, |ecx, dest| {
583+
try_write_constant(ecx, dest, place, ty, state, map)
584+
})
585+
.ok()?;
586+
return Some(Const::Val(ConstValue::Indirect { alloc_id, offset: Size::ZERO }, ty));
568587
}
588+
589+
None
569590
}
591+
}
570592

571-
fn try_make_valtree(
572-
&self,
573-
place: PlaceIndex,
574-
ty: Ty<'tcx>,
575-
state: &State<FlatSet<Scalar>>,
576-
map: &Map,
577-
) -> Option<ty::ValTree<'tcx>> {
578-
let tcx = self.patch.tcx;
579-
match ty.kind() {
580-
// ZSTs.
581-
ty::FnDef(..) => Some(ty::ValTree::zst()),
582-
583-
// Scalars.
584-
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => {
585-
if let FlatSet::Elem(Scalar::Int(value)) = state.get_idx(place, map) {
586-
Some(ty::ValTree::Leaf(value))
587-
} else {
588-
None
589-
}
590-
}
593+
fn propagatable_scalar(
594+
place: PlaceIndex,
595+
state: &State<FlatSet<Scalar>>,
596+
map: &Map,
597+
) -> Option<Scalar> {
598+
if let FlatSet::Elem(value) = state.get_idx(place, map) && value.try_to_int().is_ok() {
599+
// Do not attempt to propagate pointers, as we may fail to preserve their identity.
600+
Some(value)
601+
} else {
602+
None
603+
}
604+
}
591605

592-
// Unsupported for now.
593-
ty::Array(_, _) => None,
594-
595-
ty::Tuple(elem_tys) => {
596-
let branches = elem_tys
597-
.iter()
598-
.enumerate()
599-
.map(|(i, ty)| {
600-
let field = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i)))?;
601-
self.try_make_valtree(field, ty, state, map)
602-
})
603-
.collect::<Option<Vec<_>>>()?;
604-
Some(ty::ValTree::Branch(tcx.arena.alloc_from_iter(branches.into_iter())))
605-
}
606+
#[instrument(level = "trace", skip(ecx, state, map))]
607+
fn try_write_constant<'tcx>(
608+
ecx: &mut InterpCx<'_, 'tcx, DummyMachine>,
609+
dest: &PlaceTy<'tcx>,
610+
place: PlaceIndex,
611+
ty: Ty<'tcx>,
612+
state: &State<FlatSet<Scalar>>,
613+
map: &Map,
614+
) -> InterpResult<'tcx> {
615+
let layout = ecx.layout_of(ty)?;
616+
617+
// Fast path for ZSTs.
618+
if layout.is_zst() {
619+
return Ok(());
620+
}
621+
622+
// Fast path for scalars.
623+
if layout.abi.is_scalar()
624+
&& let Some(value) = propagatable_scalar(place, state, map)
625+
{
626+
return ecx.write_immediate(Immediate::Scalar(value), dest);
627+
}
606628

607-
ty::Adt(def, args) => {
608-
if def.is_union() {
609-
return None;
610-
}
629+
match ty.kind() {
630+
// ZSTs. Nothing to do.
631+
ty::FnDef(..) => {}
611632

612-
let (variant_idx, variant_def, variant_place) = if def.is_enum() {
613-
let discr = map.apply(place, TrackElem::Discriminant)?;
614-
let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else {
615-
return None;
616-
};
617-
let discr_bits = discr.assert_bits(discr.size());
618-
let (variant, _) =
619-
def.discriminants(tcx).find(|(_, var)| discr_bits == var.val)?;
620-
let variant_place = map.apply(place, TrackElem::Variant(variant))?;
621-
let variant_int = ty::ValTree::Leaf(variant.as_u32().into());
622-
(Some(variant_int), def.variant(variant), variant_place)
623-
} else {
624-
(None, def.non_enum_variant(), place)
633+
// Those are scalars, must be handled above.
634+
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => throw_machine_stop_str!("primitive type with provenance"),
635+
636+
ty::Tuple(elem_tys) => {
637+
for (i, elem) in elem_tys.iter().enumerate() {
638+
let Some(field) = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i))) else {
639+
throw_machine_stop_str!("missing field in tuple")
625640
};
641+
let field_dest = ecx.project_field(dest, i)?;
642+
try_write_constant(ecx, &field_dest, field, elem, state, map)?;
643+
}
644+
}
626645

627-
let branches = variant_def
628-
.fields
629-
.iter_enumerated()
630-
.map(|(i, field)| {
631-
let ty = field.ty(tcx, args);
632-
let field = map.apply(variant_place, TrackElem::Field(i))?;
633-
self.try_make_valtree(field, ty, state, map)
634-
})
635-
.collect::<Option<Vec<_>>>()?;
636-
Some(ty::ValTree::Branch(
637-
tcx.arena.alloc_from_iter(variant_idx.into_iter().chain(branches)),
638-
))
646+
ty::Adt(def, args) => {
647+
if def.is_union() {
648+
throw_machine_stop_str!("cannot propagate unions")
639649
}
640650

641-
// Do not attempt to support indirection in constants.
642-
ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_) => None,
651+
let (variant_idx, variant_def, variant_place, variant_dest) = if def.is_enum() {
652+
let Some(discr) = map.apply(place, TrackElem::Discriminant) else {
653+
throw_machine_stop_str!("missing discriminant for enum")
654+
};
655+
let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else {
656+
throw_machine_stop_str!("discriminant with provenance")
657+
};
658+
let discr_bits = discr.assert_bits(discr.size());
659+
let Some((variant, _)) = def.discriminants(*ecx.tcx).find(|(_, var)| discr_bits == var.val) else {
660+
throw_machine_stop_str!("illegal discriminant for enum")
661+
};
662+
let Some(variant_place) = map.apply(place, TrackElem::Variant(variant)) else {
663+
throw_machine_stop_str!("missing variant for enum")
664+
};
665+
let variant_dest = ecx.project_downcast(dest, variant)?;
666+
(variant, def.variant(variant), variant_place, variant_dest)
667+
} else {
668+
(FIRST_VARIANT, def.non_enum_variant(), place, dest.clone())
669+
};
670+
671+
for (i, field) in variant_def.fields.iter_enumerated() {
672+
let ty = field.ty(*ecx.tcx, args);
673+
let Some(field) = map.apply(variant_place, TrackElem::Field(i)) else {
674+
throw_machine_stop_str!("missing field in ADT")
675+
};
676+
let field_dest = ecx.project_field(&variant_dest, i.as_usize())?;
677+
try_write_constant(ecx, &field_dest, field, ty, state, map)?;
678+
}
679+
ecx.write_discriminant(variant_idx, dest)?;
680+
}
643681

644-
ty::Never
645-
| ty::Foreign(..)
646-
| ty::Alias(..)
647-
| ty::Param(_)
648-
| ty::Bound(..)
649-
| ty::Placeholder(..)
650-
| ty::Closure(..)
651-
| ty::Coroutine(..)
652-
| ty::Dynamic(..) => None,
682+
// Unsupported for now.
683+
ty::Array(_, _)
653684

654-
ty::Error(_) | ty::Infer(..) | ty::CoroutineWitness(..) => bug!(),
655-
}
685+
// Do not attempt to support indirection in constants.
686+
| ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_)
687+
688+
| ty::Never
689+
| ty::Foreign(..)
690+
| ty::Alias(..)
691+
| ty::Param(_)
692+
| ty::Bound(..)
693+
| ty::Placeholder(..)
694+
| ty::Closure(..)
695+
| ty::Coroutine(..)
696+
| ty::Dynamic(..) => throw_machine_stop_str!("unsupported type"),
697+
698+
ty::Error(_) | ty::Infer(..) | ty::CoroutineWitness(..) => bug!(),
656699
}
700+
701+
Ok(())
657702
}
658703

659704
impl<'mir, 'tcx>
@@ -671,8 +716,13 @@ impl<'mir, 'tcx>
671716
) {
672717
match &statement.kind {
673718
StatementKind::Assign(box (_, rvalue)) => {
674-
OperandCollector { state, visitor: self, map: &results.analysis.0.map }
675-
.visit_rvalue(rvalue, location);
719+
OperandCollector {
720+
state,
721+
visitor: self,
722+
ecx: &mut results.analysis.0.ecx,
723+
map: &results.analysis.0.map,
724+
}
725+
.visit_rvalue(rvalue, location);
676726
}
677727
_ => (),
678728
}
@@ -690,7 +740,12 @@ impl<'mir, 'tcx>
690740
// Don't overwrite the assignment if it already uses a constant (to keep the span).
691741
}
692742
StatementKind::Assign(box (place, _)) => {
693-
if let Some(value) = self.try_make_constant(place, state, &results.analysis.0.map) {
743+
if let Some(value) = self.try_make_constant(
744+
&mut results.analysis.0.ecx,
745+
place,
746+
state,
747+
&results.analysis.0.map,
748+
) {
694749
self.patch.assignments.insert(location, value);
695750
}
696751
}
@@ -705,8 +760,13 @@ impl<'mir, 'tcx>
705760
terminator: &'mir Terminator<'tcx>,
706761
location: Location,
707762
) {
708-
OperandCollector { state, visitor: self, map: &results.analysis.0.map }
709-
.visit_terminator(terminator, location);
763+
OperandCollector {
764+
state,
765+
visitor: self,
766+
ecx: &mut results.analysis.0.ecx,
767+
map: &results.analysis.0.map,
768+
}
769+
.visit_terminator(terminator, location);
710770
}
711771
}
712772

@@ -761,6 +821,7 @@ impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> {
761821
struct OperandCollector<'tcx, 'map, 'locals, 'a> {
762822
state: &'a State<FlatSet<Scalar>>,
763823
visitor: &'a mut Collector<'tcx, 'locals>,
824+
ecx: &'map mut InterpCx<'tcx, 'tcx, DummyMachine>,
764825
map: &'map Map,
765826
}
766827

@@ -773,15 +834,17 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
773834
location: Location,
774835
) {
775836
if let PlaceElem::Index(local) = elem
776-
&& let Some(value) = self.visitor.try_make_constant(local.into(), self.state, self.map)
837+
&& let Some(value) = self.visitor.try_make_constant(self.ecx, local.into(), self.state, self.map)
777838
{
778839
self.visitor.patch.before_effect.insert((location, local.into()), value);
779840
}
780841
}
781842

782843
fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) {
783844
if let Some(place) = operand.place() {
784-
if let Some(value) = self.visitor.try_make_constant(place, self.state, self.map) {
845+
if let Some(value) =
846+
self.visitor.try_make_constant(self.ecx, place, self.state, self.map)
847+
{
785848
self.visitor.patch.before_effect.insert((location, place), value);
786849
} else if !place.projection.is_empty() {
787850
// Try to propagate into `Index` projections.
@@ -804,7 +867,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
804867
}
805868

806869
fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {
807-
unimplemented!()
870+
false
808871
}
809872

810873
fn before_access_global(
@@ -816,13 +879,13 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
816879
is_write: bool,
817880
) -> InterpResult<'tcx> {
818881
if is_write {
819-
crate::const_prop::throw_machine_stop_str!("can't write to global");
882+
throw_machine_stop_str!("can't write to global");
820883
}
821884

822885
// If the static allocation is mutable, then we can't const prop it as its content
823886
// might be different at runtime.
824887
if alloc.inner().mutability.is_mut() {
825-
crate::const_prop::throw_machine_stop_str!("can't access mutable globals in ConstProp");
888+
throw_machine_stop_str!("can't access mutable globals in ConstProp");
826889
}
827890

828891
Ok(())
@@ -872,7 +935,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
872935
_left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
873936
_right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
874937
) -> interpret::InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)> {
875-
crate::const_prop::throw_machine_stop_str!("can't do pointer arithmetic");
938+
throw_machine_stop_str!("can't do pointer arithmetic");
876939
}
877940

878941
fn expose_ptr(

tests/mir-opt/const_debuginfo.main.ConstDebugInfo.diff

+6-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
let _10: std::option::Option<u16>;
4343
scope 7 {
4444
- debug o => _10;
45-
+ debug o => const Option::<u16>::Some(99);
45+
+ debug o => const Option::<u16>::Some(99_u16);
4646
let _17: u32;
4747
let _18: u32;
4848
scope 8 {
@@ -82,7 +82,7 @@
8282
_15 = const false;
8383
_16 = const 123_u32;
8484
StorageLive(_10);
85-
_10 = const Option::<u16>::Some(99);
85+
_10 = const Option::<u16>::Some(99_u16);
8686
_17 = const 32_u32;
8787
_18 = const 32_u32;
8888
StorageLive(_11);
@@ -98,3 +98,7 @@
9898
}
9999
}
100100

101+
ALLOC0 (size: 4, align: 2) {
102+
01 00 63 00 │ ..c.
103+
}
104+

0 commit comments

Comments
 (0)