Skip to content

Commit a1daa34

Browse files
committed
Use MirPatch in EnumSizeOpt.
Instead of `expand_statements`. This makes the code shorter and consistent with other MIR transform passes. The tests require updating because there is a slight change in MIR output: - the old code replaced the original statement with twelve new statements. - the new code inserts converts the original statement to a `nop` and then insert twelve new statements in front of it. I.e. we now end up with an extra `nop`, which doesn't matter at all.
1 parent ce36a96 commit a1daa34

5 files changed

+82
-112
lines changed

compiler/rustc_mir_transform/src/large_enums.rs

+74-112
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ use rustc_middle::ty::util::IntTypeExt;
66
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
77
use rustc_session::Session;
88

9+
use crate::patch::MirPatch;
10+
911
/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large
1012
/// enough discrepancy between them.
1113
///
@@ -41,31 +43,34 @@ impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
4143
let mut alloc_cache = FxHashMap::default();
4244
let typing_env = body.typing_env(tcx);
4345

44-
let blocks = body.basic_blocks.as_mut();
45-
let local_decls = &mut body.local_decls;
46+
let mut patch = MirPatch::new(body);
4647

47-
for bb in blocks {
48-
bb.expand_statements(|st| {
48+
for (block, data) in body.basic_blocks.as_mut().iter_enumerated_mut() {
49+
for (statement_index, st) in data.statements.iter_mut().enumerate() {
4950
let StatementKind::Assign(box (
5051
lhs,
5152
Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
5253
)) = &st.kind
5354
else {
54-
return None;
55+
continue;
5556
};
5657

57-
let ty = lhs.ty(local_decls, tcx).ty;
58+
let location = Location { block, statement_index };
5859

59-
let (adt_def, num_variants, alloc_id) =
60-
self.candidate(tcx, typing_env, ty, &mut alloc_cache)?;
60+
let ty = lhs.ty(&body.local_decls, tcx).ty;
6161

62-
let source_info = st.source_info;
63-
let span = source_info.span;
62+
let Some((adt_def, num_variants, alloc_id)) =
63+
self.candidate(tcx, typing_env, ty, &mut alloc_cache)
64+
else {
65+
continue;
66+
};
67+
68+
let span = st.source_info.span;
6469

6570
let tmp_ty = Ty::new_array(tcx, tcx.types.usize, num_variants as u64);
66-
let size_array_local = local_decls.push(LocalDecl::new(tmp_ty, span));
67-
let store_live =
68-
Statement { source_info, kind: StatementKind::StorageLive(size_array_local) };
71+
let size_array_local = patch.new_temp(tmp_ty, span);
72+
73+
let store_live = StatementKind::StorageLive(size_array_local);
6974

7075
let place = Place::from(size_array_local);
7176
let constant_vals = ConstOperand {
@@ -77,108 +82,63 @@ impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
7782
),
7883
};
7984
let rval = Rvalue::Use(Operand::Constant(Box::new(constant_vals)));
80-
let const_assign =
81-
Statement { source_info, kind: StatementKind::Assign(Box::new((place, rval))) };
82-
83-
let discr_place = Place::from(
84-
local_decls.push(LocalDecl::new(adt_def.repr().discr_type().to_ty(tcx), span)),
85-
);
86-
let store_discr = Statement {
87-
source_info,
88-
kind: StatementKind::Assign(Box::new((
89-
discr_place,
90-
Rvalue::Discriminant(*rhs),
91-
))),
92-
};
93-
94-
let discr_cast_place =
95-
Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
96-
let cast_discr = Statement {
97-
source_info,
98-
kind: StatementKind::Assign(Box::new((
99-
discr_cast_place,
100-
Rvalue::Cast(
101-
CastKind::IntToInt,
102-
Operand::Copy(discr_place),
103-
tcx.types.usize,
104-
),
105-
))),
106-
};
107-
108-
let size_place =
109-
Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
110-
let store_size = Statement {
111-
source_info,
112-
kind: StatementKind::Assign(Box::new((
113-
size_place,
114-
Rvalue::Use(Operand::Copy(Place {
115-
local: size_array_local,
116-
projection: tcx
117-
.mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]),
118-
})),
119-
))),
120-
};
121-
122-
let dst =
123-
Place::from(local_decls.push(LocalDecl::new(Ty::new_mut_ptr(tcx, ty), span)));
124-
let dst_ptr = Statement {
125-
source_info,
126-
kind: StatementKind::Assign(Box::new((
127-
dst,
128-
Rvalue::RawPtr(RawPtrKind::Mut, *lhs),
129-
))),
130-
};
85+
let const_assign = StatementKind::Assign(Box::new((place, rval)));
86+
87+
let discr_place =
88+
Place::from(patch.new_temp(adt_def.repr().discr_type().to_ty(tcx), span));
89+
let store_discr =
90+
StatementKind::Assign(Box::new((discr_place, Rvalue::Discriminant(*rhs))));
91+
92+
let discr_cast_place = Place::from(patch.new_temp(tcx.types.usize, span));
93+
let cast_discr = StatementKind::Assign(Box::new((
94+
discr_cast_place,
95+
Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_place), tcx.types.usize),
96+
)));
97+
98+
let size_place = Place::from(patch.new_temp(tcx.types.usize, span));
99+
let store_size = StatementKind::Assign(Box::new((
100+
size_place,
101+
Rvalue::Use(Operand::Copy(Place {
102+
local: size_array_local,
103+
projection: tcx.mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]),
104+
})),
105+
)));
106+
107+
let dst = Place::from(patch.new_temp(Ty::new_mut_ptr(tcx, ty), span));
108+
let dst_ptr =
109+
StatementKind::Assign(Box::new((dst, Rvalue::RawPtr(RawPtrKind::Mut, *lhs))));
131110

