Skip to content

Commit 8a32ee7

Browse files
authored
feat: Refactor and simplify code, bump to cuDNN 8.3.2, add backend API draft and complete tests (#51)
* Relax compile time bounds on layouts * Separate impl blocks for CudnnContext * Add tests for op tensor * Add doc tests for Dropout and Convolution * Bump to CUDNN v. 8.3.2 * Add backend API proof of concept * Simplify code
1 parent 3a4c7ea commit 8a32ee7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+5187
-3053
lines changed

crates/cudnn/bindgen.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
bindgen "/usr/include/cudnn.h" \
1+
bindgen "${HOME}/local/include/cudnn.h" \
22
--size_t-is-usize \
33
--allowlist-type "cudnn.*" \
44
--allowlist-function "cudnn.*" \
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
use crate::{
2+
backend::{ConvCfg, Descriptor, FloatDataType, Operation, Real, Tensor},
3+
sys, CudnnError, DataType, IntoResult,
4+
};
5+
6+
pub struct ConvBwdDataBuilder {
7+
cfg: Option<ConvCfg>,
8+
alpha: Option<Real>,
9+
beta: Option<Real>,
10+
w: Option<Tensor>,
11+
dx: Option<Tensor>,
12+
dy: Option<Tensor>,
13+
}
14+
15+
impl ConvBwdDataBuilder {
16+
pub fn set_cfg(mut self, cfg: ConvCfg) -> Self {
17+
self.cfg = Some(cfg);
18+
self
19+
}
20+
21+
pub fn set_alpha<T>(mut self, alpha: T) -> Self
22+
where
23+
T: FloatDataType,
24+
{
25+
self.alpha = Some(alpha.wrap());
26+
self
27+
}
28+
29+
pub fn set_beta<T>(mut self, beta: T) -> Self
30+
where
31+
T: FloatDataType,
32+
{
33+
self.beta = Some(beta.wrap());
34+
self
35+
}
36+
37+
pub fn set_w(mut self, w: Tensor) -> Self {
38+
self.w = Some(w);
39+
self
40+
}
41+
42+
pub fn set_dx(mut self, dx: Tensor) -> Self {
43+
self.dx = Some(dx);
44+
self
45+
}
46+
47+
pub fn set_dy(mut self, dy: Tensor) -> Self {
48+
self.dy = Some(dy);
49+
self
50+
}
51+
52+
pub fn build(self) -> Result<Operation, CudnnError> {
53+
let cfg = self.cfg.expect("convolution configuration is required.");
54+
55+
let w = self.w.expect("w tensor is required");
56+
let dx = self.dx.expect("dx tensor is required.");
57+
let dy = self.dy.expect("dy tensor is required.");
58+
59+
let alpha = self.alpha.unwrap_or(Real::Float(1.0));
60+
let beta = self.beta.unwrap_or(Real::Float(0.0));
61+
62+
unsafe {
63+
let mut raw = Descriptor::new(
64+
sys::cudnnBackendDescriptorType_t::CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR,
65+
)?;
66+
67+
raw.set_attribute(
68+
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC,
69+
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
70+
1,
71+
&cfg.raw.inner(),
72+
)
73+
?;
74+
75+
match self.alpha {
76+
Some(Real::Float(ref alpha)) => raw.set_attribute(
77+
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA,
78+
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_FLOAT,
79+
1,
80+
alpha,
81+
)?,
82+
Some(Real::Double(ref alpha)) => raw.set_attribute(
83+
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA,
84+
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_DOUBLE,
85+
1,
86+
alpha,
87+
)?,
88+
None => (),
89+
}
90+
91+
match self.beta {
92+
Some(Real::Float(ref beta)) => raw.set_attribute(
93+
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA,
94+
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_FLOAT,
95+
1,
96+
beta,
97+
)?,
98+
Some(Real::Double(ref beta)) => raw.set_attribute(
99+
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA,
100+
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_DOUBLE,
101+
1,
102+
beta,
103+
)?,
104+
None => (),
105+
}
106+
107+
raw.set_attribute(
108+
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W,
109+
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
110+
1,
111+
&w.raw.inner(),
112+
)?;
113+
114+
raw.set_attribute(
115+
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX,
116+
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
117+
1,
118+
&dx.raw.inner(),
119+
)?;
120+
121+
raw.set_attribute(
122+
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY,
123+
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
124+
1,
125+
&dy.raw.inner(),
126+
)?;
127+
128+
raw.finalize()?;
129+
130+
Ok(Operation::ConvBwdData {
131+
raw,
132+
cfg,
133+
alpha,
134+
beta,
135+
w,
136+
dx,
137+
dy,
138+
})
139+
}
140+
}
141+
}
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
use crate::{
2+
backend::{ConvCfg, Descriptor, FloatDataType, Operation, Real, Tensor},
3+
sys, CudnnError, DataType, IntoResult,
4+
};
5+
6+
pub struct ConvBwdFilterBuilder {
7+
cfg: Option<ConvCfg>,
8+
alpha: Option<Real>,
9+
beta: Option<Real>,
10+
dw: Option<Tensor>,
11+
x: Option<Tensor>,
12+
dy: Option<Tensor>,
13+
}
14+
15+
impl ConvBwdFilterBuilder {
16+
pub fn set_cfg(mut self, cfg: ConvCfg) -> Self {
17+
self.cfg = Some(cfg);
18+
self
19+
}
20+
21+
pub fn set_alpha<T>(mut self, alpha: T) -> Self
22+
where
23+
T: FloatDataType,
24+
{
25+
self.alpha = Some(alpha.wrap());
26+
self
27+
}
28+
29+
pub fn set_beta<T>(mut self, beta: T) -> Self
30+
where
31+
T: FloatDataType,
32+
{
33+
self.beta = Some(beta.wrap());
34+
self
35+
}
36+
37+
pub fn set_dw(mut self, dw: Tensor) -> Self {
38+
self.dw = Some(dw);
39+
self
40+
}
41+
42+
pub fn set_dx(mut self, x: Tensor) -> Self {
43+
self.x = Some(x);
44+
self
45+
}
46+
47+
pub fn set_dy(mut self, dy: Tensor) -> Self {
48+
self.dy = Some(dy);
49+
self
50+
}
51+
52+
pub fn build(self) -> Result<Operation, CudnnError> {
53+
let cfg = self.cfg.expect("convolution configuration is required.");
54+
let dw = self.dw.expect("dw tensor is required");
55+
let x = self.x.expect("x tensor is required.");
56+
let dy = self.dy.expect("dy tensor is required.");
57+
58+
let alpha = self.alpha.unwrap_or(Real::Float(1.0));
59+
let beta = self.beta.unwrap_or(Real::Float(0.0));
60+
61+
unsafe {
62+
let mut raw = Descriptor::new(
63+
sys::cudnnBackendDescriptorType_t::CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR,
64+
)?;
65+
66+
raw.set_attribute(
67+
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC,
68+
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
69+
1,
70+
&cfg.raw.inner(),
71+
)?;
72+
73+
match self.alpha {
74+
Some(Real::Float(ref alpha)) => raw.set_attribute(
75+
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA,
76+
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_FLOAT,
77+
1,
78+
alpha,
79+
)?,
80+
Some(Real::Double(ref alpha)) => raw.set_attribute(
81+
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA,
82+
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_DOUBLE,
83+
1,
84+
alpha,
85+
)?,
86+
None => (),
87+
}
88+
89+
match self.beta {
90+
Some(Real::Float(ref beta)) => raw.set_attribute(
91+
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA,
92+
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_FLOAT,
93+
1,
94+
beta,
95+
)?,
96+
Some(Real::Double(ref beta)) => raw.set_attribute(
97+
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA,
98+
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_DOUBLE,
99+
1,
100+
beta,
101+
)?,
102+
None => (),
103+
}
104+
105+
raw.set_attribute(
106+
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW,
107+
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
108+
1,
109+
&dw.raw.inner(),
110+
)?;
111+
112+
raw.set_attribute(
113+
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X,
114+
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
115+
1,
116+
&x.raw.inner(),
117+
)?;
118+
119+
raw.set_attribute(
120+
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY,
121+
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
122+
1,
123+
&dy.raw.inner(),
124+
)?;
125+
126+
raw.finalize()?;
127+
128+
Ok(Operation::ConvBwdFilter {
129+
raw,
130+
cfg,
131+
alpha,
132+
beta,
133+
dw,
134+
x,
135+
dy,
136+
})
137+
}
138+
}
139+
}

0 commit comments

Comments
 (0)