Skip to content

Commit de822dc

Browse files
committed
Auto merge of rust-lang#3662 - RalfJung:simd-bitmask, r=RalfJung
simd_bitmask: work correctly for sizes like 24
2 parents 773415d + ba45198 commit de822dc

File tree

4 files changed

+130
-51
lines changed

4 files changed

+130
-51
lines changed

src/tools/miri/src/helpers.rs

+2-1
Original file line number | Diff line number | Diff line change
@@ -374,7 +374,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
374374
let val = if dest.layout().abi.is_signed() {
375375
Scalar::from_int(i, dest.layout().size)
376376
} else {
377-
Scalar::from_uint(u64::try_from(i.into()).unwrap(), dest.layout().size)
377+
// `unwrap` can only fail here if `i` is negative
378+
Scalar::from_uint(u128::try_from(i.into()).unwrap(), dest.layout().size)
378379
};
379380
self.eval_context_mut().write_scalar(val, dest)
380381
}

src/tools/miri/src/intrinsics/mod.rs

+2
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
#![warn(clippy::arithmetic_side_effects)]
2+
13
mod atomic;
24
mod simd;
35

src/tools/miri/src/intrinsics/simd.rs

+71-34
Original file line number | Diff line number | Diff line change
@@ -458,26 +458,48 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
458458
);
459459
}
460460

