Skip to content

Commit d32a725

Browse files
committed
Auto merge of #62584 - eddyb:circular-math-is-hard, r=pnkfelix
rustc_codegen_ssa: fix range check in codegen_get_discr. Fixes #61696, see #61696 (comment) for more details. In short, I had wanted to use `x - a <= b - a` (an unsigned comparison, with wrapping subtraction) to check whether `x` is in `a..=b` (as it's 1 comparison instead of 2 *and* `b - a` is guaranteed to fit in the same data type, while `b` itself might not), but I ended up with `x - a + c <= b - a + c` instead, because `x - a + c` was the final value needed. That latter comparison can be rewritten as `x - (a - c) <= b - (a - c)`, so it is equivalent to checking that `x` is in `(a - c)..=b`, i.e. it also includes `(a - c)..a`, not just `a..=b`, so if `c` is not `0`, it will cause false positives. This presented itself as the non-niche ("dataful") variant sometimes being treated like a niche variant, in the presence of uninhabited variants (which made `c`, aka the index of the first niche variant, arbitrarily large). r? @nagisa, @rkruppe or @oli-obk
2 parents 69656fa + c063057 commit d32a725

File tree

2 files changed

+126
-24
lines changed

2 files changed

+126
-24
lines changed

src/librustc_codegen_ssa/mir/place.rs

+62-24
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,11 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
228228
}
229229
};
230230

