Feat: Add activation forward and backward #61

Merged (3 commits) on Mar 14, 2022
1 change: 1 addition & 0 deletions README.md
@@ -54,6 +54,7 @@ The current line-up of libraries is the following:
- `cuda_std` for GPU-side functions and utilities, such as thread index queries, memory allocation, warp intrinsics, etc.
- *Not* a low level library, provides many utility functions to make it easier to write cleaner and more reliable GPU kernels.
- Closely tied to `rustc_codegen_nvvm` which exposes GPU features through it internally.
- [`cudnn`](https://github.com/Rust-GPU/Rust-CUDA/tree/master/crates/cudnn) for a collection of GPU-accelerated primitives for deep neural networks.
- `cust` for CPU-side CUDA features such as launching GPU kernels, GPU memory allocation, device queries, etc.
- High level with features such as RAII and Rust Results that make it easier and cleaner to manage the interface to the GPU.
- A high level wrapper for the CUDA Driver API, the lower level version of the more common CUDA Runtime API used from C++.
71 changes: 71 additions & 0 deletions crates/cudnn/src/activation/activation_descriptor.rs
@@ -0,0 +1,71 @@
use crate::{sys, ActivationMode, CudnnContext, CudnnError, IntoResult, NanPropagation};
use std::mem::MaybeUninit;

/// The descriptor of a neuron activation operation.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ActivationDescriptor {
pub(crate) raw: sys::cudnnActivationDescriptor_t,
}

impl ActivationDescriptor {
/// Creates a new neuron activation descriptor.
///
/// # Arguments
///
/// * `mode` - activation function to compute.
///
/// * `nan_opt` - NaN propagation policy for the operation.
///
/// * `coefficient` - optional coefficient for the given function. It specifies the clipping
/// threshold for `ActivationMode::ClippedRelu`.
///
/// # Examples
///
/// ```
/// # use std::error::Error;
/// #
/// # fn main() -> Result<(), Box<dyn Error>> {
/// use cudnn::{ActivationDescriptor, ActivationMode, CudnnContext, NanPropagation};
///
/// let ctx = CudnnContext::new()?;
///
/// let mode = ActivationMode::Swish;
/// let nan_opt = NanPropagation::PropagateNaN;
/// let coefficient = None;
///
/// let desc = ActivationDescriptor::new(mode, nan_opt, coefficient)?;
/// # Ok(())
/// # }
/// ```
pub fn new(
mode: ActivationMode,
nan_opt: NanPropagation,
coefficient: impl Into<Option<f64>>,
) -> Result<Self, CudnnError> {
let mut raw = MaybeUninit::uninit();

unsafe {
sys::cudnnCreateActivationDescriptor(raw.as_mut_ptr()).into_result()?;

let mut raw = raw.assume_init();

let coefficient = coefficient.into().unwrap_or_else(|| match mode {
ActivationMode::ClippedRelu => std::f64::MAX,
_ => 1.0,
});

sys::cudnnSetActivationDescriptor(raw, mode.into(), nan_opt.into(), coefficient)
.into_result()?;

Ok(Self { raw })
}
}
}

impl Drop for ActivationDescriptor {
fn drop(&mut self) {
unsafe {
sys::cudnnDestroyActivationDescriptor(self.raw);
}
}
}
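
For reference, a minimal sketch (not part of this diff) of how the `coefficient` argument documented above could set an explicit clipping threshold for `ActivationMode::ClippedRelu`; the threshold value of 6.0 is an arbitrary illustration, and everything else uses the API shown in this file.

```rust
use cudnn::{ActivationDescriptor, ActivationMode, CudnnContext, NanPropagation};
use std::error::Error;

fn main() -> Result<(), Box<dyn Error>> {
    let _ctx = CudnnContext::new()?;

    // Clip the rectified output at 6.0 (a "ReLU6"-style activation).
    // Passing `None` would instead default the threshold to f64::MAX.
    let _desc = ActivationDescriptor::new(
        ActivationMode::ClippedRelu,
        NanPropagation::PropagateNaN,
        6.0,
    )?;

    Ok(())
}
```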
31 changes: 31 additions & 0 deletions crates/cudnn/src/activation/activation_mode.rs
@@ -0,0 +1,31 @@
use crate::sys;

/// Specifies a neuron activation function.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ActivationMode {
/// Selects the sigmoid function.
Sigmoid,
/// Selects the rectified linear function.
Relu,
/// Selects the hyperbolic tangent function.
Tanh,
/// Selects the clipped rectified linear function.
ClippedRelu,
/// Selects the exponential linear function.
Elu,
/// Selects the swish function.
Swish,
}

impl From<ActivationMode> for sys::cudnnActivationMode_t {
fn from(mode: ActivationMode) -> Self {
match mode {
ActivationMode::Sigmoid => Self::CUDNN_ACTIVATION_SIGMOID,
ActivationMode::Relu => Self::CUDNN_ACTIVATION_RELU,
ActivationMode::Tanh => Self::CUDNN_ACTIVATION_TANH,
ActivationMode::ClippedRelu => Self::CUDNN_ACTIVATION_CLIPPED_RELU,
ActivationMode::Elu => Self::CUDNN_ACTIVATION_ELU,
ActivationMode::Swish => Self::CUDNN_ACTIVATION_SWISH,
}
}
}
211 changes: 211 additions & 0 deletions crates/cudnn/src/activation/mod.rs
@@ -0,0 +1,211 @@
mod activation_descriptor;
mod activation_mode;

pub use activation_descriptor::*;
pub use activation_mode::*;

use crate::{private, sys, CudnnContext, CudnnError, DataType, IntoResult, TensorDescriptor};
use cust::memory::GpuBuffer;
use std::mem::MaybeUninit;

impl CudnnContext {
/// Applies a specific neuron activation function element-wise to the provided tensor.
///
/// # Arguments
///
/// * `activation_desc` - activation descriptor.
///
/// * `alpha` - scaling factor for the result.
///
/// * `x_desc` - tensor descriptor for the input.
///
/// * `x` - data for the input tensor.
///
/// * `beta` - scaling factor for the destination tensor.
///
/// * `y_desc` - tensor descriptor for the output.
///
/// * `y` - data for the output.
///
/// # Errors
///
/// Returns errors if the shapes of the `y` and `x` tensors do not match or an unsupported
/// configuration of arguments is detected.
///
/// # Examples
///
/// ```
/// # use std::error::Error;
/// #
/// # fn main() -> Result<(), Box<dyn Error>> {
/// use cudnn::{ActivationDescriptor, ActivationMode, CudnnContext, NanPropagation, TensorDescriptor};
/// use cust::memory::DeviceBuffer;
///
/// let ctx = CudnnContext::new()?;
///
/// let mode = ActivationMode::Swish;
/// let nan_opt = NanPropagation::PropagateNaN;
/// let coefficient = None;
///
/// let desc = ActivationDescriptor::new(mode, nan_opt, coefficient)?;
///
/// let alpha: f32 = 1.0;
/// let x_desc = TensorDescriptor::<i8>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
/// let x = DeviceBuffer::<i8>::from_slice(&[10, 10, 10, 10, 10])?;
///
/// let beta: f32 = 0.0;
/// let y_desc = TensorDescriptor::<i8>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
/// let mut y = DeviceBuffer::<i8>::from_slice(&[0, 0, 0, 0, 0])?;
///
/// ctx.activation_forward(&desc, alpha, &x_desc, &x, beta, &y_desc, &mut y)?;
///
/// let y_host = y.as_host_vec()?;
///
/// assert!(y_host.iter().all(|el| *el == 10));
/// # Ok(())
/// # }
/// ```
pub fn activation_forward<CompT, T>(
&self,
activation_desc: &ActivationDescriptor,
alpha: CompT,
x_desc: &TensorDescriptor<T>,
x: &impl GpuBuffer<T>,
beta: CompT,
y_desc: &TensorDescriptor<T>,
y: &mut impl GpuBuffer<T>,
) -> Result<(), CudnnError>
where
CompT: SupportedActFwd<T>,
T: DataType,
{
let alpha_ptr = &alpha as *const CompT as *const _;
let x_ptr = x.as_device_ptr().as_ptr() as *const _;

let beta_ptr = &beta as *const CompT as *const _;
let y_ptr = y.as_device_ptr().as_mut_ptr() as *mut _;

unsafe {
sys::cudnnActivationForward(
self.raw,
activation_desc.raw,
alpha_ptr,
x_desc.raw,
x_ptr,
beta_ptr,
y_desc.raw,
y_ptr,
)
.into_result()
}
}

/// Computes the gradient of a neuron activation function.
///
/// # Arguments
///
/// * `activation_desc` - descriptor of a neuron activation operation.
///
/// * `alpha` - scaling factor for the result.
///
/// * `y_desc` - tensor descriptor for the output map.
///
/// * `y` - data for the output map.
///
/// * `dy_desc` - tensor descriptor for the differential of the output map.
///
/// * `dy` - data for the differential of the output map.
///
/// * `x_desc` - tensor descriptor for the activation input.
///
/// * `x` - data for the activation input.
///
/// * `beta` - scaling factor for the destination tensor.
///
/// * `dx_desc` - tensor descriptor for the input differential.
///
/// * `dx` - data for the input differential.
///
/// # Errors
///
/// Returns errors if the shapes of the `dx` and `x` tensors do not match, the strides of the
/// tensors and their differential do not match, or an unsupported configuration of arguments
/// is detected.
pub fn activation_backward<CompT, T>(
&self,
activation_desc: &ActivationDescriptor,
alpha: CompT,
y_desc: &TensorDescriptor<T>,
y: &impl GpuBuffer<T>,
dy_desc: &TensorDescriptor<T>,
dy: &impl GpuBuffer<T>,
x_desc: &TensorDescriptor<T>,
x: &impl GpuBuffer<T>,
beta: CompT,
dx_desc: &TensorDescriptor<T>,
dx: &mut impl GpuBuffer<T>,
) -> Result<(), CudnnError>
where
CompT: SupportedActBwd<T>,
T: DataType,
{
let alpha_ptr = &alpha as *const CompT as *const _;

let y_ptr = y.as_device_ptr().as_ptr() as *const _;
let dy_ptr = dy.as_device_ptr().as_ptr() as *const _;
let x_ptr = x.as_device_ptr().as_ptr() as *const _;

let beta_ptr = &beta as *const CompT as *const _;

let dx_ptr = dx.as_device_ptr().as_mut_ptr() as *mut _;

unsafe {
sys::cudnnActivationBackward(
self.raw,
activation_desc.raw,
alpha_ptr,
y_desc.raw,
y_ptr,
dy_desc.raw,
dy_ptr,
x_desc.raw,
x_ptr,
beta_ptr,
dx_desc.raw,
dx_ptr,
)
.into_result()
}
}
}

/// Supported data type configurations for the activation forward operation.
pub trait SupportedActFwd<T>: DataType + private::Sealed
where
T: DataType,
{
}

impl SupportedActFwd<i8> for f32 {}
impl SupportedActFwd<u8> for f32 {}
impl SupportedActFwd<i32> for f32 {}
impl SupportedActFwd<i64> for f32 {}
impl SupportedActFwd<f32> for f32 {}
impl SupportedActFwd<f64> for f32 {}

impl SupportedActFwd<i8> for f64 {}
impl SupportedActFwd<u8> for f64 {}
impl SupportedActFwd<i32> for f64 {}
impl SupportedActFwd<i64> for f64 {}
impl SupportedActFwd<f32> for f64 {}
impl SupportedActFwd<f64> for f64 {}

/// Supported type configurations for the activation backward operation.
pub trait SupportedActBwd<T>: DataType + private::Sealed
where
T: DataType,
{
}

impl SupportedActBwd<f32> for f32 {}
impl SupportedActBwd<f64> for f64 {}
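
For reference, a minimal sketch (not part of this diff) of an `activation_backward` call, mirroring the forward doc example above. The tensor shapes, values, and the expected unit gradient (ReLU passes the upstream gradient through for positive inputs) are illustrative assumptions; per `SupportedActBwd`, the compute and data types must both be `f32` or both `f64`.

```rust
use cudnn::{
    ActivationDescriptor, ActivationMode, CudnnContext, NanPropagation, TensorDescriptor,
};
use cust::memory::DeviceBuffer;
use std::error::Error;

fn main() -> Result<(), Box<dyn Error>> {
    let ctx = CudnnContext::new()?;

    let desc = ActivationDescriptor::new(ActivationMode::Relu, NanPropagation::PropagateNaN, None)?;

    let alpha: f32 = 1.0;
    let beta: f32 = 0.0;

    // Forward output `y` and the gradient `dy` flowing back into it.
    let y_desc = TensorDescriptor::<f32>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
    let y = DeviceBuffer::<f32>::from_slice(&[10.0; 5])?;
    let dy_desc = TensorDescriptor::<f32>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
    let dy = DeviceBuffer::<f32>::from_slice(&[1.0; 5])?;

    // Original activation input `x` and the gradient `dx` to be computed.
    let x_desc = TensorDescriptor::<f32>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
    let x = DeviceBuffer::<f32>::from_slice(&[10.0; 5])?;
    let dx_desc = TensorDescriptor::<f32>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
    let mut dx = DeviceBuffer::<f32>::from_slice(&[0.0; 5])?;

    ctx.activation_backward(
        &desc, alpha, &y_desc, &y, &dy_desc, &dy, &x_desc, &x, beta, &dx_desc, &mut dx,
    )?;

    // ReLU has gradient 1 for positive inputs, so `dx` should equal `dy` here.
    let dx_host = dx.as_host_vec()?;
    assert!(dx_host.iter().all(|el| (*el - 1.0).abs() < f32::EPSILON));

    Ok(())
}
```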
1 change: 1 addition & 0 deletions crates/cudnn/src/attention/attention_descriptor.rs
@@ -21,6 +21,7 @@ bitflags::bitflags! {
}

/// A multi-head attention descriptor.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct AttentionDescriptor<T, U, D1, D2>
where
T: SeqDataType,
2 changes: 1 addition & 1 deletion crates/cudnn/src/convolution/convolution_descriptor.rs
@@ -9,7 +9,7 @@ use std::{marker::PhantomData, mem::MaybeUninit};
/// **Do note** that N can be either 2 or 3, respectively for a 2-d or a 3-d convolution, and that
/// the same convolution descriptor can be reused in the backward path provided it corresponds to
/// the same layer.
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ConvDescriptor<T: DataType> {
pub(crate) raw: sys::cudnnConvolutionDescriptor_t,
comp_type: PhantomData<T>,
2 changes: 1 addition & 1 deletion crates/cudnn/src/convolution/filter_descriptor.rs
@@ -5,7 +5,7 @@ use std::{
};

/// A generic description of an n-dimensional filter dataset.
-#[derive(Debug, Clone, PartialEq, Hash)]
+#[derive(Debug, PartialEq, Eq, Hash)]
pub struct FilterDescriptor<T>
where
T: DataType,
1 change: 1 addition & 0 deletions crates/cudnn/src/dropout/dropout_descriptor.rs
@@ -2,6 +2,7 @@ use crate::{error::CudnnError, sys, IntoResult};
use cust::memory::GpuBuffer;

/// The descriptor of a dropout operation.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct DropoutDescriptor<T>
where
T: GpuBuffer<u8>,
6 changes: 4 additions & 2 deletions crates/cudnn/src/lib.rs
@@ -1,6 +1,7 @@
#![allow(warnings, clippy::all)]
mod sys;

mod activation;
mod attention;
mod backend;
mod context;
@@ -11,13 +12,14 @@ mod dropout;
mod error;
mod math_type;
mod nan_propagation;
-mod op_tensor;
+mod op;
mod pooling;
mod rnn;
mod softmax;
mod tensor;
mod w_grad_mode;

pub use activation::*;
pub use attention::*;
pub use context::*;
pub use convolution::*;
@@ -27,7 +29,7 @@ pub use dropout::*;
pub use error::*;
pub use math_type::*;
pub use nan_propagation::*;
-pub use op_tensor::*;
+pub use op::*;
pub use pooling::*;
pub use rnn::*;
pub use softmax::*;
File renamed without changes.