461-
// The mask must be an integer or an array.
462-
assert!(
463-
mask.layout.ty.is_integral()
464-
|| matches!(mask.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
465-
);
466-
assert_eq!(bitmask_len, mask.layout.size.bits());
467461
assert_eq!(dest_len, yes_len);
468462
assert_eq!(dest_len, no_len);
469-
let dest_len = u32::try_from(dest_len).unwrap();
470-
let bitmask_len = u32::try_from(bitmask_len).unwrap();
471463

472-
// To read the mask, we transmute it to an integer.
473-
// That does the right thing wrt endianness.
474-
let mask_ty = this.machine.layouts.uint(mask.layout.size).unwrap();
475-
let mask = mask.transmute(mask_ty, this)?;
476-
let mask: u64 = this.read_scalar(&mask)?.to_bits(mask_ty.size)?.try_into().unwrap();
464+
// Read the mask, either as an integer or as an array.
465+
let mask: u64 = match mask.layout.ty.kind() {
466+
ty::Uint(_) => {
467+
// Any larger integer type is fine.
468+
assert!(mask.layout.size.bits() >= bitmask_len);
469+
this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap()
470+
}
471+
ty::Array(elem, _len) if elem == &this.tcx.types.u8 => {
472+
// The array must have exactly the right size.
473+
assert_eq!(mask.layout.size.bits(), bitmask_len);
474+
// Read the raw bytes.
475+
let mask = mask.assert_mem_place(); // arrays cannot be immediate
476+
let mask_bytes =
477+
this.read_bytes_ptr_strip_provenance(mask.ptr(), mask.layout.size)?;
478+
// Turn them into a `u64` in the right way.
479+
let mask_size = mask.layout.size.bytes_usize();
480+
let mut mask_arr = [0u8; 8];
481+
match this.data_layout().endian {
482+
Endian::Little => {
483+
// Fill the first N bytes.
484+
mask_arr[..mask_size].copy_from_slice(mask_bytes);
485+
u64::from_le_bytes(mask_arr)
486+
}
487+
Endian::Big => {
488+
// Fill the last N bytes.
489+
let i = mask_arr.len().strict_sub(mask_size);
490+
mask_arr[i..].copy_from_slice(mask_bytes);
491+
u64::from_be_bytes(mask_arr)
492+
}
493+
}
494+
}
495+
_ => bug!("simd_select_bitmask: invalid mask type {}", mask.layout.ty),
496+
};
477497

498+
let dest_len = u32::try_from(dest_len).unwrap();
499+
let bitmask_len = u32::try_from(bitmask_len).unwrap();
478500
for i in 0..dest_len {
479501
let bit_i = simd_bitmask_index(i, dest_len, this.data_layout().endian);
480-
let mask = mask & 1u64.checked_shl(bit_i).unwrap();
502+
let mask = mask & 1u64.strict_shl(bit_i);
481503
let yes = this.read_immediate(&this.project_index(&yes, i.into())?)?;
482504
let no = this.read_immediate(&this.project_index(&no, i.into())?)?;
483505
let dest = this.project_index(&dest, i.into())?;
@@ -489,7 +511,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
489511
// If the mask is "padded", ensure that padding is all-zero.
490512
// This deliberately does not use `simd_bitmask_index`; these bits are outside
491513
// the bitmask. It does not matter in which order we check them.
492-
let mask = mask & 1u64.checked_shl(i).unwrap();
514+
let mask = mask & 1u64.strict_shl(i);
493515
if mask != 0 {
494516
throw_ub_format!(
495517
"a SIMD bitmask less than 8 bits long must be filled with 0s for the remaining bits"
@@ -508,28 +530,43 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
508530
);
509531
}
510532

511-
// Returns either an unsigned integer or array of `u8`.
512-
assert!(
513-
dest.layout.ty.is_integral()
514-
|| matches!(dest.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
515-
);
516-
assert_eq!(bitmask_len, dest.layout.size.bits());
517533
let op_len = u32::try_from(op_len).unwrap();
518-
519534
let mut res = 0u64;
520535
for i in 0..op_len {
521536
let op = this.read_immediate(&this.project_index(&op, i.into())?)?;
522537
if simd_element_to_bool(op)? {
523-
res |= 1u64
524-
.checked_shl(simd_bitmask_index(i, op_len, this.data_layout().endian))
525-
.unwrap();
538+
let bit_i = simd_bitmask_index(i, op_len, this.data_layout().endian);
539+
res |= 1u64.strict_shl(bit_i);
540+
}
541+
}
542+
// Write the result, depending on the `dest` type.
543+
// Returns either an unsigned integer or array of `u8`.
544+
match dest.layout.ty.kind() {
545+
ty::Uint(_) => {
546+
// Any larger integer type is fine, it will be zero-extended.
547+
assert!(dest.layout.size.bits() >= bitmask_len);
548+
this.write_int(res, dest)?;
549+
}
550+
ty::Array(elem, _len) if elem == &this.tcx.types.u8 => {
551+
// The array must have exactly the right size.
552+
assert_eq!(dest.layout.size.bits(), bitmask_len);
553+
// We have to write the result byte-for-byte.
554+
let res_size = dest.layout.size.bytes_usize();
555+
let res_bytes;
556+
let res_bytes_slice = match this.data_layout().endian {
557+
Endian::Little => {
558+
res_bytes = res.to_le_bytes();
559+
&res_bytes[..res_size] // take the first N bytes
560+
}
561+
Endian::Big => {
562+
res_bytes = res.to_be_bytes();
563+
&res_bytes[res_bytes.len().strict_sub(res_size)..] // take the last N bytes
564+
}
565+
};
566+
this.write_bytes_ptr(dest.ptr(), res_bytes_slice.iter().cloned())?;
526567
}
568+
_ => bug!("simd_bitmask: invalid return type {}", dest.layout.ty),
527569
}
528-
// We have to change the type of the place to be able to write `res` into it. This
529-
// transmutes the integer to an array, which does the right thing wrt endianness.
530-
let dest =
531-
dest.transmute(this.machine.layouts.uint(dest.layout.size).unwrap(), this)?;
532-
this.write_int(res, &dest)?;
533570
}
534571
"cast" | "as" | "cast_ptr" | "expose_provenance" | "with_exposed_provenance" => {
535572
let [op] = check_arg_count(args)?;
@@ -615,8 +652,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
615652

616653
let val = if src_index < left_len {
617654
this.read_immediate(&this.project_index(&left, src_index)?)?
618-
} else if src_index < left_len.checked_add(right_len).unwrap() {
619-
let right_idx = src_index.checked_sub(left_len).unwrap();
655+
} else if src_index < left_len.strict_add(right_len) {
656+
let right_idx = src_index.strict_sub(left_len);
620657
this.read_immediate(&this.project_index(&right, right_idx)?)?
621658
} else {
622659
throw_ub_format!(
@@ -655,8 +692,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
655692

656693
let val = if src_index < left_len {
657694
this.read_immediate(&this.project_index(&left, src_index)?)?
658-
} else if src_index < left_len.checked_add(right_len).unwrap() {
659-
let right_idx = src_index.checked_sub(left_len).unwrap();
695+
} else if src_index < left_len.strict_add(right_len) {
696+
let right_idx = src_index.strict_sub(left_len);
660697
this.read_immediate(&this.project_index(&right, right_idx)?)?
661698
} else {
662699
throw_ub_format!(

src/tools/miri/tests/pass/intrinsics/portable-simd.rs

+55-16
Original file line number | Diff line number | Diff line change
@@ -323,38 +323,77 @@ fn simd_mask() {
323323
#[repr(simd, packed)]
324324
#[allow(non_camel_case_types)]
325325
#[derive(Copy, Clone, Debug, PartialEq)]
326-
struct i32x10(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32);
326+
struct i32x10([i32; 10]);
327327
impl i32x10 {
328328
fn splat(x: i32) -> Self {
329-
Self(x, x, x, x, x, x, x, x, x, x)
330-
}
331-
fn from_array(a: [i32; 10]) -> Self {
332-
unsafe { std::mem::transmute(a) }
329+
Self([x; 10])
333330
}
334331
}
335332
unsafe {
336-
let mask = i32x10::from_array([!0, !0, 0, !0, 0, 0, !0, 0, !0, 0]);
333+
let mask = i32x10([!0, !0, 0, !0, 0, 0, !0, 0, !0, 0]);
334+
let mask_bits = if cfg!(target_endian = "little") { 0b0101001011 } else { 0b1101001010 };
335+
let mask_bytes =
336+
if cfg!(target_endian = "little") { [0b01001011, 0b01] } else { [0b11, 0b01001010] };
337+
337338
let bitmask1: u16 = simd_bitmask(mask);
338339
let bitmask2: [u8; 2] = simd_bitmask(mask);
339-
if cfg!(target_endian = "little") {
340-
assert_eq!(bitmask1, 0b0101001011);
341-
assert_eq!(bitmask2, [0b01001011, 0b01]);
342-
} else {
343-
assert_eq!(bitmask1, 0b1101001010);
344-
assert_eq!(bitmask2, [0b11, 0b01001010]);
345-
}
340+
assert_eq!(bitmask1, mask_bits);
341+
assert_eq!(bitmask2, mask_bytes);
342+
346343
let selected1 = simd_select_bitmask::<u16, _>(
347-
if cfg!(target_endian = "little") { 0b0101001011 } else { 0b1101001010 },
344+
mask_bits,
348345
i32x10::splat(!0), // yes
349346
i32x10::splat(0), // no
350347
);
351348
let selected2 = simd_select_bitmask::<[u8; 2], _>(
352-
if cfg!(target_endian = "little") { [0b01001011, 0b01] } else { [0b11, 0b01001010] },
349+
mask_bytes,
353350
i32x10::splat(!0), // yes
354351
i32x10::splat(0), // no
355352
);
356353
assert_eq!(selected1, mask);
357-
assert_eq!(selected2, selected1);
354+
assert_eq!(selected2, mask);
355+
}
356+
357+
// Test for a mask where the next multiple of 8 is not a power of two.
358+
#[repr(simd, packed)]
359+
#[allow(non_camel_case_types)]
360+
#[derive(Copy, Clone, Debug, PartialEq)]
361+
struct i32x20([i32; 20]);
362+
impl i32x20 {
363+
fn splat(x: i32) -> Self {
364+
Self([x; 20])
365+
}
366+
}
367+
unsafe {
368+
let mask = i32x20([!0, !0, 0, !0, 0, 0, !0, 0, !0, 0, 0, 0, 0, !0, !0, !0, !0, !0, !0, !0]);
369+
let mask_bits = if cfg!(target_endian = "little") {
370+
0b11111110000101001011
371+
} else {
372+
0b11010010100001111111
373+
};
374+
let mask_bytes = if cfg!(target_endian = "little") {
375+
[0b01001011, 0b11100001, 0b1111]
376+
} else {
377+
[0b1101, 0b00101000, 0b01111111]
378+
};
379+
380+
let bitmask1: u32 = simd_bitmask(mask);
381+
let bitmask2: [u8; 3] = simd_bitmask(mask);
382+
assert_eq!(bitmask1, mask_bits);
383+
assert_eq!(bitmask2, mask_bytes);
384+
385+
let selected1 = simd_select_bitmask::<u32, _>(
386+
mask_bits,
387+
i32x20::splat(!0), // yes
388+
i32x20::splat(0), // no
389+
);
390+
let selected2 = simd_select_bitmask::<[u8; 3], _>(
391+
mask_bytes,
392+
i32x20::splat(!0), // yes
393+
i32x20::splat(0), // no
394+
);
395+
assert_eq!(selected1, mask);
396+
assert_eq!(selected2, mask);
358397
}
359398
}
360399

0 commit comments

Comments
 (0)