
Commit 820b5cf

Feat: Add activation forward and backward (#61)
* Fix: Missing references
* Feat: Add activation forward and backward
* Chore: Adds cudnn to crate lineup (#60)
1 parent 17b9a60 commit 820b5cf

17 files changed: +359 -37 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ The current line-up of libraries is the following:
 - `cuda_std` for GPU-side functions and utilities, such as thread index queries, memory allocation, warp intrinsics, etc.
   - *Not* a low level library, provides many utility functions to make it easier to write cleaner and more reliable GPU kernels.
   - Closely tied to `rustc_codegen_nvvm` which exposes GPU features through it internally.
+- [`cudnn`](https://github.com/Rust-GPU/Rust-CUDA/tree/master/crates/cudnn) for a collection of GPU-accelerated primitives for deep neural networks.
 - `cust` for CPU-side CUDA features such as launching GPU kernels, GPU memory allocation, device queries, etc.
   - High level with features such as RAII and Rust Results that make it easier and cleaner to manage the interface to the GPU.
   - A high level wrapper for the CUDA Driver API, the lower level version of the more common CUDA Runtime API used from C++.

crates/cudnn/src/activation/activation_descriptor.rs

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
use crate::{sys, ActivationMode, CudnnContext, CudnnError, IntoResult, NanPropagation};
use std::mem::MaybeUninit;

/// The descriptor of a neuron activation operation.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ActivationDescriptor {
    pub(crate) raw: sys::cudnnActivationDescriptor_t,
}

impl ActivationDescriptor {
    /// Creates a new neuron activation descriptor.
    ///
    /// # Arguments
    ///
    /// * `mode` - activation function to compute.
    ///
    /// * `nan_opt` - NaN propagation policy for the operation.
    ///
    /// * `coefficient` - optional coefficient for the given function. It specifies the clipping
    /// threshold for `ActivationMode::ClippedRelu`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use std::error::Error;
    /// #
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// use cudnn::{ActivationDescriptor, ActivationMode, CudnnContext, NanPropagation};
    ///
    /// let ctx = CudnnContext::new()?;
    ///
    /// let mode = ActivationMode::Swish;
    /// let nan_opt = NanPropagation::PropagateNaN;
    /// let coefficient = None;
    ///
    /// let desc = ActivationDescriptor::new(mode, nan_opt, coefficient)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(
        mode: ActivationMode,
        nan_opt: NanPropagation,
        coefficient: impl Into<Option<f64>>,
    ) -> Result<Self, CudnnError> {
        let mut raw = MaybeUninit::uninit();

        unsafe {
            sys::cudnnCreateActivationDescriptor(raw.as_mut_ptr()).into_result()?;

            let mut raw = raw.assume_init();

            let coefficient = coefficient.into().unwrap_or_else(|| match mode {
                ActivationMode::ClippedRelu => std::f64::MAX,
                _ => 1.0,
            });

            sys::cudnnSetActivationDescriptor(raw, mode.into(), nan_opt.into(), coefficient)
                .into_result()?;

            Ok(Self { raw })
        }
    }
}

impl Drop for ActivationDescriptor {
    fn drop(&mut self) {
        unsafe {
            sys::cudnnDestroyActivationDescriptor(self.raw);
        }
    }
}
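
The only non-obvious part of `ActivationDescriptor::new` is the coefficient handling: passing `None` falls back to `f64::MAX` for `ClippedRelu` and to `1.0` for every other mode. Below is a minimal usage sketch passing an explicit clipping threshold; the value `6.0` is purely illustrative and not something the crate prescribes.

```rust
use cudnn::{ActivationDescriptor, ActivationMode, CudnnContext, NanPropagation};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // A cuDNN context, created as in the doc example above.
    let _ctx = CudnnContext::new()?;

    // Clipped ReLU with an explicit clipping threshold: a bare f64 is accepted
    // because the constructor takes `impl Into<Option<f64>>`.
    let _clipped = ActivationDescriptor::new(
        ActivationMode::ClippedRelu,
        NanPropagation::PropagateNaN,
        6.0,
    )?;

    // Omitting the coefficient keeps the per-mode defaults described above.
    let _swish =
        ActivationDescriptor::new(ActivationMode::Swish, NanPropagation::PropagateNaN, None)?;

    Ok(())
}
```
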
crates/cudnn/src/activation/activation_mode.rs

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
use crate::sys;

/// Specifies a neuron activation function.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ActivationMode {
    /// Selects the sigmoid function.
    Sigmoid,
    /// Selects the rectified linear function.
    Relu,
    /// Selects the hyperbolic tangent function.
    Tanh,
    /// Selects the clipped rectified linear function.
    ClippedRelu,
    /// Selects the exponential linear function.
    Elu,
    /// Selects the swish function.
    Swish,
}

impl From<ActivationMode> for sys::cudnnActivationMode_t {
    fn from(mode: ActivationMode) -> Self {
        match mode {
            ActivationMode::Sigmoid => Self::CUDNN_ACTIVATION_SIGMOID,
            ActivationMode::Relu => Self::CUDNN_ACTIVATION_RELU,
            ActivationMode::Tanh => Self::CUDNN_ACTIVATION_TANH,
            ActivationMode::ClippedRelu => Self::CUDNN_ACTIVATION_CLIPPED_RELU,
            ActivationMode::Elu => Self::CUDNN_ACTIVATION_ELU,
            ActivationMode::Swish => Self::CUDNN_ACTIVATION_SWISH,
        }
    }
}

crates/cudnn/src/activation/mod.rs

Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
mod activation_descriptor;
mod activation_mode;

pub use activation_descriptor::*;
pub use activation_mode::*;

use crate::{private, sys, CudnnContext, CudnnError, DataType, IntoResult, TensorDescriptor};
use cust::memory::GpuBuffer;
use std::mem::MaybeUninit;

impl CudnnContext {
    /// Applies a specific neuron activation function element-wise to the provided tensor.
    ///
    /// # Arguments
    ///
    /// * `activation_desc` - activation descriptor.
    ///
    /// * `alpha` - scaling factor for the result.
    ///
    /// * `x_desc` - tensor descriptor for the input.
    ///
    /// * `x` - data for the input tensor.
    ///
    /// * `beta` - scaling factor for the destination tensor.
    ///
    /// * `y_desc` - tensor descriptor for the output.
    ///
    /// * `y` - data for the output.
    ///
    /// # Errors
    ///
    /// Returns errors if the shapes of the `y` and `x` tensors do not match or an unsupported
    /// configuration of arguments is detected.
    ///
    /// # Examples
    ///
    /// ```
    /// # use std::error::Error;
    /// #
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// use cudnn::{ActivationDescriptor, ActivationMode, CudnnContext, NanPropagation, TensorDescriptor};
    /// use cust::memory::DeviceBuffer;
    ///
    /// let ctx = CudnnContext::new()?;
    ///
    /// let mode = ActivationMode::Swish;
    /// let nan_opt = NanPropagation::PropagateNaN;
    /// let coefficient = None;
    ///
    /// let desc = ActivationDescriptor::new(mode, nan_opt, coefficient)?;
    ///
    /// let alpha: f32 = 1.0;
    /// let x_desc = TensorDescriptor::<i8>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
    /// let x = DeviceBuffer::<i8>::from_slice(&[10, 10, 10, 10, 10])?;
    ///
    /// let beta: f32 = 0.0;
    /// let y_desc = TensorDescriptor::<i8>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
    /// let mut y = DeviceBuffer::<i8>::from_slice(&[0, 0, 0, 0, 0])?;
    ///
    /// ctx.activation_forward(&desc, alpha, &x_desc, &x, beta, &y_desc, &mut y)?;
    ///
    /// let y_host = y.as_host_vec()?;
    ///
    /// assert!(y_host.iter().all(|el| *el == 10));
    /// # Ok(())
    /// # }
    /// ```
    pub fn activation_forward<CompT, T>(
        &self,
        activation_desc: &ActivationDescriptor,
        alpha: CompT,
        x_desc: &TensorDescriptor<T>,
        x: &impl GpuBuffer<T>,
        beta: CompT,
        y_desc: &TensorDescriptor<T>,
        y: &mut impl GpuBuffer<T>,
    ) -> Result<(), CudnnError>
    where
        CompT: SupportedActFwd<T>,
        T: DataType,
    {
        let alpha_ptr = &alpha as *const CompT as *const _;
        let x_ptr = x.as_device_ptr().as_ptr() as *const _;

        let beta_ptr = &beta as *const CompT as *const _;
        let y_ptr = y.as_device_ptr().as_mut_ptr() as *mut _;

        unsafe {
            sys::cudnnActivationForward(
                self.raw,
                activation_desc.raw,
                alpha_ptr,
                x_desc.raw,
                x_ptr,
                beta_ptr,
                y_desc.raw,
                y_ptr,
            )
            .into_result()
        }
    }

    /// Computes the gradient of a neuron activation function.
    ///
    /// # Arguments
    ///
    /// * `activation_desc` - descriptor of a neuron activation operation.
    ///
    /// * `alpha` - scaling factor for the result.
    ///
    /// * `y_desc` - tensor descriptor for the output map.
    ///
    /// * `y` - data for the output map.
    ///
    /// * `dy_desc` - tensor descriptor for the differential of the output map.
    ///
    /// * `dy` - data for the differential of the output map.
    ///
    /// * `x_desc` - tensor descriptor for the activation input.
    ///
    /// * `x` - data for the activation input.
    ///
    /// * `beta` - scaling factor for the destination tensor.
    ///
    /// * `dx_desc` - tensor descriptor for the input differential.
    ///
    /// * `dx` - data for the input differential.
    ///
    /// # Errors
    ///
    /// Returns errors if the shapes of the `dx` and `x` tensors do not match, the strides of the
    /// tensors and their differential do not match, or an unsupported configuration of arguments
    /// is detected.
    pub fn activation_backward<CompT, T>(
        &self,
        activation_desc: &ActivationDescriptor,
        alpha: CompT,
        y_desc: &TensorDescriptor<T>,
        y: &impl GpuBuffer<T>,
        dy_desc: &TensorDescriptor<T>,
        dy: &impl GpuBuffer<T>,
        x_desc: &TensorDescriptor<T>,
        x: &impl GpuBuffer<T>,
        beta: CompT,
        dx_desc: &TensorDescriptor<T>,
        dx: &mut impl GpuBuffer<T>,
    ) -> Result<(), CudnnError>
    where
        CompT: SupportedActBwd<T>,
        T: DataType,
    {
        let alpha_ptr = &alpha as *const CompT as *const _;

        let y_ptr = y.as_device_ptr().as_ptr() as *const _;
        let dy_ptr = dy.as_device_ptr().as_ptr() as *const _;
        let x_ptr = x.as_device_ptr().as_ptr() as *const _;

        let beta_ptr = &beta as *const CompT as *const _;

        let dx_ptr = dx.as_device_ptr().as_mut_ptr() as *mut _;

        unsafe {
            sys::cudnnActivationBackward(
                self.raw,
                activation_desc.raw,
                alpha_ptr,
                y_desc.raw,
                y_ptr,
                dy_desc.raw,
                dy_ptr,
                x_desc.raw,
                x_ptr,
                beta_ptr,
                dx_desc.raw,
                dx_ptr,
            )
            .into_result()
        }
    }
}

/// Supported data type configurations for the activation forward operation.
pub trait SupportedActFwd<T>: DataType + private::Sealed
where
    T: DataType,
{
}

impl SupportedActFwd<i8> for f32 {}
impl SupportedActFwd<u8> for f32 {}
impl SupportedActFwd<i32> for f32 {}
impl SupportedActFwd<i64> for f32 {}
impl SupportedActFwd<f32> for f32 {}
impl SupportedActFwd<f64> for f32 {}

impl SupportedActFwd<i8> for f64 {}
impl SupportedActFwd<u8> for f64 {}
impl SupportedActFwd<i32> for f64 {}
impl SupportedActFwd<i64> for f64 {}
impl SupportedActFwd<f32> for f64 {}
impl SupportedActFwd<f64> for f64 {}

/// Supported type configurations for the activation backward operation.
pub trait SupportedActBwd<T>: DataType + private::Sealed
where
    T: DataType,
{
}

impl SupportedActBwd<f32> for f32 {}
impl SupportedActBwd<f64> for f64 {}
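
Unlike the forward pass, `activation_backward` ships without a doc example in this commit. The sketch below mirrors the forward example using the `f32`/`f32` configuration declared by `SupportedActBwd`; the tensor values are illustrative only, it assumes `f32` is an accepted `DataType` for `TensorDescriptor` (as the trait impls above suggest), and it needs a working CUDA device just like the doc examples above.

```rust
use cudnn::{
    ActivationDescriptor, ActivationMode, CudnnContext, NanPropagation, TensorDescriptor,
};
use cust::memory::DeviceBuffer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let ctx = CudnnContext::new()?;

    let desc =
        ActivationDescriptor::new(ActivationMode::Relu, NanPropagation::PropagateNaN, None)?;

    // Scaling factors; CompT = f32 matches T = f32, one of the two
    // configurations allowed by `SupportedActBwd`.
    let alpha: f32 = 1.0;
    let beta: f32 = 0.0;

    // Forward-pass input, its output, and the gradient flowing back from the
    // layer above (all 1 x 1 x 1 x 5 tensors with fully packed strides).
    let x_desc = TensorDescriptor::<f32>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
    let x = DeviceBuffer::<f32>::from_slice(&[-1.0, -0.5, 0.0, 0.5, 1.0])?;

    let y_desc = TensorDescriptor::<f32>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
    let y = DeviceBuffer::<f32>::from_slice(&[0.0, 0.0, 0.0, 0.5, 1.0])?;

    let dy_desc = TensorDescriptor::<f32>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
    let dy = DeviceBuffer::<f32>::from_slice(&[1.0, 1.0, 1.0, 1.0, 1.0])?;

    // Gradient with respect to the activation input, written by the call.
    let dx_desc = TensorDescriptor::<f32>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
    let mut dx = DeviceBuffer::<f32>::from_slice(&[0.0, 0.0, 0.0, 0.0, 0.0])?;

    ctx.activation_backward(
        &desc, alpha, &y_desc, &y, &dy_desc, &dy, &x_desc, &x, beta, &dx_desc, &mut dx,
    )?;

    // For ReLU, the gradient should pass through where x > 0 and be zero elsewhere.
    let dx_host = dx.as_host_vec()?;
    println!("{:?}", dx_host);

    Ok(())
}
```
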

crates/cudnn/src/attention/attention_descriptor.rs

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ bitflags::bitflags! {
 }
 
 /// A multi-head attention descriptor.
+#[derive(Debug, PartialEq, Eq, Hash)]
 pub struct AttentionDescriptor<T, U, D1, D2>
 where
     T: SeqDataType,

crates/cudnn/src/convolution/convolution_descriptor.rs

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ use std::{marker::PhantomData, mem::MaybeUninit};
 /// **Do note** that N can be either 2 or 3, respectively for a 2-d or a 3-d convolution, and that
 /// the same convolution descriptor can be reused in the backward path provided it corresponds to
 /// the same layer.
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, PartialEq, Eq, Hash)]
 pub struct ConvDescriptor<T: DataType> {
     pub(crate) raw: sys::cudnnConvolutionDescriptor_t,
     comp_type: PhantomData<T>,

crates/cudnn/src/convolution/filter_descriptor.rs

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ use std::{
 };
 
 /// A generic description of an n-dimensional filter dataset.
-#[derive(Debug, Clone, PartialEq, Hash)]
+#[derive(Debug, PartialEq, Eq, Hash)]
 pub struct FilterDescriptor<T>
 where
     T: DataType,

crates/cudnn/src/dropout/dropout_descriptor.rs

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@ use crate::{error::CudnnError, sys, IntoResult};
 use cust::memory::GpuBuffer;
 
 /// The descriptor of a dropout operation.
+#[derive(Debug, PartialEq, Eq, Hash)]
 pub struct DropoutDescriptor<T>
 where
     T: GpuBuffer<u8>,

crates/cudnn/src/lib.rs

Lines changed: 4 additions & 2 deletions
@@ -1,6 +1,7 @@
 #![allow(warnings, clippy::all)]
 mod sys;
 
+mod activation;
 mod attention;
 mod backend;
 mod context;
@@ -11,13 +12,14 @@ mod dropout;
 mod error;
 mod math_type;
 mod nan_propagation;
-mod op_tensor;
+mod op;
 mod pooling;
 mod rnn;
 mod softmax;
 mod tensor;
 mod w_grad_mode;
 
+pub use activation::*;
 pub use attention::*;
 pub use context::*;
 pub use convolution::*;
@@ -27,7 +29,7 @@ pub use dropout::*;
 pub use error::*;
 pub use math_type::*;
 pub use nan_propagation::*;
-pub use op_tensor::*;
+pub use op::*;
 pub use pooling::*;
 pub use rnn::*;
 pub use softmax::*;
File renamed without changes.
