Skip to content

Commit 2bb7260

Browse files
authored
Rollup merge of rust-lang#107731 - RalfJung:interpret-discriminant, r=cjgillot
interpret: move discriminant reading and writing to separate file This is quite different from the otherwise fairly general read and write functions in place.rs and operand.rs, and also it's nice to have these two functions close together as they are basically inverses of each other.
2 parents 4e163af + e1926b2 commit 2bb7260

File tree

4 files changed

+245
-234
lines changed

4 files changed

+245
-234
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
//! Functions for reading and writing discriminants of multi-variant layouts (enums and generators).
2+
3+
use rustc_middle::ty::layout::{LayoutOf, PrimitiveExt};
4+
use rustc_middle::{mir, ty};
5+
use rustc_target::abi::{self, TagEncoding};
6+
use rustc_target::abi::{VariantIdx, Variants};
7+
8+
use super::{ImmTy, InterpCx, InterpResult, Machine, OpTy, PlaceTy, Scalar};
9+
10+
impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
11+
/// Writes the discriminant of the given variant.
12+
#[instrument(skip(self), level = "trace")]
13+
pub fn write_discriminant(
14+
&mut self,
15+
variant_index: VariantIdx,
16+
dest: &PlaceTy<'tcx, M::Provenance>,
17+
) -> InterpResult<'tcx> {
18+
// Layout computation excludes uninhabited variants from consideration
19+
// therefore there's no way to represent those variants in the given layout.
20+
// Essentially, uninhabited variants do not have a tag that corresponds to their
21+
// discriminant, so we cannot do anything here.
22+
// When evaluating we will always error before even getting here, but ConstProp 'executes'
23+
// dead code, so we cannot ICE here.
24+
if dest.layout.for_variant(self, variant_index).abi.is_uninhabited() {
25+
throw_ub!(UninhabitedEnumVariantWritten)
26+
}
27+
28+
match dest.layout.variants {
29+
abi::Variants::Single { index } => {
30+
assert_eq!(index, variant_index);
31+
}
32+
abi::Variants::Multiple {
33+
tag_encoding: TagEncoding::Direct,
34+
tag: tag_layout,
35+
tag_field,
36+
..
37+
} => {
38+
// No need to validate that the discriminant here because the
39+
// `TyAndLayout::for_variant()` call earlier already checks the variant is valid.
40+
41+
let discr_val =
42+
dest.layout.ty.discriminant_for_variant(*self.tcx, variant_index).unwrap().val;
43+
44+
// raw discriminants for enums are isize or bigger during
45+
// their computation, but the in-memory tag is the smallest possible
46+
// representation
47+
let size = tag_layout.size(self);
48+
let tag_val = size.truncate(discr_val);
49+
50+
let tag_dest = self.place_field(dest, tag_field)?;
51+
self.write_scalar(Scalar::from_uint(tag_val, size), &tag_dest)?;
52+
}
53+
abi::Variants::Multiple {
54+
tag_encoding:
55+
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
56+
tag: tag_layout,
57+
tag_field,
58+
..
59+
} => {
60+
// No need to validate that the discriminant here because the
61+
// `TyAndLayout::for_variant()` call earlier already checks the variant is valid.
62+
63+
if variant_index != untagged_variant {
64+
let variants_start = niche_variants.start().as_u32();
65+
let variant_index_relative = variant_index
66+
.as_u32()
67+
.checked_sub(variants_start)
68+
.expect("overflow computing relative variant idx");
69+
// We need to use machine arithmetic when taking into account `niche_start`:
70+
// tag_val = variant_index_relative + niche_start_val
71+
let tag_layout = self.layout_of(tag_layout.primitive().to_int_ty(*self.tcx))?;
72+
let niche_start_val = ImmTy::from_uint(niche_start, tag_layout);
73+
let variant_index_relative_val =
74+
ImmTy::from_uint(variant_index_relative, tag_layout);
75+
let tag_val = self.binary_op(
76+
mir::BinOp::Add,
77+
&variant_index_relative_val,
78+
&niche_start_val,
79+
)?;
80+
// Write result.
81+
let niche_dest = self.place_field(dest, tag_field)?;
82+
self.write_immediate(*tag_val, &niche_dest)?;
83+
}
84+
}
85+
}
86+
87+
Ok(())
88+
}
89+
90+
/// Read discriminant, return the runtime value as well as the variant index.
91+
/// Can also legally be called on non-enums (e.g. through the discriminant_value intrinsic)!
92+
#[instrument(skip(self), level = "trace")]
93+
pub fn read_discriminant(
94+
&self,
95+
op: &OpTy<'tcx, M::Provenance>,
96+
) -> InterpResult<'tcx, (Scalar<M::Provenance>, VariantIdx)> {
97+
trace!("read_discriminant_value {:#?}", op.layout);
98+
// Get type and layout of the discriminant.
99+
let discr_layout = self.layout_of(op.layout.ty.discriminant_ty(*self.tcx))?;
100+
trace!("discriminant type: {:?}", discr_layout.ty);
101+
102+
// We use "discriminant" to refer to the value associated with a particular enum variant.
103+
// This is not to be confused with its "variant index", which is just determining its position in the
104+
// declared list of variants -- they can differ with explicitly assigned discriminants.
105+
// We use "tag" to refer to how the discriminant is encoded in memory, which can be either
106+
// straight-forward (`TagEncoding::Direct`) or with a niche (`TagEncoding::Niche`).
107+
let (tag_scalar_layout, tag_encoding, tag_field) = match op.layout.variants {
108+
Variants::Single { index } => {
109+
let discr = match op.layout.ty.discriminant_for_variant(*self.tcx, index) {
110+
Some(discr) => {
111+
// This type actually has discriminants.
112+
assert_eq!(discr.ty, discr_layout.ty);
113+
Scalar::from_uint(discr.val, discr_layout.size)
114+
}
115+
None => {
116+
// On a type without actual discriminants, variant is 0.
117+
assert_eq!(index.as_u32(), 0);
118+
Scalar::from_uint(index.as_u32(), discr_layout.size)
119+
}
120+
};
121+
return Ok((discr, index));
122+
}
123+
Variants::Multiple { tag, ref tag_encoding, tag_field, .. } => {
124+
(tag, tag_encoding, tag_field)
125+
}
126+
};
127+
128+
// There are *three* layouts that come into play here:
129+
// - The discriminant has a type for typechecking. This is `discr_layout`, and is used for
130+
// the `Scalar` we return.
131+
// - The tag (encoded discriminant) has layout `tag_layout`. This is always an integer type,
132+
// and used to interpret the value we read from the tag field.
133+
// For the return value, a cast to `discr_layout` is performed.
134+
// - The field storing the tag has a layout, which is very similar to `tag_layout` but
135+
// may be a pointer. This is `tag_val.layout`; we just use it for sanity checks.
136+
137+
// Get layout for tag.
138+
let tag_layout = self.layout_of(tag_scalar_layout.primitive().to_int_ty(*self.tcx))?;
139+
140+
// Read tag and sanity-check `tag_layout`.
141+
let tag_val = self.read_immediate(&self.operand_field(op, tag_field)?)?;
142+
assert_eq!(tag_layout.size, tag_val.layout.size);
143+
assert_eq!(tag_layout.abi.is_signed(), tag_val.layout.abi.is_signed());
144+
trace!("tag value: {}", tag_val);
145+
146+
// Figure out which discriminant and variant this corresponds to.
147+
Ok(match *tag_encoding {
148+
TagEncoding::Direct => {
149+
let scalar = tag_val.to_scalar();
150+
// Generate a specific error if `tag_val` is not an integer.
151+
// (`tag_bits` itself is only used for error messages below.)
152+
let tag_bits = scalar
153+
.try_to_int()
154+
.map_err(|dbg_val| err_ub!(InvalidTag(dbg_val)))?
155+
.assert_bits(tag_layout.size);
156+
// Cast bits from tag layout to discriminant layout.
157+
// After the checks we did above, this cannot fail, as
158+
// discriminants are int-like.
159+
let discr_val =
160+
self.cast_from_int_like(scalar, tag_val.layout, discr_layout.ty).unwrap();
161+
let discr_bits = discr_val.assert_bits(discr_layout.size);
162+
// Convert discriminant to variant index, and catch invalid discriminants.
163+
let index = match *op.layout.ty.kind() {
164+
ty::Adt(adt, _) => {
165+
adt.discriminants(*self.tcx).find(|(_, var)| var.val == discr_bits)
166+
}
167+
ty::Generator(def_id, substs, _) => {
168+
let substs = substs.as_generator();
169+
substs
170+
.discriminants(def_id, *self.tcx)
171+
.find(|(_, var)| var.val == discr_bits)
172+
}
173+
_ => span_bug!(self.cur_span(), "tagged layout for non-adt non-generator"),
174+
}
175+
.ok_or_else(|| err_ub!(InvalidTag(Scalar::from_uint(tag_bits, tag_layout.size))))?;
176+
// Return the cast value, and the index.
177+
(discr_val, index.0)
178+
}
179+
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
180+
let tag_val = tag_val.to_scalar();
181+
// Compute the variant this niche value/"tag" corresponds to. With niche layout,
182+
// discriminant (encoded in niche/tag) and variant index are the same.
183+
let variants_start = niche_variants.start().as_u32();
184+
let variants_end = niche_variants.end().as_u32();
185+
let variant = match tag_val.try_to_int() {
186+
Err(dbg_val) => {
187+
// So this is a pointer then, and casting to an int failed.
188+
// Can only happen during CTFE.
189+
// The niche must be just 0, and the ptr not null, then we know this is
190+
// okay. Everything else, we conservatively reject.
191+
let ptr_valid = niche_start == 0
192+
&& variants_start == variants_end
193+
&& !self.scalar_may_be_null(tag_val)?;
194+
if !ptr_valid {
195+
throw_ub!(InvalidTag(dbg_val))
196+
}
197+
untagged_variant
198+
}
199+
Ok(tag_bits) => {
200+
let tag_bits = tag_bits.assert_bits(tag_layout.size);
201+
// We need to use machine arithmetic to get the relative variant idx:
202+
// variant_index_relative = tag_val - niche_start_val
203+
let tag_val = ImmTy::from_uint(tag_bits, tag_layout);
204+
let niche_start_val = ImmTy::from_uint(niche_start, tag_layout);
205+
let variant_index_relative_val =
206+
self.binary_op(mir::BinOp::Sub, &tag_val, &niche_start_val)?;
207+
let variant_index_relative =
208+
variant_index_relative_val.to_scalar().assert_bits(tag_val.layout.size);
209+
// Check if this is in the range that indicates an actual discriminant.
210+
if variant_index_relative <= u128::from(variants_end - variants_start) {
211+
let variant_index_relative = u32::try_from(variant_index_relative)
212+
.expect("we checked that this fits into a u32");
213+
// Then computing the absolute variant idx should not overflow any more.
214+
let variant_index = variants_start
215+
.checked_add(variant_index_relative)
216+
.expect("overflow computing absolute variant idx");
217+
let variants_len = op
218+
.layout
219+
.ty
220+
.ty_adt_def()
221+
.expect("tagged layout for non adt")
222+
.variants()
223+
.len();
224+
assert!(usize::try_from(variant_index).unwrap() < variants_len);
225+
VariantIdx::from_u32(variant_index)
226+
} else {
227+
untagged_variant
228+
}
229+
}
230+
};
231+
// Compute the size of the scalar we need to return.
232+
// No need to cast, because the variant index directly serves as discriminant and is
233+
// encoded in the tag.
234+
(Scalar::from_uint(variant.as_u32(), discr_layout.size), variant)
235+
}
236+
})
237+
}
238+
}

compiler/rustc_const_eval/src/interpret/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! An interpreter for MIR used in CTFE and by miri
22
33
mod cast;
4+
mod discriminant;
45
mod eval_context;
56
mod intern;
67
mod intrinsics;

0 commit comments

Comments
 (0)