diff --git a/crates/blastoff/src/context.rs b/crates/blastoff/src/context.rs index 421e2d49..2a517be0 100644 --- a/crates/blastoff/src/context.rs +++ b/crates/blastoff/src/context.rs @@ -92,8 +92,8 @@ impl CublasContext { pub fn new() -> Result { let mut raw = MaybeUninit::uninit(); unsafe { - cublas_sys::cublasCreate_v2(raw.as_mut_ptr()).to_result()?; - cublas_sys::cublasSetPointerMode_v2( + cublas_sys::cublasCreate(raw.as_mut_ptr()).to_result()?; + cublas_sys::cublasSetPointerMode( raw.assume_init(), cublas_sys::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE, ) @@ -112,7 +112,7 @@ impl CublasContext { unsafe { let inner = mem::replace(&mut ctx.raw, ptr::null_mut()); - match cublas_sys::cublasDestroy_v2(inner).to_result() { + match cublas_sys::cublasDestroy(inner).to_result() { Ok(()) => { mem::forget(ctx); Ok(()) @@ -127,7 +127,7 @@ impl CublasContext { let mut raw = MaybeUninit::::uninit(); unsafe { // getVersion can't fail - cublas_sys::cublasGetVersion_v2(self.raw, raw.as_mut_ptr().cast()) + cublas_sys::cublasGetVersion(self.raw, raw.as_mut_ptr().cast()) .to_result() .unwrap(); @@ -145,7 +145,7 @@ impl CublasContext { ) -> Result { unsafe { // cudaStream_t is the same as CUstream - cublas_sys::cublasSetStream_v2( + cublas_sys::cublasSetStream( self.raw, mem::transmute::<*mut driver_sys::CUstream_st, *mut cublas_sys::CUstream_st>( stream.as_inner(), @@ -155,7 +155,7 @@ impl CublasContext { let res = func(self)?; // reset the stream back to NULL just in case someone calls with_stream, then drops the stream, and tries to // execute a raw sys function with the context's handle. 
- cublas_sys::cublasSetStream_v2(self.raw, ptr::null_mut()).to_result()?; + cublas_sys::cublasSetStream(self.raw, ptr::null_mut()).to_result()?; Ok(res) } } @@ -340,7 +340,7 @@ impl CublasContext { impl Drop for CublasContext { fn drop(&mut self) { unsafe { - cublas_sys::cublasDestroy_v2(self.raw); + cublas_sys::cublasDestroy(self.raw); } } } diff --git a/crates/blastoff/src/raw/level1.rs b/crates/blastoff/src/raw/level1.rs index d2167fbb..7268d642 100644 --- a/crates/blastoff/src/raw/level1.rs +++ b/crates/blastoff/src/raw/level1.rs @@ -103,7 +103,7 @@ impl Level1 for f32 { incx: c_int, result: *mut c_int, ) -> cublasStatus_t { - cublasIsamax_v2(handle, n, x, incx, result) + cublasIsamax(handle, n, x, incx, result) } unsafe fn amin( handle: cublasHandle_t, @@ -112,7 +112,7 @@ impl Level1 for f32 { incx: c_int, result: *mut c_int, ) -> cublasStatus_t { - cublasIsamin_v2(handle, n, x, incx, result) + cublasIsamin(handle, n, x, incx, result) } unsafe fn axpy( handle: cublasHandle_t, @@ -123,7 +123,7 @@ impl Level1 for f32 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasSaxpy_v2(handle, n, alpha, x, incx, y, incy) + cublasSaxpy(handle, n, alpha, x, incx, y, incy) } unsafe fn copy( handle: cublasHandle_t, @@ -133,7 +133,7 @@ impl Level1 for f32 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasScopy_v2(handle, n, x, incx, y, incy) + cublasScopy(handle, n, x, incx, y, incy) } unsafe fn nrm2( handle: cublasHandle_t, @@ -142,7 +142,7 @@ impl Level1 for f32 { incx: c_int, result: *mut Self::FloatTy, ) -> cublasStatus_t { - cublasSnrm2_v2(handle, n, x, incx, result) + cublasSnrm2(handle, n, x, incx, result) } unsafe fn rot( handle: cublasHandle_t, @@ -154,7 +154,7 @@ impl Level1 for f32 { c: *const Self::FloatTy, s: *const Self, ) -> cublasStatus_t { - cublasSrot_v2(handle, n, x, incx, y, incy, c, s) + cublasSrot(handle, n, x, incx, y, incy, c, s) } unsafe fn rotg( handle: cublasHandle_t, @@ -163,7 +163,7 @@ impl Level1 for f32 { c: *mut Self::FloatTy, 
s: *mut Self, ) -> cublasStatus_t { - cublasSrotg_v2(handle, a, b, c, s) + cublasSrotg(handle, a, b, c, s) } unsafe fn rotm( handle: cublasHandle_t, @@ -174,7 +174,7 @@ impl Level1 for f32 { incy: c_int, param: *const Self::FloatTy, ) -> cublasStatus_t { - cublasSrotm_v2(handle, n, x, incx, y, incy, param) + cublasSrotm(handle, n, x, incx, y, incy, param) } unsafe fn rotmg( handle: cublasHandle_t, @@ -184,7 +184,7 @@ impl Level1 for f32 { y1: *const Self, param: *mut Self, ) -> cublasStatus_t { - cublasSrotmg_v2(handle, d1, d2, x1, y1, param) + cublasSrotmg(handle, d1, d2, x1, y1, param) } unsafe fn scal( handle: cublasHandle_t, @@ -193,7 +193,7 @@ impl Level1 for f32 { x: *mut Self, incx: c_int, ) -> cublasStatus_t { - cublasSscal_v2(handle, n, alpha, x, incx) + cublasSscal(handle, n, alpha, x, incx) } unsafe fn swap( handle: cublasHandle_t, @@ -203,7 +203,7 @@ impl Level1 for f32 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasSswap_v2(handle, n, x, incx, y, incy) + cublasSswap(handle, n, x, incx, y, incy) } } @@ -215,7 +215,7 @@ impl Level1 for f64 { incx: c_int, result: *mut c_int, ) -> cublasStatus_t { - cublasIdamax_v2(handle, n, x, incx, result) + cublasIdamax(handle, n, x, incx, result) } unsafe fn amin( handle: cublasHandle_t, @@ -224,7 +224,7 @@ impl Level1 for f64 { incx: c_int, result: *mut c_int, ) -> cublasStatus_t { - cublasIdamin_v2(handle, n, x, incx, result) + cublasIdamin(handle, n, x, incx, result) } unsafe fn axpy( handle: cublasHandle_t, @@ -235,7 +235,7 @@ impl Level1 for f64 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasDaxpy_v2(handle, n, alpha, x, incx, y, incy) + cublasDaxpy(handle, n, alpha, x, incx, y, incy) } unsafe fn copy( handle: cublasHandle_t, @@ -245,7 +245,7 @@ impl Level1 for f64 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasDcopy_v2(handle, n, x, incx, y, incy) + cublasDcopy(handle, n, x, incx, y, incy) } unsafe fn nrm2( handle: cublasHandle_t, @@ -254,7 +254,7 @@ impl Level1 for f64 { 
incx: c_int, result: *mut Self::FloatTy, ) -> cublasStatus_t { - cublasDnrm2_v2(handle, n, x, incx, result) + cublasDnrm2(handle, n, x, incx, result) } unsafe fn rot( handle: cublasHandle_t, @@ -266,7 +266,7 @@ impl Level1 for f64 { c: *const Self::FloatTy, s: *const Self, ) -> cublasStatus_t { - cublasDrot_v2(handle, n, x, incx, y, incy, c, s) + cublasDrot(handle, n, x, incx, y, incy, c, s) } unsafe fn rotg( handle: cublasHandle_t, @@ -275,7 +275,7 @@ impl Level1 for f64 { c: *mut Self::FloatTy, s: *mut Self, ) -> cublasStatus_t { - cublasDrotg_v2(handle, a, b, c, s) + cublasDrotg(handle, a, b, c, s) } unsafe fn rotm( handle: cublasHandle_t, @@ -286,7 +286,7 @@ impl Level1 for f64 { incy: c_int, param: *const Self::FloatTy, ) -> cublasStatus_t { - cublasDrotm_v2(handle, n, x, incx, y, incy, param) + cublasDrotm(handle, n, x, incx, y, incy, param) } unsafe fn rotmg( handle: cublasHandle_t, @@ -296,7 +296,7 @@ impl Level1 for f64 { y1: *const Self, param: *mut Self, ) -> cublasStatus_t { - cublasDrotmg_v2(handle, d1, d2, x1, y1, param) + cublasDrotmg(handle, d1, d2, x1, y1, param) } unsafe fn scal( handle: cublasHandle_t, @@ -305,7 +305,7 @@ impl Level1 for f64 { x: *mut Self, incx: c_int, ) -> cublasStatus_t { - cublasDscal_v2(handle, n, alpha, x, incx) + cublasDscal(handle, n, alpha, x, incx) } unsafe fn swap( handle: cublasHandle_t, @@ -315,7 +315,7 @@ impl Level1 for f64 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasDswap_v2(handle, n, x, incx, y, incy) + cublasDswap(handle, n, x, incx, y, incy) } } @@ -327,7 +327,7 @@ impl Level1 for Complex32 { incx: c_int, result: *mut c_int, ) -> cublasStatus_t { - cublasIcamax_v2(handle, n, x.cast(), incx, result) + cublasIcamax(handle, n, x.cast(), incx, result) } unsafe fn amin( handle: cublasHandle_t, @@ -336,7 +336,7 @@ impl Level1 for Complex32 { incx: c_int, result: *mut c_int, ) -> cublasStatus_t { - cublasIcamin_v2(handle, n, x.cast(), incx, result) + cublasIcamin(handle, n, x.cast(), incx, result) } 
unsafe fn axpy( handle: cublasHandle_t, @@ -347,7 +347,7 @@ impl Level1 for Complex32 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasCaxpy_v2(handle, n, alpha.cast(), x.cast(), incx, y.cast(), incy) + cublasCaxpy(handle, n, alpha.cast(), x.cast(), incx, y.cast(), incy) } unsafe fn copy( handle: cublasHandle_t, @@ -357,7 +357,7 @@ impl Level1 for Complex32 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasCcopy_v2(handle, n, x.cast(), incx, y.cast(), incy) + cublasCcopy(handle, n, x.cast(), incx, y.cast(), incy) } unsafe fn nrm2( handle: cublasHandle_t, @@ -366,7 +366,7 @@ impl Level1 for Complex32 { incx: c_int, result: *mut Self::FloatTy, ) -> cublasStatus_t { - cublasScnrm2_v2(handle, n, x.cast(), incx, result) + cublasScnrm2(handle, n, x.cast(), incx, result) } unsafe fn rot( handle: cublasHandle_t, @@ -378,7 +378,7 @@ impl Level1 for Complex32 { c: *const Self::FloatTy, s: *const Self::FloatTy, ) -> cublasStatus_t { - cublasCsrot_v2(handle, n, x.cast(), incx, y.cast(), incy, c, s) + cublasCsrot(handle, n, x.cast(), incx, y.cast(), incy, c, s) } unsafe fn rotg( handle: cublasHandle_t, @@ -387,7 +387,7 @@ impl Level1 for Complex32 { c: *mut Self::FloatTy, s: *mut Self, ) -> cublasStatus_t { - cublasCrotg_v2(handle, a.cast(), b.cast(), c, s.cast()) + cublasCrotg(handle, a.cast(), b.cast(), c, s.cast()) } unsafe fn rotm( _handle: cublasHandle_t, @@ -417,7 +417,7 @@ impl Level1 for Complex32 { x: *mut Self, incx: c_int, ) -> cublasStatus_t { - cublasCscal_v2(handle, n, alpha.cast(), x.cast(), incx) + cublasCscal(handle, n, alpha.cast(), x.cast(), incx) } unsafe fn swap( handle: cublasHandle_t, @@ -427,7 +427,7 @@ impl Level1 for Complex32 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasCswap_v2(handle, n, x.cast(), incx, y.cast(), incy) + cublasCswap(handle, n, x.cast(), incx, y.cast(), incy) } } @@ -439,7 +439,7 @@ impl Level1 for Complex64 { incx: c_int, result: *mut c_int, ) -> cublasStatus_t { - cublasIzamax_v2(handle, n, 
x.cast(), incx, result) + cublasIzamax(handle, n, x.cast(), incx, result) } unsafe fn amin( handle: cublasHandle_t, @@ -448,7 +448,7 @@ impl Level1 for Complex64 { incx: c_int, result: *mut c_int, ) -> cublasStatus_t { - cublasIzamin_v2(handle, n, x.cast(), incx, result) + cublasIzamin(handle, n, x.cast(), incx, result) } unsafe fn axpy( handle: cublasHandle_t, @@ -459,7 +459,7 @@ impl Level1 for Complex64 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasZaxpy_v2(handle, n, alpha.cast(), x.cast(), incx, y.cast(), incy) + cublasZaxpy(handle, n, alpha.cast(), x.cast(), incx, y.cast(), incy) } unsafe fn copy( handle: cublasHandle_t, @@ -469,7 +469,7 @@ impl Level1 for Complex64 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasZcopy_v2(handle, n, x.cast(), incx, y.cast(), incy) + cublasZcopy(handle, n, x.cast(), incx, y.cast(), incy) } unsafe fn nrm2( handle: cublasHandle_t, @@ -478,7 +478,7 @@ impl Level1 for Complex64 { incx: c_int, result: *mut Self::FloatTy, ) -> cublasStatus_t { - cublasDznrm2_v2(handle, n, x.cast(), incx, result) + cublasDznrm2(handle, n, x.cast(), incx, result) } unsafe fn rot( handle: cublasHandle_t, @@ -490,7 +490,7 @@ impl Level1 for Complex64 { c: *const Self::FloatTy, s: *const Self::FloatTy, ) -> cublasStatus_t { - cublasZdrot_v2(handle, n, x.cast(), incx, y.cast(), incy, c, s) + cublasZdrot(handle, n, x.cast(), incx, y.cast(), incy, c, s) } unsafe fn rotg( handle: cublasHandle_t, @@ -499,7 +499,7 @@ impl Level1 for Complex64 { c: *mut Self::FloatTy, s: *mut Self, ) -> cublasStatus_t { - cublasZrotg_v2(handle, a.cast(), b.cast(), c, s.cast()) + cublasZrotg(handle, a.cast(), b.cast(), c, s.cast()) } unsafe fn rotm( _handle: cublasHandle_t, @@ -529,7 +529,7 @@ impl Level1 for Complex64 { x: *mut Self, incx: c_int, ) -> cublasStatus_t { - cublasZscal_v2(handle, n, alpha.cast(), x.cast(), incx) + cublasZscal(handle, n, alpha.cast(), x.cast(), incx) } unsafe fn swap( handle: cublasHandle_t, @@ -539,7 +539,7 @@ impl Level1 
for Complex64 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasZswap_v2(handle, n, x.cast(), incx, y.cast(), incy) + cublasZswap(handle, n, x.cast(), incx, y.cast(), incy) } } @@ -575,7 +575,7 @@ impl ComplexLevel1 for Complex32 { incy: c_int, result: *mut Self, ) -> cublasStatus_t { - cublasCdotu_v2(handle, n, x.cast(), incx, y.cast(), incy, result.cast()) + cublasCdotu(handle, n, x.cast(), incx, y.cast(), incy, result.cast()) } unsafe fn dotc( handle: cublasHandle_t, @@ -586,7 +586,7 @@ impl ComplexLevel1 for Complex32 { incy: c_int, result: *mut Self, ) -> cublasStatus_t { - cublasCdotc_v2(handle, n, x.cast(), incx, y.cast(), incy, result.cast()) + cublasCdotc(handle, n, x.cast(), incx, y.cast(), incy, result.cast()) } } @@ -600,7 +600,7 @@ impl ComplexLevel1 for Complex64 { incy: c_int, result: *mut Self, ) -> cublasStatus_t { - cublasZdotu_v2(handle, n, x.cast(), incx, y.cast(), incy, result.cast()) + cublasZdotu(handle, n, x.cast(), incx, y.cast(), incy, result.cast()) } unsafe fn dotc( handle: cublasHandle_t, @@ -611,7 +611,7 @@ impl ComplexLevel1 for Complex64 { incy: c_int, result: *mut Self, ) -> cublasStatus_t { - cublasZdotc_v2(handle, n, x.cast(), incx, y.cast(), incy, result.cast()) + cublasZdotc(handle, n, x.cast(), incx, y.cast(), incy, result.cast()) } } @@ -638,7 +638,7 @@ impl FloatLevel1 for f32 { incy: c_int, result: *mut Self, ) -> cublasStatus_t { - cublasSdot_v2(handle, n, x, incx, y, incy, result) + cublasSdot(handle, n, x, incx, y, incy, result) } } @@ -652,6 +652,6 @@ impl FloatLevel1 for f64 { incy: c_int, result: *mut Self, ) -> cublasStatus_t { - cublasDdot_v2(handle, n, x, incx, y, incy, result) + cublasDdot(handle, n, x, incx, y, incy, result) } } diff --git a/crates/blastoff/src/raw/level3.rs b/crates/blastoff/src/raw/level3.rs index 5e6d8e17..3e770a29 100644 --- a/crates/blastoff/src/raw/level3.rs +++ b/crates/blastoff/src/raw/level3.rs @@ -85,7 +85,7 @@ impl GemmOps for f32 { c: *mut Self, ldc: c_int, ) -> cublasStatus_t 
{ - cublasSgemm_v2( + cublasSgemm( handle, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, ) } @@ -108,7 +108,7 @@ impl GemmOps for f64 { c: *mut Self, ldc: c_int, ) -> cublasStatus_t { - cublasDgemm_v2( + cublasDgemm( handle, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, ) } @@ -131,7 +131,7 @@ impl GemmOps for Complex32 { c: *mut Self, ldc: c_int, ) -> cublasStatus_t { - cublasCgemm_v2( + cublasCgemm( handle, transa, transb, @@ -167,7 +167,7 @@ impl GemmOps for Complex64 { c: *mut Self, ldc: c_int, ) -> cublasStatus_t { - cublasCgemm_v2( + cublasZgemm( handle, transa, transb, diff --git a/crates/cust/src/context/legacy.rs b/crates/cust/src/context/legacy.rs index 838d6d77..53f2b501 100644 --- a/crates/cust/src/context/legacy.rs +++ b/crates/cust/src/context/legacy.rs @@ -262,7 +262,7 @@ impl Context { // lifetime guarantees so we create-and-push, then pop, then the programmer has to // push again. let mut ctx: CUcontext = ptr::null_mut(); - driver_sys::cuCtxCreate_v2(&mut ctx as *mut CUcontext, flags.bits(), device.as_raw()) + driver_sys::cuCtxCreate(&mut ctx as *mut CUcontext, flags.bits(), device.as_raw()) .to_result()?; Ok(Context { inner: ctx }) } @@ -354,7 +354,7 @@ impl Context { unsafe { let inner = mem::replace(&mut ctx.inner, ptr::null_mut()); - match driver_sys::cuCtxDestroy_v2(inner).to_result() { + match driver_sys::cuCtxDestroy(inner).to_result() { Ok(()) => { mem::forget(ctx); Ok(()) @@ -372,7 +372,7 @@ impl Drop for Context { unsafe { let inner = mem::replace(&mut self.inner, ptr::null_mut()); - driver_sys::cuCtxDestroy_v2(inner); + driver_sys::cuCtxDestroy(inner); } } } @@ -456,7 +456,7 @@ impl ContextStack { pub fn pop() -> CudaResult { unsafe { let mut ctx: CUcontext = ptr::null_mut(); - driver_sys::cuCtxPopCurrent_v2(&mut ctx as *mut CUcontext).to_result()?; + driver_sys::cuCtxPopCurrent(&mut ctx as *mut CUcontext).to_result()?; Ok(UnownedContext { inner: ctx }) } } @@ -481,7 +481,7 @@ impl ContextStack { /// ``` 
pub fn push(ctx: &C) -> CudaResult<()> { unsafe { - driver_sys::cuCtxPushCurrent_v2(ctx.get_inner()).to_result()?; + driver_sys::cuCtxPushCurrent(ctx.get_inner()).to_result()?; Ok(()) } } diff --git a/crates/cust/src/context/mod.rs b/crates/cust/src/context/mod.rs index 6b2551bd..eb67e28a 100644 --- a/crates/cust/src/context/mod.rs +++ b/crates/cust/src/context/mod.rs @@ -215,13 +215,13 @@ impl Context { /// Nothing else should be using the primary context for this device, otherwise, /// spurious errors or segfaults will occur. pub unsafe fn reset(device: &Device) -> CudaResult<()> { - driver_sys::cuDevicePrimaryCtxReset_v2(device.as_raw()).to_result() + driver_sys::cuDevicePrimaryCtxReset(device.as_raw()).to_result() } /// Sets the flags for the device context, these flags will apply to any user of the primary /// context associated with this device. pub fn set_flags(&self, flags: ContextFlags) -> CudaResult<()> { - unsafe { driver_sys::cuDevicePrimaryCtxSetFlags_v2(self.device, flags.bits()).to_result() } + unsafe { driver_sys::cuDevicePrimaryCtxSetFlags(self.device, flags.bits()).to_result() } } /// Returns the raw handle to this context. 
@@ -291,7 +291,7 @@ impl Context { unsafe { let inner = mem::replace(&mut ctx.inner, ptr::null_mut()); - match driver_sys::cuDevicePrimaryCtxRelease_v2(ctx.device).to_result() { + match driver_sys::cuDevicePrimaryCtxRelease(ctx.device).to_result() { Ok(()) => { mem::forget(ctx); Ok(()) @@ -316,7 +316,7 @@ impl Drop for Context { unsafe { self.inner = ptr::null_mut(); - driver_sys::cuDevicePrimaryCtxRelease_v2(self.device); + driver_sys::cuDevicePrimaryCtxRelease(self.device); } } } diff --git a/crates/cust/src/device.rs b/crates/cust/src/device.rs index fb345c86..36f0cc76 100644 --- a/crates/cust/src/device.rs +++ b/crates/cust/src/device.rs @@ -295,7 +295,7 @@ impl Device { pub fn total_memory(self) -> CudaResult { unsafe { let mut memory = 0; - driver_sys::cuDeviceTotalMem_v2(&mut memory as *mut usize, self.device).to_result()?; + driver_sys::cuDeviceTotalMem(&mut memory as *mut usize, self.device).to_result()?; Ok(memory) } } diff --git a/crates/cust/src/event.rs b/crates/cust/src/event.rs index 55ed8195..18c36059 100644 --- a/crates/cust/src/event.rs +++ b/crates/cust/src/event.rs @@ -18,7 +18,7 @@ use std::ptr; use std::time::Duration; use cust_raw::driver_sys::{ - cuEventCreate, cuEventDestroy_v2, cuEventElapsedTime, cuEventQuery, cuEventRecord, + cuEventCreate, cuEventDestroy, cuEventElapsedTime, cuEventQuery, cuEventRecord, cuEventSynchronize, CUevent, }; @@ -334,7 +334,7 @@ impl Event { unsafe { let inner = mem::replace(&mut event.0, ptr::null_mut()); - match cuEventDestroy_v2(inner).to_result() { + match cuEventDestroy(inner).to_result() { Ok(()) => { mem::forget(event); Ok(()) @@ -347,7 +347,7 @@ impl Event { impl Drop for Event { fn drop(&mut self) { - unsafe { cuEventDestroy_v2(self.0) }; + unsafe { cuEventDestroy(self.0) }; } } diff --git a/crates/cust/src/graph.rs b/crates/cust/src/graph.rs index 914f42cf..b24e0963 100644 --- a/crates/cust/src/graph.rs +++ b/crates/cust/src/graph.rs @@ -395,7 +395,7 @@ impl Graph { let deps_ptr = 
deps.as_ptr().cast(); let mut node = MaybeUninit::::uninit(); let params = invocation.to_raw(); - driver_sys::cuGraphAddKernelNode_v2( + driver_sys::cuGraphAddKernelNode( node.as_mut_ptr().cast(), self.raw, deps_ptr, @@ -476,7 +476,7 @@ impl Graph { ); unsafe { let mut params = MaybeUninit::uninit(); - driver_sys::cuGraphKernelNodeGetParams_v2(node.to_raw(), params.as_mut_ptr()); + driver_sys::cuGraphKernelNodeGetParams(node.to_raw(), params.as_mut_ptr()); Ok(KernelInvocation::from_raw(params.assume_init())) } } diff --git a/crates/cust/src/link.rs b/crates/cust/src/link.rs index 26bf7202..d57aaf97 100644 --- a/crates/cust/src/link.rs +++ b/crates/cust/src/link.rs @@ -27,7 +27,7 @@ impl Linker { unsafe { let mut raw = MaybeUninit::uninit(); - driver_sys::cuLinkCreate_v2(0, null_mut(), null_mut(), raw.as_mut_ptr()).to_result()?; + driver_sys::cuLinkCreate(0, null_mut(), null_mut(), raw.as_mut_ptr()).to_result()?; Ok(Self { raw: raw.assume_init(), }) @@ -48,7 +48,7 @@ impl Linker { let ptx = ptx.as_ref(); unsafe { - driver_sys::cuLinkAddData_v2( + driver_sys::cuLinkAddData( self.raw, driver_sys::CUjitInputType::CU_JIT_INPUT_PTX, // cuda_sys wants *mut but from the API docs we know we retain ownership so @@ -73,7 +73,7 @@ impl Linker { let cubin = cubin.as_ref(); unsafe { - driver_sys::cuLinkAddData_v2( + driver_sys::cuLinkAddData( self.raw, driver_sys::CUjitInputType::CU_JIT_INPUT_CUBIN, // cuda_sys wants *mut but from the API docs we know we retain ownership so @@ -98,7 +98,7 @@ impl Linker { let fatbin = fatbin.as_ref(); unsafe { - driver_sys::cuLinkAddData_v2( + driver_sys::cuLinkAddData( self.raw, driver_sys::CUjitInputType::CU_JIT_INPUT_FATBINARY, // cuda_sys wants *mut but from the API docs we know we retain ownership so diff --git a/crates/cust/src/memory/array.rs b/crates/cust/src/memory/array.rs index 7d543e0c..01008633 100644 --- a/crates/cust/src/memory/array.rs +++ b/crates/cust/src/memory/array.rs @@ -14,9 +14,9 @@ use std::ptr::null; use 
std::ptr::null_mut; use cust_raw::driver_sys; -use cust_raw::driver_sys::cuMemcpy2D_v2; -use cust_raw::driver_sys::cuMemcpyAtoH_v2; -use cust_raw::driver_sys::cuMemcpyHtoA_v2; +use cust_raw::driver_sys::cuMemcpy2D; +use cust_raw::driver_sys::cuMemcpyAtoH; +use cust_raw::driver_sys::cuMemcpyHtoA; use cust_raw::driver_sys::CUDA_MEMCPY2D; use cust_raw::driver_sys::{CUarray, CUarray_format, CUarray_format_enum}; @@ -479,7 +479,7 @@ impl ArrayObject { } let mut handle = MaybeUninit::uninit(); - unsafe { driver_sys::cuArray3DCreate_v2(handle.as_mut_ptr(), &descriptor.desc) } + unsafe { driver_sys::cuArray3DCreate(handle.as_mut_ptr(), &descriptor.desc) } .to_result()?; Ok(Self { handle: unsafe { handle.assume_init() }, @@ -731,7 +731,7 @@ impl ArrayObject { pub fn descriptor(&self) -> CudaResult { // Use "zeroed" incase CUDA_ARRAY3D_DESCRIPTOR has uninitialized padding let mut raw_descriptor = MaybeUninit::zeroed(); - unsafe { driver_sys::cuArray3DGetDescriptor_v2(raw_descriptor.as_mut_ptr(), self.handle) } + unsafe { driver_sys::cuArray3DGetDescriptor(raw_descriptor.as_mut_ptr(), self.handle) } .to_result()?; Ok(ArrayDescriptor::from_raw(unsafe { @@ -764,8 +764,7 @@ impl ArrayObject { assert_eq!(self_size, other_size, "Array and value sizes don't match"); unsafe { if desc.height() == 0 && desc.depth() == 0 { - cuMemcpyHtoA_v2(self.handle, 0, val.as_ptr() as *const c_void, self_size) - .to_result() + cuMemcpyHtoA(self.handle, 0, val.as_ptr() as *const c_void, self_size).to_result() } else if desc.depth() == 0 { let desc = CUDA_MEMCPY2D { Height: desc.height(), @@ -787,7 +786,7 @@ impl ArrayObject { srcXInBytes: 0, srcY: 0, }; - cuMemcpy2D_v2(&desc as *const _).to_result() + cuMemcpy2D(&desc as *const _).to_result() } else { panic!(); } @@ -810,8 +809,7 @@ impl ArrayObject { assert_eq!(self_size, other_size, "Array and value sizes don't match"); unsafe { if desc.height() == 0 && desc.depth() == 0 { - cuMemcpyAtoH_v2(val.as_mut_ptr() as *mut c_void, self.handle, 0, 
self_size) - .to_result() + cuMemcpyAtoH(val.as_mut_ptr() as *mut c_void, self.handle, 0, self_size).to_result() } else if desc.depth() == 0 { let width = desc.width() * desc.num_channels() as usize * desc.format().mem_size(); let desc = CUDA_MEMCPY2D { @@ -832,7 +830,7 @@ impl ArrayObject { srcXInBytes: 0, srcY: 0, }; - cuMemcpy2D_v2(&desc as *const _).to_result()?; + cuMemcpy2D(&desc as *const _).to_result()?; Ok(()) } else { panic!(); diff --git a/crates/cust/src/memory/device/device_box.rs b/crates/cust/src/memory/device/device_box.rs index acb0d040..b5e765df 100644 --- a/crates/cust/src/memory/device/device_box.rs +++ b/crates/cust/src/memory/device/device_box.rs @@ -164,7 +164,7 @@ impl DeviceBox { unsafe { let new_box = DeviceBox::uninitialized()?; if mem::size_of::() != 0 { - driver_sys::cuMemsetD8_v2(new_box.as_device_ptr().as_raw(), 0, mem::size_of::()) + driver_sys::cuMemsetD8(new_box.as_device_ptr().as_raw(), 0, mem::size_of::()) .to_result()?; } Ok(new_box) @@ -430,12 +430,8 @@ impl CopyDestination for DeviceBox { let size = mem::size_of::(); if size != 0 { unsafe { - driver_sys::cuMemcpyHtoD_v2( - self.ptr.as_raw(), - val as *const T as *const c_void, - size, - ) - .to_result()? + driver_sys::cuMemcpyHtoD(self.ptr.as_raw(), val as *const T as *const c_void, size) + .to_result()? } } Ok(()) @@ -445,7 +441,7 @@ impl CopyDestination for DeviceBox { let size = mem::size_of::(); if size != 0 { unsafe { - driver_sys::cuMemcpyDtoH_v2(val as *const T as *mut c_void, self.ptr.as_raw(), size) + driver_sys::cuMemcpyDtoH(val as *const T as *mut c_void, self.ptr.as_raw(), size) .to_result()? } } @@ -457,8 +453,7 @@ impl CopyDestination> for DeviceBox { let size = mem::size_of::(); if size != 0 { unsafe { - driver_sys::cuMemcpyDtoD_v2(self.ptr.as_raw(), val.ptr.as_raw(), size) - .to_result()? + driver_sys::cuMemcpyDtoD(self.ptr.as_raw(), val.ptr.as_raw(), size).to_result()? 
} } Ok(()) @@ -468,8 +463,7 @@ impl CopyDestination> for DeviceBox { let size = mem::size_of::(); if size != 0 { unsafe { - driver_sys::cuMemcpyDtoD_v2(val.ptr.as_raw(), self.ptr.as_raw(), size) - .to_result()? + driver_sys::cuMemcpyDtoD(val.ptr.as_raw(), self.ptr.as_raw(), size).to_result()? } } Ok(()) @@ -479,7 +473,7 @@ impl AsyncCopyDestination for DeviceBox { unsafe fn async_copy_from(&mut self, val: &T, stream: &Stream) -> CudaResult<()> { let size = mem::size_of::(); if size != 0 { - driver_sys::cuMemcpyHtoDAsync_v2( + driver_sys::cuMemcpyHtoDAsync( self.ptr.as_raw(), val as *const _ as *const c_void, size, @@ -493,7 +487,7 @@ impl AsyncCopyDestination for DeviceBox { unsafe fn async_copy_to(&self, val: &mut T, stream: &Stream) -> CudaResult<()> { let size = mem::size_of::(); if size != 0 { - driver_sys::cuMemcpyDtoHAsync_v2( + driver_sys::cuMemcpyDtoHAsync( val as *mut _ as *mut c_void, self.ptr.as_raw(), size, @@ -508,7 +502,7 @@ impl AsyncCopyDestination> for DeviceBox { unsafe fn async_copy_from(&mut self, val: &DeviceBox, stream: &Stream) -> CudaResult<()> { let size = mem::size_of::(); if size != 0 { - driver_sys::cuMemcpyDtoDAsync_v2( + driver_sys::cuMemcpyDtoDAsync( self.ptr.as_raw(), val.ptr.as_raw(), size, @@ -522,7 +516,7 @@ impl AsyncCopyDestination> for DeviceBox { unsafe fn async_copy_to(&self, val: &mut DeviceBox, stream: &Stream) -> CudaResult<()> { let size = mem::size_of::(); if size != 0 { - driver_sys::cuMemcpyDtoDAsync_v2( + driver_sys::cuMemcpyDtoDAsync( val.ptr.as_raw(), self.ptr.as_raw(), size, diff --git a/crates/cust/src/memory/device/device_buffer.rs b/crates/cust/src/memory/device/device_buffer.rs index 6fc5dde6..873a194c 100644 --- a/crates/cust/src/memory/device/device_buffer.rs +++ b/crates/cust/src/memory/device/device_buffer.rs @@ -232,12 +232,8 @@ impl DeviceBuffer { unsafe { let new_buf = DeviceBuffer::uninitialized(size)?; if size_of::() != 0 { - driver_sys::cuMemsetD8_v2( - new_buf.as_device_ptr().as_raw(), - 0, - 
size_of::() * size, - ) - .to_result()?; + driver_sys::cuMemsetD8(new_buf.as_device_ptr().as_raw(), 0, size_of::() * size) + .to_result()?; } Ok(new_buf) } diff --git a/crates/cust/src/memory/device/device_slice.rs b/crates/cust/src/memory/device/device_slice.rs index 702b9d04..893b8c4d 100644 --- a/crates/cust/src/memory/device/device_slice.rs +++ b/crates/cust/src/memory/device/device_slice.rs @@ -250,7 +250,7 @@ impl DeviceSlice { // SAFETY: We know T can hold any value because it is `Pod`, and // sub-byte alignment isn't a thing so we know the alignment is right. unsafe { - driver_sys::cuMemsetD8_v2(self.as_raw_ptr(), value, self.size_in_bytes()).to_result() + driver_sys::cuMemsetD8(self.as_raw_ptr(), value, self.size_in_bytes()).to_result() } } @@ -300,7 +300,7 @@ impl DeviceSlice { 0, "Buffer pointer is not aligned to at least 2 bytes!" ); - unsafe { driver_sys::cuMemsetD16_v2(self.as_raw_ptr(), value, data_len / 2).to_result() } + unsafe { driver_sys::cuMemsetD16(self.as_raw_ptr(), value, data_len / 2).to_result() } } /// Sets the memory range of this buffer to contiguous `16-bit` values of `value` asynchronously. @@ -358,7 +358,7 @@ impl DeviceSlice { 0, "Buffer pointer is not aligned to at least 4 bytes!" ); - unsafe { driver_sys::cuMemsetD32_v2(self.as_raw_ptr(), value, data_len / 4).to_result() } + unsafe { driver_sys::cuMemsetD32(self.as_raw_ptr(), value, data_len / 4).to_result() } } /// Sets the memory range of this buffer to contiguous `32-bit` values of `value` asynchronously. @@ -651,7 +651,7 @@ impl + AsMut<[T]> + ?Sized> CopyDestination for let size = self.size_in_bytes(); if size != 0 { unsafe { - driver_sys::cuMemcpyHtoD_v2(self.as_raw_ptr(), val.as_ptr() as *const c_void, size) + driver_sys::cuMemcpyHtoD(self.as_raw_ptr(), val.as_ptr() as *const c_void, size) .to_result()? 
} } @@ -667,12 +667,8 @@ impl + AsMut<[T]> + ?Sized> CopyDestination for let size = self.size_in_bytes(); if size != 0 { unsafe { - driver_sys::cuMemcpyDtoH_v2( - val.as_mut_ptr() as *mut c_void, - self.as_raw_ptr(), - size, - ) - .to_result()? + driver_sys::cuMemcpyDtoH(val.as_mut_ptr() as *mut c_void, self.as_raw_ptr(), size) + .to_result()? } } Ok(()) @@ -687,8 +683,7 @@ impl CopyDestination> for DeviceSlice { let size = self.size_in_bytes(); if size != 0 { unsafe { - driver_sys::cuMemcpyDtoD_v2(self.as_raw_ptr(), val.as_raw_ptr(), size) - .to_result()? + driver_sys::cuMemcpyDtoD(self.as_raw_ptr(), val.as_raw_ptr(), size).to_result()? } } Ok(()) @@ -702,8 +697,7 @@ impl CopyDestination> for DeviceSlice { let size = self.size_in_bytes(); if size != 0 { unsafe { - driver_sys::cuMemcpyDtoD_v2(val.as_raw_ptr(), self.as_raw_ptr(), size) - .to_result()? + driver_sys::cuMemcpyDtoD(val.as_raw_ptr(), self.as_raw_ptr(), size).to_result()? } } Ok(()) @@ -729,7 +723,7 @@ impl + AsMut<[T]> + ?Sized> AsyncCopyDestination ); let size = self.size_in_bytes(); if size != 0 { - driver_sys::cuMemcpyHtoDAsync_v2( + driver_sys::cuMemcpyHtoDAsync( self.as_raw_ptr(), val.as_ptr() as *const c_void, size, @@ -748,7 +742,7 @@ impl + AsMut<[T]> + ?Sized> AsyncCopyDestination ); let size = self.size_in_bytes(); if size != 0 { - driver_sys::cuMemcpyDtoHAsync_v2( + driver_sys::cuMemcpyDtoHAsync( val.as_mut_ptr() as *mut c_void, self.as_raw_ptr(), size, @@ -767,7 +761,7 @@ impl AsyncCopyDestination> for DeviceSlice { ); let size = self.size_in_bytes(); if size != 0 { - driver_sys::cuMemcpyDtoDAsync_v2( + driver_sys::cuMemcpyDtoDAsync( self.as_raw_ptr(), val.as_raw_ptr(), size, @@ -785,7 +779,7 @@ impl AsyncCopyDestination> for DeviceSlice { ); let size = self.size_in_bytes(); if size != 0 { - driver_sys::cuMemcpyDtoDAsync_v2( + driver_sys::cuMemcpyDtoDAsync( val.as_raw_ptr(), self.as_raw_ptr(), size, diff --git a/crates/cust/src/memory/malloc.rs b/crates/cust/src/memory/malloc.rs index 
78f1f356..6255778c 100644 --- a/crates/cust/src/memory/malloc.rs +++ b/crates/cust/src/memory/malloc.rs @@ -48,7 +48,7 @@ pub unsafe fn cuda_malloc(count: usize) -> CudaResult( let mut ptr = 0; let mut pitch = 0; - driver_sys::cuMemAllocPitch_v2(&mut ptr, &mut pitch, width_bytes, height, element_size) + driver_sys::cuMemAllocPitch(&mut ptr, &mut pitch, width_bytes, height, element_size) .to_result()?; Ok((DevicePointer::from_raw(ptr), pitch)) } @@ -236,7 +236,7 @@ pub unsafe fn cuda_free(ptr: DevicePointer) -> CudaResult<()> return Err(CudaError::InvalidMemoryAllocation); } - driver_sys::cuMemFree_v2(ptr.as_raw()).to_result()?; + driver_sys::cuMemFree(ptr.as_raw()).to_result()?; Ok(()) } @@ -269,7 +269,7 @@ pub unsafe fn cuda_free_unified(mut p: UnifiedPointer) -> Cuda return Err(CudaError::InvalidMemoryAllocation); } - driver_sys::cuMemFree_v2(ptr as u64).to_result()?; + driver_sys::cuMemFree(ptr as u64).to_result()?; Ok(()) } @@ -311,7 +311,7 @@ pub unsafe fn cuda_malloc_locked(count: usize) -> CudaResult<*mut T> { } let mut ptr: *mut c_void = ptr::null_mut(); - driver_sys::cuMemAllocHost_v2(&mut ptr as *mut *mut c_void, size).to_result()?; + driver_sys::cuMemAllocHost(&mut ptr as *mut *mut c_void, size).to_result()?; let ptr = ptr as *mut T; Ok(ptr) } diff --git a/crates/cust/src/memory/mod.rs b/crates/cust/src/memory/mod.rs index d9fd4838..aa349145 100644 --- a/crates/cust/src/memory/mod.rs +++ b/crates/cust/src/memory/mod.rs @@ -205,25 +205,25 @@ mod private { impl Sealed for DeviceBox {} } -/// Simple wrapper over cuMemcpyHtoD_v2 +/// Simple wrapper over cuMemcpyHtoD #[allow(clippy::missing_safety_doc)] pub unsafe fn memcpy_htod( d_ptr: driver_sys::CUdeviceptr, src_ptr: *const c_void, size: usize, ) -> CudaResult<()> { - driver_sys::cuMemcpyHtoD_v2(d_ptr, src_ptr, size).to_result()?; + driver_sys::cuMemcpyHtoD(d_ptr, src_ptr, size).to_result()?; Ok(()) } -/// Simple wrapper over cuMemcpyDtoH_v2 +/// Simple wrapper over cuMemcpyDtoH 
#[allow(clippy::missing_safety_doc)] pub unsafe fn memcpy_dtoh( d_ptr: *mut c_void, src_ptr: driver_sys::CUdeviceptr, size: usize, ) -> CudaResult<()> { - driver_sys::cuMemcpyDtoH_v2(d_ptr, src_ptr, size).to_result()?; + driver_sys::cuMemcpyDtoH(d_ptr, src_ptr, size).to_result()?; Ok(()) } @@ -309,7 +309,7 @@ pub unsafe fn memcpy_2d_htod( Height: height, }; - driver_sys::cuMemcpy2D_v2(&pcopy).to_result()?; + driver_sys::cuMemcpy2D(&pcopy).to_result()?; Ok(()) } @@ -395,7 +395,7 @@ pub unsafe fn memcpy_2d_dtoh( Height: height, }; - driver_sys::cuMemcpy2D_v2(&pcopy).to_result()?; + driver_sys::cuMemcpy2D(&pcopy).to_result()?; Ok(()) } @@ -409,7 +409,7 @@ pub fn mem_get_info() -> CudaResult<(usize, usize)> { let mut mem_free = 0; let mut mem_total = 0; unsafe { - driver_sys::cuMemGetInfo_v2(&mut mem_free, &mut mem_total).to_result()?; + driver_sys::cuMemGetInfo(&mut mem_free, &mut mem_total).to_result()?; } Ok((mem_free, mem_total)) } diff --git a/crates/cust/src/module.rs b/crates/cust/src/module.rs index 815cbd5c..062a912b 100644 --- a/crates/cust/src/module.rs +++ b/crates/cust/src/module.rs @@ -340,7 +340,7 @@ impl Module { /// ``` #[deprecated( since = "0.3.0", - note = "load_from_string was an inconsistent name with inconsistent params, use from_ptx/from_ptx_cstr, passing + note = "load_from_string was an inconsistent name with inconsistent params, use from_ptx/from_ptx_cstr, passing an empty slice of options (usually) " )] @@ -390,7 +390,7 @@ impl Module { let mut ptr: DevicePointer = DevicePointer::null(); let mut size: usize = 0; - driver_sys::cuModuleGetGlobal_v2( + driver_sys::cuModuleGetGlobal( &mut ptr as *mut DevicePointer as *mut driver_sys::CUdeviceptr, &mut size as *mut usize, self.inner, @@ -513,12 +513,8 @@ impl CopyDestination for Symbol<'_, T> { let size = mem::size_of::(); if size != 0 { unsafe { - driver_sys::cuMemcpyHtoD_v2( - self.ptr.as_raw(), - val as *const T as *const c_void, - size, - ) - .to_result()? 
+ driver_sys::cuMemcpyHtoD(self.ptr.as_raw(), val as *const T as *const c_void, size) + .to_result()? } } Ok(()) @@ -528,7 +524,7 @@ impl CopyDestination for Symbol<'_, T> { let size = mem::size_of::(); if size != 0 { unsafe { - driver_sys::cuMemcpyDtoH_v2(val as *const T as *mut c_void, self.ptr.as_raw(), size) + driver_sys::cuMemcpyDtoH(val as *const T as *mut c_void, self.ptr.as_raw(), size) .to_result()? } } diff --git a/crates/cust/src/stream.rs b/crates/cust/src/stream.rs index dc67119d..41404da6 100644 --- a/crates/cust/src/stream.rs +++ b/crates/cust/src/stream.rs @@ -325,7 +325,7 @@ impl Stream { unsafe { let inner = mem::replace(&mut stream.inner, ptr::null_mut()); - match driver_sys::cuStreamDestroy_v2(inner).to_result() { + match driver_sys::cuStreamDestroy(inner).to_result() { Ok(()) => { mem::forget(stream); Ok(()) @@ -344,7 +344,7 @@ impl Drop for Stream { unsafe { let inner = mem::replace(&mut self.inner, ptr::null_mut()); - driver_sys::cuStreamDestroy_v2(inner); + driver_sys::cuStreamDestroy(inner); } } } diff --git a/crates/cust_raw/Cargo.toml b/crates/cust_raw/Cargo.toml index 94c91911..046a713f 100644 --- a/crates/cust_raw/Cargo.toml +++ b/crates/cust_raw/Cargo.toml @@ -11,6 +11,8 @@ build = "build/main.rs" [build-dependencies] bindgen = "0.71.1" +bimap = "0.6.3" +cc = "1.2.17" [package.metadata.docs.rs] features = [ diff --git a/crates/cust_raw/build/callbacks.rs b/crates/cust_raw/build/callbacks.rs new file mode 100644 index 00000000..a15e4366 --- /dev/null +++ b/crates/cust_raw/build/callbacks.rs @@ -0,0 +1,117 @@ +use std::cell; +use std::fs; +use std::path; +use std::sync; + +use bimap; +use bindgen::callbacks::{ItemInfo, ItemKind, MacroParsingBehavior, ParseCallbacks}; + +/// Struct to handle renaming of functions through macro expansion. 
+#[derive(Debug)] +pub(crate) struct FunctionRenames { + func_prefix: &'static str, + out_dir: path::PathBuf, + includes: path::PathBuf, + include_dirs: Vec, + macro_names: cell::RefCell>, + func_remaps: sync::OnceLock>, +} + +impl FunctionRenames { + pub fn new, I: Into>( + func_prefix: &'static str, + out_dir: P, + includes: I, + include_dirs: Vec, + ) -> Self { + Self { + func_prefix, + out_dir: out_dir.as_ref().to_path_buf(), + includes: includes.into(), + include_dirs, + macro_names: cell::RefCell::new(Vec::new()), + func_remaps: sync::OnceLock::new(), + } + } + + fn record_macro(&self, name: &str) { + self.macro_names.borrow_mut().push(name.to_string()); + } + + fn expand(&self) -> &bimap::BiHashMap { + self.func_remaps.get_or_init(|| { + let expand_me = self.out_dir.join("expand_macros.c"); + let includes = fs::read_to_string(&self.includes) + .expect("Failed to read includes for function renames"); + + let mut template = format!( + r#"{includes} +#define RENAMED2(from, to) RUST_RENAMED##from##_TO_##to +#define RENAMED(from, to) RENAMED2(from, to) +"# + ); + + for name in self.macro_names.borrow().iter() { + template.push_str(&format!("RENAMED(_{name}, {name})\n")); + } + + { + let mut temp = fs::File::create(&expand_me).unwrap(); + std::io::Write::write_all(&mut temp, template.as_bytes()).unwrap(); + } + + let mut build = cc::Build::new(); + build + .file(&expand_me) + .includes(&self.include_dirs) + .cargo_warnings(false); + + let expanded = match build.try_expand() { + Ok(expanded) => expanded, + Err(e) => panic!("Failed to expand macros: {}", e), + }; + let expanded = str::from_utf8(&expanded).unwrap(); + + let mut remaps = bimap::BiHashMap::new(); + for line in expanded.lines().rev() { + let rename_prefix = "RUST_RENAMED_"; + + if let Some((original, expanded)) = line + .strip_prefix(rename_prefix) + .and_then(|s| s.split_once("_TO_")) + .filter(|(l, r)| l != r && !r.is_empty()) + { + remaps.insert(original.to_string(), expanded.to_string()); + } + } + 
+ fs::remove_file(&expand_me).expect("Failed to remove temporary file"); + remaps + }) + } +} + +impl ParseCallbacks for FunctionRenames { + fn will_parse_macro(&self, name: &str) -> MacroParsingBehavior { + if name.starts_with(self.func_prefix) { + self.record_macro(name); + } + MacroParsingBehavior::Default + } + + fn generated_name_override(&self, item_info: ItemInfo<'_>) -> Option { + let remaps = self.expand(); + match item_info.kind { + ItemKind::Function => remaps.get_by_right(item_info.name).cloned(), + _ => None, + } + } + + fn generated_link_name_override(&self, item_info: ItemInfo<'_>) -> Option { + let remaps = self.expand(); + match item_info.kind { + ItemKind::Function => remaps.get_by_left(item_info.name).cloned(), + _ => None, + } + } +} diff --git a/crates/cust_raw/build/main.rs b/crates/cust_raw/build/main.rs index 7137b40d..61eddfda 100644 --- a/crates/cust_raw/build/main.rs +++ b/crates/cust_raw/build/main.rs @@ -2,6 +2,7 @@ use std::env; use std::fs; use std::path; +pub mod callbacks; pub mod cuda_sdk; fn main() { @@ -79,8 +80,15 @@ fn create_cuda_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { return; } let bindgen_path = path::PathBuf::from(format!("{}/driver_sys.rs", outdir.display())); + let header = "build/driver_wrapper.h"; let bindings = bindgen::Builder::default() - .header("build/driver_wrapper.h") + .header(header) + .parse_callbacks(Box::new(callbacks::FunctionRenames::new( + "cu", + outdir, + header, + sdk.cuda_include_paths().to_owned(), + ))) .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) .clang_args( sdk.cuda_include_paths() @@ -115,8 +123,15 @@ fn create_cuda_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { return; } let bindgen_path = path::PathBuf::from(format!("{}/runtime_sys.rs", outdir.display())); + let header = "build/runtime_wrapper.h"; let bindings = bindgen::Builder::default() - .header("build/runtime_wrapper.h") + .header(header) + 
.parse_callbacks(Box::new(callbacks::FunctionRenames::new( + "cuda", + outdir, + header, + sdk.cuda_include_paths().to_owned(), + ))) .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) .clang_args( sdk.cuda_include_paths() @@ -148,16 +163,23 @@ fn create_cublas_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { #[rustfmt::skip] let params = &[ (cfg!(feature = "cublas"), "cublas", "^cublas.*", "^CUBLAS.*"), - (cfg!(feature = "cublaslt"), "cublaslt", "^cublasLt.*", "^CUBLASLT.*"), - (cfg!(feature = "cublasxt"), "cublasxt", "^cublasXt.*", "^CUBLASXT.*"), + (cfg!(feature = "cublaslt"), "cublasLt", "^cublasLt.*", "^CUBLASLT.*"), + (cfg!(feature = "cublasxt"), "cublasXt", "^cublasXt.*", "^CUBLASXT.*"), ]; for (should_generate, pkg, tf, var) in params { if !should_generate { continue; } let bindgen_path = path::PathBuf::from(format!("{}/{pkg}_sys.rs", outdir.display())); + let header = format!("build/{pkg}_wrapper.h"); let bindings = bindgen::Builder::default() - .header(format!("build/{pkg}_wrapper.h")) + .header(&header) + .parse_callbacks(Box::new(callbacks::FunctionRenames::new( + pkg, + outdir, + header, + sdk.cuda_include_paths().to_owned(), + ))) .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) .clang_args( sdk.cuda_include_paths() diff --git a/crates/optix-sys/build/main.rs b/crates/optix-sys/build/main.rs index 078618f6..9301f6e6 100644 --- a/crates/optix-sys/build/main.rs +++ b/crates/optix-sys/build/main.rs @@ -15,6 +15,9 @@ fn main() { .expect("Cannot find transitive metadata 'cuda_include' from cust_raw package."); println!("cargo::rerun-if-changed=build"); + for e in sdk.related_optix_envs() { + println!("cargo::rerun-if-env-changed={}", e); + } // Emit metadata for the build script. 
println!("cargo::metadata=root={}", sdk.optix_root().display()); println!("cargo::metadata=version={}", sdk.optix_version()); diff --git a/crates/optix-sys/build/optix_sdk.rs b/crates/optix-sys/build/optix_sdk.rs index bc0cf736..46b7329f 100644 --- a/crates/optix-sys/build/optix_sdk.rs +++ b/crates/optix-sys/build/optix_sdk.rs @@ -3,6 +3,8 @@ use std::error; use std::fs; use std::path; +const OPTIX_ROOT_ENVS: &[&str] = &["OPTIX_ROOT", "OPTIX_ROOT_DIR"]; + /// Represents the OptiX SDK installation. #[derive(Debug, Clone)] pub struct OptiXSdk { @@ -60,14 +62,19 @@ impl OptiXSdk { self.optix_version % 100 } + pub fn related_optix_envs(&self) -> Vec { + OPTIX_ROOT_ENVS.iter().map(|s| s.to_string()).collect() + } + fn find_optix_root() -> Option { // the optix SDK installer sets OPTIX_ROOT_DIR whenever it installs. // We also check OPTIX_ROOT first in case someone wants to override it without overriding // the SDK-set variable. - env::var("OPTIX_ROOT") - .ok() - .or_else(|| env::var("OPTIX_ROOT_DIR").ok()) + OPTIX_ROOT_ENVS + .iter() + .filter_map(|env| env::var(env).ok()) .map(path::PathBuf::from) + .next() } /// Parses the content of the `optix.h` header file to extract the OptiX version. 
diff --git a/crates/optix/examples/ex02_pipeline/build.rs b/crates/optix/examples/ex02_pipeline/build.rs index 4e82edf0..4eb0f5b1 100644 --- a/crates/optix/examples/ex02_pipeline/build.rs +++ b/crates/optix/examples/ex02_pipeline/build.rs @@ -7,7 +7,7 @@ fn main() { println!("cargo::rerun-if-changed=build.rs"); let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); - let optix_include_paths = env::var_os("DEP_OPTIX_OPTIX_INCLUDE") + let optix_include_paths = env::var_os("DEP_OPTIX_INCLUDE_DIR") .map(|s| env::split_paths(s.as_os_str()).collect::>()) .expect("Cannot find transitive metadata 'optix_include' from optix-sys package."); diff --git a/crates/optix/examples/ex03_window/build.rs b/crates/optix/examples/ex03_window/build.rs index 06122b04..63d1bced 100644 --- a/crates/optix/examples/ex03_window/build.rs +++ b/crates/optix/examples/ex03_window/build.rs @@ -5,7 +5,7 @@ fn main() { println!("cargo::rerun-if-changed=build.rs"); let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); - let optix_include_paths = env::var_os("DEP_OPTIX_OPTIX_INCLUDE") + let optix_include_paths = env::var_os("DEP_OPTIX_INCLUDE_DIR") .map(|s| env::split_paths(s.as_os_str()).collect::>()) .expect("Cannot find transitive metadata 'optix_include' from optix-sys package."); diff --git a/examples/cuda/cpu/add/Cargo.toml b/examples/cuda/cpu/add/Cargo.toml index ae99bb78..523e0a75 100644 --- a/examples/cuda/cpu/add/Cargo.toml +++ b/examples/cuda/cpu/add/Cargo.toml @@ -14,8 +14,6 @@ log = "=0.4.17" regex-syntax = "=0.6.28" regex = "=1.11.1" thread_local = "=1.1.4" -jobserver = "=0.1.25" -cc = "=1.0.78" rayon = "=1.10" rayon-core = "=1.12.1" byteorder = "=1.4.0"