Skip to content

Commit 44aceb5

Browse files
committed
Auto merge of rust-lang#3185 - eduardosm:float_to_int_checked-generic, r=RalfJung
Refactor `float_to_int_checked` to remove its generic parameter and reduce code duplication a bit
2 parents 2a1e0ce + 2855024 commit 44aceb5

31 files changed

+185
-205
lines changed

src/tools/miri/src/helpers.rs

+57-47
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,18 @@ use std::time::Duration;
55

66
use log::trace;
77

8+
use rustc_apfloat::ieee::{Double, Single};
89
use rustc_hir::def::{DefKind, Namespace};
910
use rustc_hir::def_id::{DefId, CRATE_DEF_INDEX};
1011
use rustc_index::IndexVec;
1112
use rustc_middle::mir;
1213
use rustc_middle::ty::{
1314
self,
14-
layout::{IntegerExt as _, LayoutOf, TyAndLayout},
15-
IntTy, Ty, TyCtxt, UintTy,
15+
layout::{LayoutOf, TyAndLayout},
16+
FloatTy, IntTy, Ty, TyCtxt, UintTy,
1617
};
1718
use rustc_span::{def_id::CrateNum, sym, Span, Symbol};
18-
use rustc_target::abi::{Align, FieldIdx, FieldsShape, Integer, Size, Variants};
19+
use rustc_target::abi::{Align, FieldIdx, FieldsShape, Size, Variants};
1920
use rustc_target::spec::abi::Abi;
2021

2122
use rand::RngCore;
@@ -986,65 +987,74 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
986987
}
987988
}
988989

