Skip to content

Commit d7d0b15

Browse files
refactor(cust_raw): consolidate CUDA, cuDNN, OptiX bindgen and remove find_cuda_helper
1. Consolidation of bindgen-related "*-sys" packages
   - Remove the common dependency on `find_cuda_helper`; use the cargo metadata mechanism instead.
   - Merge all CUDA bindgen-generated code into the `cust_raw` crate for simplicity and maintainability.
   - Add CUDA Runtime API bindgen support.

2. cuDNN and OptiX integration
   - Split cudnn into cudnn (high-level API) and cudnn-sys (low-level bindings) for better abstraction.
   - Split optix into optix (high-level API) and optix-sys (low-level bindings) for better abstraction.

3. CUDA 12+ support
   - Update cust to support CUDA versions >= 12.
   - Add compatibility for CUDA 12.3+ graph API changes:
     - Rename cuGraphKernelNodeGetParams → cuGraphKernelNodeGetParams_v2.
     - Enable conditional node support for CUDA >= 12.3.

4. Temporarily disable cuDNN in CI
   - Windows CI pipelines have no cuDNN support yet.

Co-authored-by: Adam Basfop Cavendish <[email protected]>
Co-authored-by: Jorge Ortega <[email protected]>
1 parent 7eb199c commit d7d0b15

File tree

178 files changed

+2992
-20991
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

178 files changed

+2992
-20991
lines changed

.github/workflows/rust.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ jobs:
102102
run: cargo fmt --all -- --check
103103

104104
- name: Build
105-
run: cargo build --workspace --exclude "optix" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*"
105+
run: cargo build --workspace --exclude "optix*" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*" --exclude "cudnn*"
106106

107107
# Don't currently test because many tests rely on the system having a CUDA GPU
108108
# - name: Test
@@ -112,9 +112,9 @@ jobs:
112112
if: contains(matrix.os, 'ubuntu')
113113
env:
114114
RUSTFLAGS: -Dwarnings
115-
run: cargo clippy --workspace --exclude "optix" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*"
115+
run: cargo clippy --workspace --exclude "optix*" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*" --exclude "cudnn*"
116116

117117
- name: Check documentation
118118
env:
119119
RUSTDOCFLAGS: -Dwarnings
120-
run: cargo doc --workspace --all-features --document-private-items --no-deps --exclude "optix" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*"
120+
run: cargo doc --workspace --all-features --document-private-items --no-deps --exclude "optix*" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*" --exclude "cudnn*" --exclude "cust_raw"

Cargo.toml

+1-2
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@ members = [
1111
"examples/optix/*",
1212
"examples/cuda/cpu/*",
1313
"examples/cuda/gpu/*",
14-
1514
]
1615

1716
exclude = [
18-
"crates/optix/examples/common"
17+
"crates/optix/examples/common",
1918
]
2019

2120
[profile.dev.package.rustc_codegen_nvvm]

crates/blastoff/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ repository = "https://github.com/Rust-GPU/Rust-CUDA"
77

88
[dependencies]
99
bitflags = "2.8"
10-
cublas_sys = { version = "0.1", path = "../cublas_sys" }
1110
cust = { version = "0.3", path = "../cust", features = ["impl_num_complex"] }
11+
cust_raw = { path = "../cust_raw", features = ["cublas"] }
1212
num-complex = "0.4.6"
1313
half = { version = "2.4.1", optional = true }
1414

crates/blastoff/src/context.rs

+31-26
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
1-
use crate::{error::*, sys};
2-
use cust::stream::Stream;
31
use std::ffi::CString;
42
use std::mem::{self, MaybeUninit};
53
use std::os::raw::c_char;
64
use std::ptr;
75

8-
type Result<T, E = Error> = std::result::Result<T, E>;
6+
use cust::stream::Stream;
7+
use cust_raw::cublas_sys;
8+
use cust_raw::driver_sys;
9+
10+
use super::error::DropResult;
11+
use super::error::ToResult as _;
12+
13+
type Result<T, E = super::error::Error> = std::result::Result<T, E>;
914

