Skip to content

Commit 143f393

Browse files
committed
Don't alloca just to look at a discriminant
Today we're making LLVM do a bunch of extra work for every enum you match on, even trivial stuff like `Option<bool>`. Let's not.
1 parent 6650252 commit 143f393

File tree

9 files changed

+177
-167
lines changed

9 files changed

+177
-167
lines changed

Diff for: compiler/rustc_codegen_ssa/src/mir/analyze.rs

+7-3
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,12 @@ impl<'a, 'b, 'tcx, Bx: BuilderMethods<'b, 'tcx>> Visitor<'tcx> for LocalAnalyzer
205205
| PlaceContext::MutatingUse(MutatingUseContext::Retag) => {}
206206

207207
PlaceContext::NonMutatingUse(
208-
NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
208+
NonMutatingUseContext::Copy
209+
| NonMutatingUseContext::Move
210+
// Inspect covers things like `PtrMetadata` and `Discriminant`
211+
// which we can treat similarly to a `Copy` use for the purpose of
212+
// whether we can use SSA variables for things.
213+
| NonMutatingUseContext::Inspect,
209214
) => match &mut self.locals[local] {
210215
LocalKind::ZST => {}
211216
LocalKind::Memory => {}
@@ -229,8 +234,7 @@ impl<'a, 'b, 'tcx, Bx: BuilderMethods<'b, 'tcx>> Visitor<'tcx> for LocalAnalyzer
229234
| MutatingUseContext::Projection,
230235
)
231236
| PlaceContext::NonMutatingUse(
232-
NonMutatingUseContext::Inspect
233-
| NonMutatingUseContext::SharedBorrow
237+
NonMutatingUseContext::SharedBorrow
234238
| NonMutatingUseContext::FakeBorrow
235239
| NonMutatingUseContext::RawBorrow
236240
| NonMutatingUseContext::Projection,

Diff for: compiler/rustc_codegen_ssa/src/mir/intrinsic.rs

+1-9
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
6262
let callee_ty = instance.ty(bx.tcx(), bx.typing_env());
6363

6464
let ty::FnDef(def_id, fn_args) = *callee_ty.kind() else {
65-
bug!("expected fn item type, found {}", callee_ty);
65+
span_bug!(span, "expected fn item type, found {}", callee_ty);
6666
};
6767

6868
let sig = callee_ty.fn_sig(bx.tcx());
@@ -325,14 +325,6 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
325325
}
326326
}
327327

