Skip to content

Commit 20fa4b7

Browse files
calebzulawskiworkingjubilee
authored andcommitted
Make internal mask implementation safe
1 parent 11c3eef commit 20fa4b7

File tree

3 files changed

+75
-30
lines changed

3 files changed

+75
-30
lines changed

crates/core_simd/src/masks/bitmask.rs

+11-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#![allow(unused_imports)]
22
use super::MaskElement;
33
use crate::simd::intrinsics;
4-
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
4+
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
55
use core::marker::PhantomData;
66

77
/// A mask where each lane is represented by a single bit.
@@ -116,13 +116,20 @@ where
116116
}
117117

118118
#[inline]
119-
pub unsafe fn to_bitmask_integer<U>(self) -> U {
119+
pub fn to_bitmask_integer<U>(self) -> U
120+
where
121+
super::Mask<T, LANES>: ToBitMask<BitMask = U>,
122+
{
123+
// Safety: these are the same types
120124
unsafe { core::mem::transmute_copy(&self.0) }
121125
}
122126

123-
// Safety: U must be the integer with the exact number of bits required to hold the bitmask for
124127
#[inline]
125-
pub unsafe fn from_bitmask_integer<U>(bitmask: U) -> Self {
128+
pub fn from_bitmask_integer<U>(bitmask: U) -> Self
129+
where
130+
super::Mask<T, LANES>: ToBitMask<BitMask = U>,
131+
{
132+
// Safety: these are the same types
126133
unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) }
127134
}
128135

crates/core_simd/src/masks/full_masks.rs

+43-8
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
33
use super::MaskElement;
44
use crate::simd::intrinsics;
5-
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
5+
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
66

77
#[repr(transparent)]
88
pub struct Mask<T, const LANES: usize>(Simd<T, LANES>)
@@ -66,6 +66,23 @@ where
6666
}
6767
}
6868

69+
// Used for bitmask bit order workaround
70+
pub(crate) trait ReverseBits {
71+
fn reverse_bits(self) -> Self;
72+
}
73+
74+
macro_rules! impl_reverse_bits {
75+
{ $($int:ty),* } => {
76+
$(
77+
impl ReverseBits for $int {
78+
fn reverse_bits(self) -> Self { <$int>::reverse_bits(self) }
79+
}
80+
)*
81+
}
82+
}
83+
84+
impl_reverse_bits! { u8, u16, u32, u64 }
85+
6986
impl<T, const LANES: usize> Mask<T, LANES>
7087
where
7188
T: MaskElement,
@@ -110,16 +127,34 @@ where
110127
}
111128

112129
#[inline]
113-
pub unsafe fn to_bitmask_integer<U>(self) -> U {
114-
// Safety: caller must only return bitmask types
115-
unsafe { intrinsics::simd_bitmask(self.0) }
130+
pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U
131+
where
132+
super::Mask<T, LANES>: ToBitMask<BitMask = U>,
133+
{
134+
// Safety: U is required to be the appropriate bitmask type
135+
let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) };
136+
137+
// LLVM assumes bit order should match endianness
138+
if cfg!(target_endian = "big") {
139+
bitmask.reverse_bits()
140+
} else {
141+
bitmask
142+
}
116143
}
117144

118-
// Safety: U must be the integer with the exact number of bits required to hold the bitmask for
119-
// this mask
120145
#[inline]
121-
pub unsafe fn from_bitmask_integer<U>(bitmask: U) -> Self {
122-
// Safety: caller must only pass bitmask types
146+
pub(crate) fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self
147+
where
148+
super::Mask<T, LANES>: ToBitMask<BitMask = U>,
149+
{
150+
// LLVM assumes bit order should match endianness
151+
let bitmask = if cfg!(target_endian = "big") {
152+
bitmask.reverse_bits()
153+
} else {
154+
bitmask
155+
};
156+
157+
// Safety: U is required to be the appropriate bitmask type
123158
unsafe {
124159
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
125160
bitmask,

crates/core_simd/src/masks/to_bitmask.rs

+21-18
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,26 @@
11
use super::{mask_impl, Mask, MaskElement};
2+
use crate::simd::{LaneCount, SupportedLaneCount};
3+
4+
mod sealed {
5+
pub trait Sealed {}
6+
}
7+
pub use sealed::Sealed;
8+
9+
impl<T, const LANES: usize> Sealed for Mask<T, LANES>
10+
where
11+
T: MaskElement,
12+
LaneCount<LANES>: SupportedLaneCount,
13+
{
14+
}
215

316
/// Converts masks to and from integer bitmasks.
417
///
518
/// Each bit of the bitmask corresponds to a mask lane, starting with the LSB.
6-
pub trait ToBitMask {
19+
///
20+
/// # Safety
21+
/// This trait is `unsafe` and sealed, since the `BitMask` type must match the number of lanes in
22+
/// the mask.
23+
pub unsafe trait ToBitMask: Sealed {
724
/// The integer bitmask type.
825
type BitMask;
926

@@ -14,32 +31,18 @@ pub trait ToBitMask {
1431
fn from_bitmask(bitmask: Self::BitMask) -> Self;
1532
}
1633

17-
/// Converts masks to and from byte array bitmasks.
18-
///
19-
/// Each bit of the bitmask corresponds to a mask lane, starting with the LSB of the first byte.
20-
pub trait ToBitMaskArray {
21-
/// The length of the bitmask array.
22-
const BYTES: usize;
23-
24-
/// Converts a mask to a bitmask.
25-
fn to_bitmask_array(self) -> [u8; Self::BYTES];
26-
27-
/// Converts a bitmask to a mask.
28-
fn from_bitmask_array(bitmask: [u8; Self::BYTES]) -> Self;
29-
}
30-
3134
macro_rules! impl_integer_intrinsic {
3235
{ $(unsafe impl ToBitMask<BitMask=$int:ty> for Mask<_, $lanes:literal>)* } => {
3336
$(
34-
impl<T: MaskElement> ToBitMask for Mask<T, $lanes> {
37+
unsafe impl<T: MaskElement> ToBitMask for Mask<T, $lanes> {
3538
type BitMask = $int;
3639

3740
fn to_bitmask(self) -> $int {
38-
unsafe { self.0.to_bitmask_integer() }
41+
self.0.to_bitmask_integer()
3942
}
4043

4144
fn from_bitmask(bitmask: $int) -> Self {
42-
unsafe { Self(mask_impl::Mask::from_bitmask_integer(bitmask)) }
45+
Self(mask_impl::Mask::from_bitmask_integer(bitmask))
4346
}
4447
}
4548
)*

0 commit comments

Comments
 (0)