diff --git a/crates/cudnn/src/lib.rs b/crates/cudnn/src/lib.rs
index 41bab0a1..bc682fc6 100644
--- a/crates/cudnn/src/lib.rs
+++ b/crates/cudnn/src/lib.rs
@@ -12,6 +12,7 @@ mod error;
 mod math_type;
 mod nan_propagation;
 mod op_tensor;
+mod pooling;
 mod rnn;
 mod softmax;
 mod tensor;
@@ -27,6 +28,7 @@ pub use error::*;
 pub use math_type::*;
 pub use nan_propagation::*;
 pub use op_tensor::*;
+pub use pooling::*;
 pub use rnn::*;
 pub use softmax::*;
 pub use tensor::*;
diff --git a/crates/cudnn/src/pooling/mod.rs b/crates/cudnn/src/pooling/mod.rs
new file mode 100644
index 00000000..44f3cdcd
--- /dev/null
+++ b/crates/cudnn/src/pooling/mod.rs
@@ -0,0 +1,177 @@
+mod pooling_descriptor;
+mod pooling_mode;
+
+use cust::memory::GpuBuffer;
+pub use pooling_descriptor::*;
+pub use pooling_mode::*;
+
+use crate::{private, sys, CudnnContext, CudnnError, DataType, IntoResult, TensorDescriptor};
+
+impl CudnnContext {
+    /// This function computes the pooling of the input tensor and produces a smaller output
+    /// tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `pooling_desc` - descriptor of the pooling operation.
+    ///
+    /// * `alpha` - scaling factor for the result.
+    ///
+    /// * `x_desc` - descriptor for the input tensor.
+    ///
+    /// * `x` - data for the input tensor.
+    ///
+    /// * `beta` - scaling factor for the destination tensor.
+    ///
+    /// * `y_desc` - descriptor for the destination tensor.
+    ///
+    /// * `y` - data for the destination tensor.
+    ///
+    /// # Errors
+    ///
+    /// Returns errors if the batch size or channel dimensions of the two tensors differ or an
+    /// invalid combination of arguments is detected.
+    pub fn pooling_forward<T, U>(
+        &self,
+        pooling_desc: &PoolingDescriptor,
+        alpha: T,
+        x_desc: &TensorDescriptor<U>,
+        x: &impl GpuBuffer<U>,
+        beta: T,
+        y_desc: &TensorDescriptor<U>,
+        y: &mut impl GpuBuffer<U>,
+    ) -> Result<(), CudnnError>
+    where
+        T: SupportedPoolFwd<U>,
+        U: DataType,
+    {
+        let alpha_ptr = &alpha as *const T as *const _;
+        let x_ptr = x.as_device_ptr().as_ptr() as *const _;
+
+        let beta_ptr = &beta as *const T as *const _;
+        let y_ptr = y.as_device_ptr().as_mut_ptr() as *mut _;
+
+        unsafe {
+            sys::cudnnPoolingForward(
+                self.raw,
+                pooling_desc.raw,
+                alpha_ptr,
+                x_desc.raw,
+                x_ptr,
+                beta_ptr,
+                y_desc.raw,
+                y_ptr,
+            )
+            .into_result()
+        }
+    }
+
+    /// Computes the gradient of a pooling operation.
+    ///
+    /// # Arguments
+    ///
+    /// * `pooling_desc` - descriptor of the pooling operation.
+    ///
+    /// * `alpha` - scaling factor for the result.
+    ///
+    /// * `y_desc` - tensor descriptor for the output map.
+    ///
+    /// * `y` - data for the output map.
+    ///
+    /// * `dy_desc` - tensor descriptor for the differential of the output map.
+    ///
+    /// * `dy` - data for the differential of the output map.
+    ///
+    /// * `x_desc` - tensor descriptor for the pooling input.
+    ///
+    /// * `x` - data for the pooling input.
+    ///
+    /// * `beta` - scaling factor for the destination tensor.
+    ///
+    /// * `dx_desc` - tensor descriptor for the input differential.
+    ///
+    /// * `dx` - data for the input differential.
+    ///
+    /// # Errors
+    ///
+    /// Returns errors if the dimensions or the strides of `y` and `dy` tensors differ or if the
+    /// dimensions or the strides of `x` and `dx` tensors differ or if an unsupported combination
+    /// of arguments is detected.
+    pub fn pooling_backward<T, U>(
+        &self,
+        pooling_desc: &PoolingDescriptor,
+        alpha: T,
+        y_desc: &TensorDescriptor<U>,
+        y: &impl GpuBuffer<U>,
+        dy_desc: &TensorDescriptor<U>,
+        dy: &impl GpuBuffer<U>,
+        x_desc: &TensorDescriptor<U>,
+        x: &impl GpuBuffer<U>,
+        beta: T,
+        dx_desc: &TensorDescriptor<U>,
+        dx: &mut impl GpuBuffer<U>,
+    ) -> Result<(), CudnnError>
+    where
+        T: SupportedPoolBwd<U>,
+        U: DataType,
+    {
+        let alpha_ptr = &alpha as *const T as *const _;
+
+        let y_ptr = y.as_device_ptr().as_ptr() as *const _;
+        let dy_ptr = dy.as_device_ptr().as_ptr() as *const _;
+        let x_ptr = x.as_device_ptr().as_ptr() as *const _;
+
+        let beta_ptr = &beta as *const T as *const _;
+
+        let dx_ptr = dx.as_device_ptr().as_mut_ptr() as *mut _;
+
+        unsafe {
+            sys::cudnnPoolingBackward(
+                self.raw,
+                pooling_desc.raw,
+                alpha_ptr,
+                y_desc.raw,
+                y_ptr,
+                dy_desc.raw,
+                dy_ptr,
+                x_desc.raw,
+                x_ptr,
+                beta_ptr,
+                dx_desc.raw,
+                dx_ptr,
+            )
+            .into_result()
+        }
+    }
+}
+
+/// Supported data type configurations for the pooling forward operation.
+pub trait SupportedPoolFwd<T>: DataType + private::Sealed
+where
+    T: DataType,
+{
+}
+
+impl SupportedPoolFwd<i8> for f32 {}
+impl SupportedPoolFwd<u8> for f32 {}
+impl SupportedPoolFwd<i32> for f32 {}
+impl SupportedPoolFwd<i64> for f32 {}
+impl SupportedPoolFwd<f32> for f32 {}
+impl SupportedPoolFwd<f64> for f32 {}
+
+impl SupportedPoolFwd<i8> for f64 {}
+impl SupportedPoolFwd<u8> for f64 {}
+impl SupportedPoolFwd<i32> for f64 {}
+impl SupportedPoolFwd<i64> for f64 {}
+impl SupportedPoolFwd<f32> for f64 {}
+impl SupportedPoolFwd<f64> for f64 {}
+
+/// Supported type configurations for the pooling backward operation.
+pub trait SupportedPoolBwd<T>: DataType + private::Sealed
+where
+    T: DataType,
+{
+}
+
+impl SupportedPoolBwd<f32> for f32 {}
+impl SupportedPoolBwd<f64> for f64 {}
diff --git a/crates/cudnn/src/pooling/pooling_descriptor.rs b/crates/cudnn/src/pooling/pooling_descriptor.rs
new file mode 100644
index 00000000..4c225017
--- /dev/null
+++ b/crates/cudnn/src/pooling/pooling_descriptor.rs
@@ -0,0 +1,87 @@
+use crate::{sys, CudnnError, IntoResult, NanPropagation, PoolingMode};
+use std::mem::MaybeUninit;
+
+/// The descriptor of a pooling operation.
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct PoolingDescriptor {
+    pub(crate) raw: sys::cudnnPoolingDescriptor_t,
+}
+
+impl PoolingDescriptor {
+    /// Creates a new pooling descriptor.
+    ///
+    /// # Arguments
+    ///
+    /// * `mode` - pooling mode.
+    ///
+    /// * `nan_opt` - NaN propagation policy.
+    ///
+    /// * `window_shape` - shape of the pooling window.
+    ///
+    /// * `padding` - padding size for each dimension. Negative padding is allowed.
+    ///
+    /// * `stride` - stride for each dimension.
+    ///
+    /// # Errors
+    ///
+    /// Returns errors if an invalid configuration of arguments is detected.
+    ///
+    /// # Examples
+    ///
+    ///
+    /// ```
+    /// # use std::error::Error;
+    /// #
+    /// # fn main() -> Result<(), Box<dyn Error>> {
+    /// use cudnn::{CudnnContext, NanPropagation, PoolingDescriptor, PoolingMode};
+    ///
+    /// let ctx = CudnnContext::new()?;
+    ///
+    /// let mode = PoolingMode::Max;
+    /// let nan_opt = NanPropagation::PropagateNaN;
+    /// let window_shape = [2, 2];
+    /// let padding = [0, 0];
+    /// let stride = [1, 1];
+    ///
+    /// let pooling_desc = PoolingDescriptor::new(mode, nan_opt, window_shape, padding, stride)?;
+    ///
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn new<const N: usize>(
+        mode: PoolingMode,
+        nan_opt: NanPropagation,
+        window_shape: [i32; N],
+        padding: [i32; N],
+        stride: [i32; N],
+    ) -> Result<Self, CudnnError> {
+        let mut raw = MaybeUninit::uninit();
+
+        unsafe {
+            sys::cudnnCreatePoolingDescriptor(raw.as_mut_ptr()).into_result()?;
+
+            let raw = raw.assume_init();
+
+            sys::cudnnSetPoolingNdDescriptor(
+                raw,
+                mode.into(),
+                nan_opt.into(),
+                N as i32,
+                window_shape.as_ptr(),
+                padding.as_ptr(),
+                stride.as_ptr(),
+            )
+            .into_result()?;
+
+            Ok(Self { raw })
+        }
+    }
+}
+
+impl Drop for PoolingDescriptor {
+    fn drop(&mut self) {
+        unsafe {
+            sys::cudnnDestroyPoolingDescriptor(self.raw);
+        }
+    }
+}
diff --git a/crates/cudnn/src/pooling/pooling_mode.rs b/crates/cudnn/src/pooling/pooling_mode.rs
new file mode 100644
index 00000000..58d786c2
--- /dev/null
+++ b/crates/cudnn/src/pooling/pooling_mode.rs
@@ -0,0 +1,27 @@
+use crate::sys;
+
+/// Specifies the pooling method.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum PoolingMode {
+    /// The maximum value inside the pooling window is used.
+    Max,
+    /// Values inside the pooling window are averaged. The number of elements used to calculate
+    /// the average includes spatial locations falling in the padding region.
+    AvgIncludePadding,
+    /// Values inside the pooling window are averaged. The number of elements used to calculate
+    /// the average excludes spatial locations falling in the padding region.
+    AvgExcludePadding,
+    /// The maximum value inside the pooling window is used. The algorithm used is deterministic.
+    MaxDeterministic,
+}
+
+impl From<PoolingMode> for sys::cudnnPoolingMode_t {
+    fn from(mode: PoolingMode) -> Self {
+        match mode {
+            PoolingMode::Max => Self::CUDNN_POOLING_MAX,
+            PoolingMode::AvgExcludePadding => Self::CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING,
+            PoolingMode::AvgIncludePadding => Self::CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING,
+            PoolingMode::MaxDeterministic => Self::CUDNN_POOLING_MAX_DETERMINISTIC,
+        }
+    }
+}
diff --git a/crates/cudnn/src/softmax/mod.rs b/crates/cudnn/src/softmax/mod.rs
index dd9a8fd0..12c0987a 100644
--- a/crates/cudnn/src/softmax/mod.rs
+++ b/crates/cudnn/src/softmax/mod.rs
@@ -4,7 +4,7 @@ mod softmax_mode;
 pub use softmax_algo::*;
 pub use softmax_mode::*;
 
-use crate::{sys, CudnnContext, CudnnError, DataType, IntoResult, SupportedOp, TensorDescriptor};
+use crate::{private, sys, CudnnContext, CudnnError, DataType, IntoResult, TensorDescriptor};
 use cust::memory::GpuBuffer;
 
 impl CudnnContext {
@@ -45,7 +45,7 @@ impl CudnnContext {
     ) -> Result<(), CudnnError>
     where
         T: DataType,
-        CompT: SupportedOp,
+        CompT: SupportedSoftmax,
     {
         let alpha_ptr = &alpha as *const CompT as *const _;
         let x_ptr = x.as_device_ptr().as_ptr() as *const _;
@@ -112,7 +112,7 @@ impl CudnnContext {
     ) -> Result<(), CudnnError>
     where
         T: DataType,
-        CompT: SupportedOp,
+        CompT: SupportedSoftmax,
     {
         let alpha_ptr = &alpha as *const CompT as *const _;
         let y_ptr = y.as_device_ptr().as_ptr() as *const _;
@@ -140,3 +140,9 @@ impl CudnnContext {
         }
     }
 }
+
+/// Supported data type configurations for softmax operations.
+pub trait SupportedSoftmax: DataType + private::Sealed {}
+
+impl SupportedSoftmax for f32 {}
+impl SupportedSoftmax for f64 {}
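A rough usage sketch for the new forward entry point, not part of the patch: it mirrors the descriptor setup from the `PoolingDescriptor::new` doc example above and assumes the tensor descriptors and `cust` device buffers are created elsewhere with matching shapes, and that `DeviceBuffer` satisfies the `GpuBuffer` bound the signatures require. The hypothetical helper name and its parameters are illustrative only.

```rust
use cudnn::{
    CudnnContext, CudnnError, NanPropagation, PoolingDescriptor, PoolingMode, TensorDescriptor,
};
use cust::memory::DeviceBuffer;

/// Hypothetical helper: applies 2x2 max pooling to `x`, writing the result into `y`.
/// Descriptors and buffers are assumed to be allocated by the caller with matching shapes.
fn max_pool_2x2(
    ctx: &CudnnContext,
    x_desc: &TensorDescriptor<f32>,
    x: &DeviceBuffer<f32>,
    y_desc: &TensorDescriptor<f32>,
    y: &mut DeviceBuffer<f32>,
) -> Result<(), CudnnError> {
    // Same configuration as the `PoolingDescriptor::new` doc example in the patch.
    let pooling_desc = PoolingDescriptor::new(
        PoolingMode::Max,
        NanPropagation::PropagateNaN,
        [2, 2], // window shape
        [0, 0], // padding
        [1, 1], // stride
    )?;

    // alpha = 1.0, beta = 0.0: `y` is overwritten with the pooling result, using the
    // (f32, f32) configuration enabled by `SupportedPoolFwd<f32> for f32`.
    ctx.pooling_forward(&pooling_desc, 1.0_f32, x_desc, x, 0.0_f32, y_desc, y)
}
```

The backward pass would be driven the same way through `pooling_backward`, with `SupportedPoolBwd` restricting the scaling type to match the tensor data type.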