Skip to content

Feat: Add pooling forward and backward passes #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/cudnn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mod error;
mod math_type;
mod nan_propagation;
mod op_tensor;
mod pooling;
mod rnn;
mod softmax;
mod tensor;
Expand All @@ -27,6 +28,7 @@ pub use error::*;
pub use math_type::*;
pub use nan_propagation::*;
pub use op_tensor::*;
pub use pooling::*;
pub use rnn::*;
pub use softmax::*;
pub use tensor::*;
Expand Down
177 changes: 177 additions & 0 deletions crates/cudnn/src/pooling/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
mod pooling_descriptor;
mod pooling_mode;

use cust::memory::GpuBuffer;
pub use pooling_descriptor::*;
pub use pooling_mode::*;

use crate::{private, sys, CudnnContext, CudnnError, DataType, IntoResult, TensorDescriptor};

impl CudnnContext {
    /// This function computes the pooling of the input tensor and produces a smaller tensor in
    /// output.
    ///
    /// # Arguments
    ///
    /// * `pooling_desc` - descriptor of the pooling operation.
    ///
    /// * `alpha` - scaling factor for the result.
    ///
    /// * `x_desc` - descriptor for the input tensor.
    ///
    /// * `x` - data for the input tensor.
    ///
    /// * `beta` - scaling factor for the destination tensor.
    ///
    /// * `y_desc` - descriptor for the destination tensor.
    ///
    /// * `y` - data for the destination tensor.
    ///
    /// # Errors
    ///
    /// Returns errors if the batch size or channels dimensions of the two tensors differ or an
    /// invalid combination of arguments is detected.
    pub fn pooling_forward<T, U>(
        &self,
        pooling_desc: &PoolingDescriptor,
        alpha: T,
        x_desc: &TensorDescriptor<U>,
        x: &impl GpuBuffer<U>,
        beta: T,
        y_desc: &TensorDescriptor<U>,
        y: &mut impl GpuBuffer<U>,
    ) -> Result<(), CudnnError>
    where
        T: SupportedPoolFwd<U>,
        U: DataType,
    {
        // cuDNN reads the scaling factors through host pointers, so they only
        // need to live for the duration of this call.
        let alpha_ptr = &alpha as *const T as *const _;
        let x_ptr = x.as_device_ptr().as_ptr() as *const _;

        let beta_ptr = &beta as *const T as *const _;
        let y_ptr = y.as_device_ptr().as_mut_ptr() as *mut _;

        // SAFETY: the descriptors and device buffers are valid for the whole call,
        // `alpha` / `beta` live on the host stack until it returns, and the
        // `SupportedPoolFwd` bound restricts `T`/`U` to accepted type combinations.
        unsafe {
            sys::cudnnPoolingForward(
                self.raw,
                pooling_desc.raw,
                alpha_ptr,
                x_desc.raw,
                x_ptr,
                beta_ptr,
                y_desc.raw,
                y_ptr,
            )
            .into_result()
        }
    }

    /// Computes the gradient of a pooling operation.
    ///
    /// # Arguments
    ///
    /// * `pooling_desc` - descriptor of the pooling operation.
    ///
    /// * `alpha` - scaling factor for the result.
    ///
    /// * `y_desc` - tensor descriptor for the output map.
    ///
    /// * `y` - data for the output map.
    ///
    /// * `dy_desc` - tensor descriptor for the differential of the output map.
    ///
    /// * `dy` - data for the differential of the output map.
    ///
    /// * `x_desc` - tensor descriptor for the pooling input.
    ///
    /// * `x` - data for the pooling input.
    ///
    /// * `beta` - scaling factor for the destination tensor.
    ///
    /// * `dx_desc` - tensor descriptor for the input differential.
    ///
    /// * `dx` - data for the input differential.
    ///
    /// # Errors
    ///
    /// Returns errors if the dimensions or the strides of `y` and `dy` tensors differ or if the
    /// dimensions or the strides of `x` and `dx` tensors differ or if an unsupported combination
    /// of arguments is detected.
    pub fn pooling_backward<T, U>(
        &self,
        pooling_desc: &PoolingDescriptor,
        alpha: T,
        y_desc: &TensorDescriptor<U>,
        y: &impl GpuBuffer<U>,
        dy_desc: &TensorDescriptor<U>,
        dy: &impl GpuBuffer<U>,
        x_desc: &TensorDescriptor<U>,
        x: &impl GpuBuffer<U>,
        beta: T,
        dx_desc: &TensorDescriptor<U>,
        dx: &mut impl GpuBuffer<U>,
    ) -> Result<(), CudnnError>
    where
        T: SupportedPoolBwd<U>,
        U: DataType,
    {
        let alpha_ptr = &alpha as *const T as *const _;

        let y_ptr = y.as_device_ptr().as_ptr() as *const _;
        let dy_ptr = dy.as_device_ptr().as_ptr() as *const _;
        let x_ptr = x.as_device_ptr().as_ptr() as *const _;

        let beta_ptr = &beta as *const T as *const _;

        // `dx` is the only buffer written by the kernel, hence the only mutable one.
        let dx_ptr = dx.as_device_ptr().as_mut_ptr() as *mut _;

        // SAFETY: the descriptors and device buffers are valid for the whole call,
        // `alpha` / `beta` live on the host stack until it returns, and the
        // `SupportedPoolBwd` bound restricts `T`/`U` to accepted type combinations.
        unsafe {
            sys::cudnnPoolingBackward(
                self.raw,
                pooling_desc.raw,
                alpha_ptr,
                y_desc.raw,
                y_ptr,
                dy_desc.raw,
                dy_ptr,
                x_desc.raw,
                x_ptr,
                beta_ptr,
                dx_desc.raw,
                dx_ptr,
            )
            .into_result()
        }
    }
}

