Skip to content

Commit 6c43b45

Browse files
refactor(cust_raw): consolidate CUDA, OptiX bindgen and remove find_cuda_helper
1. Consolidation of bindgen related "*-sys" packages - Remove the common dependency of `find_cuda_helper`. Use the cargo metadata mechanism instead. - Merged all CUDA bindgen-generated code into the cust_raw crate for simplicity and maintainability. 2. OptiX Integration - Split optix into optix (high-level API) and optix-sys (low-level bindgens) for better abstraction. 3. CUDA 12+ Support - Updated cust to support CUDA versions >= 12. - Added compatibility for CUDA 12.3+ graph API changes: - Renamed cuGraphKernelNodeGetParams → cuGraphKernelNodeGetParams_v2. - Kernel node parameters now require an explicit kernel and context. - Enabled conditional node support for CUDA >= 12.3. Co-authored-by: Adam Basfop Cavendish <[email protected]> Co-authored-by: Jorge Ortega <[email protected]>
1 parent 7eb199c commit 6c43b45

File tree

177 files changed

+24044
-14980
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

177 files changed

+24044
-14980
lines changed

Cargo.toml

+1-2
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@ members = [
1111
"examples/optix/*",
1212
"examples/cuda/cpu/*",
1313
"examples/cuda/gpu/*",
14-
1514
]
1615

1716
exclude = [
18-
"crates/optix/examples/common"
17+
"crates/optix/examples/common",
1918
]
2019

2120
[profile.dev.package.rustc_codegen_nvvm]

crates/blastoff/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ repository = "https://github.com/Rust-GPU/Rust-CUDA"
77

88
[dependencies]
99
bitflags = "2.8"
10-
cublas_sys = { version = "0.1", path = "../cublas_sys" }
1110
cust = { version = "0.3", path = "../cust", features = ["impl_num_complex"] }
11+
cust_raw = { path = "../cust_raw", features = ["cublas"] }
1212
num-complex = "0.4.6"
1313
half = { version = "2.4.1", optional = true }
1414

crates/blastoff/src/context.rs

+32-26
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
1-
use crate::{error::*, sys};
2-
use cust::stream::Stream;
31
use std::ffi::CString;
42
use std::mem::{self, MaybeUninit};
53
use std::os::raw::c_char;
64
use std::ptr;
75

8-
type Result<T, E = Error> = std::result::Result<T, E>;
6+
use cust::stream::Stream;
7+
use cust_raw::cublas_sys;
8+
use cust_raw::driver_sys;
9+
10+
use super::error::DropResult;
11+
use super::error::ToResult as _;
12+
13+
type Result<T, E = super::error::Error> = std::result::Result<T, E>;
914

