Skip to content

Commit cbdcbf0

Browse files
committed
interpret: reset provenance on typed copies
1 parent 85dc22f commit cbdcbf0

23 files changed

+489
-135
lines changed

compiler/rustc_const_eval/src/const_eval/eval_queries.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ fn eval_body_using_ecx<'tcx, R: InterpretationResult<'tcx>>(
9494
let intern_result = intern_const_alloc_recursive(ecx, intern_kind, &ret);
9595

9696
// Since evaluation had no errors, validate the resulting constant.
97-
const_validate_mplace(&ecx, &ret, cid)?;
97+
const_validate_mplace(ecx, &ret, cid)?;
9898

9999
// Only report this after validation, as validation produces much better diagnostics.
100100
// FIXME: ensure validation always reports this and stop making interning care about it.
@@ -391,7 +391,7 @@ fn eval_in_interpreter<'tcx, R: InterpretationResult<'tcx>>(
391391

392392
#[inline(always)]
393393
fn const_validate_mplace<'tcx>(
394-
ecx: &InterpCx<'tcx, CompileTimeMachine<'tcx>>,
394+
ecx: &mut InterpCx<'tcx, CompileTimeMachine<'tcx>>,
395395
mplace: &MPlaceTy<'tcx>,
396396
cid: GlobalId<'tcx>,
397397
) -> Result<(), ErrorHandled> {

compiler/rustc_const_eval/src/interpret/memory.rs

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88
99
use std::assert_matches::assert_matches;
1010
use std::borrow::Cow;
11-
use std::cell::Cell;
1211
use std::collections::VecDeque;
13-
use std::{fmt, ptr};
12+
use std::{fmt, mem, ptr};
1413

1514
use rustc_ast::Mutability;
1615
use rustc_data_structures::fx::{FxHashSet, FxIndexMap};
@@ -118,7 +117,7 @@ pub struct Memory<'tcx, M: Machine<'tcx>> {
118117
/// This stores whether we are currently doing reads purely for the purpose of validation.
119118
/// Those reads do not trigger the machine's hooks for memory reads.
120119
/// Needless to say, this must only be set with great care!
121-
validation_in_progress: Cell<bool>,
120+
validation_in_progress: bool,
122121
}
123122

124123
/// A reference to some allocation that was already bounds-checked for the given region
@@ -145,7 +144,7 @@ impl<'tcx, M: Machine<'tcx>> Memory<'tcx, M> {
145144
alloc_map: M::MemoryMap::default(),
146145
extra_fn_ptr_map: FxIndexMap::default(),
147146
dead_alloc_map: FxIndexMap::default(),
148-
validation_in_progress: Cell::new(false),
147+
validation_in_progress: false,
149148
}
150149
}
151150