132111
let dst_cast_ty = Ty::new_mut_ptr(tcx, tcx.types.u8);
133-
let dst_cast_place =
134-
Place::from(local_decls.push(LocalDecl::new(dst_cast_ty, span)));
135-
let dst_cast = Statement {
136-
source_info,
137-
kind: StatementKind::Assign(Box::new((
138-
dst_cast_place,
139-
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
140-
))),
141-
};
112+
let dst_cast_place = Place::from(patch.new_temp(dst_cast_ty, span));
113+
let dst_cast = StatementKind::Assign(Box::new((
114+
dst_cast_place,
115+
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
116+
)));
142117

143-
let src =
144-
Place::from(local_decls.push(LocalDecl::new(Ty::new_imm_ptr(tcx, ty), span)));
145-
let src_ptr = Statement {
146-
source_info,
147-
kind: StatementKind::Assign(Box::new((
148-
src,
149-
Rvalue::RawPtr(RawPtrKind::Const, *rhs),
150-
))),
151-
};
118+
let src = Place::from(patch.new_temp(Ty::new_imm_ptr(tcx, ty), span));
119+
let src_ptr =
120+
StatementKind::Assign(Box::new((src, Rvalue::RawPtr(RawPtrKind::Const, *rhs))));
152121

153122
let src_cast_ty = Ty::new_imm_ptr(tcx, tcx.types.u8);
154-
let src_cast_place =
155-
Place::from(local_decls.push(LocalDecl::new(src_cast_ty, span)));
156-
let src_cast = Statement {
157-
source_info,
158-
kind: StatementKind::Assign(Box::new((
159-
src_cast_place,
160-
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
161-
))),
162-
};
123+
let src_cast_place = Place::from(patch.new_temp(src_cast_ty, span));
124+
let src_cast = StatementKind::Assign(Box::new((
125+
src_cast_place,
126+
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
127+
)));
163128