/// Supported data type configurations for the pooling forward operation.
///
/// This is a sealed marker trait: `Self` is the type of the host-side scaling factors
/// (`alpha` / `beta`) and `T` is the tensor element type. Only the combinations listed
/// below may be used with [`CudnnContext::pooling_forward`].
// NOTE(review): the impl list presumably mirrors the type configurations accepted by
// `cudnnPoolingForward` — confirm against the cuDNN API reference.
pub trait SupportedPoolFwd<T>: DataType + private::Sealed
where
    T: DataType,
{
}

// f32 scaling factors are accepted for every supported tensor element type.
impl SupportedPoolFwd<i8> for f32 {}
impl SupportedPoolFwd<u8> for f32 {}
impl SupportedPoolFwd<i32> for f32 {}
impl SupportedPoolFwd<i64> for f32 {}
impl SupportedPoolFwd<f32> for f32 {}
impl SupportedPoolFwd<f64> for f32 {}

// f64 scaling factors are likewise accepted for every supported tensor element type.
impl SupportedPoolFwd<i8> for f64 {}
impl SupportedPoolFwd<u8> for f64 {}
impl SupportedPoolFwd<i32> for f64 {}
impl SupportedPoolFwd<i64> for f64 {}
impl SupportedPoolFwd<f32> for f64 {}
impl SupportedPoolFwd<f64> for f64 {}

/// Supported type configurations for the pooling backward operation.
///
/// Sealed marker trait pairing the scaling factor type (`Self`) with the tensor element
/// type (`T`) for [`CudnnContext::pooling_backward`]. Unlike the forward pass, only
/// matching floating point pairs are permitted here.
pub trait SupportedPoolBwd<T>: DataType + private::Sealed
where
    T: DataType,
{
}

impl SupportedPoolBwd<f32> for f32 {}
impl SupportedPoolBwd<f64> for f64 {}
87 changes: 87 additions & 0 deletions crates/cudnn/src/pooling/pooling_descriptor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use crate::{sys, CudnnError, IntoResult, NanPropagation, PoolingMode};
use std::mem::MaybeUninit;

/// The descriptor of a pooling operation.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct PoolingDescriptor {
    // Raw cuDNN handle owned by this wrapper; released by the `Drop` impl.
    pub(crate) raw: sys::cudnnPoolingDescriptor_t,
}

impl PoolingDescriptor {
    /// Creates a new pooling descriptor.
    ///
    /// # Arguments
    ///
    /// * `mode` - pooling mode.
    ///
    /// * `nan_opt` - nan propagation policy.
    ///
    /// * `window_shape` - shape of the pooling window.
    ///
    /// * `padding` - padding size for each dimension. Negative padding is allowed.
    ///
    /// * `stride` - stride for each dimension.
    ///
    /// # Errors
    ///
    /// Returns errors if an invalid configuration of arguments is detected.
    ///
    /// # Examples
    ///
    ///
    /// ```
    /// # use std::error::Error;
    /// #
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// use cudnn::{CudnnContext, NanPropagation, PoolingDescriptor, PoolingMode};
    ///
    /// let ctx = CudnnContext::new()?;
    ///
    /// let mode = PoolingMode::Max;
    /// let nan_opt = NanPropagation::PropagateNaN;
    /// let window_shape = [2, 2];
    /// let padding = [0, 0];
    /// let stride = [1, 1];
    ///
    /// let pooling_desc = PoolingDescriptor::new(mode, nan_opt, window_shape, padding, stride)?;
    ///
    /// # Ok(())
    /// # }
    /// ```
    pub fn new<const N: usize>(
        mode: PoolingMode,
        nan_opt: NanPropagation,
        window_shape: [i32; N],
        padding: [i32; N],
        stride: [i32; N],
    ) -> Result<Self, CudnnError> {
        let mut raw = MaybeUninit::uninit();

        unsafe {
            sys::cudnnCreatePoolingDescriptor(raw.as_mut_ptr()).into_result()?;

            // SAFETY: creation succeeded, so the handle has been initialized.
            let raw = raw.assume_init();

            // If configuring the descriptor fails we must destroy the handle here:
            // propagating the error with `?` would leak it, since `Self` (and thus
            // its `Drop` impl) is only constructed on the success path.
            if let Err(e) = sys::cudnnSetPoolingNdDescriptor(
                raw,
                mode.into(),
                nan_opt.into(),
                N as i32,
                window_shape.as_ptr(),
                padding.as_ptr(),
                stride.as_ptr(),
            )
            .into_result()
            {
                sys::cudnnDestroyPoolingDescriptor(raw);
                return Err(e);
            }

            Ok(Self { raw })
        }
    }
}

