
Commit 321110e

Feat: Add Softmax forward and backward passes (#54)
* Feat: Add Softmax
* Chore: Remove Ok() return from attn fwd
1 parent 5d2bf59 commit 321110e

File tree

crates/cudnn/src/attention/mod.rs
crates/cudnn/src/lib.rs
crates/cudnn/src/softmax/mod.rs
crates/cudnn/src/softmax/softmax_algo.rs
crates/cudnn/src/softmax/softmax_mode.rs

5 files changed: +189 -3 lines changed

crates/cudnn/src/attention/mod.rs

Lines changed: 1 addition & 3 deletions

@@ -180,9 +180,7 @@ impl CudnnContext {
                 reserve_space_size,
                 reserve_space_ptr,
             )
-            .into_result()?;
-
-            Ok(())
+            .into_result()
         }
     }
 }

crates/cudnn/src/lib.rs

Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,7 @@ mod math_type;
 mod nan_propagation;
 mod op_tensor;
 mod rnn;
+mod softmax;
 mod tensor;
 mod w_grad_mode;

@@ -27,6 +28,7 @@ pub use math_type::*;
 pub use nan_propagation::*;
 pub use op_tensor::*;
 pub use rnn::*;
+pub use softmax::*;
 pub use tensor::*;
 pub use w_grad_mode::*;

crates/cudnn/src/softmax/mod.rs

Lines changed: 142 additions & 0 deletions (new file)

mod softmax_algo;
mod softmax_mode;

pub use softmax_algo::*;
pub use softmax_mode::*;

use crate::{sys, CudnnContext, CudnnError, DataType, IntoResult, SupportedOp, TensorDescriptor};
use cust::memory::GpuBuffer;

impl CudnnContext {
    /// Computes the softmax function.
    ///
    /// # Arguments
    ///
    /// * `algo` - softmax algorithm to compute.
    ///
    /// * `mode` - specifies the softmax mode.
    ///
    /// * `alpha` - scaling factor for the result. Must be stored in host memory.
    ///
    /// * `x_desc` - tensor descriptor for the operand.
    ///
    /// * `x` - operand data in device memory.
    ///
    /// * `beta` - scaling factor for the destination tensor.
    ///
    /// * `y_desc` - tensor descriptor for the result.
    ///
    /// * `y` - output data in device memory.
    ///
    /// # Errors
    ///
    /// Returns an error if the supplied configuration is not supported, the tensor shapes
    /// differ, or the data types of the input and destination tensors do not match.
    pub fn softmax_forward<T, CompT>(
        &self,
        algo: SoftmaxAlgo,
        mode: SoftmaxMode,
        alpha: CompT,
        x_desc: &TensorDescriptor<T>,
        x: &impl GpuBuffer<T>,
        beta: CompT,
        y_desc: &TensorDescriptor<T>,
        y: &mut impl GpuBuffer<T>,
    ) -> Result<(), CudnnError>
    where
        T: DataType,
        CompT: SupportedOp<T, T, T>,
    {
        let alpha_ptr = &alpha as *const CompT as *const _;
        let x_ptr = x.as_device_ptr().as_ptr() as *const _;

        let beta_ptr = &beta as *const CompT as *const _;
        let y_ptr = y.as_device_ptr().as_mut_ptr() as *mut _;

        unsafe {
            sys::cudnnSoftmaxForward(
                self.raw,
                algo.into(),
                mode.into(),
                alpha_ptr,
                x_desc.raw,
                x_ptr,
                beta_ptr,
                y_desc.raw,
                y_ptr,
            )
            .into_result()
        }
    }

    /// Computes the gradient of the softmax function.
    ///
    /// # Arguments
    ///
    /// * `algo` - softmax algorithm whose gradient is computed.
    ///
    /// * `mode` - specifies the softmax mode whose gradient is computed.
    ///
    /// * `alpha` - scaling factor for the result. Must be stored in host memory.
    ///
    /// * `y_desc` - tensor descriptor for the operand.
    ///
    /// * `y` - operand data, i.e. the softmax output, in device memory.
    ///
    /// * `dy_desc` - tensor descriptor for the differential of the operand.
    ///
    /// * `dy` - differential of the operand, i.e. the incoming gradient, in device memory.
    ///
    /// * `beta` - scaling factor for the destination differential tensor.
    ///
    /// * `dx_desc` - tensor descriptor for the resulting differential.
    ///
    /// * `dx` - resulting differential data, i.e. the computed gradient, in device memory.
    ///
    /// # Errors
    ///
    /// Returns an error if the supplied configuration is not supported, the tensor shapes
    /// differ, or the data types of the input and differential tensors do not match.
    pub fn softmax_backward<T, CompT>(
        &self,
        algo: SoftmaxAlgo,
        mode: SoftmaxMode,
        alpha: CompT,
        y_desc: &TensorDescriptor<T>,
        y: &impl GpuBuffer<T>,
        dy_desc: &TensorDescriptor<T>,
        dy: &impl GpuBuffer<T>,
        beta: CompT,
        dx_desc: &TensorDescriptor<T>,
        dx: &mut impl GpuBuffer<T>,
    ) -> Result<(), CudnnError>
    where
        T: DataType,
        CompT: SupportedOp<T, T, T>,
    {
        let alpha_ptr = &alpha as *const CompT as *const _;
        let y_ptr = y.as_device_ptr().as_ptr() as *const _;

        let beta_ptr = &beta as *const CompT as *const _;
        let dy_ptr = dy.as_device_ptr().as_ptr() as *const _;

        let dx_ptr = dx.as_device_ptr().as_mut_ptr() as *mut _;

        unsafe {
            sys::cudnnSoftmaxBackward(
                self.raw,
                algo.into(),
                mode.into(),
                alpha_ptr,
                y_desc.raw,
                y_ptr,
                dy_desc.raw,
                dy_ptr,
                beta_ptr,
                dx_desc.raw,
                dx_ptr,
            )
            .into_result()
        }
    }
}
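Usage sketch: the snippet below shows how these bindings are intended to be called. It is illustrative only and not part of this commit. In particular, the `TensorDescriptor::<f32>::new` constructor, the error type conversions, and the assumption of an already-initialized CUDA context are hypothetical; the exact APIs in the crate may differ.

use cust::memory::DeviceBuffer;

fn softmax_example(ctx: &CudnnContext) -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical descriptor for a [1, 5, 1, 1] NCHW tensor of f32; the real
    // constructor signature in this crate may differ.
    let desc = TensorDescriptor::<f32>::new(&[1, 5, 1, 1])?;

    // Assumes cust::quick_init() (or equivalent) has already been called.
    let x = DeviceBuffer::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0])?;
    let mut y = DeviceBuffer::from_slice(&[0.0f32; 5])?;

    // Computes y = 1.0 * softmax(x) + 0.0 * y over the whole instance.
    ctx.softmax_forward(
        SoftmaxAlgo::Accurate,
        SoftmaxMode::Instance,
        1.0f32,
        &desc,
        &x,
        0.0f32,
        &desc,
        &mut y,
    )?;

    // The backward pass follows the same pattern, consuming the forward
    // output `y` and an incoming gradient `dy` to produce `dx`.
    Ok(())
}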
crates/cudnn/src/softmax/softmax_algo.rs

Lines changed: 24 additions & 0 deletions (new file)

use crate::sys;

/// Specifies the implementation of the softmax function.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SoftmaxAlgo {
    /// This implementation applies the straightforward softmax operation.
    Fast,
    /// This implementation subtracts the maximum value of the input domain from each point
    /// to avoid potential floating point overflows in the softmax evaluation.
    Accurate,
    /// This entry performs the log softmax operation, avoiding overflows by shifting each
    /// point in the input domain as in the accurate version.
    Log,
}

impl From<SoftmaxAlgo> for sys::cudnnSoftmaxAlgorithm_t {
    fn from(algo: SoftmaxAlgo) -> Self {
        match algo {
            SoftmaxAlgo::Fast => Self::CUDNN_SOFTMAX_FAST,
            SoftmaxAlgo::Accurate => Self::CUDNN_SOFTMAX_ACCURATE,
            SoftmaxAlgo::Log => Self::CUDNN_SOFTMAX_LOG,
        }
    }
}
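As a reference for what the three variants compute, here is a small host-side sketch (not part of the crate) of the same math. cuDNN's device implementations differ in execution strategy but agree on these formulas.

/// Host-side reference of the three softmax variants over one normalization group.
fn softmax_reference(x: &[f32], algo: SoftmaxAlgo) -> Vec<f32> {
    // `Accurate` and `Log` shift the inputs by their maximum so `exp` never
    // receives a large positive argument; `Fast` exponentiates the raw values.
    let max = match algo {
        SoftmaxAlgo::Fast => 0.0,
        SoftmaxAlgo::Accurate | SoftmaxAlgo::Log => {
            x.iter().copied().fold(f32::NEG_INFINITY, f32::max)
        }
    };
    let exps: Vec<f32> = x.iter().map(|&v| (v - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    match algo {
        // softmax(x)_i = exp(x_i - max) / sum_j exp(x_j - max)
        SoftmaxAlgo::Fast | SoftmaxAlgo::Accurate => exps.iter().map(|&e| e / sum).collect(),
        // log softmax(x)_i = (x_i - max) - ln(sum_j exp(x_j - max))
        SoftmaxAlgo::Log => x.iter().map(|&v| (v - max) - sum.ln()).collect(),
    }
}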
crates/cudnn/src/softmax/softmax_mode.rs

Lines changed: 20 additions & 0 deletions (new file)

use crate::sys;

/// Specifies how the softmax input must be processed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SoftmaxMode {
    /// The softmax operation is computed per image (N) across the dimensions C, H, W.
    Instance,
    /// The softmax operation is computed per spatial location (H, W) per image (N) across
    /// the dimension C.
    Channel,
}

impl From<SoftmaxMode> for sys::cudnnSoftmaxMode_t {
    fn from(mode: SoftmaxMode) -> Self {
        match mode {
            SoftmaxMode::Channel => Self::CUDNN_SOFTMAX_MODE_CHANNEL,
            SoftmaxMode::Instance => Self::CUDNN_SOFTMAX_MODE_INSTANCE,
        }
    }
}
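To make the two modes concrete, the helper below (illustrative, not part of the crate) computes, for an NCHW tensor, how many values each softmax normalization group spans and how many independent groups there are.

/// For an NCHW tensor, returns (group size, group count) for a softmax mode;
/// every group is normalized independently and its outputs sum to one.
fn softmax_groups(n: usize, c: usize, h: usize, w: usize, mode: SoftmaxMode) -> (usize, usize) {
    match mode {
        // One group per image: normalizes across all of C, H and W.
        SoftmaxMode::Instance => (c * h * w, n),
        // One group per spatial location of each image: normalizes across C only.
        SoftmaxMode::Channel => (c, n * h * w),
    }
}

For a hypothetical [2, 10, 4, 4] activation, `Instance` yields 2 groups of 160 values, while `Channel` yields 32 groups of 10 values, the usual per-pixel distribution over classes.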
