Skip to content

Commit 17b9a60

Browse files
authored
Feat: Add pooling forward and backward (#55)
1 parent 321110e commit 17b9a60

File tree

5 files changed

+302
-3
lines changed

5 files changed

+302
-3
lines changed

crates/cudnn/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ mod error;
1212
mod math_type;
1313
mod nan_propagation;
1414
mod op_tensor;
15+
mod pooling;
1516
mod rnn;
1617
mod softmax;
1718
mod tensor;
@@ -27,6 +28,7 @@ pub use error::*;
2728
pub use math_type::*;
2829
pub use nan_propagation::*;
2930
pub use op_tensor::*;
31+
pub use pooling::*;
3032
pub use rnn::*;
3133
pub use softmax::*;
3234
pub use tensor::*;

crates/cudnn/src/pooling/mod.rs

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
mod pooling_descriptor;
2+
mod pooling_mode;
3+
4+
use cust::memory::GpuBuffer;
5+
pub use pooling_descriptor::*;
6+
pub use pooling_mode::*;
7+
8+
use crate::{private, sys, CudnnContext, CudnnError, DataType, IntoResult, TensorDescriptor};
9+
10+
impl CudnnContext {
    /// This function computes the pooling of the input tensor and produces a smaller tensor in
    /// output.
    ///
    /// # Arguments
    ///
    /// * `pooling_desc` - descriptor of the pooling operation.
    ///
    /// * `alpha` - scaling factor for the result.
    ///
    /// * `x_desc` - descriptor for the input tensor.
    ///
    /// * `x` - data for the input tensor.
    ///
    /// * `beta` - scaling factor for the destination tensor.
    ///
    /// * `y_desc` - descriptor for the destination tensor.
    ///
    /// * `y` - data for the destination tensor.
    ///
    /// # Errors
    ///
    /// Returns errors if the batch size or channels dimensions of the two tensor differ or an
    /// invalid combination of arguments is detected.
    pub fn pooling_forward<T, U>(
        &self,
        pooling_desc: &PoolingDescriptor,
        alpha: T,
        x_desc: &TensorDescriptor<U>,
        x: &impl GpuBuffer<U>,
        beta: T,
        y_desc: &TensorDescriptor<U>,
        y: &mut impl GpuBuffer<U>,
    ) -> Result<(), CudnnError>
    where
        T: SupportedPoolFwd<U>,
        U: DataType,
    {
        // cuDNN takes the scaling factors by opaque host pointer.
        let alpha_ptr = &alpha as *const T as *const _;
        let x_ptr = x.as_device_ptr().as_ptr() as *const _;

        let beta_ptr = &beta as *const T as *const _;
        let y_ptr = y.as_device_ptr().as_mut_ptr() as *mut _;

        // SAFETY: the scaling factors are live host values for the duration of the
        // call and the device pointers come from buffers borrowed for the call;
        // the raw handles were produced by this crate's descriptor constructors.
        unsafe {
            sys::cudnnPoolingForward(
                self.raw,
                pooling_desc.raw,
                alpha_ptr,
                x_desc.raw,
                x_ptr,
                beta_ptr,
                y_desc.raw,
                y_ptr,
            )
            .into_result()
        }
    }

    /// Computes the gradient of a pooling operation.
    ///
    /// # Arguments
    ///
    /// * `pooling_desc` - descriptor of the pooling operation.
    ///
    /// * `alpha` - scaling factor for the result.
    ///
    /// * `y_desc` - tensor descriptor for the output map.
    ///
    /// * `y` - data for the output map.
    ///
    /// * `dy_desc` - tensor descriptor for the differential of the output map.
    ///
    /// * `dy` - data for the differential of the output map.
    ///
    /// * `x_desc` - tensor descriptor for the pooling input.
    ///
    /// * `x` - data for the pooling input.
    ///
    /// * `beta` - scaling factor for the destination tensor.
    ///
    /// * `dx_desc` - tensor descriptor for the input differential.
    ///
    /// * `dx` - data for the input differential.
    ///
    /// # Errors
    ///
    /// Returns errors if the dimensions or the strides of `y` and `dy` tensors differ or if the
    /// dimensions or the strides of `x` and `dx` tensors differ or if an unsupported combination
    /// of arguments is detected.
    pub fn pooling_backward<T, U>(
        &self,
        pooling_desc: &PoolingDescriptor,
        alpha: T,
        y_desc: &TensorDescriptor<U>,
        y: &impl GpuBuffer<U>,
        dy_desc: &TensorDescriptor<U>,
        dy: &impl GpuBuffer<U>,
        x_desc: &TensorDescriptor<U>,
        x: &impl GpuBuffer<U>,
        beta: T,
        dx_desc: &TensorDescriptor<U>,
        dx: &mut impl GpuBuffer<U>,
    ) -> Result<(), CudnnError>
    where
        T: SupportedPoolBwd<U>,
        U: DataType,
    {
        let alpha_ptr = &alpha as *const T as *const _;

        let y_ptr = y.as_device_ptr().as_ptr() as *const _;
        let dy_ptr = dy.as_device_ptr().as_ptr() as *const _;
        let x_ptr = x.as_device_ptr().as_ptr() as *const _;

        let beta_ptr = &beta as *const T as *const _;

        // `dx` is the only output of the backward pass.
        let dx_ptr = dx.as_device_ptr().as_mut_ptr() as *mut _;

        // SAFETY: same contract as `pooling_forward` — host scalars and device
        // buffers are valid for the whole call; only `dx` is written.
        unsafe {
            sys::cudnnPoolingBackward(
                self.raw,
                pooling_desc.raw,
                alpha_ptr,
                y_desc.raw,
                y_ptr,
                dy_desc.raw,
                dy_ptr,
                x_desc.raw,
                x_ptr,
                beta_ptr,
                dx_desc.raw,
                dx_ptr,
            )
            .into_result()
        }
    }
}
147+
148+
/// Supported data type configurations for the pooling forward operation.
///
/// `Self` is the type of the scaling parameters (`alpha` and `beta`) and `T` is the
/// element type of the input and output tensors, as used by
/// `CudnnContext::pooling_forward`.
pub trait SupportedPoolFwd<T>: DataType + private::Sealed
where
    T: DataType,
{
}

// `f32` scaling parameters.
impl SupportedPoolFwd<i8> for f32 {}
impl SupportedPoolFwd<u8> for f32 {}
impl SupportedPoolFwd<i32> for f32 {}
impl SupportedPoolFwd<i64> for f32 {}
impl SupportedPoolFwd<f32> for f32 {}
impl SupportedPoolFwd<f64> for f32 {}

// `f64` scaling parameters.
impl SupportedPoolFwd<i8> for f64 {}
impl SupportedPoolFwd<u8> for f64 {}
impl SupportedPoolFwd<i32> for f64 {}
impl SupportedPoolFwd<i64> for f64 {}
impl SupportedPoolFwd<f32> for f64 {}
impl SupportedPoolFwd<f64> for f64 {}
168+
169+
/// Supported type configurations for the pooling backward operation.
///
/// `Self` is the type of the scaling parameters and `T` the element type of the
/// tensors; only the matching floating point configurations are implemented.
pub trait SupportedPoolBwd<T>: DataType + private::Sealed
where
    T: DataType,
{
}

impl SupportedPoolBwd<f32> for f32 {}
impl SupportedPoolBwd<f64> for f64 {}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
use crate::{sys, CudnnError, IntoResult, NanPropagation, PoolingMode};
2+
use std::mem::MaybeUninit;
3+
4+
/// The descriptor of a pooling operation.
///
/// Owns the underlying raw cuDNN pooling descriptor handle, which is released
/// when this value is dropped.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct PoolingDescriptor {
    // Raw cuDNN handle; derived equality and hashing operate on the handle itself.
    pub(crate) raw: sys::cudnnPoolingDescriptor_t,
}
9+
10+
impl PoolingDescriptor {
11+
/// Creates a new pooling descriptor.
12+
///
13+
/// # Arguments
14+
///
15+
/// * `mode` - pooling mode.
16+
///
17+
/// * `nan_opt` - nan propagation policy.
18+
///
19+
/// * `window_shape` - shape of the pooling window.
20+
///
21+
/// * `padding` - padding size for each dimension. Negative padding is allowed.
22+
///
23+
/// * `stride` - stride for each dimension.
24+
///
25+
/// # Errors
26+
///
27+
/// Returns errors if an invalid configuration of arguments is detected.
28+
///
29+
/// # Examples
30+
///
31+
///
32+
/// ```
33+
/// # use std::error::Error;
34+
/// #
35+
/// # fn main() -> Result<(), Box<dyn Error>> {
36+
/// use cudnn::{CudnnContext, NanPropagation, PoolingDescriptor, PoolingMode};
37+
///
38+
/// let ctx = CudnnContext::new()?;
39+
///
40+
/// let mode = PoolingMode::Max;
41+
/// let nan_opt = NanPropagation::PropagateNaN;
42+
/// let window_shape = [2, 2];
43+
/// let padding = [0, 0];
44+
/// let stride = [1, 1];
45+
///
46+
/// let pooling_desc = PoolingDescriptor::new(mode, nan_opt, window_shape, padding, stride)?;
47+
///
48+
/// # Ok(())
49+
/// # }
50+
/// ```
51+
pub fn new<const N: usize>(
52+
mode: PoolingMode,
53+
nan_opt: NanPropagation,
54+
window_shape: [i32; N],
55+
padding: [i32; N],
56+
stride: [i32; N],
57+
) -> Result<Self, CudnnError> {
58+
let mut raw = MaybeUninit::uninit();
59+
60+
unsafe {
61+
sys::cudnnCreatePoolingDescriptor(raw.as_mut_ptr()).into_result()?;
62+
63+
let mut raw = raw.assume_init();
64+
65+
sys::cudnnSetPoolingNdDescriptor(
66+
raw,
67+
mode.into(),
68+
nan_opt.into(),
69+
N as i32,
70+
window_shape.as_ptr(),
71+
padding.as_ptr(),
72+
stride.as_ptr(),
73+
)
74+
.into_result()?;
75+
76+
Ok(Self { raw })
77+
}
78+
}
79+
}
80+
81+
impl Drop for PoolingDescriptor {
    fn drop(&mut self) {
        unsafe {
            // Any error returned by the destroy call is deliberately ignored:
            // `drop` has no way to propagate failures.
            sys::cudnnDestroyPoolingDescriptor(self.raw);
        }
    }
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
use crate::sys;
2+
3+
/// Specifies the pooling method.
///
/// Maps onto cuDNN's `cudnnPoolingMode_t` via the `From` conversion below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PoolingMode {
    /// The maximum value inside the pooling window is used.
    Max,
    /// Values inside the pooling window are averaged. The number of elements used to calculate
    /// the average includes spatial locations falling in the padding region.
    AvgIncludePadding,
    /// Values inside the pooling window are averaged. The number of elements used to calculate
    /// the average excludes spatial locations falling in the padding region.
    AvgExcludePadding,
    /// The maximum value inside the pooling window is used. The algorithm used is deterministic.
    MaxDeterministic,
}
17+
18+
impl From<PoolingMode> for sys::cudnnPoolingMode_t {
19+
fn from(mode: PoolingMode) -> Self {
20+
match mode {
21+
PoolingMode::Max => Self::CUDNN_POOLING_MAX,
22+
PoolingMode::AvgExcludePadding => Self::CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING,
23+
PoolingMode::AvgIncludePadding => Self::CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING,
24+
PoolingMode::MaxDeterministic => Self::CUDNN_POOLING_MAX_DETERMINISTIC,
25+
}
26+
}
27+
}