164-
let deinit_old =
165-
Statement { source_info, kind: StatementKind::Deinit(Box::new(dst)) };
166-
167-
let copy_bytes = Statement {
168-
source_info,
169-
kind: StatementKind::Intrinsic(Box::new(
170-
NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
171-
src: Operand::Copy(src_cast_place),
172-
dst: Operand::Copy(dst_cast_place),
173-
count: Operand::Copy(size_place),
174-
}),
175-
)),
176-
};
129+
let deinit_old = StatementKind::Deinit(Box::new(dst));
130+
131+
let copy_bytes = StatementKind::Intrinsic(Box::new(
132+
NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
133+
src: Operand::Copy(src_cast_place),
134+
dst: Operand::Copy(dst_cast_place),
135+
count: Operand::Copy(size_place),
136+
}),
137+
));
177138

178-
let store_dead =
179-
Statement { source_info, kind: StatementKind::StorageDead(size_array_local) };
139+
let store_dead = StatementKind::StorageDead(size_array_local);
180140

181-
let iter = [
141+
let stmts = [
182142
store_live,
183143
const_assign,
184144
store_discr,
@@ -191,14 +151,16 @@ impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
191151
deinit_old,
192152
copy_bytes,
193153
store_dead,
194-
]
195-
.into_iter();
154+
];
155+
for stmt in stmts {
156+
patch.add_statement(location, stmt);
157+
}
196158

197159
st.make_nop();
198-
199-
Some(iter)
200-
});
160+
}
201161
}
162+
163+
patch.apply(body);
202164
}
203165

204166
fn is_required(&self) -> bool {

tests/mir-opt/enum_opt.cand.EnumSizeOpt.32bit.diff

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
+ Deinit(_8);
4848
+ copy_nonoverlapping(dst = copy _9, src = copy _11, count = copy _7);
4949
+ StorageDead(_4);
50+
+ nop;
5051
StorageDead(_2);
5152
- _0 = move _1;
5253
+ StorageLive(_12);
@@ -61,6 +62,7 @@
6162
+ Deinit(_16);
6263
+ copy_nonoverlapping(dst = copy _17, src = copy _19, count = copy _15);
6364
+ StorageDead(_12);
65+
+ nop;
6466
StorageDead(_1);
6567
return;
6668
}

tests/mir-opt/enum_opt.cand.EnumSizeOpt.64bit.diff

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
+ Deinit(_8);
4848
+ copy_nonoverlapping(dst = copy _9, src = copy _11, count = copy _7);
4949
+ StorageDead(_4);
50+
+ nop;
5051
StorageDead(_2);
5152
- _0 = move _1;
5253
+ StorageLive(_12);
@@ -61,6 +62,7 @@
6162
+ Deinit(_16);
6263
+ copy_nonoverlapping(dst = copy _17, src = copy _19, count = copy _15);
6364
+ StorageDead(_12);
65+
+ nop;
6466
StorageDead(_1);
6567
return;
6668
}

tests/mir-opt/enum_opt.unin.EnumSizeOpt.32bit.diff

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
+ Deinit(_8);
4848
+ copy_nonoverlapping(dst = copy _9, src = copy _11, count = copy _7);
4949
+ StorageDead(_4);
50+
+ nop;
5051
StorageDead(_2);
5152
- _0 = move _1;
5253
+ StorageLive(_12);
@@ -61,6 +62,7 @@
6162
+ Deinit(_16);
6263
+ copy_nonoverlapping(dst = copy _17, src = copy _19, count = copy _15);
6364
+ StorageDead(_12);
65+
+ nop;
6466
StorageDead(_1);
6567
return;
6668
}

tests/mir-opt/enum_opt.unin.EnumSizeOpt.64bit.diff

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
+ Deinit(_8);
4848
+ copy_nonoverlapping(dst = copy _9, src = copy _11, count = copy _7);
4949
+ StorageDead(_4);
50+
+ nop;
5051
StorageDead(_2);
5152
- _0 = move _1;
5253
+ StorageLive(_12);
@@ -61,6 +62,7 @@
6162
+ Deinit(_16);
6263
+ copy_nonoverlapping(dst = copy _17, src = copy _19, count = copy _15);
6364
+ StorageDead(_12);
65+
+ nop;
6466
StorageDead(_1);
6567
return;
6668
}

0 commit comments

Comments
 (0)