1
- use crate :: { error:: * , sys} ;
2
- use cust:: stream:: Stream ;
3
1
use std:: ffi:: CString ;
4
2
use std:: mem:: { self , MaybeUninit } ;
5
3
use std:: os:: raw:: c_char;
6
4
use std:: ptr;
7
5
8
- type Result < T , E = Error > = std:: result:: Result < T , E > ;
6
+ use cust:: stream:: Stream ;
7
+ use cust_raw:: cublas_sys;
8
+ use cust_raw:: driver_sys;
9
+
10
+ use super :: error:: DropResult ;
11
+ use super :: error:: ToResult as _;
12
+
13
+ type Result < T , E = super :: error:: Error > = std:: result:: Result < T , E > ;
9
14
10
15
bitflags:: bitflags! {
11
16
/// Configures precision levels for the math in cuBLAS.
12
- #[ derive( Default ) ]
17
+ #[ derive( Debug , Default , Clone , Copy , PartialEq , Eq , Hash ) ]
13
18
pub struct MathMode : u32 {
14
19
/// Highest performance mode which uses compute and intermediate storage precisions
15
20
/// with at least the same number of mantissa and exponent bits as requested. Will
@@ -68,7 +73,7 @@ bitflags::bitflags! {
68
73
/// - [Matrix Multiplication <span style="float:right;">`gemm`</span>](CublasContext::gemm)
69
74
#[ derive( Debug ) ]
70
75
pub struct CublasContext {
71
- pub ( crate ) raw : sys :: v2 :: cublasHandle_t ,
76
+ pub ( crate ) raw : cublas_sys :: cublasHandle_t ,
72
77
}
73
78
74
79
impl CublasContext {
@@ -87,10 +92,10 @@ impl CublasContext {
87
92
pub fn new ( ) -> Result < Self > {
88
93
let mut raw = MaybeUninit :: uninit ( ) ;
89
94
unsafe {
90
- sys :: v2 :: cublasCreate_v2 ( raw. as_mut_ptr ( ) ) . to_result ( ) ?;
91
- sys :: v2 :: cublasSetPointerMode_v2 (
95
+ cublas_sys :: cublasCreate_v2 ( raw. as_mut_ptr ( ) ) . to_result ( ) ?;
96
+ cublas_sys :: cublasSetPointerMode_v2 (
92
97
raw. assume_init ( ) ,
93
- sys :: v2 :: cublasPointerMode_t:: CUBLAS_POINTER_MODE_DEVICE ,
98
+ cublas_sys :: cublasPointerMode_t:: CUBLAS_POINTER_MODE_DEVICE ,
94
99
)
95
100
. to_result ( ) ?;
96
101
Ok ( Self {
@@ -107,7 +112,7 @@ impl CublasContext {
107
112
108
113
unsafe {
109
114
let inner = mem:: replace ( & mut ctx. raw , ptr:: null_mut ( ) ) ;
110
- match sys :: v2 :: cublasDestroy_v2 ( inner) . to_result ( ) {
115
+ match cublas_sys :: cublasDestroy_v2 ( inner) . to_result ( ) {
111
116
Ok ( ( ) ) => {
112
117
mem:: forget ( ctx) ;
113
118
Ok ( ( ) )
@@ -122,7 +127,7 @@ impl CublasContext {
122
127
let mut raw = MaybeUninit :: < u32 > :: uninit ( ) ;
123
128
unsafe {
124
129
// getVersion can't fail
125
- sys :: v2 :: cublasGetVersion_v2 ( self . raw , raw. as_mut_ptr ( ) . cast ( ) )
130
+ cublas_sys :: cublasGetVersion_v2 ( self . raw , raw. as_mut_ptr ( ) . cast ( ) )
126
131
. to_result ( )
127
132
. unwrap ( ) ;
128
133
@@ -140,17 +145,17 @@ impl CublasContext {
140
145
) -> Result < T > {
141
146
unsafe {
142
147
// cudaStream_t is the same as CUstream
143
- sys :: v2 :: cublasSetStream_v2 (
148
+ cublas_sys :: cublasSetStream_v2 (
144
149
self . raw ,
145
- mem:: transmute :: < * mut cust :: sys :: CUstream_st , * mut cublas_sys:: v2 :: CUstream_st > (
150
+ mem:: transmute :: < * mut driver_sys :: CUstream_st , * mut cublas_sys:: CUstream_st > (
146
151
stream. as_inner ( ) ,
147
152
) ,
148
153
)
149
154
. to_result ( ) ?;
150
155
let res = func ( self ) ?;
151
156
// reset the stream back to NULL just in case someone calls with_stream, then drops the stream, and tries to
152
157
// execute a raw sys function with the context's handle.
153
- sys :: v2 :: cublasSetStream_v2 ( self . raw , ptr:: null_mut ( ) ) . to_result ( ) ?;
158
+ cublas_sys :: cublasSetStream_v2 ( self . raw , ptr:: null_mut ( ) ) . to_result ( ) ?;
154
159
Ok ( res)
155
160
}
156
161
}
@@ -180,12 +185,12 @@ impl CublasContext {
180
185
/// ```
181
186
pub fn set_atomics_mode ( & self , allowed : bool ) -> Result < ( ) > {
182
187
unsafe {
183
- Ok ( sys :: v2 :: cublasSetAtomicsMode (
188
+ Ok ( cublas_sys :: cublasSetAtomicsMode (
184
189
self . raw ,
185
190
if allowed {
186
- sys :: v2 :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED
191
+ cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED
187
192
} else {
188
- sys :: v2 :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED
193
+ cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED
189
194
} ,
190
195
)
191
196
. to_result ( ) ?)
@@ -210,10 +215,10 @@ impl CublasContext {
210
215
pub fn get_atomics_mode ( & self ) -> Result < bool > {
211
216
let mut mode = MaybeUninit :: uninit ( ) ;
212
217
unsafe {
213
- sys :: v2 :: cublasGetAtomicsMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
218
+ cublas_sys :: cublasGetAtomicsMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
214
219
Ok ( match mode. assume_init ( ) {
215
- sys :: v2 :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED => true ,
216
- sys :: v2 :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED => false ,
220
+ cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED => true ,
221
+ cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED => false ,
217
222
} )
218
223
}
219
224
}
@@ -233,9 +238,9 @@ impl CublasContext {
233
238
/// ```
234
239
pub fn set_math_mode ( & self , math_mode : MathMode ) -> Result < ( ) > {
235
240
unsafe {
236
- Ok ( sys :: v2 :: cublasSetMathMode (
241
+ Ok ( cublas_sys :: cublasSetMathMode (
237
242
self . raw ,
238
- mem:: transmute :: < u32 , cublas_sys:: v2 :: cublasMath_t > ( math_mode. bits ( ) ) ,
243
+ mem:: transmute :: < u32 , cublas_sys:: cublasMath_t > ( math_mode. bits ( ) ) ,
239
244
)
240
245
. to_result ( ) ?)
241
246
}
@@ -258,7 +263,7 @@ impl CublasContext {
258
263
pub fn get_math_mode ( & self ) -> Result < MathMode > {
259
264
let mut mode = MaybeUninit :: uninit ( ) ;
260
265
unsafe {
261
- sys :: v2 :: cublasGetMathMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
266
+ cublas_sys :: cublasGetMathMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
262
267
Ok ( MathMode :: from_bits ( mode. assume_init ( ) as u32 )
263
268
. expect ( "Invalid MathMode from cuBLAS" ) )
264
269
}
@@ -298,7 +303,7 @@ impl CublasContext {
298
303
let path = log_file_name. map ( |p| CString :: new ( p) . expect ( "nul in log_file_name" ) ) ;
299
304
let path_ptr = path. map_or ( ptr:: null ( ) , |s| s. as_ptr ( ) ) ;
300
305
301
- sys :: v2 :: cublasLoggerConfigure (
306
+ cublas_sys :: cublasLoggerConfigure (
302
307
enable as i32 ,
303
308
log_to_stdout as i32 ,
304
309
log_to_stderr as i32 ,
@@ -315,7 +320,7 @@ impl CublasContext {
315
320
///
316
321
/// The callback must not panic and unwind.
317
322
pub unsafe fn set_logger_callback ( callback : Option < unsafe extern "C" fn ( * const c_char ) > ) {
318
- sys :: v2 :: cublasSetLoggerCallback ( callback)
323
+ cublas_sys :: cublasSetLoggerCallback ( callback)
319
324
. to_result ( )
320
325
. unwrap ( ) ;
321
326
}
@@ -324,7 +329,7 @@ impl CublasContext {
324
329
pub fn get_logger_callback ( ) -> Option < unsafe extern "C" fn ( * const c_char ) > {
325
330
let mut cb = MaybeUninit :: uninit ( ) ;
326
331
unsafe {
327
- sys :: v2 :: cublasGetLoggerCallback ( cb. as_mut_ptr ( ) )
332
+ cublas_sys :: cublasGetLoggerCallback ( cb. as_mut_ptr ( ) )
328
333
. to_result ( )
329
334
. unwrap ( ) ;
330
335
cb. assume_init ( )
@@ -335,7 +340,7 @@ impl CublasContext {
335
340
impl Drop for CublasContext {
336
341
fn drop ( & mut self ) {
337
342
unsafe {
338
- sys :: v2 :: cublasDestroy_v2 ( self . raw ) ;
343
+ cublas_sys :: cublasDestroy_v2 ( self . raw ) ;
339
344
}
340
345
}
341
346
}
0 commit comments