Skip to content

Commit 7f201f8

Browse files
committed
Implement roundps and roundpd SSE4.1 intrinsics
1 parent e25a04f commit 7f201f8

File tree

2 files changed

+124
-16
lines changed

2 files changed

+124
-16
lines changed

Diff for: src/tools/miri/src/shims/x86/sse41.rs

+64-16
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,14 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
148148

149149
round_first::<rustc_apfloat::ieee::Single>(this, left, right, rounding, dest)?;
150150
}
151+
// Used to implement the _mm_floor_ps, _mm_ceil_ps and _mm_round_ps
152+
// functions. Rounds the elements of `op` according to `rounding`.
153+
"round.ps" => {
154+
let [op, rounding] =
155+
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
156+
157+
round_all::<rustc_apfloat::ieee::Single>(this, op, rounding, dest)?;
158+
}
151159
// Used to implement the _mm_floor_sd, _mm_ceil_sd and _mm_round_sd
152160
// functions. Rounds the first element of `right` according to `rounding`
153161
// and copies the remaining elements from `left`.
@@ -157,6 +165,14 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
157165

158166
round_first::<rustc_apfloat::ieee::Double>(this, left, right, rounding, dest)?;
159167
}
168+
// Used to implement the _mm_floor_pd, _mm_ceil_pd and _mm_round_pd
169+
// functions. Rounds the elements of `op` according to `rounding`.
170+
"round.pd" => {
171+
let [op, rounding] =
172+
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
173+
174+
round_all::<rustc_apfloat::ieee::Double>(this, op, rounding, dest)?;
175+
}
160176
// Used to implement the _mm_minpos_epu16 function.
161177
// Find the minimum unsinged 16-bit integer in `op` and
162178
// returns its value and position.
@@ -283,22 +299,7 @@ fn round_first<'tcx, F: rustc_apfloat::Float>(
283299
assert_eq!(dest_len, left_len);
284300
assert_eq!(dest_len, right_len);
285301

286-
// The fourth bit of `rounding` only affects the SSE status
287-
// register, which cannot be accessed from Miri (or from Rust,
288-
// for that matter), so we can ignore it.
289-
let rounding = match this.read_scalar(rounding)?.to_i32()? & !0b1000 {
290-
// When the third bit is 0, the rounding mode is determined by the
291-
// first two bits.
292-
0b000 => rustc_apfloat::Round::NearestTiesToEven,
293-
0b001 => rustc_apfloat::Round::TowardNegative,
294-
0b010 => rustc_apfloat::Round::TowardPositive,
295-
0b011 => rustc_apfloat::Round::TowardZero,
296-
// When the third bit is 1, the rounding mode is determined by the
297-
// SSE status register. Since we do not support modifying it from
298-
// Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
299-
0b100..=0b111 => rustc_apfloat::Round::NearestTiesToEven,
300-
rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"),
301-
};
302+
let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;
302303

303304
let op0: F = this.read_scalar(&this.project_index(&right, 0)?)?.to_float()?;
304305
let res = op0.round_to_integral(rounding).value;
@@ -317,3 +318,50 @@ fn round_first<'tcx, F: rustc_apfloat::Float>(
317318

318319
Ok(())
319320
}
321+
322+
// Rounds all elements of `op` according to `rounding`.
323+
fn round_all<'tcx, F: rustc_apfloat::Float>(
324+
this: &mut crate::MiriInterpCx<'_, 'tcx>,
325+
op: &OpTy<'tcx, Provenance>,
326+
rounding: &OpTy<'tcx, Provenance>,
327+
dest: &PlaceTy<'tcx, Provenance>,
328+
) -> InterpResult<'tcx, ()> {
329+
let (op, op_len) = this.operand_to_simd(op)?;
330+
let (dest, dest_len) = this.place_to_simd(dest)?;
331+
332+
assert_eq!(dest_len, op_len);
333+
334+
let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;
335+
336+
for i in 0..dest_len {
337+
let op: F = this.read_scalar(&this.project_index(&op, i)?)?.to_float()?;
338+
let res = op.round_to_integral(rounding).value;
339+
this.write_scalar(
340+
Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
341+
&this.project_index(&dest, i)?,
342+
)?;
343+
}
344+
345+
Ok(())
346+
}
347+
348+
/// Gets equivalent `rustc_apfloat::Round` from rounding mode immediate of
349+
/// `round.{ss,sd,ps,pd}` intrinsics.
350+
fn rounding_from_imm<'tcx>(rounding: i32) -> InterpResult<'tcx, rustc_apfloat::Round> {
351+
// The fourth bit of `rounding` only affects the SSE status
352+
// register, which cannot be accessed from Miri (or from Rust,
353+
// for that matter), so we can ignore it.
354+
match rounding & !0b1000 {
355+
// When the third bit is 0, the rounding mode is determined by the
356+
// first two bits.
357+
0b000 => Ok(rustc_apfloat::Round::NearestTiesToEven),
358+
0b001 => Ok(rustc_apfloat::Round::TowardNegative),
359+
0b010 => Ok(rustc_apfloat::Round::TowardPositive),
360+
0b011 => Ok(rustc_apfloat::Round::TowardZero),
361+
// When the third bit is 1, the rounding mode is determined by the
362+
// SSE status register. Since we do not support modifying it from
363+
// Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
364+
0b100..=0b111 => Ok(rustc_apfloat::Round::NearestTiesToEven),
365+
rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"),
366+
}
367+
}