328-
sym::discriminant_value => {
329-
if ret_ty.is_integral() {
330-
args[0].deref(bx.cx()).codegen_get_discr(bx, ret_ty)
331-
} else {
332-
span_bug!(span, "Invalid discriminant type for `{:?}`", arg_tys[0])
333-
}
334-
}
335-
336328
// This requires that atomic intrinsics follow a specific naming pattern:
337329
// "atomic_<operation>[_<ordering>]"
338330
name if let Some(atomic) = name_str.strip_prefix("atomic_") => {

Diff for: compiler/rustc_codegen_ssa/src/mir/operand.rs

+137-2
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@ use std::fmt;
33
use arrayvec::ArrayVec;
44
use either::Either;
55
use rustc_abi as abi;
6-
use rustc_abi::{Align, BackendRepr, Size};
6+
use rustc_abi::{Align, BackendRepr, FIRST_VARIANT, Primitive, Size, TagEncoding, Variants};
77
use rustc_middle::mir::interpret::{Pointer, Scalar, alloc_range};
88
use rustc_middle::mir::{self, ConstValue};
99
use rustc_middle::ty::Ty;
1010
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
1111
use rustc_middle::{bug, span_bug};
12-
use tracing::debug;
12+
use tracing::{debug, instrument};
1313

1414
use super::place::{PlaceRef, PlaceValue};
1515
use super::{FunctionCx, LocalRef};
16+
use crate::common::IntPredicate;
1617
use crate::traits::*;
1718
use crate::{MemFlags, size_of_val};
1819

@@ -415,6 +416,140 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
415416

416417
OperandRef { val, layout: field }
417418
}
419+
420+
/// Obtain the actual discriminant of a value.
421+
#[instrument(level = "trace", skip(fx, bx))]
422+
pub fn codegen_get_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(
423+
self,
424+
fx: &mut FunctionCx<'a, 'tcx, Bx>,
425+
bx: &mut Bx,
426+
cast_to: Ty<'tcx>,
427+
) -> V {
428+
let dl = &bx.tcx().data_layout;
429+
let cast_to_layout = bx.cx().layout_of(cast_to);
430+
let cast_to = bx.cx().immediate_backend_type(cast_to_layout);
431+
if self.layout.is_uninhabited() {
432+
return bx.cx().const_poison(cast_to);
433+
}
434+
let (tag_scalar, tag_encoding, tag_field) = match self.layout.variants {
435+
Variants::Empty => unreachable!("we already handled uninhabited types"),
436+
Variants::Single { index } => {
437+
let discr_val =
438+
if let Some(discr) = self.layout.ty.discriminant_for_variant(bx.tcx(), index) {
439+
discr.val
440+
} else {
441+
assert_eq!(index, FIRST_VARIANT);
442+
0
443+
};
444+
return bx.cx().const_uint_big(cast_to, discr_val);
445+
}
446+
Variants::Multiple { tag, ref tag_encoding, tag_field, .. } => {
447+
(tag, tag_encoding, tag_field)
448+
}
449+
};
450+
451+
// Read the tag/niche-encoded discriminant from memory.
452+
let tag_op = match self.val {
453+
OperandValue::ZeroSized => bug!(),
454+
OperandValue::Immediate(_) | OperandValue::Pair(_, _) => {
455+
self.extract_field(fx, bx, tag_field)
456+
}
457+
OperandValue::Ref(place) => {
458+
let tag = place.with_type(self.layout).project_field(bx, tag_field);
459+
bx.load_operand(tag)
460+
}
461+
};
462+
let tag_imm = tag_op.immediate();
463+
464+
// Decode the discriminant (specifically if it's niche-encoded).
465+
match *tag_encoding {
466+
TagEncoding::Direct => {
467+
let signed = match tag_scalar.primitive() {
468+
// We use `i1` for bytes that are always `0` or `1`,
469+
// e.g., `#[repr(i8)] enum E { A, B }`, but we can't
470+
// let LLVM interpret the `i1` as signed, because
471+
// then `i1 1` (i.e., `E::B`) is effectively `i8 -1`.
472+
Primitive::Int(_, signed) => !tag_scalar.is_bool() && signed,
473+
_ => false,
474+
};
475+
bx.intcast(tag_imm, cast_to, signed)
476+
}
477+
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
478+
// Cast to an integer so we don't have to treat a pointer as a
479+
// special case.
480+
let (tag, tag_llty) = match tag_scalar.primitive() {
481+
// FIXME(erikdesjardins): handle non-default addrspace ptr sizes
482+
Primitive::Pointer(_) => {
483+
let t = bx.type_from_integer(dl.ptr_sized_integer());
484+
let tag = bx.ptrtoint(tag_imm, t);
485+
(tag, t)
486+
}
487+
_ => (tag_imm, bx.cx().immediate_backend_type(tag_op.layout)),
488+
};
489+
490+
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
491+
492+
// We have a subrange `niche_start..=niche_end` inside `range`.
493+
// If the value of the tag is inside this subrange, it's a
494+
// "niche value", an increment of the discriminant. Otherwise it
495+
// indicates the untagged variant.
496+
// A general algorithm to extract the discriminant from the tag
497+
// is:
498+
// relative_tag = tag - niche_start
499+
// is_niche = relative_tag <= (ule) relative_max
500+
// discr = if is_niche {
501+
// cast(relative_tag) + niche_variants.start()
502+
// } else {
503+
// untagged_variant
504+
// }
505+
// However, we will likely be able to emit simpler code.
506+
let (is_niche, tagged_discr, delta) = if relative_max == 0 {
507+
// Best case scenario: only one tagged variant. This will
508+
// likely become just a comparison and a jump.
509+
// The algorithm is:
510+
// is_niche = tag == niche_start
511+
// discr = if is_niche {
512+
// niche_start
513+
// } else {
514+
// untagged_variant
515+
// }
516+
let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
517+
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
518+
let tagged_discr =
519+
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
520+
(is_niche, tagged_discr, 0)
521+
} else {
522+
// The special cases don't apply, so we'll have to go with
523+
// the general algorithm.
524+
let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
525+
let cast_tag = bx.intcast(relative_discr, cast_to, false);
526+
let is_niche = bx.icmp(
527+
IntPredicate::IntULE,
528+
relative_discr,
529+
bx.cx().const_uint(tag_llty, relative_max as u64),
530+
);
531+
(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
532+
};
533+
534+
let tagged_discr = if delta == 0 {
535+
tagged_discr
536+
} else {
537+
bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
538+
};
539+
540+
let discr = bx.select(
541+
is_niche,
542+
tagged_discr,
543+
bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
544+
);
545+
546+
// In principle we could insert assumes on the possible range of `discr`, but
547+
// currently in LLVM this seems to be a pessimization.
548+
549+
discr
550+
}
551+
}
552+
}
418553
}
419554

