Skip to content

Commit 842ac87

Browse files
calebzulawskiworkingjubilee
authored andcommitted
Use bitmask trait
1 parent 4910274 commit 842ac87

File tree

5 files changed

+93
-60
lines changed

5 files changed

+93
-60
lines changed

crates/core_simd/src/masks.rs

+4-18
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
)]
1313
mod mask_impl;
1414

15-
use crate::simd::intrinsics;
16-
use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount};
15+
mod to_bitmask;
16+
pub use to_bitmask::ToBitMask;
17+
18+
use crate::simd::{intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount};
1719
use core::cmp::Ordering;
1820
use core::{fmt, mem};
1921

@@ -216,22 +218,6 @@ where
216218
}
217219
}
218220

219-
/// Convert this mask to a bitmask, with one bit set per lane.
220-
#[cfg(feature = "generic_const_exprs")]
221-
#[inline]
222-
#[must_use = "method returns a new array and does not mutate the original value"]
223-
pub fn to_bitmask(self) -> [u8; LaneCount::<LANES>::BITMASK_LEN] {
224-
self.0.to_bitmask()
225-
}
226-
227-
/// Convert a bitmask to a mask.
228-
#[cfg(feature = "generic_const_exprs")]
229-
#[inline]
230-
#[must_use = "method returns a new mask and does not mutate the original value"]
231-
pub fn from_bitmask(bitmask: [u8; LaneCount::<LANES>::BITMASK_LEN]) -> Self {
232-
Self(mask_impl::Mask::from_bitmask(bitmask))
233-
}
234-
235221
/// Returns true if any lane is set, or false otherwise.
236222
#[inline]
237223
#[must_use = "method returns a new bool and does not mutate the original value"]

crates/core_simd/src/masks/bitmask.rs

+3-9
Original file line numberDiff line numberDiff line change
@@ -115,20 +115,14 @@ where
115115
unsafe { Self(intrinsics::simd_bitmask(value), PhantomData) }
116116
}
117117

118-
#[cfg(feature = "generic_const_exprs")]
119118
#[inline]
120-
#[must_use = "method returns a new array and does not mutate the original value"]
121-
pub fn to_bitmask(self) -> [u8; LaneCount::<LANES>::BITMASK_LEN] {
122-
// Safety: these are the same type and we are laundering the generic
119+
pub unsafe fn to_bitmask_intrinsic<U>(self) -> U {
123120
unsafe { core::mem::transmute_copy(&self.0) }
124121
}
125122

126-
#[cfg(feature = "generic_const_exprs")]
127123
#[inline]
128-
#[must_use = "method returns a new mask and does not mutate the original value"]
129-
pub fn from_bitmask(bitmask: [u8; LaneCount::<LANES>::BITMASK_LEN]) -> Self {
130-
// Safety: these are the same type and we are laundering the generic
131-
Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData)
124+
pub unsafe fn from_bitmask_intrinsic<U>(bitmask: U) -> Self {
125+
unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) }
132126
}
133127

134128
#[inline]

crates/core_simd/src/masks/full_masks.rs

+5-30
Original file line numberDiff line numberDiff line change
@@ -109,41 +109,16 @@ where
109109
unsafe { Mask(intrinsics::simd_cast(self.0)) }
110110
}
111111