crates/cudnn/src/softmax/mod.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ mod softmax_mode;
44
pub use softmax_algo::*;
55
pub use softmax_mode::*;
66

7-
use crate::{sys, CudnnContext, CudnnError, DataType, IntoResult, SupportedOp, TensorDescriptor};
7+
use crate::{private, sys, CudnnContext, CudnnError, DataType, IntoResult, TensorDescriptor};
88
use cust::memory::GpuBuffer;
99

1010
impl CudnnContext {
@@ -45,7 +45,7 @@ impl CudnnContext {
4545
) -> Result<(), CudnnError>
4646
where
4747
T: DataType,
48-
CompT: SupportedOp<T, T, T>,
48+
CompT: SupportedSoftmax<T>,
4949
{
5050
let alpha_ptr = &alpha as *const CompT as *const _;
5151
let x_ptr = x.as_device_ptr().as_ptr() as *const _;
@@ -112,7 +112,7 @@ impl CudnnContext {
112112
) -> Result<(), CudnnError>
113113
where
114114
T: DataType,
115-
CompT: SupportedOp<T, T, T>,
115+
CompT: SupportedSoftmax<T>,
116116
{
117117
let alpha_ptr = &alpha as *const CompT as *const _;
118118
let y_ptr = y.as_device_ptr().as_ptr() as *const _;
@@ -140,3 +140,9 @@ impl CudnnContext {
140140
}
141141
}
142142
}
143+
144+
/// Supported data type configurations for softmax operations.
///
/// `Self` is the type of the scaling parameters and `T` the element type of the
/// tensors; only matching `f32`/`f32` and `f64`/`f64` configurations are supported.
pub trait SupportedSoftmax<T>: DataType + private::Sealed {}

impl SupportedSoftmax<f32> for f32 {}
impl SupportedSoftmax<f64> for f64 {}

0 commit comments

Comments
 (0)