1015
bitflags::bitflags! {
1116
/// Configures precision levels for the math in cuBLAS.
12-
#[derive(Default)]
17+
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
1318
pub struct MathMode: u32 {
1419
/// Highest performance mode which uses compute and intermediate storage precisions
1520
/// with at least the same number of mantissa and exponent bits as requested. Will
@@ -68,7 +73,7 @@ bitflags::bitflags! {
6873
/// - [Matrix Multiplication <span style="float:right;">`gemm`</span>](CublasContext::gemm)
6974
#[derive(Debug)]
7075
pub struct CublasContext {
71-
pub(crate) raw: sys::v2::cublasHandle_t,
76+
pub(crate) raw: cublas_sys::cublasHandle_t,
7277
}
7378

7479
impl CublasContext {
@@ -87,10 +92,10 @@ impl CublasContext {
8792
pub fn new() -> Result<Self> {
8893
let mut raw = MaybeUninit::uninit();
8994
unsafe {
90-
sys::v2::cublasCreate_v2(raw.as_mut_ptr()).to_result()?;
91-
sys::v2::cublasSetPointerMode_v2(
95+
cublas_sys::cublasCreate_v2(raw.as_mut_ptr()).to_result()?;
96+
cublas_sys::cublasSetPointerMode_v2(
9297
raw.assume_init(),
93-
sys::v2::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE,
98+
cublas_sys::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE,
9499
)
95100
.to_result()?;
96101
Ok(Self {
@@ -107,7 +112,7 @@ impl CublasContext {
107112

108113
unsafe {
109114
let inner = mem::replace(&mut ctx.raw, ptr::null_mut());
110-
match sys::v2::cublasDestroy_v2(inner).to_result() {
115+
match cublas_sys::cublasDestroy_v2(inner).to_result() {
111116
Ok(()) => {
112117
mem::forget(ctx);
113118
Ok(())
@@ -122,7 +127,7 @@ impl CublasContext {
122127
let mut raw = MaybeUninit::<u32>::uninit();
123128
unsafe {
124129
// getVersion can't fail
125-
sys::v2::cublasGetVersion_v2(self.raw, raw.as_mut_ptr().cast())
130+
cublas_sys::cublasGetVersion_v2(self.raw, raw.as_mut_ptr().cast())
126131
.to_result()
127132
.unwrap();
128133

@@ -140,17 +145,17 @@ impl CublasContext {
140145
) -> Result<T> {
141146
unsafe {
142147
// cudaStream_t is the same as CUstream
143-
sys::v2::cublasSetStream_v2(
148+
cublas_sys::cublasSetStream_v2(
144149
self.raw,
145-
mem::transmute::<*mut cust::sys::CUstream_st, *mut cublas_sys::v2::CUstream_st>(
150+
mem::transmute::<*mut driver_sys::CUstream_st, *mut cublas_sys::CUstream_st>(
146151
stream.as_inner(),
147152
),
148153
)
149154
.to_result()?;
150155
let res = func(self)?;
151156
// reset the stream back to NULL just in case someone calls with_stream, then drops the stream, and tries to
152157
// execute a raw sys function with the context's handle.
153-
sys::v2::cublasSetStream_v2(self.raw, ptr::null_mut()).to_result()?;
158+
cublas_sys::cublasSetStream_v2(self.raw, ptr::null_mut()).to_result()?;
154159
Ok(res)
155160
}
156161
}
@@ -180,12 +185,12 @@ impl CublasContext {
180185
/// ```
181186
pub fn set_atomics_mode(&self, allowed: bool) -> Result<()> {
182187
unsafe {
183-
Ok(sys::v2::cublasSetAtomicsMode(
188+
Ok(cublas_sys::cublasSetAtomicsMode(
184189
self.raw,
185190
if allowed {
186-
sys::v2::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED
191+
cublas_sys::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED
187192
} else {
188-
sys::v2::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED
193+
cublas_sys::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED
189194
},
190195
)
191196
.to_result()?)
@@ -210,10 +215,11 @@ impl CublasContext {
210215
pub fn get_atomics_mode(&self) -> Result<bool> {
211216
let mut mode = MaybeUninit::uninit();
212217
unsafe {
213-
sys::v2::cublasGetAtomicsMode(self.raw, mode.as_mut_ptr()).to_result()?;
218+
cublas_sys::cublasGetAtomicsMode(self.raw, mode.as_mut_ptr()).to_result()?;
214219
Ok(match mode.assume_init() {
215-
sys::v2::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED => true,
216-
sys::v2::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED => false,
220+
cublas_sys::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED => true,
221+
cublas_sys::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED => false,
222+
_ => false,
217223
})
218224
}
219225
}
@@ -233,9 +239,9 @@ impl CublasContext {
233239
/// ```
234240
pub fn set_math_mode(&self, math_mode: MathMode) -> Result<()> {
235241
unsafe {
236-
Ok(sys::v2::cublasSetMathMode(
242+
Ok(cublas_sys::cublasSetMathMode(
237243
self.raw,
238-
mem::transmute::<u32, cublas_sys::v2::cublasMath_t>(math_mode.bits()),
244+
mem::transmute::<u32, cublas_sys::cublasMath_t>(math_mode.bits()),
239245
)
240246
.to_result()?)
241247
}
@@ -258,7 +264,7 @@ impl CublasContext {
258264
pub fn get_math_mode(&self) -> Result<MathMode> {
259265
let mut mode = MaybeUninit::uninit();
260266
unsafe {
261-
sys::v2::cublasGetMathMode(self.raw, mode.as_mut_ptr()).to_result()?;
267+
cublas_sys::cublasGetMathMode(self.raw, mode.as_mut_ptr()).to_result()?;
262268
Ok(MathMode::from_bits(mode.assume_init() as u32)
263269
.expect("Invalid MathMode from cuBLAS"))
264270
}
@@ -298,7 +304,7 @@ impl CublasContext {
298304
let path = log_file_name.map(|p| CString::new(p).expect("nul in log_file_name"));
299305
let path_ptr = path.map_or(ptr::null(), |s| s.as_ptr());
300306

301-
sys::v2::cublasLoggerConfigure(
307+
cublas_sys::cublasLoggerConfigure(
302308
enable as i32,
303309
log_to_stdout as i32,
304310
log_to_stderr as i32,
@@ -315,7 +321,7 @@ impl CublasContext {
315321
///
316322
/// The callback must not panic and unwind.
317323
pub unsafe fn set_logger_callback(callback: Option<unsafe extern "C" fn(*const c_char)>) {
318-
sys::v2::cublasSetLoggerCallback(callback)
324+
cublas_sys::cublasSetLoggerCallback(callback)
319325
.to_result()
320326
.unwrap();
321327
}
@@ -324,7 +330,7 @@ impl CublasContext {
324330
pub fn get_logger_callback() -> Option<unsafe extern "C" fn(*const c_char)> {
325331
let mut cb = MaybeUninit::uninit();
326332
unsafe {
327-
sys::v2::cublasGetLoggerCallback(cb.as_mut_ptr())
333+
cublas_sys::cublasGetLoggerCallback(cb.as_mut_ptr())
328334
.to_result()
329335
.unwrap();
330336
cb.assume_init()
@@ -335,7 +341,7 @@ impl CublasContext {
335341
impl Drop for CublasContext {
336342
fn drop(&mut self) {
337343
unsafe {
338-
sys::v2::cublasDestroy_v2(self.raw);
344+
cublas_sys::cublasDestroy_v2(self.raw);
339345
}
340346
}
341347
}

crates/blastoff/src/error.rs

+28-24
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
use crate::sys;
2-
use cust::error::CudaError;
31
use std::{ffi::CStr, fmt::Display};
42

3+
use cust::error::CudaError;
4+
use cust_raw::cublas_sys;
5+
56
/// Result that contains the un-dropped value on error.
67
pub type DropResult<T> = std::result::Result<(), (CublasError, T)>;
78

@@ -24,7 +25,7 @@ impl std::error::Error for CublasError {}
2425
impl Display for CublasError {
2526
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2627
unsafe {
27-
let ptr = sys::v2::cublasGetStatusString(self.into_raw());
28+
let ptr = cublas_sys::cublasGetStatusString(self.into_raw());
2829
let cow = CStr::from_ptr(ptr).to_string_lossy();
2930
f.write_str(cow.as_ref())
3031
}
@@ -35,39 +36,42 @@ pub trait ToResult {
3536
fn to_result(self) -> Result<(), CublasError>;
3637
}
3738

38-
impl ToResult for sys::v2::cublasStatus_t {
39+
impl ToResult for cublas_sys::cublasStatus_t {
3940
fn to_result(self) -> Result<(), CublasError> {
41+
use cust_raw::cublas_sys::cublasStatus_t::*;
4042
use CublasError::*;
4143

4244
Err(match self {
43-
sys::v2::cublasStatus_t::CUBLAS_STATUS_SUCCESS => return Ok(()),
44-
sys::v2::cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => NotInitialized,
45-
sys::v2::cublasStatus_t::CUBLAS_STATUS_ALLOC_FAILED => AllocFailed,
46-
sys::v2::cublasStatus_t::CUBLAS_STATUS_INVALID_VALUE => InvalidValue,
47-
sys::v2::cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH => ArchMismatch,
48-
sys::v2::cublasStatus_t::CUBLAS_STATUS_MAPPING_ERROR => MappingError,
49-
sys::v2::cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED => ExecutionFailed,
50-
sys::v2::cublasStatus_t::CUBLAS_STATUS_INTERNAL_ERROR => InternalError,
51-
sys::v2::cublasStatus_t::CUBLAS_STATUS_NOT_SUPPORTED => NotSupported,
52-
sys::v2::cublasStatus_t::CUBLAS_STATUS_LICENSE_ERROR => LicenseError,
45+
CUBLAS_STATUS_SUCCESS => return Ok(()),
46+
CUBLAS_STATUS_NOT_INITIALIZED => NotInitialized,
47+
CUBLAS_STATUS_ALLOC_FAILED => AllocFailed,
48+
CUBLAS_STATUS_INVALID_VALUE => InvalidValue,
49+
CUBLAS_STATUS_ARCH_MISMATCH => ArchMismatch,
50+
CUBLAS_STATUS_MAPPING_ERROR => MappingError,
51+
CUBLAS_STATUS_EXECUTION_FAILED => ExecutionFailed,
52+
CUBLAS_STATUS_INTERNAL_ERROR => InternalError,
53+
CUBLAS_STATUS_NOT_SUPPORTED => NotSupported,
54+
CUBLAS_STATUS_LICENSE_ERROR => LicenseError,
55+
_ => NotSupported,
5356
})
5457
}
5558
}
5659

5760
impl CublasError {
58-
pub fn into_raw(self) -> sys::v2::cublasStatus_t {
61+
pub fn into_raw(self) -> cublas_sys::cublasStatus_t {
62+
use cust_raw::cublas_sys::cublasStatus_t::*;
5963
use CublasError::*;
6064

6165
match self {
62-
NotInitialized => sys::v2::cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED,
63-
AllocFailed => sys::v2::cublasStatus_t::CUBLAS_STATUS_ALLOC_FAILED,
64-
InvalidValue => sys::v2::cublasStatus_t::CUBLAS_STATUS_INVALID_VALUE,
65-
ArchMismatch => sys::v2::cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH,
66-
MappingError => sys::v2::cublasStatus_t::CUBLAS_STATUS_MAPPING_ERROR,
67-
ExecutionFailed => sys::v2::cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED,
68-
InternalError => sys::v2::cublasStatus_t::CUBLAS_STATUS_INTERNAL_ERROR,
69-
NotSupported => sys::v2::cublasStatus_t::CUBLAS_STATUS_NOT_SUPPORTED,
70-
LicenseError => sys::v2::cublasStatus_t::CUBLAS_STATUS_LICENSE_ERROR,
66+
NotInitialized => CUBLAS_STATUS_NOT_INITIALIZED,
67+
AllocFailed => CUBLAS_STATUS_ALLOC_FAILED,
68+
InvalidValue => CUBLAS_STATUS_INVALID_VALUE,
69+
ArchMismatch => CUBLAS_STATUS_ARCH_MISMATCH,
70+
MappingError => CUBLAS_STATUS_MAPPING_ERROR,
71+
ExecutionFailed => CUBLAS_STATUS_EXECUTION_FAILED,
72+
InternalError => CUBLAS_STATUS_INTERNAL_ERROR,
73+
NotSupported => CUBLAS_STATUS_NOT_SUPPORTED,
74+
LicenseError => CUBLAS_STATUS_LICENSE_ERROR,
7175
}
7276
}
7377
}

crates/blastoff/src/lib.rs

+14-14
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#![allow(clippy::too_many_arguments)]
1111
#![cfg_attr(docsrs, feature(doc_cfg))]
1212

13-
pub use cublas_sys as sys;
13+
pub use cust_raw::cublas_sys;
1414
use num_complex::{Complex32, Complex64};
1515

1616
pub use context::*;
@@ -39,34 +39,34 @@ pub trait BlasDatatype: private::Sealed + cust::memory::DeviceCopy {
3939
/// The corresponding float type. For complex numbers this means their backing
4040
/// precision, and for floats it is just themselves.
4141
type FloatTy: Float;
42-
fn to_raw(&self) -> sys::v2::cudaDataType;
42+
fn to_raw(&self) -> cublas_sys::cudaDataType;
4343
}
4444

4545
impl BlasDatatype for f32 {
4646
type FloatTy = f32;
47-
fn to_raw(&self) -> sys::v2::cudaDataType {
48-
sys::v2::cudaDataType::CUDA_R_32F
47+
fn to_raw(&self) -> cublas_sys::cudaDataType {
48+
cublas_sys::cudaDataType::CUDA_R_32F
4949
}
5050
}
5151

5252
impl BlasDatatype for f64 {
5353
type FloatTy = f64;
54-
fn to_raw(&self) -> sys::v2::cudaDataType {
55-
sys::v2::cudaDataType::CUDA_R_64F
54+
fn to_raw(&self) -> cublas_sys::cudaDataType {
55+
cublas_sys::cudaDataType::CUDA_R_64F
5656
}
5757
}
5858

5959
impl BlasDatatype for Complex32 {
6060
type FloatTy = f32;
61-
fn to_raw(&self) -> sys::v2::cudaDataType {
62-
sys::v2::cudaDataType::CUDA_C_32F
61+
fn to_raw(&self) -> cublas_sys::cudaDataType {
62+
cublas_sys::cudaDataType::CUDA_C_32F
6363
}
6464
}
6565

6666
impl BlasDatatype for Complex64 {
6767
type FloatTy = f64;
68-
fn to_raw(&self) -> sys::v2::cudaDataType {
69-
sys::v2::cudaDataType::CUDA_C_64F
68+
fn to_raw(&self) -> cublas_sys::cudaDataType {
69+
cublas_sys::cudaDataType::CUDA_C_64F
7070
}
7171
}
7272

@@ -106,11 +106,11 @@ pub enum MatrixOp {
106106

107107
impl MatrixOp {
108108
/// Returns the corresponding `cublasOperation_t` for this operation.
109-
pub fn to_raw(self) -> sys::v2::cublasOperation_t {
109+
pub fn to_raw(self) -> cublas_sys::cublasOperation_t {
110110
match self {
111-
MatrixOp::None => sys::v2::cublasOperation_t::CUBLAS_OP_N,
112-
MatrixOp::Transpose => sys::v2::cublasOperation_t::CUBLAS_OP_T,
113-
MatrixOp::ConjugateTranspose => sys::v2::cublasOperation_t::CUBLAS_OP_C,
111+
MatrixOp::None => cublas_sys::cublasOperation_t::CUBLAS_OP_N,
112+
MatrixOp::Transpose => cublas_sys::cublasOperation_t::CUBLAS_OP_T,
113+
MatrixOp::ConjugateTranspose => cublas_sys::cublasOperation_t::CUBLAS_OP_C,
114114
}
115115
}
116116
}

crates/blastoff/src/raw/level1.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
use crate::{sys::v2::*, BlasDatatype};
2-
use num_complex::{Complex32, Complex64};
31
use std::os::raw::c_int;
42

3+
use cust_raw::cublas_sys::*;
4+
use num_complex::{Complex32, Complex64};
5+
6+
use crate::BlasDatatype;
7+
58
pub trait Level1: BlasDatatype {
69
unsafe fn amax(
710
handle: cublasHandle_t,

crates/blastoff/src/raw/level3.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
use crate::{sys::v2::*, GemmDatatype};
2-
use num_complex::{Complex32, Complex64};
31
use std::os::raw::c_int;
42

3+
use cust_raw::cublas_sys::*;
4+
use num_complex::{Complex32, Complex64};
5+
6+
use crate::GemmDatatype;
7+
58
pub trait GemmOps: GemmDatatype {
69
unsafe fn gemm(
710
handle: cublasHandle_t,

crates/cublas_sys/Cargo.toml

-9
This file was deleted.

0 commit comments

Comments
 (0)