Skip to content

Commit ce3263e

Browse files
committed
Auto merge of rust-lang#124113 - RalfJung:interpret-scalar-ops, r=oli-obk
interpret: use ScalarInt for bin-ops; avoid PartialOrd for ScalarInt Best reviewed commit-by-commit r? `@oli-obk`
2 parents d1a0fa5 + d3f927d commit ce3263e

File tree

15 files changed

+211
-152
lines changed

15 files changed

+211
-152
lines changed

Diff for: compiler/rustc_codegen_cranelift/src/constant.rs

+19-22
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ pub(crate) fn codegen_const_value<'tcx>(
110110
if fx.clif_type(layout.ty).is_some() {
111111
return CValue::const_val(fx, layout, int);
112112
} else {
113-
let raw_val = int.size().truncate(int.to_bits(int.size()).unwrap());
113+
let raw_val = int.size().truncate(int.assert_bits(int.size()));
114114
let val = match int.size().bytes() {
115115
1 => fx.bcx.ins().iconst(types::I8, raw_val as i64),
116116
2 => fx.bcx.ins().iconst(types::I16, raw_val as i64),
@@ -491,27 +491,24 @@ pub(crate) fn mir_operand_get_const_val<'tcx>(
491491
return None;
492492
}
493493
let scalar_int = mir_operand_get_const_val(fx, operand)?;
494-
let scalar_int = match fx
495-
.layout_of(*ty)
496-
.size
497-
.cmp(&scalar_int.size())
498-
{
499-
Ordering::Equal => scalar_int,
500-
Ordering::Less => match ty.kind() {
501-
ty::Uint(_) => ScalarInt::try_from_uint(
502-
scalar_int.try_to_uint(scalar_int.size()).unwrap(),
503-
fx.layout_of(*ty).size,
504-
)
505-
.unwrap(),
506-
ty::Int(_) => ScalarInt::try_from_int(
507-
scalar_int.try_to_int(scalar_int.size()).unwrap(),
508-
fx.layout_of(*ty).size,
509-
)
510-
.unwrap(),
511-
_ => unreachable!(),
512-
},
513-
Ordering::Greater => return None,
514-
};
494+
let scalar_int =
495+
match fx.layout_of(*ty).size.cmp(&scalar_int.size()) {
496+
Ordering::Equal => scalar_int,
497+
Ordering::Less => match ty.kind() {
498+
ty::Uint(_) => ScalarInt::try_from_uint(
499+
scalar_int.assert_uint(scalar_int.size()),
500+
fx.layout_of(*ty).size,
501+
)
502+
.unwrap(),
503+
ty::Int(_) => ScalarInt::try_from_int(
504+
scalar_int.assert_int(scalar_int.size()),
505+
fx.layout_of(*ty).size,
506+
)
507+
.unwrap(),
508+
_ => unreachable!(),
509+
},
510+
Ordering::Greater => return None,
511+
};
515512
computed_scalar_int = Some(scalar_int);
516513
}
517514
Rvalue::Use(operand) => {

Diff for: compiler/rustc_codegen_cranelift/src/value_and_place.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ impl<'tcx> CValue<'tcx> {
326326

327327
let val = match layout.ty.kind() {
328328
ty::Uint(UintTy::U128) | ty::Int(IntTy::I128) => {
329-
let const_val = const_val.to_bits(layout.size).unwrap();
329+
let const_val = const_val.assert_bits(layout.size);
330330
let lsb = fx.bcx.ins().iconst(types::I64, const_val as u64 as i64);
331331
let msb = fx.bcx.ins().iconst(types::I64, (const_val >> 64) as u64 as i64);
332332
fx.bcx.ins().iconcat(lsb, msb)
@@ -338,7 +338,7 @@ impl<'tcx> CValue<'tcx> {
338338
| ty::Ref(..)
339339
| ty::RawPtr(..)
340340
| ty::FnPtr(..) => {
341-
let raw_val = const_val.size().truncate(const_val.to_bits(layout.size).unwrap());
341+
let raw_val = const_val.size().truncate(const_val.assert_bits(layout.size));
342342
fx.bcx.ins().iconst(clif_ty, raw_val as i64)
343343
}
344344
ty::Float(FloatTy::F32) => {

Diff for: compiler/rustc_const_eval/src/interpret/discriminant.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
295295
&niche_start_val,
296296
)?
297297
.to_scalar()
298-
.try_to_int()
299-
.unwrap();
298+
.assert_int();
300299
Ok(Some((tag, tag_field)))
301300
}
302301
}

Diff for: compiler/rustc_const_eval/src/interpret/operand.rs

+22-2
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ use std::assert_matches::assert_matches;
66
use either::{Either, Left, Right};
77

88
use rustc_hir::def::Namespace;
9+
use rustc_middle::mir::interpret::ScalarSizeMismatch;
910
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
1011
use rustc_middle::ty::print::{FmtPrinter, PrettyPrinter};
11-
use rustc_middle::ty::{ConstInt, Ty, TyCtxt};
12+
use rustc_middle::ty::{ConstInt, ScalarInt, Ty, TyCtxt};
1213
use rustc_middle::{mir, ty};
1314
use rustc_target::abi::{self, Abi, HasDataLayout, Size};
1415

@@ -210,6 +211,12 @@ impl<'tcx, Prov: Provenance> ImmTy<'tcx, Prov> {
210211
ImmTy { imm: Immediate::Uninit, layout }
211212
}
212213

214+
#[inline]
215+
pub fn from_scalar_int(s: ScalarInt, layout: TyAndLayout<'tcx>) -> Self {
216+
assert_eq!(s.size(), layout.size);
217+
Self::from_scalar(Scalar::from(s), layout)
218+
}
219+
213220
#[inline]
214221
pub fn try_from_uint(i: impl Into<u128>, layout: TyAndLayout<'tcx>) -> Option<Self> {
215222
Some(Self::from_scalar(Scalar::try_from_uint(i, layout.size)?, layout))
@@ -223,7 +230,6 @@ impl<'tcx, Prov: Provenance> ImmTy<'tcx, Prov> {
223230
pub fn try_from_int(i: impl Into<i128>, layout: TyAndLayout<'tcx>) -> Option<Self> {
224231
Some(Self::from_scalar(Scalar::try_from_int(i, layout.size)?, layout))
225232
}
226-
227233
#[inline]
228234
pub fn from_int(i: impl Into<i128>, layout: TyAndLayout<'tcx>) -> Self {
229235
Self::from_scalar(Scalar::from_int(i, layout.size), layout)
@@ -242,6 +248,20 @@ impl<'tcx, Prov: Provenance> ImmTy<'tcx, Prov> {
242248
Self::from_scalar(Scalar::from_i8(c as i8), layout)
243249
}
244250

251+
/// Return the immediate as a `ScalarInt`. Ensures that it has the size that the layout of the
252+
/// immediate indicates.
253+
#[inline]
254+
pub fn to_scalar_int(&self) -> InterpResult<'tcx, ScalarInt> {
255+
let s = self.to_scalar().to_scalar_int()?;
256+
if s.size() != self.layout.size {
257+
throw_ub!(ScalarSizeMismatch(ScalarSizeMismatch {
258+
target_size: self.layout.size.bytes(),
259+
data_size: s.size().bytes(),
260+
}));
261+
}
262+
Ok(s)
263+
}
264+
245265
#[inline]
246266
pub fn to_const_int(self) -> ConstInt {
247267
assert!(self.layout.ty.is_integral());

Diff for: compiler/rustc_const_eval/src/interpret/operator.rs

+59-54
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use rustc_apfloat::{Float, FloatConvert};
22
use rustc_middle::mir;
33
use rustc_middle::mir::interpret::{InterpResult, Scalar};
44
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
5-
use rustc_middle::ty::{self, FloatTy, Ty};
5+
use rustc_middle::ty::{self, FloatTy, ScalarInt, Ty};
66
use rustc_span::symbol::sym;
77
use rustc_target::abi::Abi;
88

@@ -146,14 +146,20 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
146146
fn binary_int_op(
147147
&self,
148148
bin_op: mir::BinOp,
149-
// passing in raw bits
150-
l: u128,
151-
left_layout: TyAndLayout<'tcx>,
152-
r: u128,
153-
right_layout: TyAndLayout<'tcx>,
149+
left: &ImmTy<'tcx, M::Provenance>,
150+
right: &ImmTy<'tcx, M::Provenance>,
154151
) -> InterpResult<'tcx, (ImmTy<'tcx, M::Provenance>, bool)> {
155152
use rustc_middle::mir::BinOp::*;
156153

154+
// This checks the size, so that we can just assert it below.
155+
let l = left.to_scalar_int()?;
156+
let r = right.to_scalar_int()?;
157+
// Prepare to convert the values to signed or unsigned form.
158+
let l_signed = || l.assert_int(left.layout.size);
159+
let l_unsigned = || l.assert_uint(left.layout.size);
160+
let r_signed = || r.assert_int(right.layout.size);
161+
let r_unsigned = || r.assert_uint(right.layout.size);
162+
157163
let throw_ub_on_overflow = match bin_op {
158164
AddUnchecked => Some(sym::unchecked_add),
159165
SubUnchecked => Some(sym::unchecked_sub),
@@ -165,69 +171,72 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
165171

166172
// Shift ops can have an RHS with a different numeric type.
167173
if matches!(bin_op, Shl | ShlUnchecked | Shr | ShrUnchecked) {
168-
let size = left_layout.size.bits();
174+
let size = left.layout.size.bits();
169175
// The shift offset is implicitly masked to the type size. (This is the one MIR operator
170176
// that does *not* directly map to a single LLVM operation.) Compute how much we
171177
// actually shift and whether there was an overflow due to shifting too much.
172-
let (shift_amount, overflow) = if right_layout.abi.is_signed() {
173-
let shift_amount = self.sign_extend(r, right_layout) as i128;
178+
let (shift_amount, overflow) = if right.layout.abi.is_signed() {
179+
let shift_amount = r_signed();
174180
let overflow = shift_amount < 0 || shift_amount >= i128::from(size);
181+
// Deliberately wrapping `as` casts: shift_amount *can* be negative, but the result
182+
// of the `as` will be equal modulo `size` (since it is a power of two).
175183
let masked_amount = (shift_amount as u128) % u128::from(size);
176-
debug_assert_eq!(overflow, shift_amount != (masked_amount as i128));
184+
assert_eq!(overflow, shift_amount != (masked_amount as i128));
177185
(masked_amount, overflow)
178186
} else {
179-
let shift_amount = r;
187+
let shift_amount = r_unsigned();
180188
let masked_amount = shift_amount % u128::from(size);
181189
(masked_amount, shift_amount != masked_amount)
182190
};
183191
let shift_amount = u32::try_from(shift_amount).unwrap(); // we masked so this will always fit
184192
// Compute the shifted result.
185-
let result = if left_layout.abi.is_signed() {
186-
let l = self.sign_extend(l, left_layout) as i128;
193+
let result = if left.layout.abi.is_signed() {
194+
let l = l_signed();
187195
let result = match bin_op {
188196
Shl | ShlUnchecked => l.checked_shl(shift_amount).unwrap(),
189197
Shr | ShrUnchecked => l.checked_shr(shift_amount).unwrap(),
190198
_ => bug!(),
191199
};
192-
result as u128
200+
ScalarInt::truncate_from_int(result, left.layout.size).0
193201
} else {
194-
match bin_op {
202+
let l = l_unsigned();
203+
let result = match bin_op {
195204
Shl | ShlUnchecked => l.checked_shl(shift_amount).unwrap(),
196205
Shr | ShrUnchecked => l.checked_shr(shift_amount).unwrap(),
197206
_ => bug!(),
198-
}
207+
};
208+
ScalarInt::truncate_from_uint(result, left.layout.size).0
199209
};
200-
let truncated = self.truncate(result, left_layout);
201210

202211
if overflow && let Some(intrinsic_name) = throw_ub_on_overflow {
203212
throw_ub_custom!(
204213
fluent::const_eval_overflow_shift,
205-
val = if right_layout.abi.is_signed() {
206-
(self.sign_extend(r, right_layout) as i128).to_string()
214+
val = if right.layout.abi.is_signed() {
215+
r_signed().to_string()
207216
} else {
208-
r.to_string()
217+
r_unsigned().to_string()
209218
},
210219
name = intrinsic_name
211220
);
212221
}
213222

214-
return Ok((ImmTy::from_uint(truncated, left_layout), overflow));
223+
return Ok((ImmTy::from_scalar_int(result, left.layout), overflow));
215224
}
216225

217226
// For the remaining ops, the types must be the same on both sides
218-
if left_layout.ty != right_layout.ty {
227+
if left.layout.ty != right.layout.ty {
219228
span_bug!(
220229
self.cur_span(),
221230
"invalid asymmetric binary op {bin_op:?}: {l:?} ({l_ty}), {r:?} ({r_ty})",
222-
l_ty = left_layout.ty,
223-
r_ty = right_layout.ty,
231+
l_ty = left.layout.ty,
232+
r_ty = right.layout.ty,
224233
)
225234
}
226235

227-
let size = left_layout.size;
236+
let size = left.layout.size;
228237

229238
// Operations that need special treatment for signed integers
230-
if left_layout.abi.is_signed() {
239+
if left.layout.abi.is_signed() {
231240
let op: Option<fn(&i128, &i128) -> bool> = match bin_op {
232241
Lt => Some(i128::lt),
233242
Le => Some(i128::le),
@@ -236,18 +245,14 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
236245
_ => None,
237246
};
238247
if let Some(op) = op {
239-
let l = self.sign_extend(l, left_layout) as i128;
240-
let r = self.sign_extend(r, right_layout) as i128;
241-
return Ok((ImmTy::from_bool(op(&l, &r), *self.tcx), false));
248+
return Ok((ImmTy::from_bool(op(&l_signed(), &r_signed()), *self.tcx), false));
242249
}
243250
if bin_op == Cmp {
244-
let l = self.sign_extend(l, left_layout) as i128;
245-
let r = self.sign_extend(r, right_layout) as i128;
246-
return Ok(self.three_way_compare(l, r));
251+
return Ok(self.three_way_compare(l_signed(), r_signed()));
247252
}
248253
let op: Option<fn(i128, i128) -> (i128, bool)> = match bin_op {
249-
Div if r == 0 => throw_ub!(DivisionByZero),
250-
Rem if r == 0 => throw_ub!(RemainderByZero),
254+
Div if r.is_null() => throw_ub!(DivisionByZero),
255+
Rem if r.is_null() => throw_ub!(RemainderByZero),
251256
Div => Some(i128::overflowing_div),
252257
Rem => Some(i128::overflowing_rem),
253258
Add | AddUnchecked => Some(i128::overflowing_add),
@@ -256,8 +261,8 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
256261
_ => None,
257262
};
258263
if let Some(op) = op {
259-
let l = self.sign_extend(l, left_layout) as i128;
260-
let r = self.sign_extend(r, right_layout) as i128;
264+
let l = l_signed();
265+
let r = r_signed();
261266

262267
// We need a special check for overflowing Rem and Div since they are *UB*
263268
// on overflow, which can happen with "int_min $OP -1".
@@ -272,17 +277,19 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
272277
}
273278

274279
let (result, oflo) = op(l, r);
275-
// This may be out-of-bounds for the result type, so we have to truncate ourselves.
280+
// This may be out-of-bounds for the result type, so we have to truncate.
276281
// If that truncation loses any information, we have an overflow.
277-
let result = result as u128;
278-
let truncated = self.truncate(result, left_layout);
279-
let overflow = oflo || self.sign_extend(truncated, left_layout) != result;
282+
let (result, lossy) = ScalarInt::truncate_from_int(result, left.layout.size);
283+
let overflow = oflo || lossy;
280284
if overflow && let Some(intrinsic_name) = throw_ub_on_overflow {
281285
throw_ub_custom!(fluent::const_eval_overflow, name = intrinsic_name);
282286
}
283-
return Ok((ImmTy::from_uint(truncated, left_layout), overflow));
287+
return Ok((ImmTy::from_scalar_int(result, left.layout), overflow));
284288
}
285289
}
290+
// From here on it's okay to treat everything as unsigned.
291+
let l = l_unsigned();
292+
let r = r_unsigned();
286293

287294
if bin_op == Cmp {
288295
return Ok(self.three_way_compare(l, r));
@@ -297,12 +304,12 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
297304
Gt => ImmTy::from_bool(l > r, *self.tcx),
298305
Ge => ImmTy::from_bool(l >= r, *self.tcx),
299306

300-
BitOr => ImmTy::from_uint(l | r, left_layout),
301-
BitAnd => ImmTy::from_uint(l & r, left_layout),
302-
BitXor => ImmTy::from_uint(l ^ r, left_layout),
307+
BitOr => ImmTy::from_uint(l | r, left.layout),
308+
BitAnd => ImmTy::from_uint(l & r, left.layout),
309+
BitXor => ImmTy::from_uint(l ^ r, left.layout),
303310

304311
Add | AddUnchecked | Sub | SubUnchecked | Mul | MulUnchecked | Rem | Div => {
305-
assert!(!left_layout.abi.is_signed());
312+
assert!(!left.layout.abi.is_signed());
306313
let op: fn(u128, u128) -> (u128, bool) = match bin_op {
307314
Add | AddUnchecked => u128::overflowing_add,
308315
Sub | SubUnchecked => u128::overflowing_sub,
@@ -316,21 +323,21 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
316323
let (result, oflo) = op(l, r);
317324
// Truncate to target type.
318325
// If that truncation loses any information, we have an overflow.
319-
let truncated = self.truncate(result, left_layout);
320-
let overflow = oflo || truncated != result;
326+
let (result, lossy) = ScalarInt::truncate_from_uint(result, left.layout.size);
327+
let overflow = oflo || lossy;
321328
if overflow && let Some(intrinsic_name) = throw_ub_on_overflow {
322329
throw_ub_custom!(fluent::const_eval_overflow, name = intrinsic_name);
323330
}
324-
return Ok((ImmTy::from_uint(truncated, left_layout), overflow));
331+
return Ok((ImmTy::from_scalar_int(result, left.layout), overflow));
325332
}
326333

327334
_ => span_bug!(
328335
self.cur_span(),
329336
"invalid binary op {:?}: {:?}, {:?} (both {})",
330337
bin_op,
331-
l,
332-
r,
333-
right_layout.ty,
338+
left,
339+
right,
340+
right.layout.ty,
334341
),
335342
};
336343

@@ -427,9 +434,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
427434
right.layout.ty
428435
);
429436

430-
let l = left.to_scalar().to_bits(left.layout.size)?;
431-
let r = right.to_scalar().to_bits(right.layout.size)?;
432-
self.binary_int_op(bin_op, l, left.layout, r, right.layout)
437+
self.binary_int_op(bin_op, left, right)
433438
}
434439
_ if left.layout.ty.is_any_ptr() => {
435440
// The RHS type must be a `pointer` *or an integer type* (for `Offset`).

0 commit comments

Comments
 (0)