989-
/// Converts `f` to integer type `dest_ty` after rounding with mode `round`.
990+
/// Converts `src` from floating point to integer type `dest_ty`
991+
/// after rounding with mode `round`.
990992
/// Returns `None` if `f` is NaN or out of range.
991-
fn float_to_int_checked<F>(
993+
fn float_to_int_checked(
992994
&self,
993-
f: F,
995+
src: &ImmTy<'tcx, Provenance>,
994996
cast_to: TyAndLayout<'tcx>,
995997
round: rustc_apfloat::Round,
996-
) -> Option<ImmTy<'tcx, Provenance>>
997-
where
998-
F: rustc_apfloat::Float + Into<Scalar<Provenance>>,
999-
{
998+
) -> InterpResult<'tcx, Option<ImmTy<'tcx, Provenance>>> {
1000999
let this = self.eval_context_ref();
10011000

1002-
let val = match cast_to.ty.kind() {
1003-
// Unsigned
1004-
ty::Uint(t) => {
1005-
let size = Integer::from_uint_ty(this, *t).size();
1006-
let res = f.to_u128_r(size.bits_usize(), round, &mut false);
1007-
if res.status.intersects(
1008-
rustc_apfloat::Status::INVALID_OP
1009-
| rustc_apfloat::Status::OVERFLOW
1010-
| rustc_apfloat::Status::UNDERFLOW,
1011-
) {
1012-
// Floating point value is NaN (flagged with INVALID_OP) or outside the range
1013-
// of values of the integer type (flagged with OVERFLOW or UNDERFLOW).
1014-
return None;
1015-
} else {
1016-
// Floating point value can be represented by the integer type after rounding.
1017-
// The INEXACT flag is ignored on purpose to allow rounding.
1018-
Scalar::from_uint(res.value, size)
1001+
fn float_to_int_inner<'tcx, F: rustc_apfloat::Float>(
1002+
this: &MiriInterpCx<'_, 'tcx>,
1003+
src: F,
1004+
cast_to: TyAndLayout<'tcx>,
1005+
round: rustc_apfloat::Round,
1006+
) -> (Scalar<Provenance>, rustc_apfloat::Status) {
1007+
let int_size = cast_to.layout.size;
1008+
match cast_to.ty.kind() {
1009+
// Unsigned
1010+
ty::Uint(_) => {
1011+
let res = src.to_u128_r(int_size.bits_usize(), round, &mut false);
1012+
(Scalar::from_uint(res.value, int_size), res.status)
10191013
}
1020-
}
1021-
// Signed
1022-
ty::Int(t) => {
1023-
let size = Integer::from_int_ty(this, *t).size();
1024-
let res = f.to_i128_r(size.bits_usize(), round, &mut false);
1025-
if res.status.intersects(
1026-
rustc_apfloat::Status::INVALID_OP
1027-
| rustc_apfloat::Status::OVERFLOW
1028-
| rustc_apfloat::Status::UNDERFLOW,
1029-
) {
1030-
// Floating point value is NaN (flagged with INVALID_OP) or outside the range
1031-
// of values of the integer type (flagged with OVERFLOW or UNDERFLOW).
1032-
return None;
1033-
} else {
1034-
// Floating point value can be represented by the integer type after rounding.
1035-
// The INEXACT flag is ignored on purpose to allow rounding.
1036-
Scalar::from_int(res.value, size)
1014+
// Signed
1015+
ty::Int(_) => {
1016+
let res = src.to_i128_r(int_size.bits_usize(), round, &mut false);
1017+
(Scalar::from_int(res.value, int_size), res.status)
10371018
}
1019+
// Nothing else
1020+
_ =>
1021+
span_bug!(
1022+
this.cur_span(),
1023+
"attempted float-to-int conversion with non-int output type {}",
1024+
cast_to.ty,
1025+
),
10381026
}
1027+
}
1028+
1029+
let (val, status) = match src.layout.ty.kind() {
1030+
// f32
1031+
ty::Float(FloatTy::F32) =>
1032+
float_to_int_inner::<Single>(this, src.to_scalar().to_f32()?, cast_to, round),
1033+
// f64
1034+
ty::Float(FloatTy::F64) =>
1035+
float_to_int_inner::<Double>(this, src.to_scalar().to_f64()?, cast_to, round),
10391036
// Nothing else
10401037
_ =>
10411038
span_bug!(
10421039
this.cur_span(),
1043-
"attempted float-to-int conversion with non-int output type {}",
1044-
cast_to.ty,
1040+
"attempted float-to-int conversion with non-float input type {}",
1041+
src.layout.ty,
10451042
),
10461043
};
1047-
Some(ImmTy::from_scalar(val, cast_to))
1044+
1045+
if status.intersects(
1046+
rustc_apfloat::Status::INVALID_OP
1047+
| rustc_apfloat::Status::OVERFLOW
1048+
| rustc_apfloat::Status::UNDERFLOW,
1049+
) {
1050+
// Floating point value is NaN (flagged with INVALID_OP) or outside the range
1051+
// of values of the integer type (flagged with OVERFLOW or UNDERFLOW).
1052+
Ok(None)
1053+
} else {
1054+
// Floating point value can be represented by the integer type after rounding.
1055+
// The INEXACT flag is ignored on purpose to allow rounding.
1056+
Ok(Some(ImmTy::from_scalar(val, cast_to)))
1057+
}
10481058
}
10491059

10501060
/// Returns an integer type that is twice wide as `ty`

src/tools/miri/src/shims/intrinsics/mod.rs

+8-30
Original file line numberDiff line numberDiff line change
@@ -365,36 +365,14 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
365365
let [val] = check_arg_count(args)?;
366366
let val = this.read_immediate(val)?;
367367

368-
let res = match val.layout.ty.kind() {
369-
ty::Float(FloatTy::F32) => {
370-
let f = val.to_scalar().to_f32()?;
371-
this
372-
.float_to_int_checked(f, dest.layout, Round::TowardZero)
373-
.ok_or_else(|| {
374-
err_ub_format!(
375-
"`float_to_int_unchecked` intrinsic called on {f} which cannot be represented in target type `{:?}`",
376-
dest.layout.ty
377-
)
378-
})?
379-
}
380-
ty::Float(FloatTy::F64) => {
381-
let f = val.to_scalar().to_f64()?;
382-
this
383-
.float_to_int_checked(f, dest.layout, Round::TowardZero)
384-
.ok_or_else(|| {
385-
err_ub_format!(
386-
"`float_to_int_unchecked` intrinsic called on {f} which cannot be represented in target type `{:?}`",
387-
dest.layout.ty
388-
)
389-
})?
390-
}
391-
_ =>
392-
span_bug!(
393-
this.cur_span(),
394-
"`float_to_int_unchecked` called with non-float input type {:?}",
395-
val.layout.ty
396-
),
397-
};
368+
let res = this
369+
.float_to_int_checked(&val, dest.layout, Round::TowardZero)?
370+
.ok_or_else(|| {
371+
err_ub_format!(
372+
"`float_to_int_unchecked` intrinsic called on {val} which cannot be represented in target type `{:?}`",
373+
dest.layout.ty
374+
)
375+
})?;
398376

399377
this.write_immediate(*res, dest)?;
400378
}

src/tools/miri/src/shims/intrinsics/simd.rs

+3-14
Original file line numberDiff line numberDiff line change
@@ -447,22 +447,11 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
447447
(ty::Float(_), ty::Int(_) | ty::Uint(_)) if safe_cast =>
448448
this.float_to_float_or_int(&op, dest.layout)?,
449449
// Float-to-int in unchecked mode
450-
(ty::Float(FloatTy::F32), ty::Int(_) | ty::Uint(_)) if unsafe_cast => {
451-
let f = op.to_scalar().to_f32()?;
452-
this.float_to_int_checked(f, dest.layout, Round::TowardZero)
450+
(ty::Float(_), ty::Int(_) | ty::Uint(_)) if unsafe_cast => {
451+
this.float_to_int_checked(&op, dest.layout, Round::TowardZero)?
453452
.ok_or_else(|| {
454453
err_ub_format!(
455-
"`simd_cast` intrinsic called on {f} which cannot be represented in target type `{:?}`",
456-
dest.layout.ty
457-
)
458-
})?
459-
}
460-
(ty::Float(FloatTy::F64), ty::Int(_) | ty::Uint(_)) if unsafe_cast => {
461-
let f = op.to_scalar().to_f64()?;
462-
this.float_to_int_checked(f, dest.layout, Round::TowardZero)
463-
.ok_or_else(|| {
464-
err_ub_format!(
465-
"`simd_cast` intrinsic called on {f} which cannot be represented in target type `{:?}`",
454+
"`simd_cast` intrinsic called on {op} which cannot be represented in target type `{:?}`",
466455
dest.layout.ty
467456
)
468457
})?

src/tools/miri/src/shims/x86/mod.rs

+38-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use rustc_middle::mir;
1+
use rustc_middle::{mir, ty};
22
use rustc_span::Symbol;
33
use rustc_target::abi::Size;
44
use rustc_target::spec::abi::Abi;
@@ -331,6 +331,43 @@ fn bin_op_simd_float_all<'tcx, F: rustc_apfloat::Float>(
331331
Ok(())
332332
}
333333

334+
/// Converts each element of `op` from floating point to signed integer.
335+
///
336+
/// When the input value is NaN or out of range, fall back to minimum value.
337+
///
338+
/// If `op` has more elements than `dest`, extra elements are ignored. If `op`
339+
/// has less elements than `dest`, the rest is filled with zeros.
340+
fn convert_float_to_int<'tcx>(
341+
this: &mut crate::MiriInterpCx<'_, 'tcx>,
342+
op: &OpTy<'tcx, Provenance>,
343+
rnd: rustc_apfloat::Round,
344+
dest: &PlaceTy<'tcx, Provenance>,
345+
) -> InterpResult<'tcx, ()> {
346+
let (op, op_len) = this.operand_to_simd(op)?;
347+
let (dest, dest_len) = this.place_to_simd(dest)?;
348+
349+
// Output must be *signed* integers.
350+
assert!(matches!(dest.layout.field(this, 0).ty.kind(), ty::Int(_)));
351+
352+
for i in 0..op_len.min(dest_len) {
353+
let op = this.read_immediate(&this.project_index(&op, i)?)?;
354+
let dest = this.project_index(&dest, i)?;
355+
356+
let res = this.float_to_int_checked(&op, dest.layout, rnd)?.unwrap_or_else(|| {
357+
// Fallback to minimum acording to SSE/AVX semantics.
358+
ImmTy::from_int(dest.layout.size.signed_int_min(), dest.layout)
359+
});
360+
this.write_immediate(*res, &dest)?;
361+
}
362+
// Fill remainder with zeros
363+
for i in op_len..dest_len {
364+
let dest = this.project_index(&dest, i)?;
365+
this.write_scalar(Scalar::from_int(0, dest.layout.size), &dest)?;
366+
}
367+
368+
Ok(())
369+
}
370+
334371
/// Horizontaly performs `which` operation on adjacent values of
335372
/// `left` and `right` SIMD vectors and stores the result in `dest`.
336373
fn horizontal_bin_op<'tcx>(

src/tools/miri/src/shims/x86/sse.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
168168
let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
169169
let (op, _) = this.operand_to_simd(op)?;
170170

171-
let op = this.read_scalar(&this.project_index(&op, 0)?)?.to_f32()?;
171+
let op = this.read_immediate(&this.project_index(&op, 0)?)?;
172172

173173
let rnd = match unprefixed_name {
174174
// "current SSE rounding mode", assume nearest
@@ -180,7 +180,7 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
180180
_ => unreachable!(),
181181
};
182182

183-
let res = this.float_to_int_checked(op, dest.layout, rnd).unwrap_or_else(|| {
183+
let res = this.float_to_int_checked(&op, dest.layout, rnd)?.unwrap_or_else(|| {
184184
// Fallback to minimum acording to SSE semantics.
185185
ImmTy::from_int(dest.layout.size.signed_int_min(), dest.layout)
186186
});

0 commit comments

Comments
 (0)