Skip to content

Commit 110b092

Browse files
committed
simd_bitmask: work correctly for bitmask sizes like 24 bits (where the byte-rounded size is not a power of two)
1 parent 773415d commit 110b092

File tree

4 files changed

+122
-42
lines changed

4 files changed

+122
-42
lines changed

src/tools/miri/src/helpers.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
374374
let val = if dest.layout().abi.is_signed() {
375375
Scalar::from_int(i, dest.layout().size)
376376
} else {
377-
Scalar::from_uint(u64::try_from(i.into()).unwrap(), dest.layout().size)
377+
// `unwrap` can only fail here if `i` is negative
378+
Scalar::from_uint(u128::try_from(i.into()).unwrap(), dest.layout().size)
378379
};
379380
self.eval_context_mut().write_scalar(val, dest)
380381
}

src/tools/miri/src/intrinsics/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#![warn(clippy::arithmetic_side_effects)]
2+
13
mod atomic;
24
mod simd;
35

src/tools/miri/src/intrinsics/simd.rs

Lines changed: 63 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -458,23 +458,45 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
458458
);
459459
}
460460

461-
// The mask must be an integer or an array.
462-
assert!(
463-
mask.layout.ty.is_integral()
464-
|| matches!(mask.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
465-
);
466-
assert_eq!(bitmask_len, mask.layout.size.bits());
467461
assert_eq!(dest_len, yes_len);
468462
assert_eq!(dest_len, no_len);
469-
let dest_len = u32::try_from(dest_len).unwrap();
470-
let bitmask_len = u32::try_from(bitmask_len).unwrap();
471463

472-
// To read the mask, we transmute it to an integer.
473-
// That does the right thing wrt endianness.
474-
let mask_ty = this.machine.layouts.uint(mask.layout.size).unwrap();
475-
let mask = mask.transmute(mask_ty, this)?;
476-
let mask: u64 = this.read_scalar(&mask)?.to_bits(mask_ty.size)?.try_into().unwrap();
464+
// Read the mask, either as an integer or as an array.
465+
let mask: u64 = match mask.layout.ty.kind() {
466+
ty::Uint(_) => {
467+
// Any integer type at least as wide as the bitmask is fine.
468+
assert!(mask.layout.size.bits() >= bitmask_len);
469+
this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap()
470+
}
471+
ty::Array(elem, _len) if elem == &this.tcx.types.u8 => {
472+
// The array must have exactly the right size.
473+
assert_eq!(mask.layout.size.bits(), bitmask_len);
474+
// Read the raw bytes.
475+
let mask = mask.assert_mem_place(); // arrays cannot be immediate
476+
let mask_bytes =
477+
this.read_bytes_ptr_strip_provenance(mask.ptr(), mask.layout.size)?;
478+
// Turn them into a `u64` in the right way.
479+
let mask_size = mask.layout.size.bytes_usize();
480+
let mut mask_arr = [0u8; 8];
481+
match this.data_layout().endian {
482+
Endian::Little => {
483+
// Fill the first N bytes.
484+
mask_arr[..mask_size].copy_from_slice(mask_bytes);
485+
u64::from_le_bytes(mask_arr)
486+
}
487+
Endian::Big => {
488+
// Fill the last N bytes.
489+
let i = mask_arr.len().strict_sub(mask_size);
490+
mask_arr[i..].copy_from_slice(mask_bytes);
491+
u64::from_be_bytes(mask_arr)
492+
}
493+
}
494+
}
495+
_ => bug!("simd_select_bitmask: invalid mask type {}", mask.layout.ty),
496+
};
477497

498+
let dest_len = u32::try_from(dest_len).unwrap();
499+
let bitmask_len = u32::try_from(bitmask_len).unwrap();
478500
for i in 0..dest_len {
479501
let bit_i = simd_bitmask_index(i, dest_len, this.data_layout().endian);
480502
let mask = mask & 1u64.checked_shl(bit_i).unwrap();
@@ -508,14 +530,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
508530
);
509531
}
510532