Diff for: src/tools/miri/tests/pass/intrinsics-x86-sse41.rs

+60
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,36 @@ unsafe fn test_sse41() {
147147
}
148148
test_mm_round_sd();
149149

150+
#[target_feature(enable = "sse4.1")]
151+
unsafe fn test_mm_round_pd() {
152+
let a = _mm_setr_pd(-1.75, -4.25);
153+
let r = _mm_round_pd::<_MM_FROUND_TO_NEAREST_INT>(a);
154+
let e = _mm_setr_pd(-2.0, -4.0);
155+
assert_eq_m128d(r, e);
156+
157+
let a = _mm_setr_pd(-1.75, -4.25);
158+
let r = _mm_round_pd::<_MM_FROUND_TO_NEG_INF>(a);
159+
let e = _mm_setr_pd(-2.0, -5.0);
160+
assert_eq_m128d(r, e);
161+
162+
let a = _mm_setr_pd(-1.75, -4.25);
163+
let r = _mm_round_pd::<_MM_FROUND_TO_POS_INF>(a);
164+
let e = _mm_setr_pd(-1.0, -4.0);
165+
assert_eq_m128d(r, e);
166+
167+
let a = _mm_setr_pd(-1.75, -4.25);
168+
let r = _mm_round_pd::<_MM_FROUND_TO_ZERO>(a);
169+
let e = _mm_setr_pd(-1.0, -4.0);
170+
assert_eq_m128d(r, e);
171+
172+
// Assume round-to-nearest by default
173+
let a = _mm_setr_pd(-1.75, -4.25);
174+
let r = _mm_round_pd::<_MM_FROUND_CUR_DIRECTION>(a);
175+
let e = _mm_setr_pd(-2.0, -4.0);
176+
assert_eq_m128d(r, e);
177+
}
178+
test_mm_round_pd();
179+
150180
#[target_feature(enable = "sse4.1")]
151181
unsafe fn test_mm_round_ss() {
152182
let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5);
@@ -182,6 +212,36 @@ unsafe fn test_sse41() {
182212
}
183213
test_mm_round_ss();
184214

215+
#[target_feature(enable = "sse4.1")]
216+
unsafe fn test_mm_round_ps() {
217+
let a = _mm_setr_ps(-1.75, -4.25, -8.5, -16.5);
218+
let r = _mm_round_ps::<_MM_FROUND_TO_NEAREST_INT>(a);
219+
let e = _mm_setr_ps(-2.0, -4.0, -8.0, -16.0);
220+
assert_eq_m128(r, e);
221+
222+
let a = _mm_setr_ps(-1.75, -4.25, -8.5, -16.5);
223+
let r = _mm_round_ps::<_MM_FROUND_TO_NEG_INF>(a);
224+
let e = _mm_setr_ps(-2.0, -5.0, -9.0, -17.0);
225+
assert_eq_m128(r, e);
226+
227+
let a = _mm_setr_ps(-1.75, -4.25, -8.5, -16.5);
228+
let r = _mm_round_ps::<_MM_FROUND_TO_POS_INF>(a);
229+
let e = _mm_setr_ps(-1.0, -4.0, -8.0, -16.0);
230+
assert_eq_m128(r, e);
231+
232+
let a = _mm_setr_ps(-1.75, -4.25, -8.5, -16.5);
233+
let r = _mm_round_ps::<_MM_FROUND_TO_ZERO>(a);
234+
let e = _mm_setr_ps(-1.0, -4.0, -8.0, -16.0);
235+
assert_eq_m128(r, e);
236+
237+
// Assume round-to-nearest by default
238+
let a = _mm_setr_ps(-1.75, -4.25, -8.5, -16.5);
239+
let r = _mm_round_ps::<_MM_FROUND_CUR_DIRECTION>(a);
240+
let e = _mm_setr_ps(-2.0, -4.0, -8.0, -16.0);
241+
assert_eq_m128(r, e);
242+
}
243+
test_mm_round_ps();
244+
185245
#[target_feature(enable = "sse4.1")]
186246
unsafe fn test_mm_minpos_epu16() {
187247
let a = _mm_setr_epi16(23, 18, 44, 97, 50, 13, 67, 66);

0 commit comments

Comments
 (0)