1015
bitflags::bitflags! {
1116
/// Configures precision levels for the math in cuBLAS.
12-
#[derive(Default)]
17+
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
1318
pub struct MathMode: u32 {
1419
/// Highest performance mode which uses compute and intermediate storage precisions
1520
/// with at least the same number of mantissa and exponent bits as requested. Will
@@ -68,7 +73,7 @@ bitflags::bitflags! {
6873
/// - [Matrix Multiplication <span style="float:right;">`gemm`</span>](CublasContext::gemm)
6974
#[derive(Debug)]
7075
pub struct CublasContext {
71-
pub(crate) raw: sys::v2::cublasHandle_t,
76+
pub(crate) raw: cublas_sys::cublasHandle_t,
7277
}
7378

7479
impl CublasContext {
@@ -87,10 +92,10 @@ impl CublasContext {
8792
pub fn new() -> Result<Self> {
8893
let mut raw = MaybeUninit::uninit();
8994
unsafe {
90-
sys::v2::cublasCreate_v2(raw.as_mut_ptr()).to_result()?;
91-
sys::v2::cublasSetPointerMode_v2(
95+
cublas_sys::cublasCreate_v2(raw.as_mut_ptr()).to_result()?;
96+
cublas_sys::cublasSetPointerMode_v2(
9297
raw.assume_init(),
93-
sys::v2::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE,
98+
cublas_sys::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE,
9499
)
95100
.to_result()?;
96101
Ok(Self {
@@ -107,7 +112,7 @@ impl CublasContext {
107112

108113
unsafe {
109114
let inner = mem::replace(&mut ctx.raw, ptr::null_mut());
110-
match sys::v2::cublasDestroy_v2(inner).to_result() {
115+
match cublas_sys::cublasDestroy_v2(inner).to_result() {
111116
Ok(()) => {
112117
mem::forget(ctx);
113118
Ok(())
@@ -122,7 +127,7 @@ impl CublasContext {
122127
let mut raw = MaybeUninit::<u32>::uninit();
123128
unsafe {
124129
// getVersion can't fail
125-
sys::v2::cublasGetVersion_v2(self.raw, raw.as_mut_ptr().cast())
130+
cublas_sys::cublasGetVersion_v2(self.raw, raw.as_mut_ptr().cast())
126131
.to_result()
127132
.unwrap();
128133

@@ -140,17 +145,17 @@ impl CublasContext {
140145
) -> Result<T> {
141146
unsafe {
142147
// cudaStream_t is the same as CUstream
143-
sys::v2::cublasSetStream_v2(
148+
cublas_sys::cublasSetStream_v2(
144149
self.raw,
145-
mem::transmute::<*mut cust::sys::CUstream_st, *mut cublas_sys::v2::CUstream_st>(
150+
mem::transmute::<*mut driver_sys::CUstream_st, *mut cublas_sys::CUstream_st>(
146151
stream.as_inner(),
147152
),
148153
)
149154
.to_result()?;
150155
let res = func(self)?;
151156
// reset the stream back to NULL just in case someone calls with_stream, then drops the stream, and tries to
152157
// execute a raw sys function with the context's handle.
153-
sys::v2::cublasSetStream_v2(self.raw, ptr::null_mut()).to_result()?;
158+
cublas_sys::cublasSetStream_v2(self.raw, ptr::null_mut()).to_result()?;
154159
Ok(res)
155160
}
156161
}
@@ -180,12 +185,12 @@ impl CublasContext {
180185
/// ```
181186
pub fn set_atomics_mode(&self, allowed: bool) -> Result<()> {
182187
unsafe {
183-
Ok(sys::v2::cublasSetAtomicsMode(
188+
Ok(cublas_sys::cublasSetAtomicsMode(
184189
self.raw,
185190
if allowed {
186-
sys::v2::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED
191+
cublas_sys::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED
187192
} else {
188-
sys::v2::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED
193+
cublas_sys::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED
189194
},
190195
)
191196
.to_result()?)
@@ -210,10 +215,10 @@ impl CublasContext {
210215
pub fn get_atomics_mode(&self) -> Result<bool> {
211216
let mut mode = MaybeUninit::uninit();
212217
unsafe {
213-
sys::v2::cublasGetAtomicsMode(self.raw, mode.as_mut_ptr()).to_result()?;
218+
cublas_sys::cublasGetAtomicsMode(self.raw, mode.as_mut_ptr()).to_result()?;
214219
Ok(match mode.assume_init() {
215-
sys::v2::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED => true,
216-
sys::v2::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED => false,
220+
cublas_sys::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED => true,
221+
cublas_sys::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED => false,
217222
})
218223
}
219224
}
@@ -233,9 +238,9 @@ impl CublasContext {
233238
/// ```
234239
pub fn set_math_mode(&self, math_mode: MathMode) -> Result<()> {
235240
unsafe {
236-
Ok(sys::v2::cublasSetMathMode(
241+
Ok(cublas_sys::cublasSetMathMode(
237242
self.raw,
238-
mem::transmute::<u32, cublas_sys::v2::cublasMath_t>(math_mode.bits()),
243+
mem::transmute::<u32, cublas_sys::cublasMath_t>(math_mode.bits()),
239244
)
240245
.to_result()?)
241246
}
@@ -258,7 +263,7 @@ impl CublasContext {
258263
pub fn get_math_mode(&self) -> Result<MathMode> {
259264
let mut mode = MaybeUninit::uninit();
260265
unsafe {
261-
sys::v2::cublasGetMathMode(self.raw, mode.as_mut_ptr()).to_result()?;
266+
cublas_sys::cublasGetMathMode(self.raw, mode.as_mut_ptr()).to_result()?;
262267
Ok(MathMode::from_bits(mode.assume_init() as u32)
263268
.expect("Invalid MathMode from cuBLAS"))
264269
}
@@ -298,7 +303,7 @@ impl CublasContext {
298303
let path = log_file_name.map(|p| CString::new(p).expect("nul in log_file_name"));
299304
let path_ptr = path.map_or(ptr::null(), |s| s.as_ptr());
300305

301-
sys::v2::cublasLoggerConfigure(
306+
cublas_sys::cublasLoggerConfigure(
302307
enable as i32,
303308
log_to_stdout as i32,
304309
log_to_stderr as i32,
@@ -315,7 +320,7 @@ impl CublasContext {
315320
///
316321
/// The callback must not panic and unwind.
317322
pub unsafe fn set_logger_callback(callback: Option<unsafe extern "C" fn(*const c_char)>) {
318-
sys::v2::cublasSetLoggerCallback(callback)
323+
cublas_sys::cublasSetLoggerCallback(callback)
319324
.to_result()
320325
.unwrap();
321326
}
@@ -324,7 +329,7 @@ impl CublasContext {
324329
pub fn get_logger_callback() -> Option<unsafe extern "C" fn(*const c_char)> {
325330
let mut cb = MaybeUninit::uninit();
326331
unsafe {
327-
sys::v2::cublasGetLoggerCallback(cb.as_mut_ptr())
332+
cublas_sys::cublasGetLoggerCallback(cb.as_mut_ptr())
328333
.to_result()
329334
.unwrap();
330335
cb.assume_init()
@@ -335,7 +340,7 @@ impl CublasContext {
335340
impl Drop for CublasContext {
336341
fn drop(&mut self) {
337342
unsafe {
338-
sys::v2::cublasDestroy_v2(self.raw);
343+
cublas_sys::cublasDestroy_v2(self.raw);
339344
}
340345
}
341346
}

crates/blastoff/src/error.rs

+27-24
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
use crate::sys;
2-
use cust::error::CudaError;
31
use std::{ffi::CStr, fmt::Display};
42

3+
use cust::error::CudaError;
4+
use cust_raw::cublas_sys;
5+
56
/// Result that contains the un-dropped value on error.
67
pub type DropResult<T> = std::result::Result<(), (CublasError, T)>;
78

@@ -24,7 +25,7 @@ impl std::error::Error for CublasError {}
2425
impl Display for CublasError {
2526
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2627
unsafe {
27-
let ptr = sys::v2::cublasGetStatusString(self.into_raw());
28+
let ptr = cublas_sys::cublasGetStatusString(self.into_raw());
2829
let cow = CStr::from_ptr(ptr).to_string_lossy();
2930
f.write_str(cow.as_ref())
3031
}
@@ -35,39 +36,41 @@ pub trait ToResult {
3536
fn to_result(self) -> Result<(), CublasError>;
3637
}
3738

38-
impl ToResult for sys::v2::cublasStatus_t {
39+
impl ToResult for cublas_sys::cublasStatus_t {
3940
fn to_result(self) -> Result<(), CublasError> {
41+
use cust_raw::cublas_sys::cublasStatus_t::*;
4042
use CublasError::*;
4143

4244
Err(match self {
43-
sys::v2::cublasStatus_t::CUBLAS_STATUS_SUCCESS => return Ok(()),
44-
sys::v2::cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => NotInitialized,
45-
sys::v2::cublasStatus_t::CUBLAS_STATUS_ALLOC_FAILED => AllocFailed,
46-
sys::v2::cublasStatus_t::CUBLAS_STATUS_INVALID_VALUE => InvalidValue,
47-
sys::v2::cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH => ArchMismatch,
48-
sys::v2::cublasStatus_t::CUBLAS_STATUS_MAPPING_ERROR => MappingError,
49-
sys::v2::cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED => ExecutionFailed,
50-
sys::v2::cublasStatus_t::CUBLAS_STATUS_INTERNAL_ERROR => InternalError,
51-
sys::v2::cublasStatus_t::CUBLAS_STATUS_NOT_SUPPORTED => NotSupported,
52-
sys::v2::cublasStatus_t::CUBLAS_STATUS_LICENSE_ERROR => LicenseError,
45+
CUBLAS_STATUS_SUCCESS => return Ok(()),
46+
CUBLAS_STATUS_NOT_INITIALIZED => NotInitialized,
47+
CUBLAS_STATUS_ALLOC_FAILED => AllocFailed,
48+
CUBLAS_STATUS_INVALID_VALUE => InvalidValue,
49+
CUBLAS_STATUS_ARCH_MISMATCH => ArchMismatch,
50+
CUBLAS_STATUS_MAPPING_ERROR => MappingError,
51+
CUBLAS_STATUS_EXECUTION_FAILED => ExecutionFailed,
52+
CUBLAS_STATUS_INTERNAL_ERROR => InternalError,
53+
CUBLAS_STATUS_NOT_SUPPORTED => NotSupported,
54+
CUBLAS_STATUS_LICENSE_ERROR => LicenseError,
5355
})
5456
}
5557
}
5658

5759
impl CublasError {
58-
pub fn into_raw(self) -> sys::v2::cublasStatus_t {
60+
pub fn into_raw(self) -> cublas_sys::cublasStatus_t {
61+
use cust_raw::cublas_sys::cublasStatus_t::*;
5962
use CublasError::*;
6063

6164
match self {
62-
NotInitialized => sys::v2::cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED,
63-
AllocFailed => sys::v2::cublasStatus_t::CUBLAS_STATUS_ALLOC_FAILED,
64-
InvalidValue => sys::v2::cublasStatus_t::CUBLAS_STATUS_INVALID_VALUE,
65-
ArchMismatch => sys::v2::cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH,
66-
MappingError => sys::v2::cublasStatus_t::CUBLAS_STATUS_MAPPING_ERROR,
67-
ExecutionFailed => sys::v2::cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED,
68-
InternalError => sys::v2::cublasStatus_t::CUBLAS_STATUS_INTERNAL_ERROR,
69-
NotSupported => sys::v2::cublasStatus_t::CUBLAS_STATUS_NOT_SUPPORTED,
70-
LicenseError => sys::v2::cublasStatus_t::CUBLAS_STATUS_LICENSE_ERROR,
65+
NotInitialized => CUBLAS_STATUS_NOT_INITIALIZED,
66+
AllocFailed => CUBLAS_STATUS_ALLOC_FAILED,
67+
InvalidValue => CUBLAS_STATUS_INVALID_VALUE,
68+
ArchMismatch => CUBLAS_STATUS_ARCH_MISMATCH,
69+
MappingError => CUBLAS_STATUS_MAPPING_ERROR,
70+
ExecutionFailed => CUBLAS_STATUS_EXECUTION_FAILED,
71+
InternalError => CUBLAS_STATUS_INTERNAL_ERROR,
72+
NotSupported => CUBLAS_STATUS_NOT_SUPPORTED,
73+
LicenseError => CUBLAS_STATUS_LICENSE_ERROR,
7174
}
7275
}
7376
}

crates/blastoff/src/lib.rs

+14-14
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#![allow(clippy::too_many_arguments)]
1111
#![cfg_attr(docsrs, feature(doc_cfg))]
1212

13-
pub use cublas_sys as sys;
13+
pub use cust_raw::cublas_sys;
1414
use num_complex::{Complex32, Complex64};
1515

1616
pub use context::*;
@@ -39,34 +39,34 @@ pub trait BlasDatatype: private::Sealed + cust::memory::DeviceCopy {
3939
/// The corresponding float type. For complex numbers this means their backing
4040
/// precision, and for floats it is just themselves.
4141
type FloatTy: Float;
42-
fn to_raw(&self) -> sys::v2::cudaDataType;
42+
fn to_raw(&self) -> cublas_sys::cudaDataType;
4343
}
4444

4545
impl BlasDatatype for f32 {
4646
type FloatTy = f32;
47-
fn to_raw(&self) -> sys::v2::cudaDataType {
48-
sys::v2::cudaDataType::CUDA_R_32F
47+
fn to_raw(&self) -> cublas_sys::cudaDataType {
48+
cublas_sys::cudaDataType::CUDA_R_32F
4949
}
5050
}
5151

5252
impl BlasDatatype for f64 {
5353
type FloatTy = f64;
54-
fn to_raw(&self) -> sys::v2::cudaDataType {
55-
sys::v2::cudaDataType::CUDA_R_64F
54+
fn to_raw(&self) -> cublas_sys::cudaDataType {
55+
cublas_sys::cudaDataType::CUDA_R_64F
5656
}
5757
}
5858

5959
impl BlasDatatype for Complex32 {
6060
type FloatTy = f32;
61-
fn to_raw(&self) -> sys::v2::cudaDataType {
62-
sys::v2::cudaDataType::CUDA_C_32F
61+
fn to_raw(&self) -> cublas_sys::cudaDataType {
62+
cublas_sys::cudaDataType::CUDA_C_32F
6363
}
6464
}
6565

6666
impl BlasDatatype for Complex64 {
6767
type FloatTy = f64;
68-
fn to_raw(&self) -> sys::v2::cudaDataType {
69-
sys::v2::cudaDataType::CUDA_C_64F
68+
fn to_raw(&self) -> cublas_sys::cudaDataType {
69+
cublas_sys::cudaDataType::CUDA_C_64F
7070
}
7171
}
7272

@@ -106,11 +106,11 @@ pub enum MatrixOp {
106106

107107
impl MatrixOp {
108108
/// Returns the corresponding `cublasOperation_t` for this operation.
109-
pub fn to_raw(self) -> sys::v2::cublasOperation_t {
109+
pub fn to_raw(self) -> cublas_sys::cublasOperation_t {
110110
match self {
111-
MatrixOp::None => sys::v2::cublasOperation_t::CUBLAS_OP_N,
112-
MatrixOp::Transpose => sys::v2::cublasOperation_t::CUBLAS_OP_T,
113-
MatrixOp::ConjugateTranspose => sys::v2::cublasOperation_t::CUBLAS_OP_C,
111+
MatrixOp::None => cublas_sys::cublasOperation_t::CUBLAS_OP_N,
112+
MatrixOp::Transpose => cublas_sys::cublasOperation_t::CUBLAS_OP_T,
113+
MatrixOp::ConjugateTranspose => cublas_sys::cublasOperation_t::CUBLAS_OP_C,
114114
}
115115
}
116116
}

crates/blastoff/src/raw/level1.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
use crate::{sys::v2::*, BlasDatatype};
2-
use num_complex::{Complex32, Complex64};
31
use std::os::raw::c_int;
42

3+
use cust_raw::cublas_sys::*;
4+
use num_complex::{Complex32, Complex64};
5+
6+
use crate::BlasDatatype;
7+
58
pub trait Level1: BlasDatatype {
69
unsafe fn amax(
710
handle: cublasHandle_t,

crates/blastoff/src/raw/level3.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
use crate::{sys::v2::*, GemmDatatype};
2-
use num_complex::{Complex32, Complex64};
31
use std::os::raw::c_int;
42

3+
use cust_raw::cublas_sys::*;
4+
use num_complex::{Complex32, Complex64};
5+
6+
use crate::GemmDatatype;
7+
58
pub trait GemmOps: GemmDatatype {
69
unsafe fn gemm(
710
handle: cublasHandle_t,

0 commit comments

Comments
 (0)