420555
impl<'a, 'tcx, V: CodegenObject> OperandValue<V> {

Diff for: compiler/rustc_codegen_ssa/src/mir/place.rs

-124
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use rustc_abi::Primitive::{Int, Pointer};
21
use rustc_abi::{Align, BackendRepr, FieldsShape, Size, TagEncoding, VariantIdx, Variants};
32
use rustc_middle::mir::PlaceTy;
43
use rustc_middle::mir::interpret::Scalar;
@@ -233,129 +232,6 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
233232
val.with_type(field)
234233
}
235234

236-
/// Obtain the actual discriminant of a value.
237-
#[instrument(level = "trace", skip(bx))]
238-
pub fn codegen_get_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(
239-
self,
240-
bx: &mut Bx,
241-
cast_to: Ty<'tcx>,
242-
) -> V {
243-
let dl = &bx.tcx().data_layout;
244-
let cast_to_layout = bx.cx().layout_of(cast_to);
245-
let cast_to = bx.cx().immediate_backend_type(cast_to_layout);
246-
if self.layout.is_uninhabited() {
247-
return bx.cx().const_poison(cast_to);
248-
}
249-
let (tag_scalar, tag_encoding, tag_field) = match self.layout.variants {
250-
Variants::Empty => unreachable!("we already handled uninhabited types"),
251-
Variants::Single { index } => {
252-
let discr_val = self
253-
.layout
254-
.ty
255-
.discriminant_for_variant(bx.cx().tcx(), index)
256-
.map_or(index.as_u32() as u128, |discr| discr.val);
257-
return bx.cx().const_uint_big(cast_to, discr_val);
258-
}
259-
Variants::Multiple { tag, ref tag_encoding, tag_field, .. } => {
260-
(tag, tag_encoding, tag_field)
261-
}
262-
};
263-
264-
// Read the tag/niche-encoded discriminant from memory.
265-
let tag = self.project_field(bx, tag_field);
266-
let tag_op = bx.load_operand(tag);
267-
let tag_imm = tag_op.immediate();
268-
269-
// Decode the discriminant (specifically if it's niche-encoded).
270-
match *tag_encoding {
271-
TagEncoding::Direct => {
272-
let signed = match tag_scalar.primitive() {
273-
// We use `i1` for bytes that are always `0` or `1`,
274-
// e.g., `#[repr(i8)] enum E { A, B }`, but we can't
275-
// let LLVM interpret the `i1` as signed, because
276-
// then `i1 1` (i.e., `E::B`) is effectively `i8 -1`.
277-
Int(_, signed) => !tag_scalar.is_bool() && signed,
278-
_ => false,
279-
};
280-
bx.intcast(tag_imm, cast_to, signed)
281-
}
282-
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
283-
// Cast to an integer so we don't have to treat a pointer as a
284-
// special case.
285-
let (tag, tag_llty) = match tag_scalar.primitive() {
286-
// FIXME(erikdesjardins): handle non-default addrspace ptr sizes
287-
Pointer(_) => {
288-
let t = bx.type_from_integer(dl.ptr_sized_integer());
289-
let tag = bx.ptrtoint(tag_imm, t);
290-
(tag, t)
291-
}
292-
_ => (tag_imm, bx.cx().immediate_backend_type(tag_op.layout)),
293-
};
294-
295-
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
296-
297-
// We have a subrange `niche_start..=niche_end` inside `range`.
298-
// If the value of the tag is inside this subrange, it's a
299-
// "niche value", an increment of the discriminant. Otherwise it
300-
// indicates the untagged variant.
301-
// A general algorithm to extract the discriminant from the tag
302-
// is:
303-
// relative_tag = tag - niche_start
304-
// is_niche = relative_tag <= (ule) relative_max
305-
// discr = if is_niche {
306-
// cast(relative_tag) + niche_variants.start()
307-
// } else {
308-
// untagged_variant
309-
// }
310-
// However, we will likely be able to emit simpler code.
311-
let (is_niche, tagged_discr, delta) = if relative_max == 0 {
312-
// Best case scenario: only one tagged variant. This will
313-
// likely become just a comparison and a jump.
314-
// The algorithm is:
315-
// is_niche = tag == niche_start
316-
// discr = if is_niche {
317-
// niche_start
318-
// } else {
319-
// untagged_variant
320-
// }
321-
let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
322-
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
323-
let tagged_discr =
324-
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
325-
(is_niche, tagged_discr, 0)
326-
} else {
327-
// The special cases don't apply, so we'll have to go with
328-
// the general algorithm.
329-
let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
330-
let cast_tag = bx.intcast(relative_discr, cast_to, false);
331-
let is_niche = bx.icmp(
332-
IntPredicate::IntULE,
333-
relative_discr,
334-
bx.cx().const_uint(tag_llty, relative_max as u64),
335-
);
336-
(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
337-
};
338-
339-
let tagged_discr = if delta == 0 {
340-
tagged_discr
341-
} else {
342-
bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
343-
};
344-
345-
let discr = bx.select(
346-
is_niche,
347-
tagged_discr,
348-
bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
349-
);
350-
351-
// In principle we could insert assumes on the possible range of `discr`, but
352-
// currently in LLVM this seems to be a pessimization.
353-
354-
discr
355-
}
356-
}
357-
}
358-
359235
/// Sets the discriminant for a new value of the given case of the given
360236
/// representation.
361237
pub fn codegen_set_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(

Diff for: compiler/rustc_codegen_ssa/src/mir/rvalue.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,8 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
706706
mir::Rvalue::Discriminant(ref place) => {
707707
let discr_ty = rvalue.ty(self.mir, bx.tcx());
708708
let discr_ty = self.monomorphize(discr_ty);
709-
let discr = self.codegen_place(bx, place.as_ref()).codegen_get_discr(bx, discr_ty);
709+
let operand = self.codegen_consume(bx, place.as_ref());
710+
let discr = operand.codegen_get_discr(self, bx, discr_ty);
710711
OperandRef {
711712
val: OperandValue::Immediate(discr),
712713
layout: self.cx.layout_of(discr_ty),

0 commit comments

Comments
 (0)