impl Drop for PoolingDescriptor {
    fn drop(&mut self) {
        unsafe {
            // The returned status is ignored: `drop` has no way to report an
            // error, and the handle is going away regardless.
            sys::cudnnDestroyPoolingDescriptor(self.raw);
        }
    }
}
27 changes: 27 additions & 0 deletions crates/cudnn/src/pooling/pooling_mode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use crate::sys;

/// Specifies the pooling method.
///
/// Each variant maps onto a `sys::cudnnPoolingMode_t` constant via the `From` impl
/// in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PoolingMode {
    /// The maximum value inside the pooling window is used.
    Max,
    /// Values inside the pooling window are averaged. The number of elements used to calculate
    /// the average includes spatial locations falling in the padding region.
    AvgIncludePadding,
    /// Values inside the pooling window are averaged. The number of elements used to calculate
    /// the average excludes spatial locations falling in the padding region.
    AvgExcludePadding,
    /// The maximum value inside the pooling window is used. The algorithm used is deterministic.
    MaxDeterministic,
}

impl From<PoolingMode> for sys::cudnnPoolingMode_t {
    /// Maps each safe wrapper variant onto the corresponding raw cuDNN constant.
    fn from(pooling_mode: PoolingMode) -> Self {
        // Arms are listed in the enum's declaration order.
        match pooling_mode {
            PoolingMode::Max => Self::CUDNN_POOLING_MAX,
            PoolingMode::AvgIncludePadding => Self::CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING,
            PoolingMode::AvgExcludePadding => Self::CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING,
            PoolingMode::MaxDeterministic => Self::CUDNN_POOLING_MAX_DETERMINISTIC,
        }
    }
}
12 changes: 9 additions & 3 deletions crates/cudnn/src/softmax/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod softmax_mode;
pub use softmax_algo::*;
pub use softmax_mode::*;

use crate::{sys, CudnnContext, CudnnError, DataType, IntoResult, SupportedOp, TensorDescriptor};
use crate::{private, sys, CudnnContext, CudnnError, DataType, IntoResult, TensorDescriptor};
use cust::memory::GpuBuffer;

impl CudnnContext {
Expand Down Expand Up @@ -45,7 +45,7 @@ impl CudnnContext {
) -> Result<(), CudnnError>
where
T: DataType,
CompT: SupportedOp<T, T, T>,
CompT: SupportedSoftmax<T>,
{
let alpha_ptr = &alpha as *const CompT as *const _;
let x_ptr = x.as_device_ptr().as_ptr() as *const _;
Expand Down Expand Up @@ -112,7 +112,7 @@ impl CudnnContext {
) -> Result<(), CudnnError>
where
T: DataType,
CompT: SupportedOp<T, T, T>,
CompT: SupportedSoftmax<T>,
{
let alpha_ptr = &alpha as *const CompT as *const _;
let y_ptr = y.as_device_ptr().as_ptr() as *const _;
Expand Down Expand Up @@ -140,3 +140,9 @@ impl CudnnContext {
}
}
}

/// Supported data type configurations for softmax operations.
///
/// Sealed marker trait pairing the scaling factor type (`Self`) with the tensor
/// element type (`T`), matching the bounds used by the softmax functions above.
// Consistency fix: `T` is now bounded by `DataType`, exactly as in the
// `SupportedPoolFwd` / `SupportedPoolBwd` marker traits of the pooling module.
pub trait SupportedSoftmax<T>: DataType + private::Sealed
where
    T: DataType,
{
}

impl SupportedSoftmax<f32> for f32 {}
impl SupportedSoftmax<f64> for f64 {}