@@ -682,15 +681,15 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
682681
// We want to call the hook on *all* accesses that involve an AllocId, including zero-sized
683682
// accesses. That means we cannot rely on the closure above or the `Some` branch below. We
684683
// do this after `check_and_deref_ptr` to ensure some basic sanity has already been checked.
685-
if !self.memory.validation_in_progress.get() {
684+
if !self.memory.validation_in_progress {
686685
if let Ok((alloc_id, ..)) = self.ptr_try_get_alloc_id(ptr, size_i64) {
687686
M::before_alloc_read(self, alloc_id)?;
688687
}
689688
}
690689

691690
if let Some((alloc_id, offset, prov, alloc)) = ptr_and_alloc {
692691
let range = alloc_range(offset, size);
693-
if !self.memory.validation_in_progress.get() {
692+
if !self.memory.validation_in_progress {
694693
M::before_memory_read(
695694
self.tcx,
696695
&self.machine,
@@ -766,11 +765,14 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
766765
let parts = self.get_ptr_access(ptr, size)?;
767766
if let Some((alloc_id, offset, prov)) = parts {
768767
let tcx = self.tcx;
768+
let validation_in_progress = self.memory.validation_in_progress;
769769
// FIXME: can we somehow avoid looking up the allocation twice here?
770770
// We cannot call `get_raw_mut` inside `check_and_deref_ptr` as that would duplicate `&mut self`.
771771
let (alloc, machine) = self.get_alloc_raw_mut(alloc_id)?;
772772
let range = alloc_range(offset, size);
773-
M::before_memory_write(tcx, machine, &mut alloc.extra, (alloc_id, prov), range)?;
773+
if !validation_in_progress {
774+
M::before_memory_write(tcx, machine, &mut alloc.extra, (alloc_id, prov), range)?;
775+
}
774776
Ok(Some(AllocRefMut { alloc, range, tcx: *tcx, alloc_id }))
775777
} else {
776778
Ok(None)
@@ -1014,16 +1016,16 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
10141016
///
10151017
/// We do this so Miri's allocation access tracking does not show the validation
10161018
/// reads as spurious accesses.
1017-
pub fn run_for_validation<R>(&self, f: impl FnOnce() -> R) -> R {
1019+
pub fn run_for_validation<R>(&mut self, f: impl FnOnce(&mut Self) -> R) -> R {
10181020
// This deliberately uses `==` on `bool` to follow the pattern
10191021
// `assert!(val.replace(new) == old)`.
10201022
assert!(
1021-
self.memory.validation_in_progress.replace(true) == false,
1023+
mem::replace(&mut self.memory.validation_in_progress, true) == false,
10221024
"`validation_in_progress` was already set"
10231025
);
1024-
let res = f();
1026+
let res = f(self);
10251027
assert!(
1026-
self.memory.validation_in_progress.replace(false) == true,
1028+
mem::replace(&mut self.memory.validation_in_progress, false) == true,
10271029
"`validation_in_progress` was unset by someone else"
10281030
);
10291031
res
@@ -1115,6 +1117,10 @@ impl<'a, 'tcx, M: Machine<'tcx>> std::fmt::Debug for DumpAllocs<'a, 'tcx, M> {
11151117
impl<'tcx, 'a, Prov: Provenance, Extra, Bytes: AllocBytes>
11161118
AllocRefMut<'a, 'tcx, Prov, Extra, Bytes>
11171119
{
1120+
pub fn as_ref<'b>(&'b self) -> AllocRef<'b, 'tcx, Prov, Extra, Bytes> {
1121+
AllocRef { alloc: self.alloc, range: self.range, tcx: self.tcx, alloc_id: self.alloc_id }
1122+
}
1123+
11181124
/// `range` is relative to this allocation reference, not the base of the allocation.
11191125
pub fn write_scalar(&mut self, range: AllocRange, val: Scalar<Prov>) -> InterpResult<'tcx> {
11201126
let range = self.range.subrange(range);
@@ -1137,6 +1143,14 @@ impl<'tcx, 'a, Prov: Provenance, Extra, Bytes: AllocBytes>
11371143
.write_uninit(&self.tcx, self.range)
11381144
.map_err(|e| e.to_interp_error(self.alloc_id))?)
11391145
}
1146+
1147+
/// Remove all provenance in the reference range.
1148+
pub fn clear_provenance(&mut self) -> InterpResult<'tcx> {
1149+
Ok(self
1150+
.alloc
1151+
.clear_provenance(&self.tcx, self.range)
1152+
.map_err(|e| e.to_interp_error(self.alloc_id))?)
1153+
}
11401154
}
11411155

11421156
impl<'tcx, 'a, Prov: Provenance, Extra, Bytes: AllocBytes> AllocRef<'a, 'tcx, Prov, Extra, Bytes> {
@@ -1278,7 +1292,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
12781292
};
12791293
let src_alloc = self.get_alloc_raw(src_alloc_id)?;
12801294
let src_range = alloc_range(src_offset, size);
1281-
assert!(!self.memory.validation_in_progress.get(), "we can't be copying during validation");
1295+
assert!(!self.memory.validation_in_progress, "we can't be copying during validation");
12821296
M::before_memory_read(
12831297
tcx,
12841298
&self.machine,

compiler/rustc_const_eval/src/interpret/operand.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,20 @@ impl<Prov: Provenance> Immediate<Prov> {
137137
}
138138
}
139139
}
140+
141+
pub fn clear_provenance<'tcx>(&mut self) -> InterpResult<'tcx> {
142+
match self {
143+
Immediate::Scalar(s) => {
144+
s.clear_provenance()?;
145+
}
146+
Immediate::ScalarPair(a, b) => {
147+
a.clear_provenance()?;
148+
b.clear_provenance()?;
149+
}
150+
Immediate::Uninit => {}
151+
}
152+
Ok(())
153+
}
140154
}
141155

142156
// ScalarPair needs a type to interpret, so we often have an immediate and a type together

compiler/rustc_const_eval/src/interpret/place.rs

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -605,8 +605,9 @@ where
605605
if M::enforce_validity(self, dest.layout()) {
606606
// Data got changed, better make sure it matches the type!
607607
self.validate_operand(
608-
&dest.to_op(self)?,
608+
&dest.to_place(),
609609
M::enforce_validity_recursively(self, dest.layout()),
610+
/*reset_provenance*/ true,
610611
)?;
611612
}
612613

@@ -636,7 +637,7 @@ where
636637
/// Write an immediate to a place.
637638
/// If you use this you are responsible for validating that things got copied at the
638639
/// right type.
639-
fn write_immediate_no_validate(
640+
pub(super) fn write_immediate_no_validate(
640641
&mut self,
641642
src: Immediate<M::Provenance>,
642643
dest: &impl Writeable<'tcx, M::Provenance>,
@@ -684,15 +685,7 @@ where
684685

685686
match value {
686687
Immediate::Scalar(scalar) => {
687-
let Abi::Scalar(s) = layout.abi else {
688-
span_bug!(
689-
self.cur_span(),
690-
"write_immediate_to_mplace: invalid Scalar layout: {layout:#?}",
691-
)
692-
};
693-
let size = s.size(&tcx);
694-
assert_eq!(size, layout.size, "abi::Scalar size does not match layout size");
695-
alloc.write_scalar(alloc_range(Size::ZERO, size), scalar)
688+
alloc.write_scalar(alloc_range(Size::ZERO, scalar.size()), scalar)
696689
}
697690
Immediate::ScalarPair(a_val, b_val) => {
698691
let Abi::ScalarPair(a, b) = layout.abi else {
@@ -702,16 +695,15 @@ where
702695
layout
703696
)
704697
};
705-
let (a_size, b_size) = (a.size(&tcx), b.size(&tcx));
706-
let b_offset = a_size.align_to(b.align(&tcx).abi);
698+
let b_offset = a.size(&tcx).align_to(b.align(&tcx).abi);
707699
assert!(b_offset.bytes() > 0); // in `operand_field` we use the offset to tell apart the fields
708700

709701
// It is tempting to verify `b_offset` against `layout.fields.offset(1)`,
710702
// but that does not work: We could be a newtype around a pair, then the
711703
// fields do not match the `ScalarPair` components.
712704

713-
alloc.write_scalar(alloc_range(Size::ZERO, a_size), a_val)?;
714-
alloc.write_scalar(alloc_range(b_offset, b_size), b_val)
705+
alloc.write_scalar(alloc_range(Size::ZERO, a_val.size()), a_val)?;
706+
alloc.write_scalar(alloc_range(b_offset, b_val.size()), b_val)
715707
}
716708
Immediate::Uninit => alloc.write_uninit(),
717709
}
@@ -736,6 +728,26 @@ where
736728
Ok(())
737729
}
738730

731+
/// Remove all provenance in the given place.
732+
pub fn clear_provenance(
733+
&mut self,
734+
dest: &impl Writeable<'tcx, M::Provenance>,
735+
) -> InterpResult<'tcx> {
736+
match self.as_mplace_or_mutable_local(&dest.to_place())? {
737+
Right((local_val, _local_layout)) => {
738+
local_val.clear_provenance()?;
739+
}
740+
Left(mplace) => {
741+
let Some(mut alloc) = self.get_place_alloc_mut(&mplace)? else {
742+
// Zero-sized access
743+
return Ok(());
744+
};
745+
alloc.clear_provenance()?;
746+
}
747+
}
748+
Ok(())
749+
}
750+
739751
/// Copies the data from an operand to a place.
740752
/// The layouts of the `src` and `dest` may disagree.
741753
/// Does not perform validation of the destination.
@@ -789,23 +801,30 @@ where
789801
allow_transmute: bool,
790802
validate_dest: bool,
791803
) -> InterpResult<'tcx> {
792-
// Generally for transmutation, data must be valid both at the old and new type.
793-
// But if the types are the same, the 2nd validation below suffices.
794-
if src.layout().ty != dest.layout().ty && M::enforce_validity(self, src.layout()) {
795-
self.validate_operand(
796-
&src.to_op(self)?,
797-
M::enforce_validity_recursively(self, src.layout()),
798-
)?;
799-
}
804+
// These are technically *two* typed copies: `src` is a not-yet-loaded value,
805+
// so we're doing a typed copy at `src` type from there to some intermediate storage.
806+
// And then we're doing a second typed copy from that intermediate storage to `dest`.
807+
// But as an optimization, we only make a single direct copy here.
800808

801809
// Do the actual copy.
802810
self.copy_op_no_validate(src, dest, allow_transmute)?;
803811

804812
if validate_dest && M::enforce_validity(self, dest.layout()) {
805-
// Data got changed, better make sure it matches the type!
813+
let dest = dest.to_place();
814+
// Given that there were two typed copies, we have to ensure this is valid at both types,
815+
// and we have to ensure this loses provenance and padding according to both types.
816+
// But if the types are identical, we only do one pass.
817+
if src.layout().ty != dest.layout().ty {
818+
self.validate_operand(
819+
&dest.transmute(src.layout(), self)?,
820+
M::enforce_validity_recursively(self, src.layout()),
821+
/*reset_provenance*/ true,
822+
)?;
823+
}
806824
self.validate_operand(
807-
&dest.to_op(self)?,
825+
&dest,
808826
M::enforce_validity_recursively(self, dest.layout()),
827+
/*reset_provenance*/ true,
809828
)?;
810829
}
811830

0 commit comments

Comments
 (0)