
Commit d20e023

Feat: Add attention forward pass and bump cust to 0.3.2 (#53)
* Feat: Simplify RNN traits
* Feat: Add attention forward pass
1 parent 0d6cd23 commit d20e023

File tree

12 files changed: +584 -71 lines changed


crates/cudnn/Cargo.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -6,4 +6,4 @@ version = "0.1.0"
 
 [dependencies]
 bitflags = "1.3.2"
-cust = {version = "0.3.0", path = "../cust"}
+cust = {version = "0.3.2", path = "../cust"}
```

crates/cudnn/README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,2 +1,2 @@
-# cudnn-rs
+# cudnn
 Type safe cuDNN wrapper for the Rust programming language.
```

crates/cudnn/src/attention/attention_descriptor.rs

Lines changed: 192 additions & 0 deletions
```rust
use crate::{sys, CudnnError, DataType, DropoutDescriptor, IntoResult, MathType, SeqDataType};
use cust::memory::GpuBuffer;
use std::{marker::PhantomData, mem::MaybeUninit};

bitflags::bitflags! {
    /// Miscellaneous switches for configuring auxiliary multi-head attention features.
    pub struct AttnModeFlags: u32 {
        /// Forward declaration of the mapping between Q, K and V vectors when the beam size is
        /// greater than one in the Q input. Multiple Q vectors from the same beam bundle map to
        /// the **same** K, V vectors. This means that the beam size in the K, V sets is equal
        /// to 1.
        const CUDNN_ATTN_QUERYMAP_ALL_TO_ONE = 0;
        /// Forward declaration of the mapping between Q, K and V vectors when the beam size is
        /// greater than one in the Q input. Multiple Q vectors from the same beam bundle map to
        /// **different** K, V vectors. This requires the beam size in the K, V sets to be the
        /// same as that in the Q input.
        const CUDNN_ATTN_QUERYMAP_ONE_TO_ONE = 1;
        /// Use no biases in the attention input and output projections.
        const CUDNN_ATTN_DISABLE_PROJ_BIASES = 0;
        /// Use extra biases in the attention input and output projections.
        const CUDNN_ATTN_ENABLE_PROJ_BIASES = 2;
    }
}

/// A multi-head attention descriptor.
pub struct AttentionDescriptor<T, U, D1, D2>
where
    T: SeqDataType,
    U: SupportedAttn<T>,
    D1: GpuBuffer<u8>,
    D2: GpuBuffer<u8>,
{
    pub(crate) raw: sys::cudnnAttnDescriptor_t,
    data_type: PhantomData<T>,
    math_prec: PhantomData<U>,
    attn_dropout_desc: DropoutDescriptor<D1>,
    post_dropout_desc: DropoutDescriptor<D2>,
}

impl<T, U, D1, D2> AttentionDescriptor<T, U, D1, D2>
where
    T: SeqDataType,
    U: SupportedAttn<T>,
    D1: GpuBuffer<u8>,
    D2: GpuBuffer<u8>,
{
    /// Creates a new multi-head attention descriptor.
    ///
    /// # Arguments
    ///
    /// * `mode` - bit flags enabling various attention options that do not require additional
    /// numerical values.
    ///
    /// * `n_heads` - number of attention heads.
    ///
    /// * `sm_scaler` - softmax sharpening/smoothing coefficient. Must be positive.
    ///
    /// * `math_type` - NVIDIA Tensor Cores setting.
    ///
    /// * `attn_dropout_desc` - descriptor of the dropout operation applied to the softmax output.
    ///
    /// * `post_dropout_desc` - descriptor of the dropout operation applied to the multi-head
    /// attention output, just before the point where residual connections are added.
    ///
    /// * `q_size` - Q vectors length.
    ///
    /// * `k_size` - K vectors length.
    ///
    /// * `v_size` - V vectors length.
    ///
    /// * `q_proj_size` - Q vectors length after the input projection.
    ///
    /// * `k_proj_size` - K vectors length after the input projection.
    ///
    /// * `v_proj_size` - V vectors length after the input projection.
    ///
    /// * `o_proj_size` - H vectors length after the output projection.
    ///
    /// * `qo_max_seq_length` - largest sequence length expected in sequence data descriptors
    /// related to Q, O, dQ and dO inputs and outputs.
    ///
    /// * `kv_max_seq_length` - largest sequence length expected in sequence data descriptors
    /// related to K, V, dK and dV inputs and outputs.
    ///
    /// * `max_batch_size` - largest batch size expected in any sequence data descriptor.
    ///
    /// * `max_beam_size` - largest beam size expected in any sequence data descriptor.
    ///
    /// # Errors
    ///
    /// Returns errors if an unsupported combination of arguments is detected. Some examples
    /// include:
    ///
    /// * the post-projection sizes of Q and K are not equal.
    ///
    /// * the math type is not supported.
    ///
    /// * one or more of the following arguments were either negative or zero: `n_heads`,
    /// `q_size`, `k_size`, `v_size`, `qo_max_seq_length`, `kv_max_seq_length`, `max_batch_size`
    /// and `max_beam_size`.
    ///
    /// * one or more of the following arguments were negative: `q_proj_size`, `k_proj_size`,
    /// `v_proj_size`, `sm_scaler`.
    pub fn new(
        mode: AttnModeFlags,
        n_heads: i32,
        sm_scaler: f64,
        math_type: MathType,
        attn_dropout_desc: DropoutDescriptor<D1>,
        post_dropout_desc: DropoutDescriptor<D2>,
        q_size: i32,
        k_size: i32,
        v_size: i32,
        q_proj_size: impl Into<Option<i32>>,
        k_proj_size: impl Into<Option<i32>>,
        v_proj_size: impl Into<Option<i32>>,
        o_proj_size: impl Into<Option<i32>>,
        qo_max_seq_length: i32,
        kv_max_seq_length: i32,
        max_batch_size: i32,
        max_beam_size: i32,
    ) -> Result<Self, CudnnError> {
        let mut raw = MaybeUninit::uninit();

        unsafe {
            sys::cudnnCreateAttnDescriptor(raw.as_mut_ptr()).into_result()?;

            let raw = raw.assume_init();

            sys::cudnnSetAttnDescriptor(
                raw,
                mode.bits(),
                n_heads,
                sm_scaler,
                T::into_raw(),
                U::into_raw(),
                math_type.into(),
                attn_dropout_desc.raw,
                post_dropout_desc.raw,
                q_size,
                k_size,
                v_size,
                // A projection size of 0 tells cuDNN to disable that projection.
                q_proj_size.into().unwrap_or(0),
                k_proj_size.into().unwrap_or(0),
                v_proj_size.into().unwrap_or(0),
                o_proj_size.into().unwrap_or(0),
                qo_max_seq_length,
                kv_max_seq_length,
                max_batch_size,
                max_beam_size,
            )
            .into_result()?;

            Ok(Self {
                raw,
                data_type: PhantomData,
                math_prec: PhantomData,
                attn_dropout_desc,
                post_dropout_desc,
            })
        }
    }
}

impl<T, U, D1, D2> Drop for AttentionDescriptor<T, U, D1, D2>
where
    T: SeqDataType,
    U: SupportedAttn<T>,
    D1: GpuBuffer<u8>,
    D2: GpuBuffer<u8>,
{
    fn drop(&mut self) {
        unsafe {
            sys::cudnnDestroyAttnDescriptor(self.raw);
        }
    }
}

/// Controls the compute math precision in the multi-head attention. The following applies:
///
/// * For input and output in `f32`, the math precision of the layer can only be `f32`.
///
/// * For input and output in `f64`, the math precision of the layer can only be `f64`.
pub trait SupportedAttn<T>
where
    Self: DataType,
    T: SeqDataType,
{
}

impl SupportedAttn<f32> for f32 {}
impl SupportedAttn<f64> for f64 {}
```

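To make the shape of this constructor concrete, here is a minimal, hypothetical sketch of building a descriptor. It assumes the crate is used as `cudnn`, that the `MathType` value and the two `DropoutDescriptor`s are created elsewhere (their constructors are not part of this diff), and every size below is an arbitrary example value:

```rust
use cudnn::{AttentionDescriptor, AttnModeFlags, CudnnError, DropoutDescriptor, MathType};
use cust::memory::DeviceBuffer;

/// Illustrative only: a 4-head f32 attention descriptor with no input/output
/// projections and no projection biases.
fn build_attn_desc(
    math_type: MathType,
    attn_dropout_desc: DropoutDescriptor<DeviceBuffer<u8>>,
    post_dropout_desc: DropoutDescriptor<DeviceBuffer<u8>>,
) -> Result<AttentionDescriptor<f32, f32, DeviceBuffer<u8>, DeviceBuffer<u8>>, CudnnError> {
    AttentionDescriptor::new(
        // Both flags are zero-valued, so ORing them just documents the defaults.
        AttnModeFlags::CUDNN_ATTN_QUERYMAP_ALL_TO_ONE
            | AttnModeFlags::CUDNN_ATTN_DISABLE_PROJ_BIASES,
        4,   // n_heads
        1.0, // sm_scaler: neither sharpen nor smooth the softmax
        math_type,
        attn_dropout_desc,
        post_dropout_desc,
        64,   // q_size
        64,   // k_size: matches q_size here, since the input projections are disabled
        64,   // v_size
        None, // q_proj_size: 0, i.e. no Q input projection
        None, // k_proj_size
        None, // v_proj_size
        None, // o_proj_size: no output projection
        32,   // qo_max_seq_length
        32,   // kv_max_seq_length
        16,   // max_batch_size
        1,    // max_beam_size
    )
}
```
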
crates/cudnn/src/attention/mod.rs

Lines changed: 188 additions & 0 deletions
```rust
mod attention_descriptor;
mod seq_data_axis;
mod seq_data_descriptor;

pub use attention_descriptor::*;
pub use seq_data_axis::*;
pub use seq_data_descriptor::*;

use crate::{sys, CudnnContext, CudnnError, DataType, IntoResult};
use cust::memory::GpuBuffer;
use std::mem::MaybeUninit;

impl CudnnContext {
    /// Computes the weight, work, and reserve space buffer sizes used by the following
    /// functions:
    ///
    /// * `multi_head_attn_forward()`
    ///
    /// * `multi_head_attn_backward_data()`
    ///
    /// * `multi_head_attn_backward_weights()`
    ///
    /// # Arguments
    ///
    /// * `desc` - multi-head attention descriptor.
    ///
    /// # Errors
    ///
    /// Returns errors if invalid arguments are detected.
    pub fn get_attn_buffers_size<T, U, D1, D2>(
        &self,
        desc: &AttentionDescriptor<T, U, D1, D2>,
    ) -> Result<(usize, usize, usize), CudnnError>
    where
        T: SeqDataType,
        U: SupportedAttn<T>,
        D1: GpuBuffer<u8>,
        D2: GpuBuffer<u8>,
    {
        let mut weight_space_size = MaybeUninit::uninit();
        let mut work_space_size = MaybeUninit::uninit();
        let mut reserve_space_size = MaybeUninit::uninit();

        unsafe {
            sys::cudnnGetMultiHeadAttnBuffers(
                self.raw,
                desc.raw,
                weight_space_size.as_mut_ptr(),
                work_space_size.as_mut_ptr(),
                reserve_space_size.as_mut_ptr(),
            )
            .into_result()?;

            Ok((
                weight_space_size.assume_init(),
                work_space_size.assume_init(),
                reserve_space_size.assume_init(),
            ))
        }
    }
    // (impl continues below)
```
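Before the `impl` block continues with the forward pass itself, a hedged sketch of how these sizes might be consumed. It assumes a `CudnnContext` named `ctx`, the hypothetical `attn_desc` from the earlier example, `f32` data, and an enclosing function returning `Result<(), Box<dyn std::error::Error>>` so `?` can mix cust and cudnn errors:

```rust
use cust::memory::DeviceBuffer;

// The three sizes are reported in bytes.
let (weight_size, work_size, reserve_size) = ctx.get_attn_buffers_size(&attn_desc)?;

// `multi_head_attn_forward` below takes buffers typed like the data (`f32` here),
// so convert byte counts into element counts before allocating.
let elems = |bytes: usize| bytes / std::mem::size_of::<f32>();
let weights = unsafe { DeviceBuffer::<f32>::zeroed(elems(weight_size))? };
let mut work_space = unsafe { DeviceBuffer::<f32>::zeroed(elems(work_size))? };
// The reserve space is only needed for training; inference passes `None` instead.
let mut reserve_space = unsafe { DeviceBuffer::<f32>::zeroed(elems(reserve_size))? };
```

The `impl` block then continues with the forward pass: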
```rust
    /// Computes the forward response of a multi-head attention layer.
    ///
    /// When `reserve_space` is `None` the function operates in inference mode, in which the
    /// backward functions are not invoked, otherwise training mode is assumed.
    ///
    /// # Arguments
    ///
    /// * `attn_desc` - multi-head attention descriptor.
    ///
    /// * `current_idx` - time-step in the queries to process. When this argument is negative,
    /// all Q time-steps are processed. When `current_idx` is zero or positive, the forward
    /// response is computed for the selected time-step only.
    ///
    /// * `lo_win_idx` - integer array specifying the start indices of the attention window for
    /// each Q time-step. The start index in the K, V sets is inclusive.
    ///
    /// * `hi_win_idx` - integer array specifying the end indices of the attention window for
    /// each Q time-step. The end index is exclusive.
    ///
    /// * `device_seq_lengths_qo` - device array specifying the sequence lengths of the query,
    /// residual, and output sequence data.
    ///
    /// * `device_seq_lengths_kv` - device array specifying the sequence lengths of the key and
    /// value input data.
    ///
    /// * `q_desc` - descriptor for the query and residual sequence data.
    ///
    /// * `queries` - queries data in device memory.
    ///
    /// * `residuals` - residuals data in device memory. Set this argument to `None` if no
    /// residual connections are required.
    ///
    /// * `k_desc` - descriptor for the keys sequence data.
    ///
    /// * `keys` - keys data in device memory.
    ///
    /// * `v_desc` - descriptor for the values sequence data.
    ///
    /// * `values` - values data in device memory.
    ///
    /// * `o_desc` - descriptor for the output sequence data.
    ///
    /// * `out` - output data in device memory.
    ///
    /// * `weights` - weight buffer in device memory.
    ///
    /// * `work_space` - work space buffer in device memory.
    ///
    /// * `reserve_space` - reserve space buffer in device memory. This argument should be
    /// `None` in inference mode.
    pub fn multi_head_attn_forward<T, U, D1, D2>(
        &self,
        attn_desc: &AttentionDescriptor<T, U, D1, D2>,
        current_idx: i32,
        lo_win_idx: &[i32],
        hi_win_idx: &[i32],
        device_seq_lengths_qo: &impl GpuBuffer<i32>,
        device_seq_lengths_kv: &impl GpuBuffer<i32>,
        q_desc: &SeqDataDescriptor<T>,
        queries: &impl GpuBuffer<T>,
        residuals: Option<&impl GpuBuffer<T>>,
        k_desc: &SeqDataDescriptor<T>,
        keys: &impl GpuBuffer<T>,
        v_desc: &SeqDataDescriptor<T>,
        values: &impl GpuBuffer<T>,
        o_desc: &SeqDataDescriptor<T>,
        out: &mut impl GpuBuffer<T>,
        weights: &impl GpuBuffer<T>,
        work_space: &mut impl GpuBuffer<T>,
        reserve_space: Option<&mut impl GpuBuffer<T>>,
    ) -> Result<(), CudnnError>
    where
        T: SeqDataType,
        U: SupportedAttn<T>,
        D1: GpuBuffer<u8>,
        D2: GpuBuffer<u8>,
    {
        let device_seq_lengths_qo_ptr = device_seq_lengths_qo.as_device_ptr().as_ptr() as *const _;
        let device_seq_lengths_kv_ptr = device_seq_lengths_kv.as_device_ptr().as_ptr() as *const _;

        let queries_ptr = queries.as_device_ptr().as_ptr() as *const _;
        // A null pointer tells cuDNN to skip the residual connections.
        let residuals_ptr = residuals.map_or(std::ptr::null(), |buff| {
            buff.as_device_ptr().as_ptr() as *const _
        });
        let keys_ptr = keys.as_device_ptr().as_ptr() as *const _;
        let values_ptr = values.as_device_ptr().as_ptr() as *const _;
        let out_ptr = out.as_device_ptr().as_mut_ptr() as *mut _;

        let weights_ptr = weights.as_device_ptr().as_ptr() as *const _;
        let work_space_ptr = work_space.as_device_ptr().as_mut_ptr() as *mut _;

        // Sizes are passed to cuDNN in bytes. Inference passes a null reserve space
        // pointer with zero size; training passes the whole reserve buffer.
        let (reserve_space_ptr, reserve_space_size) =
            reserve_space.map_or((std::ptr::null_mut(), 0), |buff| {
                (
                    buff.as_device_ptr().as_mut_ptr() as *mut _,
                    buff.len() * std::mem::size_of::<T>(),
                )
            });

        unsafe {
            sys::cudnnMultiHeadAttnForward(
                self.raw,
                attn_desc.raw,
                current_idx,
                lo_win_idx.as_ptr(),
                hi_win_idx.as_ptr(),
                device_seq_lengths_qo_ptr,
                device_seq_lengths_kv_ptr,
                q_desc.raw,
                queries_ptr,
                residuals_ptr,
                k_desc.raw,
                keys_ptr,
                v_desc.raw,
                values_ptr,
                o_desc.raw,
                out_ptr,
                weights.len() * std::mem::size_of::<T>(),
                weights_ptr,
                work_space.len() * std::mem::size_of::<T>(),
                work_space_ptr,
                reserve_space_size,
                reserve_space_ptr,
            )
            .into_result()?;

            Ok(())
        }
    }
}
```

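Continuing the same hypothetical sketch, a forward pass in training mode. It assumes `queries`, `keys`, `values` and `out` are `DeviceBuffer<f32>`s shaped to match their sequence data descriptors (`q_desc`, `k_desc`, `v_desc`, `o_desc`, built via the `SeqDataDescriptor` API from `seq_data_descriptor.rs`, which this commit also adds but is not shown here), with batch size 16 and sequence length 32 as in the example descriptor:

```rust
// One attention window per Q time-step: each window spans all 32 K, V time-steps.
let lo_win_idx = vec![0i32; 32];
let hi_win_idx = vec![32i32; 32];

// Per-sequence lengths for the whole batch, uploaded to the device as i32 arrays.
let seq_lengths_qo = DeviceBuffer::from_slice(&[32i32; 16])?;
let seq_lengths_kv = DeviceBuffer::from_slice(&[32i32; 16])?;

ctx.multi_head_attn_forward(
    &attn_desc,
    -1, // negative index: process every Q time-step
    &lo_win_idx,
    &hi_win_idx,
    &seq_lengths_qo,
    &seq_lengths_kv,
    &q_desc,
    &queries,
    None::<&DeviceBuffer<f32>>, // no residual connections
    &k_desc,
    &keys,
    &v_desc,
    &values,
    &o_desc,
    &mut out,
    &weights,
    &mut work_space,
    Some(&mut reserve_space), // training mode; use `None` for inference
)?;
```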