511-
// Returns either an unsigned integer or array of `u8`.
512-
assert!(
513-
dest.layout.ty.is_integral()
514-
|| matches!(dest.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
515-
);
516-
assert_eq!(bitmask_len, dest.layout.size.bits());
517533
let op_len = u32::try_from(op_len).unwrap();
518-
519534
let mut res = 0u64;
520535
for i in 0..op_len {
521536
let op = this.read_immediate(&this.project_index(&op, i.into())?)?;
@@ -525,11 +540,34 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
525540
.unwrap();
526541
}
527542
}
528-
// We have to change the type of the place to be able to write `res` into it. This
529-
// transmutes the integer to an array, which does the right thing wrt endianness.
530-
let dest =
531-
dest.transmute(this.machine.layouts.uint(dest.layout.size).unwrap(), this)?;
532-
this.write_int(res, &dest)?;
543+
// Write the result, depending on the `dest` type.
544+
// Returns either an unsigned integer or an array of `u8`.
545+
match dest.layout.ty.kind() {
546+
ty::Uint(_) => {
547+
// Any integer type at least as wide as the bitmask is fine; the result will be zero-extended.
548+
assert!(dest.layout.size.bits() >= bitmask_len);
549+
this.write_int(res, dest)?;
550+
}
551+
ty::Array(elem, _len) if elem == &this.tcx.types.u8 => {
552+
// The array must have exactly the right size.
553+
assert_eq!(dest.layout.size.bits(), bitmask_len);
554+
// We have to write the result byte-for-byte.
555+
let res_size = dest.layout.size.bytes_usize();
556+
let res_bytes;
557+
let res_bytes_slice = match this.data_layout().endian {
558+
Endian::Little => {
559+
res_bytes = res.to_le_bytes();
560+
&res_bytes[..res_size] // take the first N bytes
561+
}
562+
Endian::Big => {
563+
res_bytes = res.to_be_bytes();
564+
&res_bytes[res_bytes.len().strict_sub(res_size)..] // take the last N bytes
565+
}
566+
};
567+
this.write_bytes_ptr(dest.ptr(), res_bytes_slice.iter().cloned())?;
568+
}
569+
_ => bug!("simd_bitmask: invalid return type {}", dest.layout.ty),
570+
}
533571
}
534572
"cast" | "as" | "cast_ptr" | "expose_provenance" | "with_exposed_provenance" => {
535573
let [op] = check_arg_count(args)?;

src/tools/miri/tests/pass/intrinsics/portable-simd.rs

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -323,38 +323,77 @@ fn simd_mask() {
323323
#[repr(simd, packed)]
324324
#[allow(non_camel_case_types)]
325325
#[derive(Copy, Clone, Debug, PartialEq)]
326-
struct i32x10(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32);
326+
struct i32x10([i32; 10]);
327327
impl i32x10 {
328328
fn splat(x: i32) -> Self {
329-
Self(x, x, x, x, x, x, x, x, x, x)
330-
}
331-
fn from_array(a: [i32; 10]) -> Self {
332-
unsafe { std::mem::transmute(a) }
329+
Self([x; 10])
333330
}
334331
}
335332
unsafe {
336-
let mask = i32x10::from_array([!0, !0, 0, !0, 0, 0, !0, 0, !0, 0]);
333+
let mask = i32x10([!0, !0, 0, !0, 0, 0, !0, 0, !0, 0]);
334+
let mask_bits = if cfg!(target_endian = "little") { 0b0101001011 } else { 0b1101001010 };
335+
let mask_bytes =
336+
if cfg!(target_endian = "little") { [0b01001011, 0b01] } else { [0b11, 0b01001010] };
337+
337338
let bitmask1: u16 = simd_bitmask(mask);
338339
let bitmask2: [u8; 2] = simd_bitmask(mask);
339-
if cfg!(target_endian = "little") {
340-
assert_eq!(bitmask1, 0b0101001011);
341-
assert_eq!(bitmask2, [0b01001011, 0b01]);
342-
} else {
343-
assert_eq!(bitmask1, 0b1101001010);
344-
assert_eq!(bitmask2, [0b11, 0b01001010]);
345-
}
340+
assert_eq!(bitmask1, mask_bits);
341+
assert_eq!(bitmask2, mask_bytes);
342+
346343
let selected1 = simd_select_bitmask::<u16, _>(
347-
if cfg!(target_endian = "little") { 0b0101001011 } else { 0b1101001010 },
344+
mask_bits,
348345
i32x10::splat(!0), // yes
349346
i32x10::splat(0), // no
350347
);
351348
let selected2 = simd_select_bitmask::<[u8; 2], _>(
352-
if cfg!(target_endian = "little") { [0b01001011, 0b01] } else { [0b11, 0b01001010] },
349+
mask_bytes,
353350
i32x10::splat(!0), // yes
354351
i32x10::splat(0), // no
355352
);
356353
assert_eq!(selected1, mask);
357-
assert_eq!(selected2, selected1);
354+
assert_eq!(selected2, mask);
355+
}
356+
357+
// Test for a mask where the next multiple of 8 is not a power of two.
358+
#[repr(simd, packed)]
359+
#[allow(non_camel_case_types)]
360+
#[derive(Copy, Clone, Debug, PartialEq)]
361+
struct i32x20([i32; 20]);
362+
impl i32x20 {
363+
fn splat(x: i32) -> Self {
364+
Self([x; 20])
365+
}
366+
}
367+
unsafe {
368+
let mask = i32x20([!0, !0, 0, !0, 0, 0, !0, 0, !0, 0, 0, 0, 0, !0, !0, !0, !0, !0, !0, !0]);
369+
let mask_bits = if cfg!(target_endian = "little") {
370+
0b11111110000101001011
371+
} else {
372+
0b11010010100001111111
373+
};
374+
let mask_bytes = if cfg!(target_endian = "little") {
375+
[0b01001011, 0b11100001, 0b1111]
376+
} else {
377+
[0b1101, 0b00101000, 0b01111111]
378+
};
379+
380+
let bitmask1: u32 = simd_bitmask(mask);
381+
let bitmask2: [u8; 3] = simd_bitmask(mask);
382+
assert_eq!(bitmask1, mask_bits);
383+
assert_eq!(bitmask2, mask_bytes);
384+
385+
let selected1 = simd_select_bitmask::<u32, _>(
386+
mask_bits,
387+
i32x20::splat(!0), // yes
388+
i32x20::splat(0), // no
389+
);
390+
let selected2 = simd_select_bitmask::<[u8; 3], _>(
391+
mask_bytes,
392+
i32x20::splat(!0), // yes
393+
i32x20::splat(0), // no
394+
);
395+
assert_eq!(selected1, mask);
396+
assert_eq!(selected2, mask);
358397
}
359398
}
360399

0 commit comments

Comments
 (0)