This repository was archived by the owner on May 28, 2025. It is now read-only.

Commit 98dad13

calebzulawski authored and workingjubilee committed

Make implementation more scalable by using a helper trait to determine bitmask size. Improve bitmask to int conversion.

1 parent eec4280 · commit 98dad13

File tree: 12 files changed, +291 −204 lines (three of the changed files are shown below).
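
The commit's core idea: instead of always storing the mask in a `u64`, a helper trait lets each mask type choose byte storage sized to its lane count. A minimal sketch of the shape that trait must have, inferred from the bounds and calls in the diffs below (the actual `Mask` trait is defined elsewhere in the crate and is not part of the hunks shown here):

pub trait Mask {
    // Byte storage with one bit per lane, e.g. [u8; 1] for up to 8 lanes
    // or [u8; 4] for 32. The bounds mirror how the diffs below use it:
    // `default()`, `as_ref()`, `as_mut()`, and the `Copy` impl on BitMask.
    type BitMask: Copy + Default + AsRef<[u8]> + AsMut<[u8]>;

    // An unsigned integer of the same size (u8, u16, u32, ...); this is
    // the representation the simd_bitmask and simd_select_bitmask
    // intrinsics work with, hence the transmutes between the two.
    type IntBitMask;
}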

crates/core_simd/src/comparisons.rs

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ macro_rules! implement_mask_ops {
             where
                 crate::$vector<LANES>: LanesAtMost32,
                 crate::$inner_ty<LANES>: LanesAtMost32,
+                crate::$mask<LANES>: crate::Mask,
             {
                 /// Test if each lane is equal to the corresponding lane in `other`.
                 #[inline]
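
This bound feeds the new helper trait into every comparison impl the macro generates. As an illustration only, with assumed substitutions `$vector = SimdF32`, `$inner_ty = SimdI32`, `$mask = Mask32` (the macro invocation itself is outside this hunk), the expanded header would read roughly:

impl<const LANES: usize> crate::SimdF32<LANES>
where
    crate::SimdF32<LANES>: LanesAtMost32,
    crate::SimdI32<LANES>: LanesAtMost32,
    crate::Mask32<LANES>: crate::Mask, // bound added by this commit
{
    // lanes_eq, lanes_lt, ... comparisons returning a mask
}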

crates/core_simd/src/intrinsics.rs

Lines changed: 3 additions & 0 deletions
@@ -79,6 +79,9 @@ extern "platform-intrinsic" {
 
     // truncate integer vector to bitmask
     pub(crate) fn simd_bitmask<T, U>(x: T) -> U;
+
+    // select
+    pub(crate) fn simd_select_bitmask<T, U>(m: T, a: U, b: U) -> U;
 }
 
 #[cfg(feature = "std")]
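
For orientation: `simd_bitmask` truncates an integer mask vector to one bit per lane, and the new `simd_select_bitmask(m, a, b)` goes the other way, taking lane i from `a` when bit i of `m` is set and from `b` otherwise. A scalar model of the select semantics (illustrative sketch, fixed at four i32 lanes and a u8 bitmask; the real intrinsic is lowered by the compiler):

fn select_bitmask_model(m: u8, a: [i32; 4], b: [i32; 4]) -> [i32; 4] {
    let mut out = [0; 4];
    for i in 0..4 {
        // Bit i of the bitmask chooses the source for lane i.
        out[i] = if (m >> i) & 1 != 0 { a[i] } else { b[i] };
    }
    out
}

// select_bitmask_model(0b0101, [9, 9, 9, 9], [0, 0, 0, 0]) == [9, 0, 9, 0]

This is what lets the new `to_int` in bitmask.rs below expand a bitmask in one step: selecting between `!V::default()` (every bit set, i.e. -1 in each lane) and `V::default()` (all zeros) sign-extends each mask bit across its lane.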

crates/core_simd/src/masks/bitmask.rs

Lines changed: 104 additions & 83 deletions
@@ -1,56 +1,103 @@
+use crate::Mask;
+use core::marker::PhantomData;
+
+/// Helper trait for limiting int conversion types
+pub trait ConvertToInt {}
+impl<const LANES: usize> ConvertToInt for crate::SimdI8<LANES> where Self: crate::LanesAtMost32 {}
+impl<const LANES: usize> ConvertToInt for crate::SimdI16<LANES> where Self: crate::LanesAtMost32 {}
+impl<const LANES: usize> ConvertToInt for crate::SimdI32<LANES> where Self: crate::LanesAtMost32 {}
+impl<const LANES: usize> ConvertToInt for crate::SimdI64<LANES> where Self: crate::LanesAtMost32 {}
+impl<const LANES: usize> ConvertToInt for crate::SimdIsize<LANES> where Self: crate::LanesAtMost32 {}
+
 /// A mask where each lane is represented by a single bit.
-#[derive(Copy, Clone, Debug, PartialOrd, PartialEq, Ord, Eq, Hash)]
 #[repr(transparent)]
-pub struct BitMask<const LANES: usize>(u64);
+pub struct BitMask<T: Mask, const LANES: usize>(T::BitMask, PhantomData<[(); LANES]>);
 
-impl<const LANES: usize> BitMask<LANES>
-{
+impl<T: Mask, const LANES: usize> Copy for BitMask<T, LANES> {}
+
+impl<T: Mask, const LANES: usize> Clone for BitMask<T, LANES> {
+    fn clone(&self) -> Self {
+        *self
+    }
+}
+
+impl<T: Mask, const LANES: usize> PartialEq for BitMask<T, LANES> {
+    fn eq(&self, other: &Self) -> bool {
+        self.0.as_ref() == other.0.as_ref()
+    }
+}
+
+impl<T: Mask, const LANES: usize> PartialOrd for BitMask<T, LANES> {
+    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        self.0.as_ref().partial_cmp(other.0.as_ref())
+    }
+}
+
+impl<T: Mask, const LANES: usize> Eq for BitMask<T, LANES> {}
+
+impl<T: Mask, const LANES: usize> Ord for BitMask<T, LANES> {
+    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
+        self.0.as_ref().cmp(other.0.as_ref())
+    }
+}
+
+impl<T: Mask, const LANES: usize> BitMask<T, LANES> {
     #[inline]
     pub fn splat(value: bool) -> Self {
+        let mut mask = T::BitMask::default();
         if value {
-            Self(u64::MAX >> (64 - LANES))
+            mask.as_mut().fill(u8::MAX)
         } else {
-            Self(u64::MIN)
+            mask.as_mut().fill(u8::MIN)
+        }
+        if LANES % 8 > 0 {
+            *mask.as_mut().last_mut().unwrap() &= u8::MAX >> (8 - LANES % 8);
         }
+        Self(mask, PhantomData)
     }
 
     #[inline]
     pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
-        (self.0 >> lane) & 0x1 > 0
+        (self.0.as_ref()[lane / 8] >> lane % 8) & 0x1 > 0
     }
 
     #[inline]
     pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
-        self.0 ^= ((value ^ self.test_unchecked(lane)) as u64) << lane
+        self.0.as_mut()[lane / 8] ^= ((value ^ self.test_unchecked(lane)) as u8) << (lane % 8)
     }
 
     #[inline]
-    pub fn to_int<V, T>(self) -> V
+    pub fn to_int<V>(self) -> V
     where
-        V: Default + AsMut<[T; LANES]>,
-        T: From<i8>,
+        V: ConvertToInt + Default + core::ops::Not<Output = V>,
     {
-        // TODO this should be an intrinsic sign-extension
-        let mut v = V::default();
-        for i in 0..LANES {
-            let lane = unsafe { self.test_unchecked(i) };
-            v.as_mut()[i] = (-(lane as i8)).into();
+        unsafe {
+            let mask: T::IntBitMask = core::mem::transmute_copy(&self);
+            crate::intrinsics::simd_select_bitmask(mask, !V::default(), V::default())
         }
-        v
     }
 
     #[inline]
     pub unsafe fn from_int_unchecked<V>(value: V) -> Self
     where
         V: crate::LanesAtMost32,
     {
-        let mask: V::BitMask = crate::intrinsics::simd_bitmask(value);
-        Self(mask.into())
+        // TODO remove the transmute when rustc is more flexible
+        assert_eq!(
+            core::mem::size_of::<T::IntBitMask>(),
+            core::mem::size_of::<T::BitMask>()
+        );
+        let mask: T::IntBitMask = crate::intrinsics::simd_bitmask(value);
+        Self(core::mem::transmute_copy(&mask), PhantomData)
     }
 
     #[inline]
-    pub fn to_bitmask(self) -> u64 {
-        self.0
+    pub fn to_bitmask<U: Mask>(self) -> U::BitMask {
+        assert_eq!(
+            core::mem::size_of::<T::BitMask>(),
+            core::mem::size_of::<U::BitMask>()
+        );
+        unsafe { core::mem::transmute_copy(&self.0) }
     }
 
     #[inline]
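
With the storage now a byte array instead of a `u64`, lane `i` lives at bit `i % 8` of byte `i / 8`, and any operation that could touch the padding bits in the final byte must clear them, as `splat` does with `u8::MAX >> (8 - LANES % 8)`. A standalone worked example at LANES = 12 (illustrative sketch, not crate code):

fn main() {
    // Twelve lanes occupy two bytes; the high 4 bits of byte 1 are padding.
    let mut mask = [0u8; 2];

    // Set lane 10: byte index 10 / 8 == 1, bit index 10 % 8 == 2.
    mask[10 / 8] |= 1 << (10 % 8);
    assert_eq!(mask, [0b0000_0000, 0b0000_0100]);

    // splat(true) fills every byte with 0xFF, then trims the last byte:
    // u8::MAX >> (8 - 12 % 8) == 0xFF >> 4 == 0b0000_1111.
    let all = [0xFFu8, 0xFF >> (8 - 12 % 8)];
    assert_eq!(all, [0b1111_1111, 0b0000_1111]);
}

The `assert_eq!` size checks in `from_int_unchecked` and `to_bitmask` above guard the `transmute_copy` between `T::BitMask` (the byte array) and `T::IntBitMask` (the same bits as one unsigned integer), a size relationship the commit notes rustc cannot yet express directly.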
@@ -64,87 +111,61 @@ impl<const LANES: usize> BitMask<LANES>
     }
 }
 
-impl<const LANES: usize> core::ops::BitAnd for BitMask<LANES>
+impl<T: Mask, const LANES: usize> core::ops::BitAnd for BitMask<T, LANES>
+where
+    T::BitMask: Default + AsRef<[u8]> + AsMut<[u8]>,
 {
     type Output = Self;
     #[inline]
-    fn bitand(self, rhs: Self) -> Self {
-        Self(self.0 & rhs.0)
+    fn bitand(mut self, rhs: Self) -> Self {
+        for (l, r) in self.0.as_mut().iter_mut().zip(rhs.0.as_ref().iter()) {
+            *l &= r;
+        }
+        self
     }
 }
 
-impl<const LANES: usize> core::ops::BitAnd<bool> for BitMask<LANES>
+impl<T: Mask, const LANES: usize> core::ops::BitOr for BitMask<T, LANES>
+where
+    T::BitMask: Default + AsRef<[u8]> + AsMut<[u8]>,
 {
     type Output = Self;
     #[inline]
-    fn bitand(self, rhs: bool) -> Self {
-        self & Self::splat(rhs)
-    }
-}
-
-impl<const LANES: usize> core::ops::BitAnd<BitMask<LANES>> for bool
-{
-    type Output = BitMask<LANES>;
-    #[inline]
-    fn bitand(self, rhs: BitMask<LANES>) -> BitMask<LANES> {
-        BitMask::<LANES>::splat(self) & rhs
+    fn bitor(mut self, rhs: Self) -> Self {
+        for (l, r) in self.0.as_mut().iter_mut().zip(rhs.0.as_ref().iter()) {
+            *l |= r;
+        }
+        self
     }
 }
 
-impl<const LANES: usize> core::ops::BitOr for BitMask<LANES>
-{
+impl<T: Mask, const LANES: usize> core::ops::BitXor for BitMask<T, LANES> {
     type Output = Self;
     #[inline]
-    fn bitor(self, rhs: Self) -> Self {
-        Self(self.0 | rhs.0)
+    fn bitxor(mut self, rhs: Self) -> Self::Output {
+        for (l, r) in self.0.as_mut().iter_mut().zip(rhs.0.as_ref().iter()) {
+            *l ^= r;
+        }
+        self
     }
 }
 
-impl<const LANES: usize> core::ops::BitXor for BitMask<LANES>
-{
+impl<T: Mask, const LANES: usize> core::ops::Not for BitMask<T, LANES> {
     type Output = Self;
     #[inline]
-    fn bitxor(self, rhs: Self) -> Self::Output {
-        Self(self.0 ^ rhs.0)
-    }
-}
-
-impl<const LANES: usize> core::ops::Not for BitMask<LANES>
-{
-    type Output = BitMask<LANES>;
-    #[inline]
-    fn not(self) -> Self::Output {
-        Self(!self.0) & Self::splat(true)
-    }
-}
-
-impl<const LANES: usize> core::ops::BitAndAssign for BitMask<LANES>
-{
-    #[inline]
-    fn bitand_assign(&mut self, rhs: Self) {
-        self.0 &= rhs.0;
-    }
-}
-
-impl<const LANES: usize> core::ops::BitOrAssign for BitMask<LANES>
-{
-    #[inline]
-    fn bitor_assign(&mut self, rhs: Self) {
-        self.0 |= rhs.0;
-    }
-}
-
-impl<const LANES: usize> core::ops::BitXorAssign for BitMask<LANES>
-{
-    #[inline]
-    fn bitxor_assign(&mut self, rhs: Self) {
-        self.0 ^= rhs.0;
+    fn not(mut self) -> Self::Output {
+        for x in self.0.as_mut() {
+            *x = !*x;
+        }
+        if LANES % 8 > 0 {
+            *self.0.as_mut().last_mut().unwrap() &= u8::MAX >> (8 - LANES % 8);
+        }
+        self
     }
 }
 
-pub type Mask8<const LANES: usize> = BitMask<LANES>;
-pub type Mask16<const LANES: usize> = BitMask<LANES>;
-pub type Mask32<const LANES: usize> = BitMask<LANES>;
-pub type Mask64<const LANES: usize> = BitMask<LANES>;
-pub type Mask128<const LANES: usize> = BitMask<LANES>;
-pub type MaskSize<const LANES: usize> = BitMask<LANES>;
+pub type Mask8<T, const LANES: usize> = BitMask<T, LANES>;
+pub type Mask16<T, const LANES: usize> = BitMask<T, LANES>;
+pub type Mask32<T, const LANES: usize> = BitMask<T, LANES>;
+pub type Mask64<T, const LANES: usize> = BitMask<T, LANES>;
+pub type MaskSize<T, const LANES: usize> = BitMask<T, LANES>;
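
One subtlety in the new `Not` impl: flipping every byte would also set the padding bits in the final byte, so it re-applies the same trim as `splat`. Keeping the padding bits zeroed is what keeps the byte-slice `PartialEq` and `Ord` impls above honest: two masks with equal lanes must compare equal byte-for-byte. A small standalone illustration at LANES = 12:

fn main() {
    // !splat(false) at 12 lanes: flip both bytes, then clear the padding.
    let empty = [0b0000_0000u8, 0b0000_0000];
    let mut flipped = [!empty[0], !empty[1]];
    assert_eq!(flipped, [0b1111_1111, 0b1111_1111]); // padding bits now set

    flipped[1] &= u8::MAX >> (8 - 12 % 8); // the trim from the diff
    assert_eq!(flipped, [0b1111_1111, 0b0000_1111]); // == splat(true)
}

Note that the `bool` operand impls and the `BitAndAssign`/`BitOrAssign`/`BitXorAssign` impls on the old `u64` representation are removed in this hunk rather than ported to the generic form.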
