Skip to content

Commit 3a4c7ea

Browse files
committed
Feat: start work for CAS ops
1 parent 668b33b commit 3a4c7ea

File tree

3 files changed

+88
-0
lines changed

3 files changed

+88
-0
lines changed

crates/cuda_std/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Notable changes to this project will be documented in this file.
66

77
- Added warp shuffles, matches, reductions, and votes in the `warp` module.
88
- Added `activemask` in the `warp` module to query a mask of the active threads.
9+
- Fixed `lane_id` generating invalid ptx.
910

1011
## 0.2.2 - 2/7/22
1112

crates/cuda_std/src/atomic.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,18 @@ macro_rules! safety_doc {
6161
};
6262
}
6363

64+
// taken from stdlib compare_and_swap docs
65+
fn double_ordering_from_one(ordering: Ordering) -> (Ordering, Ordering) {
66+
match ordering {
67+
Ordering::Relaxed => (Ordering::Relaxed, Ordering::Relaxed),
68+
Ordering::Acquire => (Ordering::Acquire, Ordering::Acquire),
69+
Ordering::Release => (Ordering::Release, Ordering::Relaxed),
70+
Ordering::AcqRel => (Ordering::AcqRel, Ordering::Acquire),
71+
Ordering::SeqCst => (Ordering::SeqCst, Ordering::SeqCst),
72+
_ => unreachable!(),
73+
}
74+
}
75+
6476
macro_rules! atomic_float {
6577
($float_ty:ident, $atomic_ty:ident, $align:tt, $scope:ident, $width:tt $(,$unsafety:ident)?) => {
6678
#[doc = concat!("A ", stringify!($width), "-bit float type which can be safely shared between threads and synchronizes across ", scope_doc!($scope))]
@@ -220,6 +232,17 @@ macro_rules! atomic_float {
220232
#[cfg(not(target_os = "cuda"))]
221233
self.as_atomic_bits().store(val.to_bits(), order);
222234
}
235+
236+
// $(#[doc = safety_doc!($unsafety)])?
237+
// pub $($unsafety)? fn compare_and_swap(&self, current: f32, new: f32, order: Ordering) -> Result<$float_ty, $float_ty> {
238+
// #[cfg(target_os = "cuda")]
239+
// unsafe {
240+
// let res = mid::[<atomic_compare_and_swap_ $float_ty _ $scope>](self.v.get().cast(), order, current, new);
241+
// }
242+
243+
// #[cfg(not(target_os = "cuda"))]
244+
// self.as_atomic_bits().compare_exchange
245+
// }
223246
}
224247
}
225248
};

crates/cuda_std/src/atomic/mid.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,67 @@ fetch_ops_1_param! {
251251
max => (u32, u64, i32, i64),
252252
exch => (u32, u64, i32, i64, f32, f64),
253253
}
254+
255+
macro_rules! inner_cas {
256+
($($type:ty, $scope:ident),* $(,)?) => {
257+
$(
258+
paste! {
259+
#[$crate::gpu_only]
260+
#[allow(clippy::missing_safety_doc)]
261+
pub unsafe fn [<atomic_compare_and_swap_ $type _ $scope>](ptr: *mut $type, current: $type, new: $type, ordering: Ordering) -> $type {
262+
if ge_sm70() {
263+
match ordering {
264+
SeqCst => {
265+
intrinsics::[<fence_sc_ $scope>]();
266+
intrinsics::[<atomic_fetch_cas_acquire_ $type _ $scope>](ptr, current, new)
267+
},
268+
Acquire => intrinsics::[<atomic_fetch_cas_acquire_ $type _ $scope>](ptr, current, new),
269+
AcqRel => intrinsics::[<atomic_fetch_cas_acqrel_ $type _ $scope>](ptr, current, new),
270+
Release => intrinsics::[<atomic_fetch_cas_release_ $type _ $scope>](ptr, current, new),
271+
Relaxed => intrinsics::[<atomic_fetch_cas_relaxed_ $type _ $scope>](ptr, current, new),
272+
_ => unimplemented!("Weird ordering added by core")
273+
}
274+
} else {
275+
match ordering {
276+
SeqCst | AcqRel => {
277+
intrinsics::[<membar_ $scope>]();
278+
let val = intrinsics::[<atomic_fetch_cas_volatile_ $type _ $scope>](ptr, current, new);
279+
intrinsics::[<membar_ $scope>]();
280+
val
281+
},
282+
Acquire => {
283+
let val = intrinsics::[<atomic_fetch_cas_volatile_ $type _ $scope>](ptr, current, new);
284+
intrinsics::[<membar_ $scope>]();
285+
val
286+
},
287+
Release => {
288+
intrinsics::[<membar_ $scope>]();
289+
intrinsics::[<atomic_fetch_cas_volatile_ $type _ $scope>](ptr, current, new)
290+
},
291+
Relaxed => {
292+
intrinsics::[<atomic_fetch_cas_volatile_ $type _ $scope>](ptr, current, new)
293+
},
294+
_ => unimplemented!("Weird ordering added by core")
295+
}
296+
}
297+
}
298+
}
299+
)*
300+
}
301+
}
302+
303+
macro_rules! impl_cas {
304+
($($type:ident),* $(,)?) => {
305+
$(
306+
inner_cas!(
307+
$type, block,
308+
$type, device,
309+
$type, system,
310+
);
311+
)*
312+
};
313+
}
314+
315+
impl_cas! {
316+
u32, u64, i32, i64, f32, f64
317+
}

0 commit comments

Comments
 (0)