112-
#[cfg(feature = "generic_const_exprs")]
113112
#[inline]
114-
#[must_use = "method returns a new array and does not mutate the original value"]
115-
pub fn to_bitmask(self) -> [u8; LaneCount::<LANES>::BITMASK_LEN] {
116-
unsafe {
117-
let mut bitmask: [u8; LaneCount::<LANES>::BITMASK_LEN] =
118-
intrinsics::simd_bitmask(self.0);
119-
120-
// There is a bug where LLVM appears to implement this operation with the wrong
121-
// bit order.
122-
// TODO fix this in a better way
123-
if cfg!(target_endian = "big") {
124-
for x in bitmask.as_mut() {
125-
*x = x.reverse_bits();
126-
}
127-
}
128-
129-
bitmask
130-
}
113+
pub unsafe fn to_bitmask_intrinsic<U>(self) -> U {
114+
// Safety: caller must only return bitmask types
115+
unsafe { intrinsics::simd_bitmask(self.0) }
131116
}
132117

133-
#[cfg(feature = "generic_const_exprs")]
134118
#[inline]
135-
#[must_use = "method returns a new mask and does not mutate the original value"]
136-
pub fn from_bitmask(mut bitmask: [u8; LaneCount::<LANES>::BITMASK_LEN]) -> Self {
119+
pub unsafe fn from_bitmask_intrinsic<U>(bitmask: U) -> Self {
120+
// Safety: caller must only pass bitmask types
137121
unsafe {
138-
// There is a bug where LLVM appears to implement this operation with the wrong
139-
// bit order.
140-
// TODO fix this in a better way
141-
if cfg!(target_endian = "big") {
142-
for x in bitmask.as_mut() {
143-
*x = x.reverse_bits();
144-
}
145-
}
146-
147122
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
148123
bitmask,
149124
Self::splat(true).to_int(),
+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
use super::{mask_impl, Mask, MaskElement};
2+
3+
/// Converts masks to and from bitmasks.
4+
///
5+
/// In a bitmask, each bit represents if the corresponding lane in the mask is set.
6+
pub trait ToBitMask<BitMask> {
7+
/// Converts a mask to a bitmask.
8+
fn to_bitmask(self) -> BitMask;
9+
10+
/// Converts a bitmask to a mask.
11+
fn from_bitmask(bitmask: BitMask) -> Self;
12+
}
13+
14+
macro_rules! impl_integer_intrinsic {
15+
{ $(unsafe impl ToBitMask<$int:ty> for Mask<_, $lanes:literal>)* } => {
16+
$(
17+
impl<T: MaskElement> ToBitMask<$int> for Mask<T, $lanes> {
18+
fn to_bitmask(self) -> $int {
19+
unsafe { self.0.to_bitmask_intrinsic() }
20+
}
21+
22+
fn from_bitmask(bitmask: $int) -> Self {
23+
unsafe { Self(mask_impl::Mask::from_bitmask_intrinsic(bitmask)) }
24+
}
25+
}
26+
)*
27+
}
28+
}
29+
30+
impl_integer_intrinsic! {
31+
unsafe impl ToBitMask<u8> for Mask<_, 8>
32+
unsafe impl ToBitMask<u16> for Mask<_, 16>
33+
unsafe impl ToBitMask<u32> for Mask<_, 32>
34+
unsafe impl ToBitMask<u64> for Mask<_, 64>
35+
}
36+
37+
macro_rules! impl_integer_via {
38+
{ $(impl ToBitMask<$int:ty, via $via:ty> for Mask<_, $lanes:literal>)* } => {
39+
$(
40+
impl<T: MaskElement> ToBitMask<$int> for Mask<T, $lanes> {
41+
fn to_bitmask(self) -> $int {
42+
let bitmask: $via = self.to_bitmask();
43+
bitmask as _
44+
}
45+
46+
fn from_bitmask(bitmask: $int) -> Self {
47+
Self::from_bitmask(bitmask as $via)
48+
}
49+
}
50+
)*
51+
}
52+
}
53+
54+
impl_integer_via! {
55+
impl ToBitMask<u16, via u8> for Mask<_, 8>
56+
impl ToBitMask<u32, via u8> for Mask<_, 8>
57+
impl ToBitMask<u64, via u8> for Mask<_, 8>
58+
59+
impl ToBitMask<u32, via u16> for Mask<_, 16>
60+
impl ToBitMask<u64, via u16> for Mask<_, 16>
61+
62+
impl ToBitMask<u64, via u32> for Mask<_, 32>
63+
}
64+
65+
#[cfg(target_pointer_width = "32")]
66+
impl_integer_via! {
67+
impl ToBitMask<usize, via u8> for Mask<_, 8>
68+
impl ToBitMask<usize, via u16> for Mask<_, 16>
69+
impl ToBitMask<usize, via u32> for Mask<_, 32>
70+
}
71+
72+
#[cfg(target_pointer_width = "64")]
73+
impl_integer_via! {
74+
impl ToBitMask<usize, via u8> for Mask<_, 8>
75+
impl ToBitMask<usize, via u16> for Mask<_, 16>
76+
impl ToBitMask<usize, via u32> for Mask<_, 32>
77+
impl ToBitMask<usize, via u64> for Mask<_, 64>
78+
}

crates/core_simd/tests/masks.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,16 @@ macro_rules! test_mask_api {
6868
assert_eq!(core_simd::Mask::<$type, 8>::from_int(int), mask);
6969
}
7070

71-
#[cfg(feature = "generic_const_exprs")]
7271
#[test]
7372
fn roundtrip_bitmask_conversion() {
73+
use core_simd::ToBitMask;
7474
let values = [
7575
true, false, false, true, false, false, true, false,
7676
true, true, false, false, false, false, false, true,
7777
];
7878
let mask = core_simd::Mask::<$type, 16>::from_array(values);
79-
let bitmask = mask.to_bitmask();
80-
assert_eq!(bitmask, [0b01001001, 0b10000011]);
79+
let bitmask: u16 = mask.to_bitmask();
80+
assert_eq!(bitmask, 0b1000001101001001);
8181
assert_eq!(core_simd::Mask::<$type, 16>::from_bitmask(bitmask), mask);
8282
}
8383
}

0 commit comments

Comments
 (0)