Skip to content

Commit 12497d5

Browse files
committed
refactor(cust_raw): Parse macro function renames in cuda headers.
Adds a custom bindgen callback that prevents function renames due to macro defines while still linking to the intend function after macro expansion. It does this by tracking macros when generating the bindings, so that it can change the name of the function back to what it was before the macro changed and link to the macro expanded function name. Doing so helps prevents breaking changes across CUDA versions when generating the bindings, and generates function bindings that match those used in Nvidia's CUDA documentation. misc.: optix-sys rebuilds if related optix environment variables change. misc.: unpin cc and jobserver from add example.
1 parent b85d9ca commit 12497d5

File tree

25 files changed

+289
-162
lines changed

25 files changed

+289
-162
lines changed

crates/blastoff/src/context.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ impl CublasContext {
9292
pub fn new() -> Result<Self> {
9393
let mut raw = MaybeUninit::uninit();
9494
unsafe {
95-
cublas_sys::cublasCreate_v2(raw.as_mut_ptr()).to_result()?;
96-
cublas_sys::cublasSetPointerMode_v2(
95+
cublas_sys::cublasCreate(raw.as_mut_ptr()).to_result()?;
96+
cublas_sys::cublasSetPointerMode(
9797
raw.assume_init(),
9898
cublas_sys::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE,
9999
)
@@ -112,7 +112,7 @@ impl CublasContext {
112112

113113
unsafe {
114114
let inner = mem::replace(&mut ctx.raw, ptr::null_mut());
115-
match cublas_sys::cublasDestroy_v2(inner).to_result() {
115+
match cublas_sys::cublasDestroy(inner).to_result() {
116116
Ok(()) => {
117117
mem::forget(ctx);
118118
Ok(())
@@ -127,7 +127,7 @@ impl CublasContext {
127127
let mut raw = MaybeUninit::<u32>::uninit();
128128
unsafe {
129129
// getVersion can't fail
130-
cublas_sys::cublasGetVersion_v2(self.raw, raw.as_mut_ptr().cast())
130+
cublas_sys::cublasGetVersion(self.raw, raw.as_mut_ptr().cast())
131131
.to_result()
132132
.unwrap();
133133

@@ -145,7 +145,7 @@ impl CublasContext {
145145
) -> Result<T> {
146146
unsafe {
147147
// cudaStream_t is the same as CUstream
148-
cublas_sys::cublasSetStream_v2(
148+
cublas_sys::cublasSetStream(
149149
self.raw,
150150
mem::transmute::<*mut driver_sys::CUstream_st, *mut cublas_sys::CUstream_st>(
151151
stream.as_inner(),
@@ -155,7 +155,7 @@ impl CublasContext {
155155
let res = func(self)?;
156156
// reset the stream back to NULL just in case someone calls with_stream, then drops the stream, and tries to
157157
// execute a raw sys function with the context's handle.
158-
cublas_sys::cublasSetStream_v2(self.raw, ptr::null_mut()).to_result()?;
158+
cublas_sys::cublasSetStream(self.raw, ptr::null_mut()).to_result()?;
159159
Ok(res)
160160
}
161161
}
@@ -340,7 +340,7 @@ impl CublasContext {
340340
impl Drop for CublasContext {
341341
fn drop(&mut self) {
342342
unsafe {
343-
cublas_sys::cublasDestroy_v2(self.raw);
343+
cublas_sys::cublasDestroy(self.raw);
344344
}
345345
}
346346
}

0 commit comments

Comments
 (0)