231-
let discr = self.project_field(bx, discr_index);
232-
let lldiscr = bx.load_operand(discr).immediate();
231+
// Read the tag/niche-encoded discriminant from memory.
232+
let encoded_discr = self.project_field(bx, discr_index);
233+
let encoded_discr = bx.load_operand(encoded_discr);
234+
235+
// Decode the discriminant (specifically if it's niche-encoded).
233236
match *discr_kind {
234237
layout::DiscriminantKind::Tag => {
235238
let signed = match discr_scalar.value {
@@ -240,38 +243,73 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
240243
layout::Int(_, signed) => !discr_scalar.is_bool() && signed,
241244
_ => false
242245
};
243-
bx.intcast(lldiscr, cast_to, signed)
246+
bx.intcast(encoded_discr.immediate(), cast_to, signed)
244247
}
245248
layout::DiscriminantKind::Niche {
246249
dataful_variant,
247250
ref niche_variants,
248251
niche_start,
249252
} => {
250-
let niche_llty = bx.cx().immediate_backend_type(discr.layout);
251-
if niche_variants.start() == niche_variants.end() {
252-
// FIXME(eddyb): check the actual primitive type here.
253-
let niche_llval = if niche_start == 0 {
254-
// HACK(eddyb): using `c_null` as it works on all types.
253+
// Rebase from niche values to discriminants, and check
254+
// whether the result is in range for the niche variants.
255+
let niche_llty = bx.cx().immediate_backend_type(encoded_discr.layout);
256+
let encoded_discr = encoded_discr.immediate();
257+
258+
// We first compute the "relative discriminant" (wrt `niche_variants`),
259+
// that is, if `n = niche_variants.end() - niche_variants.start()`,
260+
// we remap `niche_start..=niche_start + n` (which may wrap around)
261+
// to (non-wrap-around) `0..=n`, to be able to check whether the
262+
// discriminant corresponds to a niche variant with one comparison.
263+
// We also can't go directly to the (variant index) discriminant
264+
// and check that it is in the range `niche_variants`, because
265+
// that might not fit in the same type, on top of needing an extra
266+
// comparison (see also the comment on `let niche_discr`).
267+
let relative_discr = if niche_start == 0 {
268+
// Avoid subtracting `0`, which wouldn't work for pointers.
269+
// FIXME(eddyb) check the actual primitive type here.
270+
encoded_discr
271+
} else {
272+
bx.sub(encoded_discr, bx.cx().const_uint_big(niche_llty, niche_start))
273+
};
274+
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
275+
let is_niche = {
276+
let relative_max = if relative_max == 0 {
277+
// Avoid calling `const_uint`, which wouldn't work for pointers.
278+
// FIXME(eddyb) check the actual primitive type here.
255279
bx.cx().const_null(niche_llty)
256280
} else {
257-
bx.cx().const_uint_big(niche_llty, niche_start)
281+
bx.cx().const_uint(niche_llty, relative_max as u64)
282+
};
283+
bx.icmp(IntPredicate::IntULE, relative_discr, relative_max)
284+
};
285+
286+
// NOTE(eddyb) this addition needs to be performed on the final
287+
// type, in case the niche itself can't represent all variant
288+
// indices (e.g. `u8` niche with more than `256` variants,
289+
// but enough uninhabited variants so that the remaining variants
290+
// fit in the niche).
291+
// In other words, `niche_variants.end - niche_variants.start`
292+
// is representable in the niche, but `niche_variants.end`
293+
// might not be, in extreme cases.
294+
let niche_discr = {
295+
let relative_discr = if relative_max == 0 {
296+
// HACK(eddyb) since we have only one niche, we know which
297+
// one it is, and we can avoid having a dynamic value here.
298+
bx.cx().const_uint(cast_to, 0)
299+
} else {
300+
bx.intcast(relative_discr, cast_to, false)
258301
};
259-
let select_arg = bx.icmp(IntPredicate::IntEQ, lldiscr, niche_llval);
260-
bx.select(select_arg,
302+
bx.add(
303+
relative_discr,
261304
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64),
262-
bx.cx().const_uint(cast_to, dataful_variant.as_u32() as u64))
263-
} else {
264-
// Rebase from niche values to discriminant values.
265-
let delta = niche_start.wrapping_sub(niche_variants.start().as_u32() as u128);
266-
let lldiscr = bx.sub(lldiscr, bx.cx().const_uint_big(niche_llty, delta));
267-
let lldiscr_max =
268-
bx.cx().const_uint(niche_llty, niche_variants.end().as_u32() as u64);
269-
let select_arg = bx.icmp(IntPredicate::IntULE, lldiscr, lldiscr_max);
270-
let cast = bx.intcast(lldiscr, cast_to, false);
271-
bx.select(select_arg,
272-
cast,
273-
bx.cx().const_uint(cast_to, dataful_variant.as_u32() as u64))
274-
}
305+
)
306+
};
307+
308+
bx.select(
309+
is_niche,
310+
niche_discr,
311+
bx.cx().const_uint(cast_to, dataful_variant.as_u32() as u64),
312+
)
275313
}
276314
}
277315
}
+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
pub enum Infallible {}
2+
3+
// The check that the `bool` field of `V1` is encoding a "niche variant"
4+
// (i.e. not `V1`, so `V3` or `V4`) used to be mathematically incorrect,
5+
// causing valid `V1` values to be interpreted as other variants.
6+
pub enum E1 {
7+
V1 { f: bool },
8+
V2 { f: Infallible },
9+
V3,
10+
V4,
11+
}
12+
13+
// Computing the discriminant used to be done using the niche type (here `u8`,
14+
// from the `bool` field of `V1`), overflowing for variants with large enough
15+
// indices (`V3` and `V4`), causing them to be interpreted as other variants.
16+
pub enum E2<X> {
17+
V1 { f: bool },
18+
19+
/*_00*/ _01(X), _02(X), _03(X), _04(X), _05(X), _06(X), _07(X),
20+
_08(X), _09(X), _0A(X), _0B(X), _0C(X), _0D(X), _0E(X), _0F(X),
21+
_10(X), _11(X), _12(X), _13(X), _14(X), _15(X), _16(X), _17(X),
22+
_18(X), _19(X), _1A(X), _1B(X), _1C(X), _1D(X), _1E(X), _1F(X),
23+
_20(X), _21(X), _22(X), _23(X), _24(X), _25(X), _26(X), _27(X),
24+
_28(X), _29(X), _2A(X), _2B(X), _2C(X), _2D(X), _2E(X), _2F(X),
25+
_30(X), _31(X), _32(X), _33(X), _34(X), _35(X), _36(X), _37(X),
26+
_38(X), _39(X), _3A(X), _3B(X), _3C(X), _3D(X), _3E(X), _3F(X),
27+
_40(X), _41(X), _42(X), _43(X), _44(X), _45(X), _46(X), _47(X),
28+
_48(X), _49(X), _4A(X), _4B(X), _4C(X), _4D(X), _4E(X), _4F(X),
29+
_50(X), _51(X), _52(X), _53(X), _54(X), _55(X), _56(X), _57(X),
30+
_58(X), _59(X), _5A(X), _5B(X), _5C(X), _5D(X), _5E(X), _5F(X),
31+
_60(X), _61(X), _62(X), _63(X), _64(X), _65(X), _66(X), _67(X),
32+
_68(X), _69(X), _6A(X), _6B(X), _6C(X), _6D(X), _6E(X), _6F(X),
33+
_70(X), _71(X), _72(X), _73(X), _74(X), _75(X), _76(X), _77(X),
34+
_78(X), _79(X), _7A(X), _7B(X), _7C(X), _7D(X), _7E(X), _7F(X),
35+
_80(X), _81(X), _82(X), _83(X), _84(X), _85(X), _86(X), _87(X),
36+
_88(X), _89(X), _8A(X), _8B(X), _8C(X), _8D(X), _8E(X), _8F(X),
37+
_90(X), _91(X), _92(X), _93(X), _94(X), _95(X), _96(X), _97(X),
38+
_98(X), _99(X), _9A(X), _9B(X), _9C(X), _9D(X), _9E(X), _9F(X),
39+
_A0(X), _A1(X), _A2(X), _A3(X), _A4(X), _A5(X), _A6(X), _A7(X),
40+
_A8(X), _A9(X), _AA(X), _AB(X), _AC(X), _AD(X), _AE(X), _AF(X),
41+
_B0(X), _B1(X), _B2(X), _B3(X), _B4(X), _B5(X), _B6(X), _B7(X),
42+
_B8(X), _B9(X), _BA(X), _BB(X), _BC(X), _BD(X), _BE(X), _BF(X),
43+
_C0(X), _C1(X), _C2(X), _C3(X), _C4(X), _C5(X), _C6(X), _C7(X),
44+
_C8(X), _C9(X), _CA(X), _CB(X), _CC(X), _CD(X), _CE(X), _CF(X),
45+
_D0(X), _D1(X), _D2(X), _D3(X), _D4(X), _D5(X), _D6(X), _D7(X),
46+
_D8(X), _D9(X), _DA(X), _DB(X), _DC(X), _DD(X), _DE(X), _DF(X),
47+
_E0(X), _E1(X), _E2(X), _E3(X), _E4(X), _E5(X), _E6(X), _E7(X),
48+
_E8(X), _E9(X), _EA(X), _EB(X), _EC(X), _ED(X), _EE(X), _EF(X),
49+
_F0(X), _F1(X), _F2(X), _F3(X), _F4(X), _F5(X), _F6(X), _F7(X),
50+
_F8(X), _F9(X), _FA(X), _FB(X), _FC(X), _FD(X), _FE(X), _FF(X),
51+
52+
V3,
53+
V4,
54+
}
55+
56+
fn main() {
57+
if let E1::V2 { .. } = (E1::V1 { f: true }) {
58+
unreachable!()
59+
}
60+
61+
if let E2::V1 { .. } = E2::V3::<Infallible> {
62+
unreachable!()
63+
}
64+
}

0 commit comments

Comments
 (0)