Skip to content

Commit 54978ca

Browse files
authored
Feat: Add multi-head attention backward passes and convolution fused bias-act (#62)
* Feat: Add `AttnWeight` enum and `ScalingDataType` trait
* Feat: Add multi-head attention backward passes
* Feat: Add convolution fused bias activation
1 parent 820b5cf commit 54978ca

File tree

11 files changed

+680
-149
lines changed

11 files changed

+680
-149
lines changed

crates/cudnn/src/activation/activation_mode.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ pub enum ActivationMode {
1515
Elu,
1616
/// Selects the swish function.
1717
Swish,
18+
/// Selects no activation.
19+
///
20+
/// **Do note** that this is only valid for an activation descriptor passed to
21+
/// [`convolution_bias_act_forward()`](CudnnContext::convolution_bias_act_fwd).
22+
Identity,
1823
}
1924

2025
impl From<ActivationMode> for sys::cudnnActivationMode_t {
@@ -26,6 +31,7 @@ impl From<ActivationMode> for sys::cudnnActivationMode_t {
2631
ActivationMode::ClippedRelu => Self::CUDNN_ACTIVATION_CLIPPED_RELU,
2732
ActivationMode::Elu => Self::CUDNN_ACTIVATION_ELU,
2833
ActivationMode::Swish => Self::CUDNN_ACTIVATION_SWISH,
34+
ActivationMode::Identity => Self::CUDNN_ACTIVATION_IDENTITY,
2935
}
3036
}
3137
}

crates/cudnn/src/activation/mod.rs

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ mod activation_mode;
44
pub use activation_descriptor::*;
55
pub use activation_mode::*;
66

7-
use crate::{private, sys, CudnnContext, CudnnError, DataType, IntoResult, TensorDescriptor};
7+
use crate::{
8+
private, sys, CudnnContext, CudnnError, DataType, IntoResult, ScalingDataType, TensorDescriptor,
9+
};
810
use cust::memory::GpuBuffer;
911
use std::mem::MaybeUninit;
1012

@@ -49,11 +51,11 @@ impl CudnnContext {
4951
///
5052
/// let desc = ActivationDescriptor::new(mode, nan_opt, coefficient)?;
5153
///
52-
/// let alpha: f32 = 1.0;
54+
/// let alpha = 1.0;
5355
/// let x_desc = TensorDescriptor::<i8>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
5456
/// let x = DeviceBuffer::<i8>::from_slice(&[10, 10, 10, 10, 10])?;
5557
///
56-
/// let beta: f32 = 0.0;
58+
/// let beta = 0.0;
5759
/// let y_desc = TensorDescriptor::<i8>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
5860
/// let mut y = DeviceBuffer::<i8>::from_slice(&[0, 0, 0, 0, 0])?;
5961
///
@@ -76,7 +78,7 @@ impl CudnnContext {
7678
y: &mut impl GpuBuffer<T>,
7779
) -> Result<(), CudnnError>
7880
where
79-
CompT: SupportedActFwd<T>,
81+
CompT: ScalingDataType<T>,
8082
T: DataType,
8183
{
8284
let alpha_ptr = &alpha as *const CompT as *const _;
@@ -179,27 +181,6 @@ impl CudnnContext {
179181
}
180182
}
181183

182-
/// Supported data type configurations for the activation forward operation.
183-
pub trait SupportedActFwd<T>: DataType + private::Sealed
184-
where
185-
T: DataType,
186-
{
187-
}
188-
189-
impl SupportedActFwd<i8> for f32 {}
190-
impl SupportedActFwd<u8> for f32 {}
191-
impl SupportedActFwd<i32> for f32 {}
192-
impl SupportedActFwd<i64> for f32 {}
193-
impl SupportedActFwd<f32> for f32 {}
194-
impl SupportedActFwd<f64> for f32 {}
195-
196-
impl SupportedActFwd<i8> for f64 {}
197-
impl SupportedActFwd<u8> for f64 {}
198-
impl SupportedActFwd<i32> for f64 {}
199-
impl SupportedActFwd<i64> for f64 {}
200-
impl SupportedActFwd<f32> for f64 {}
201-
impl SupportedActFwd<f64> for f64 {}
202-
203184
/// Supported type configurations for the activation backward operation.
204185
pub trait SupportedActBwd<T>: DataType + private::Sealed
205186
where
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
use crate::sys;
2+
3+
/// Identifies one group of multi-head attention parameters: the projection
/// weights or biases associated with the queries, keys, values, or output.
///
/// Passed to weight-retrieval functions to select which parameter group of
/// the attention layer to address.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AttnWeight {
    /// Input projection weights applied to the queries.
    QWeights,
    /// Input projection weights applied to the keys.
    KWeights,
    /// Input projection weights applied to the values.
    VWeights,
    /// Output projection weights.
    OWeights,
    /// Input projection biases applied to the queries.
    QBiases,
    /// Input projection biases applied to the keys.
    KBiases,
    /// Input projection biases applied to the values.
    VBiases,
    /// Output projection biases.
    OBiases,
}
23+
24+
impl From<AttnWeight> for sys::cudnnMultiHeadAttnWeightKind_t {
25+
fn from(kind: AttnWeight) -> Self {
26+
match kind {
27+
AttnWeight::QWeights => Self::CUDNN_MH_ATTN_Q_WEIGHTS,
28+
AttnWeight::KWeights => Self::CUDNN_MH_ATTN_K_WEIGHTS,
29+
AttnWeight::VWeights => Self::CUDNN_MH_ATTN_V_WEIGHTS,
30+
AttnWeight::OWeights => Self::CUDNN_MH_ATTN_O_WEIGHTS,
31+
AttnWeight::QBiases => Self::CUDNN_MH_ATTN_Q_BIASES,
32+
AttnWeight::KBiases => Self::CUDNN_MH_ATTN_K_BIASES,
33+
AttnWeight::VBiases => Self::CUDNN_MH_ATTN_V_BIASES,
34+
AttnWeight::OBiases => Self::CUDNN_MH_ATTN_O_BIASES,
35+
}
36+
}
37+
}

0 commit comments

Comments
 (0)