From 89a858f56334957d6671c9fc5e758674e2efe291 Mon Sep 17 00:00:00 2001 From: frjnn Date: Tue, 15 Mar 2022 13:48:23 +0100 Subject: [PATCH 1/3] Feat: Add AttnWeight enum and ScalingDataType trait --- crates/cudnn/src/activation/mod.rs | 31 ++----- .../src/attention/attention_weights_kind.rs | 37 +++++++++ crates/cudnn/src/attention/mod.rs | 2 + crates/cudnn/src/data_type.rs | 16 ++++ crates/cudnn/src/op/mod.rs | 83 ++++++++++++++++--- crates/cudnn/src/pooling/mod.rs | 27 +----- crates/cudnn/src/rnn/rnn_algo.rs | 30 ++----- crates/cudnn/src/rnn/rnn_descriptor.rs | 2 +- 8 files changed, 148 insertions(+), 80 deletions(-) create mode 100644 crates/cudnn/src/attention/attention_weights_kind.rs diff --git a/crates/cudnn/src/activation/mod.rs b/crates/cudnn/src/activation/mod.rs index 71d310cb..c1ff96ed 100644 --- a/crates/cudnn/src/activation/mod.rs +++ b/crates/cudnn/src/activation/mod.rs @@ -4,7 +4,9 @@ mod activation_mode; pub use activation_descriptor::*; pub use activation_mode::*; -use crate::{private, sys, CudnnContext, CudnnError, DataType, IntoResult, TensorDescriptor}; +use crate::{ + private, sys, CudnnContext, CudnnError, DataType, IntoResult, ScalingDataType, TensorDescriptor, +}; use cust::memory::GpuBuffer; use std::mem::MaybeUninit; @@ -49,11 +51,11 @@ impl CudnnContext { /// /// let desc = ActivationDescriptor::new(mode, nan_opt, coefficient)?; /// - /// let alpha: f32 = 1.0; + /// let alpha = 1.0; /// let x_desc = TensorDescriptor::::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?; /// let x = DeviceBuffer::::from_slice(&[10, 10, 10, 10, 10])?; /// - /// let beta: f32 = 0.0; + /// let beta = 0.0; /// let y_desc = TensorDescriptor::::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?; /// let mut y = DeviceBuffer::::from_slice(&[0, 0, 0, 0, 0])?; /// @@ -76,7 +78,7 @@ impl CudnnContext { y: &mut impl GpuBuffer, ) -> Result<(), CudnnError> where - CompT: SupportedActFwd, + CompT: ScalingDataType, T: DataType, { let alpha_ptr = &alpha as *const CompT as *const _; @@ -179,27 +181,6 @@ impl CudnnContext { } } -/// Supported data type configurations for the activation forward operation. -pub trait SupportedActFwd: DataType + private::Sealed -where - T: DataType, -{ -} - -impl SupportedActFwd for f32 {} -impl SupportedActFwd for f32 {} -impl SupportedActFwd for f32 {} -impl SupportedActFwd for f32 {} -impl SupportedActFwd for f32 {} -impl SupportedActFwd for f32 {} - -impl SupportedActFwd for f64 {} -impl SupportedActFwd for f64 {} -impl SupportedActFwd for f64 {} -impl SupportedActFwd for f64 {} -impl SupportedActFwd for f64 {} -impl SupportedActFwd for f64 {} - /// Supported type configurations for the activation backward operation. pub trait SupportedActBwd: DataType + private::Sealed where diff --git a/crates/cudnn/src/attention/attention_weights_kind.rs b/crates/cudnn/src/attention/attention_weights_kind.rs new file mode 100644 index 00000000..11c5ad76 --- /dev/null +++ b/crates/cudnn/src/attention/attention_weights_kind.rs @@ -0,0 +1,37 @@ +use crate::sys; + +/// Specifies a group of weights or biases for the multi-head attention layer. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum AttnWeight { + /// Selects the input projection weights for queries. + QWeights, + /// Selects the input projection weights for keys. + KWeights, + /// Selects the input projection weights for values. + VWeights, + /// Selects the output projection weights. + OWeights, + /// Selects the input projection biases for queries. + QBiases, + /// Selects the input projection biases for keys. + KBiases, + /// Selects the input projection biases for values. + VBiases, + /// Selects the output projection biases. + OBiases, +} + +impl From for sys::cudnnMultiHeadAttnWeightKind_t { + fn from(kind: AttnWeight) -> Self { + match kind { + AttnWeight::QWeights => Self::CUDNN_MH_ATTN_Q_WEIGHTS, + AttnWeight::KWeights => Self::CUDNN_MH_ATTN_K_WEIGHTS, + AttnWeight::VWeights => Self::CUDNN_MH_ATTN_V_WEIGHTS, + AttnWeight::OWeights => Self::CUDNN_MH_ATTN_O_WEIGHTS, + AttnWeight::QBiases => Self::CUDNN_MH_ATTN_Q_BIASES, + AttnWeight::KBiases => Self::CUDNN_MH_ATTN_K_BIASES, + AttnWeight::VBiases => Self::CUDNN_MH_ATTN_V_BIASES, + AttnWeight::OBiases => Self::CUDNN_MH_ATTN_O_BIASES, + } + } +} diff --git a/crates/cudnn/src/attention/mod.rs b/crates/cudnn/src/attention/mod.rs index 00317436..bbc6dabe 100644 --- a/crates/cudnn/src/attention/mod.rs +++ b/crates/cudnn/src/attention/mod.rs @@ -1,8 +1,10 @@ mod attention_descriptor; +mod attention_weights_kind; mod seq_data_axis; mod seq_data_descriptor; pub use attention_descriptor::*; +pub use attention_weights_kind::*; pub use seq_data_axis::*; pub use seq_data_descriptor::*; diff --git a/crates/cudnn/src/data_type.rs b/crates/cudnn/src/data_type.rs index c93fd10b..aba15889 100644 --- a/crates/cudnn/src/data_type.rs +++ b/crates/cudnn/src/data_type.rs @@ -30,6 +30,7 @@ pub struct Vec4; #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub struct Vec32; +/// Vectorized data type. Vectorization size can be either 4 or 32 elements. pub trait VecType: private::Sealed where T: DataType, @@ -55,3 +56,18 @@ macro_rules! impl_cudnn_vec_type { impl_cudnn_vec_type!(Vec4, i8, CUDNN_DATA_INT8x4); impl_cudnn_vec_type!(Vec32, i8, CUDNN_DATA_INT8x32); impl_cudnn_vec_type!(Vec4, u8, CUDNN_DATA_UINT8x4); + +/// Admissible data types for scaling parameters. +pub trait ScalingDataType: DataType + private::Sealed +where + T: DataType, +{ +} + +impl ScalingDataType for f32 {} +impl ScalingDataType for f32 {} +impl ScalingDataType for f32 {} +impl ScalingDataType for f32 {} +impl ScalingDataType for f32 {} + +impl ScalingDataType for f64 {} diff --git a/crates/cudnn/src/op/mod.rs b/crates/cudnn/src/op/mod.rs index 8dec4c5c..c01a72e6 100644 --- a/crates/cudnn/src/op/mod.rs +++ b/crates/cudnn/src/op/mod.rs @@ -4,7 +4,9 @@ mod op_tensor_op; pub use op_tensor_descriptor::*; pub use op_tensor_op::*; -use crate::{sys, CudnnContext, CudnnError, DataType, IntoResult, TensorDescriptor}; +use crate::{ + sys, CudnnContext, CudnnError, DataType, IntoResult, ScalingDataType, TensorDescriptor, +}; use cust::memory::GpuBuffer; impl CudnnContext { @@ -255,6 +257,10 @@ impl CudnnContext { /// to dimension five (5) are supported. This routine does not support tensor formats beyond /// these dimensions. /// + /// # Errors + /// + /// Returns error if an unsupported configurations of arguments is detected. + /// /// # Examples /// /// ``` @@ -282,19 +288,18 @@ impl CudnnContext { /// # Ok(()) /// # } /// ``` - pub fn add_assign( + pub fn add_assign( &self, alpha: CompT, a_desc: &TensorDescriptor, a: &impl GpuBuffer, gamma: CompT, - c_desc: &TensorDescriptor, - c: &mut impl GpuBuffer, + c_desc: &TensorDescriptor, + c: &mut impl GpuBuffer, ) -> Result<(), CudnnError> where - CompT: SupportedOp, + CompT: ScalingDataType, T1: DataType, - T2: DataType, { let a_data = a.as_device_ptr().as_ptr() as *const std::ffi::c_void; let c_data = c.as_device_ptr().as_mut_ptr() as *mut std::ffi::c_void; @@ -320,7 +325,9 @@ impl CudnnContext { /// /// * `value` - value to set. Must be stored in host memory. /// - /// **Do note** that this routine is only available for `f32` and `f64` tensors. + /// # Errors + /// + /// Returns error if an unsupported configurations of arguments is detected. /// /// # Examples /// @@ -345,19 +352,75 @@ impl CudnnContext { /// # Ok(()) /// # } /// ``` - pub fn set( + pub fn set( &self, desc: &TensorDescriptor, data: &mut impl GpuBuffer, - value: T, + value: CompT, ) -> Result<(), CudnnError> where + CompT: ScalingDataType, T: DataType, { let data = data.as_device_ptr().as_mut_ptr() as *mut std::ffi::c_void; - let value = &value as *const T as *const std::ffi::c_void; + let value = &value as *const CompT as *const std::ffi::c_void; unsafe { sys::cudnnSetTensor(self.raw, desc.raw, data, value).into_result() } } + + /// This function scales all the element of a tensor by a given value. + /// + /// # Arguments + /// + /// * `desc` - descriptor of the tensor to scale. + /// + /// * `data` - data of the tensor. + /// + /// * `value` - value in the host memory to a single value that all elements of the tensor will + /// be scaled with. + /// + /// # Errors + /// + /// Returns error if an unsupported configurations of arguments is detected. + /// + /// # Examples + /// + /// ``` + /// # use std::error::Error; + /// # + /// # fn main() -> Result<(), Box> { + /// use cudnn::{CudnnContext, ScalarC, TensorDescriptor}; + /// use cust::memory::DeviceBuffer; + /// + /// let ctx = CudnnContext::new()?; + /// + /// let value = 7.0; + /// let desc = TensorDescriptor::::new_format(&[1, 1, 1, 5], ScalarC::Nchw)?; + /// let mut data = DeviceBuffer::::from_slice(&[2, 2, 2, 2, 2])?; + /// + /// ctx.scale(&desc, &mut data, value)?; + /// + /// let data_host = data.as_host_vec()?; + /// + /// assert!(data_host.iter().all(|x| *x == 14)); + /// # Ok(()) + /// # } + /// ``` + pub fn scale( + &self, + desc: &TensorDescriptor, + data: &mut impl GpuBuffer, + value: CompT, + ) -> Result<(), CudnnError> + where + CompT: ScalingDataType, + T: DataType, + { + let data = data.as_device_ptr().as_mut_ptr() as *mut std::ffi::c_void; + + let value = &value as *const CompT as *const std::ffi::c_void; + + unsafe { sys::cudnnScaleTensor(self.raw, desc.raw, data, value).into_result() } + } } diff --git a/crates/cudnn/src/pooling/mod.rs b/crates/cudnn/src/pooling/mod.rs index 4b2c9309..f2d07c6d 100644 --- a/crates/cudnn/src/pooling/mod.rs +++ b/crates/cudnn/src/pooling/mod.rs @@ -4,7 +4,9 @@ mod pooling_mode; pub use pooling_descriptor::*; pub use pooling_mode::*; -use crate::{private, sys, CudnnContext, CudnnError, DataType, IntoResult, TensorDescriptor}; +use crate::{ + private, sys, CudnnContext, CudnnError, DataType, IntoResult, ScalingDataType, TensorDescriptor, +}; use cust::memory::GpuBuffer; impl CudnnContext { @@ -42,7 +44,7 @@ impl CudnnContext { y: &mut impl GpuBuffer, ) -> Result<(), CudnnError> where - CompT: SupportedPoolFwd, + CompT: ScalingDataType, T: DataType, { let alpha_ptr = &alpha as *const CompT as *const _; @@ -145,27 +147,6 @@ impl CudnnContext { } } -/// Supported data type configurations for the pooling forward operation. -pub trait SupportedPoolFwd: DataType + private::Sealed -where - T: DataType, -{ -} - -impl SupportedPoolFwd for f32 {} -impl SupportedPoolFwd for f32 {} -impl SupportedPoolFwd for f32 {} -impl SupportedPoolFwd for f32 {} -impl SupportedPoolFwd for f32 {} -impl SupportedPoolFwd for f32 {} - -impl SupportedPoolFwd for f64 {} -impl SupportedPoolFwd for f64 {} -impl SupportedPoolFwd for f64 {} -impl SupportedPoolFwd for f64 {} -impl SupportedPoolFwd for f64 {} -impl SupportedPoolFwd for f64 {} - /// Supported type configurations for the pooling backward operation. pub trait SupportedPoolBwd: DataType + private::Sealed where diff --git a/crates/cudnn/src/rnn/rnn_algo.rs b/crates/cudnn/src/rnn/rnn_algo.rs index b151beb4..531cfcd9 100644 --- a/crates/cudnn/src/rnn/rnn_algo.rs +++ b/crates/cudnn/src/rnn/rnn_algo.rs @@ -2,35 +2,23 @@ use crate::sys; /// A recurrent neural network algorithm. /// -/// **Do note** that double precision is only supported by `RnnAlgo::AlgoStandard`. +/// **Do note** that double precision is only supported by `RnnAlgo::Standard`. #[non_exhaustive] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum RnnAlgo { - AlgoStandard, - AlgoPersistStatic, - AlgoPersistDynamic, -} - -impl From for RnnAlgo { - fn from(raw: sys::cudnnRNNAlgo_t) -> Self { - match raw { - sys::cudnnRNNAlgo_t::CUDNN_RNN_ALGO_STANDARD => Self::AlgoStandard, - sys::cudnnRNNAlgo_t::CUDNN_RNN_ALGO_PERSIST_STATIC => Self::AlgoPersistStatic, - sys::cudnnRNNAlgo_t::CUDNN_RNN_ALGO_PERSIST_DYNAMIC => Self::AlgoPersistDynamic, - // This whole enumeration is not documented in the cuDNN docs, the 3 fields above - // are just briefly mentioned and the others never appear. I therefore assume they are - // of no use. - _ => unreachable!(), - } - } + Standard, + PersistStatic, + PersistDynamic, + PersistStaticSmallH, } impl From for sys::cudnnRNNAlgo_t { fn from(algo: RnnAlgo) -> Self { match algo { - RnnAlgo::AlgoStandard => sys::cudnnRNNAlgo_t::CUDNN_RNN_ALGO_STANDARD, - RnnAlgo::AlgoPersistStatic => sys::cudnnRNNAlgo_t::CUDNN_RNN_ALGO_PERSIST_STATIC, - RnnAlgo::AlgoPersistDynamic => sys::cudnnRNNAlgo_t::CUDNN_RNN_ALGO_PERSIST_DYNAMIC, + RnnAlgo::Standard => sys::cudnnRNNAlgo_t::CUDNN_RNN_ALGO_STANDARD, + RnnAlgo::PersistStatic => sys::cudnnRNNAlgo_t::CUDNN_RNN_ALGO_PERSIST_STATIC, + RnnAlgo::PersistDynamic => sys::cudnnRNNAlgo_t::CUDNN_RNN_ALGO_PERSIST_DYNAMIC, + RnnAlgo::PersistStaticSmallH => sys::cudnnRNNAlgo_t::CUDNN_RNN_ALGO_PERSIST_STATIC, } } } diff --git a/crates/cudnn/src/rnn/rnn_descriptor.rs b/crates/cudnn/src/rnn/rnn_descriptor.rs index 3f906519..3afd7b27 100644 --- a/crates/cudnn/src/rnn/rnn_descriptor.rs +++ b/crates/cudnn/src/rnn/rnn_descriptor.rs @@ -108,7 +108,7 @@ where /// /// let ctx = CudnnContext::new()?; /// - /// let algo = RnnAlgo::AlgoStandard; + /// let algo = RnnAlgo::Standard; /// let cell_mode = RnnMode::Lstm; /// let bias_mode = RnnBiasMode::SingleRecurrentBias; /// let dir_mode = RnnDirectionMode::Unidirectional; From b04be11cea9844c73b138644a5717c7c0792fef8 Mon Sep 17 00:00:00 2001 From: frjnn Date: Wed, 16 Mar 2022 11:30:24 +0100 Subject: [PATCH 2/3] Feat: Add multi-head attention backward passes --- crates/cudnn/src/attention/mod.rs | 271 +++++++++++++++++++++++++++++- crates/cudnn/src/rnn/mod.rs | 6 +- 2 files changed, 269 insertions(+), 8 deletions(-) diff --git a/crates/cudnn/src/attention/mod.rs b/crates/cudnn/src/attention/mod.rs index bbc6dabe..ce2ce552 100644 --- a/crates/cudnn/src/attention/mod.rs +++ b/crates/cudnn/src/attention/mod.rs @@ -8,7 +8,7 @@ pub use attention_weights_kind::*; pub use seq_data_axis::*; pub use seq_data_descriptor::*; -use crate::{sys, CudnnContext, CudnnError, DataType, IntoResult}; +use crate::{sys, CudnnContext, CudnnError, DataType, IntoResult, WGradMode}; use cust::memory::GpuBuffer; use std::mem::MaybeUninit; @@ -105,7 +105,7 @@ impl CudnnContext { /// /// * `out` - out data in device memory. /// - /// * `weights` - weight buffer in device memory. + /// * `weights` - weights buffer in device memory. /// /// * `work_space` - work space buffer in device memory. /// @@ -128,9 +128,9 @@ impl CudnnContext { values: &impl GpuBuffer, o_desc: &SeqDataDescriptor, out: &mut impl GpuBuffer, - weights: &impl GpuBuffer, - work_space: &mut impl GpuBuffer, - reserve_space: Option<&mut impl GpuBuffer>, + weights: &impl GpuBuffer, + work_space: &mut impl GpuBuffer, + reserve_space: Option<&mut impl GpuBuffer>, ) -> Result<(), CudnnError> where T: SeqDataType, @@ -185,4 +185,265 @@ impl CudnnContext { .into_result() } } + + /// Computes exact, first-order derivatives of the multi-head attention block with respect to its + /// inputs: Q, K, V. + /// + /// This function does not output partial derivatives for residual connections because this + /// result is equal to `d_out`. If the multi-head attention model enables residual connections + /// sourced directly from Q, then the `d_out` tensor needs to be added to `d_queries` to obtain + /// the correct result of the latter. + /// + /// This function must be invoked after `multi_head_attn_forward()`. The `lo_win_idx`, + /// `hi_win_idx`, `queries`, `keys`, `values`, `weights`, and `reserve_space` arguments + /// should be the same as in the `multi_head_attn_forward()` call. + /// + /// Furthermore, `device_seq_lengths_dqdo` and `device_seq_lengths_dkdv` device buffers should + /// contain the same start and end attention window indices as `device_seq_lengths_qo` and + /// `device_seq_lengths_qo` as in the forward function invocation. + /// + /// **Do note** that this function does not verify that sequence lengths stored in + /// `device_seq_lengths_dqdo` and `device_seq_lengths_dkdv` contain the same settings as + /// `seq_lengths` in the corresponding sequence data descriptor. + /// + /// # Arguments + /// + /// * `attn_desc` - multi-head attention descriptor. + /// + /// * `lo_win_idx` - integer array specifying the start indices of the attention window for + /// each Q time-step. The start index in K, V sets is inclusive. + /// + /// * `hi_win_idx` - integer array specifying the end indices of the attention window for each + /// Q time-step. The end index is exclusive. + /// + /// * `device_seq_lengths_dqdo` - device buffer containing a copy of the sequence length array + /// from the `dq_desc` or `do_desc` sequence data descriptors. + /// + /// * `device_seq_lengths_dkdv` - device buffer containing a copy of the sequence length array + /// from the `dk_desc` or `dv_desc` sequence data descriptors. + /// + /// * `do_desc` - descriptor for the output differential, i.e. the vectors of partial + /// derivatives of the loss function with respect to the multi-head attention outputs. + /// + /// * `d_out` - output differential. + /// + /// * `dq_desc` - descriptor for the queries differential. + /// + /// * `d_queries` - gradients of the loss function computed with respect to queries vectors. + /// + /// * `queries` - queries data. This must be the same input as in `multi_head_attn_forward()`. + /// + /// * `dk_desc` - descriptor for the keys and keys gradient sequence data. + /// + /// * `d_keys` - gradients of the loss function computed with respect to keys vectors. + /// + /// * `keys` - keys data. This must be the same input as in `multi_head_attn_forward()`. + /// + /// * `dv_desc` - descriptor for values and values gradient sequence data. + /// + /// * `d_values` - gradients of the loss function computed with respect to values vectors. + /// + /// * `values` - values data. This must be the same input as in `multi_head_attn_forward()`. + /// + /// * `weights` - weights buffer in the device memory. + /// + /// * `work_space` - work space buffer in device memory. Used for temporary API storage. + /// + /// * `reserve_space` - reserve space buffer in device memory. + /// + /// # Errors + /// + /// Returns errors if an invalid or incompatible input argument was encountered, an inconsistent + /// internal state was encountered, a requested option or a combination of input arguments is + /// not supported or in case of insufficient amount of shared memory to launch the kernel. + pub fn multi_head_attn_backward_data( + &self, + attn_desc: &AttentionDescriptor, + current_idx: i32, + lo_win_idx: &[i32], + hi_win_idx: &[i32], + device_seq_lengths_dqdo: &impl GpuBuffer, + device_seq_lengths_dkdv: &impl GpuBuffer, + do_desc: &SeqDataDescriptor, + d_out: &impl GpuBuffer, + dq_desc: &SeqDataDescriptor, + d_queries: &mut impl GpuBuffer, + queries: &impl GpuBuffer, + dk_desc: &SeqDataDescriptor, + d_keys: &mut impl GpuBuffer, + keys: &impl GpuBuffer, + dv_desc: &SeqDataDescriptor, + d_values: &mut impl GpuBuffer, + values: &impl GpuBuffer, + weights: &impl GpuBuffer, + work_space: &mut impl GpuBuffer, + reserve_space: &mut impl GpuBuffer, + ) -> Result<(), CudnnError> + where + T: SeqDataType, + U: SupportedAttn, + D1: GpuBuffer, + D2: GpuBuffer, + { + let device_seq_lengths_dqdo_ptr = + device_seq_lengths_dqdo.as_device_ptr().as_ptr() as *const _; + let device_seq_lengths_dkdv_ptr = + device_seq_lengths_dqdo.as_device_ptr().as_ptr() as *const _; + + let d_out_ptr = d_out.as_device_ptr().as_ptr() as *const _; + + let d_queries_ptr = d_queries.as_device_ptr().as_mut_ptr() as *mut _; + let queries_ptr = queries.as_device_ptr().as_ptr() as *const _; + + let d_keys_ptr = d_keys.as_device_ptr().as_mut_ptr() as *mut _; + let keys_ptr = keys.as_device_ptr().as_ptr() as *const _; + + let d_values_ptr = d_values.as_device_ptr().as_mut_ptr() as *mut _; + let values_ptr = values.as_device_ptr().as_ptr() as *const _; + + let weights_ptr = weights.as_device_ptr().as_ptr() as *const _; + let work_space_ptr = work_space.as_device_ptr().as_mut_ptr() as *mut _; + let reserve_space_ptr = reserve_space.as_device_ptr().as_mut_ptr() as *mut _; + + unsafe { + sys::cudnnMultiHeadAttnBackwardData( + self.raw, + attn_desc.raw, + lo_win_idx.as_ptr(), + hi_win_idx.as_ptr(), + device_seq_lengths_dqdo_ptr, + device_seq_lengths_dkdv_ptr, + do_desc.raw, + d_out_ptr, + dq_desc.raw, + d_queries_ptr, + queries_ptr, + dk_desc.raw, + d_keys_ptr, + keys_ptr, + dv_desc.raw, + d_values_ptr, + values_ptr, + weights.len(), + weights_ptr, + work_space.len(), + work_space_ptr, + reserve_space.len(), + reserve_space_ptr, + ) + .into_result() + } + } + + /// This function computes exact, first-order derivatives of the multi-head attention block + /// with respect to its trainable parameters: projection weights and projection biases. + /// + /// All gradient results with respect to weights and biases are written to the `d_weights` + /// buffer. The size and the organization of the `d_weights` buffer is the same as the `weights` + /// buffer that holds multi-head attention weights and biases. + /// + /// Gradient of the loss function with respect to weights or biases is typically computed over + /// multiple batches. In such a case, partial results computed for each batch should be + /// summed together. The `grad_mode` argument specifies if the gradients from the current + /// batch should be added to previously computed results or the `d_weights` buffer should + /// be overwritten with the new results. + /// + /// **Do note** that this function should be invoked **after** `multi_head_attn_backward_data()`. + /// Also, the `queries`, `keys`, `values`, `weights`, and `reserve_space` arguments should be + /// the same as in `multi_head_attn_fwd()` and `multi_head_attn_backward_data()` calls. The + /// `d_out` argument should be the same as in `multi_head_attn_backward_data()`. + /// + /// # Arguments + /// + /// * `attn_desc` - multi-head attention descriptor. + /// + /// * `grad_mode` - gradient accumulation mode. + /// + /// * `q_desc` - descriptor for the query and residual sequence data. + /// + /// * `queries` - queries data in the device memory. + /// + /// * `k_desc` - descriptor for the keys sequence data. + /// + /// * `keys` - keys data in device memory. + /// + /// * `v_desc` - descriptor for the values sequence data. + /// + /// * `values` - values data in device memory. + /// + /// * `do_desc` - descriptor for the output differential sequence data. + /// + /// * `d_out` - output differential data in device memory. + /// + /// * `weights` - weights buffer in the device memory. + /// + /// * `d_weights` - weights gradient buffer in the device memory. + /// + /// * `work_space` - work space buffer in device memory. Used for temporary API storage. + /// + /// * `reserve_space` - reserve space buffer in device memory. + /// + /// # Errors + /// + /// Returns errors if an invalid or incompatible input argument was encountered, an inconsistent + /// internal state was encountered, a requested option or a combination of input arguments is + /// not supported or in case of insufficient amount of shared memory to launch the kernel. + pub fn multi_head_attn_backward_weights( + &self, + attn_desc: &AttentionDescriptor, + grad_mode: WGradMode, + q_desc: &SeqDataDescriptor, + queries: &impl GpuBuffer, + k_desc: &SeqDataDescriptor, + keys: &impl GpuBuffer, + v_desc: &SeqDataDescriptor, + values: &impl GpuBuffer, + do_desc: &SeqDataDescriptor, + d_out: &impl GpuBuffer, + weights: &impl GpuBuffer, + d_weights: &mut impl GpuBuffer, + work_space: &mut impl GpuBuffer, + reserve_space: &mut impl GpuBuffer, + ) -> Result<(), CudnnError> + where + T: SeqDataType, + U: SupportedAttn, + D1: GpuBuffer, + D2: GpuBuffer, + { + let queries_ptr = queries.as_device_ptr().as_ptr() as *const _; + let keys_ptr = keys.as_device_ptr().as_ptr() as *const _; + let values_ptr = values.as_device_ptr().as_ptr() as *const _; + + let d_out_ptr = d_out.as_device_ptr().as_ptr() as *const _; + + let weights_ptr = weights.as_device_ptr().as_ptr() as *const _; + let d_weights_ptr = d_weights.as_device_ptr().as_mut_ptr() as *mut _; + let work_space_ptr = work_space.as_device_ptr().as_mut_ptr() as *mut _; + let reserve_space_ptr = reserve_space.as_device_ptr().as_mut_ptr() as *mut _; + + unsafe { + sys::cudnnMultiHeadAttnBackwardWeights( + self.raw, + attn_desc.raw, + grad_mode.into(), + q_desc.raw, + queries_ptr, + k_desc.raw, + keys_ptr, + v_desc.raw, + values_ptr, + do_desc.raw, + d_out_ptr, + weights.len(), + weights_ptr, + d_weights_ptr, + work_space.len(), + work_space_ptr, + reserve_space.len(), + reserve_space_ptr, + ) + .into_result() + } + } } diff --git a/crates/cudnn/src/rnn/mod.rs b/crates/cudnn/src/rnn/mod.rs index d0e698e3..04806545 100644 --- a/crates/cudnn/src/rnn/mod.rs +++ b/crates/cudnn/src/rnn/mod.rs @@ -511,7 +511,7 @@ impl CudnnContext { /// /// * `rnn_desc` - RNN descriptor. /// - /// * `add_grad` - weight gradient output mode. Only `WGradMode::Add` is supported. + /// * `grad_mode` - weight gradient output mode. Only `WGradMode::Add` is supported. /// /// * `device_seq_lengths` - a copy of `seq_lengths` from `x_desc` or `y_desc` RNN data /// descriptors. The `device_seq_lengths` array must be stored in GPU memory as it is accessed @@ -546,7 +546,7 @@ impl CudnnContext { pub fn rnn_backward_weights( &self, rnn_desc: &RnnDescriptor, - add_grad: WGradMode, + grad_mode: WGradMode, device_seq_lengths: &impl GpuBuffer, x_desc: &RnnDataDescriptor, x: &impl GpuBuffer, @@ -576,7 +576,7 @@ impl CudnnContext { sys::cudnnRNNBackwardWeights_v8( self.raw, rnn_desc.raw, - add_grad.into(), + grad_mode.into(), device_sequence_lengths_ptr, x_desc.raw, x_ptr, From 91f597896d8c30ff509fa8c7cb80d67d9c852c24 Mon Sep 17 00:00:00 2001 From: frjnn Date: Wed, 16 Mar 2022 14:06:57 +0100 Subject: [PATCH 3/3] Feat: Add convolution fused bias activation --- .../cudnn/src/activation/activation_mode.rs | 6 + crates/cudnn/src/convolution/mod.rs | 318 ++++++++++++++---- 2 files changed, 263 insertions(+), 61 deletions(-) diff --git a/crates/cudnn/src/activation/activation_mode.rs b/crates/cudnn/src/activation/activation_mode.rs index 82437d92..94c23ed3 100644 --- a/crates/cudnn/src/activation/activation_mode.rs +++ b/crates/cudnn/src/activation/activation_mode.rs @@ -15,6 +15,11 @@ pub enum ActivationMode { Elu, /// Selects the swish function. Swish, + /// Selects no activation. + /// + /// **Do note** that this is only valid for an activation descriptor passed to + /// [`convolution_bias_act_forward()`](CudnnContext::convolution_bias_act_fwd). + Identity, } impl From for sys::cudnnActivationMode_t { @@ -26,6 +31,7 @@ impl From for sys::cudnnActivationMode_t { ActivationMode::ClippedRelu => Self::CUDNN_ACTIVATION_CLIPPED_RELU, ActivationMode::Elu => Self::CUDNN_ACTIVATION_ELU, ActivationMode::Swish => Self::CUDNN_ACTIVATION_SWISH, + ActivationMode::Identity => Self::CUDNN_ACTIVATION_IDENTITY, } } } diff --git a/crates/cudnn/src/convolution/mod.rs b/crates/cudnn/src/convolution/mod.rs index 3eec41a5..1757f768 100644 --- a/crates/cudnn/src/convolution/mod.rs +++ b/crates/cudnn/src/convolution/mod.rs @@ -10,7 +10,9 @@ pub use convolution_descriptor::*; pub use convolution_mode::*; pub use filter_descriptor::*; -use crate::{sys, CudnnContext, CudnnError, DataType, IntoResult, TensorDescriptor}; +use crate::{ + sys, ActivationDescriptor, CudnnContext, CudnnError, DataType, IntoResult, TensorDescriptor, +}; use cust::memory::GpuBuffer; use std::mem::MaybeUninit; @@ -72,17 +74,17 @@ impl CudnnContext { /// # Ok(()) /// # } /// ``` - pub fn get_convolution_forward_algorithm( + pub fn get_convolution_forward_algorithm( &self, x_desc: &TensorDescriptor, w_desc: &FilterDescriptor, y_desc: &TensorDescriptor, - conv_desc: &ConvDescriptor, + conv_desc: &ConvDescriptor, ) -> Result, CudnnError> where T1: DataType, T2: DataType, - CompType: SupportedConv, + CompT: SupportedConv, T3: DataType, { let mut returned_algo_count = MaybeUninit::uninit(); @@ -184,17 +186,17 @@ impl CudnnContext { /// # Ok(()) /// # } /// ``` - pub fn get_convolution_backward_data_algorithm( + pub fn get_convolution_backward_data_algorithm( &self, w_desc: &FilterDescriptor, dy_desc: &TensorDescriptor, dx_desc: &TensorDescriptor, - conv_desc: &ConvDescriptor, + conv_desc: &ConvDescriptor, ) -> Result, CudnnError> where T1: DataType, T2: DataType, - CompType: SupportedConv, + CompT: SupportedConv, T3: DataType, { let mut returned_algo_count = MaybeUninit::uninit(); @@ -296,17 +298,17 @@ impl CudnnContext { /// # Ok(()) /// # } /// ``` - pub fn get_convolution_backward_filter_algorithm( + pub fn get_convolution_backward_filter_algorithm( &self, x_desc: &TensorDescriptor, dy_desc: &TensorDescriptor, dw_desc: &FilterDescriptor, - conv_desc: &ConvDescriptor, + conv_desc: &ConvDescriptor, ) -> Result, CudnnError> where T1: DataType, T2: DataType, - CompType: SupportedConv, + CompT: SupportedConv, T3: DataType, { let mut returned_algo_count = MaybeUninit::uninit(); @@ -418,25 +420,25 @@ impl CudnnContext { /// &w_desc, /// &y_desc, /// &conv_desc, - /// &algo, + /// algo, /// )?; /// /// let workspace = size.map(|size| unsafe { DeviceBuffer::::uninitialized(size).unwrap() }); /// # Ok(()) /// # } /// ``` - pub fn get_convolution_forward_workspace_size( + pub fn get_convolution_forward_workspace_size( &self, x_desc: &TensorDescriptor, w_desc: &FilterDescriptor, y_desc: &TensorDescriptor, - conv_desc: &ConvDescriptor, - algo: &ConvFwdAlgo, + conv_desc: &ConvDescriptor, + algo: ConvFwdAlgo, ) -> Result, CudnnError> where T1: DataType, T2: DataType, - CompType: SupportedConv, + CompT: SupportedConv, T3: DataType, { let mut size = MaybeUninit::uninit(); @@ -448,7 +450,7 @@ impl CudnnContext { w_desc.raw, conv_desc.raw, y_desc.raw, - (*algo).into(), + algo.into(), size.as_mut_ptr(), ) .into_result()?; @@ -525,25 +527,25 @@ impl CudnnContext { /// &dy_desc, /// &dx_desc, /// &conv_desc, - /// &algo, + /// algo, /// )?; /// /// let workspace = size.map(|size| unsafe { DeviceBuffer::::uninitialized(size).unwrap() }); /// # Ok(()) /// # } /// ``` - pub fn get_convolution_backward_data_workspace_size( + pub fn get_convolution_backward_data_workspace_size( &self, w_desc: &FilterDescriptor, dy_desc: &TensorDescriptor, dx_desc: &TensorDescriptor, - conv_desc: &ConvDescriptor, - algo: &ConvBwdDataAlgo, + conv_desc: &ConvDescriptor, + algo: ConvBwdDataAlgo, ) -> Result, CudnnError> where T1: DataType, T2: DataType, - CompType: SupportedConv, + CompT: SupportedConv, T3: DataType, { let mut size = MaybeUninit::uninit(); @@ -555,7 +557,7 @@ impl CudnnContext { dy_desc.raw, conv_desc.raw, dx_desc.raw, - (*algo).into(), + algo.into(), size.as_mut_ptr(), ) .into_result()?; @@ -632,25 +634,25 @@ impl CudnnContext { /// &dy_desc, /// &dw_desc, /// &conv_desc, - /// &algo, + /// algo, /// )?; /// /// let workspace = size.map(|size| unsafe { DeviceBuffer::::uninitialized(size).unwrap() }); /// # Ok(()) /// # } /// ``` - pub fn get_convolution_backward_filter_workspace_size( + pub fn get_convolution_backward_filter_workspace_size( &self, x_desc: &TensorDescriptor, dy_desc: &TensorDescriptor, dw_desc: &FilterDescriptor, - conv_desc: &ConvDescriptor, - algo: &ConvBwdFilterAlgo, + conv_desc: &ConvDescriptor, + algo: ConvBwdFilterAlgo, ) -> Result, CudnnError> where T1: DataType, T2: DataType, - CompType: SupportedConv, + CompT: SupportedConv, T3: DataType, { let mut size = MaybeUninit::uninit(); @@ -662,7 +664,7 @@ impl CudnnContext { dy_desc.raw, conv_desc.raw, dw_desc.raw, - (*algo).into(), + algo.into(), size.as_mut_ptr(), ) .into_result()?; @@ -753,7 +755,7 @@ impl CudnnContext { /// &w_desc, /// &y_desc, /// &conv_desc, - /// &algo, + /// algo, /// )?; /// /// let mut workspace = size.map(|size| unsafe { DeviceBuffer::::uninitialized(size).unwrap() }); @@ -768,7 +770,7 @@ impl CudnnContext { /// &w_desc, /// &w, /// &conv_desc, - /// &algo, + /// algo, /// workspace.as_mut(), /// beta, /// &y_desc, @@ -777,24 +779,24 @@ impl CudnnContext { /// # Ok(()) /// # } /// ``` - pub fn convolution_forward( + pub fn convolution_forward( &self, - alpha: CompType, + alpha: CompT, x_desc: &TensorDescriptor, x: &impl GpuBuffer, w_desc: &FilterDescriptor, w: &impl GpuBuffer, - conv_desc: &ConvDescriptor, - algo: &ConvFwdAlgo, + conv_desc: &ConvDescriptor, + algo: ConvFwdAlgo, work_space: Option<&mut W>, - beta: CompType, + beta: CompT, y_desc: &TensorDescriptor, y: &mut impl GpuBuffer, ) -> Result<(), CudnnError> where T1: DataType, T2: DataType, - CompType: SupportedConv, + CompT: SupportedConv, T3: DataType, W: GpuBuffer, { @@ -802,8 +804,8 @@ impl CudnnContext { let w_data = w.as_device_ptr().as_ptr() as *const std::ffi::c_void; let y_data = y.as_device_ptr().as_mut_ptr() as *mut std::ffi::c_void; - let alpha = &alpha as *const CompType as *const std::ffi::c_void; - let beta = &beta as *const CompType as *const std::ffi::c_void; + let alpha = &alpha as *const CompT as *const std::ffi::c_void; + let beta = &beta as *const CompT as *const std::ffi::c_void; // If the size is 0 then the algorithm can work in-place and cuDNN expects a null // pointer. @@ -825,7 +827,7 @@ impl CudnnContext { w_desc.raw, w_data, conv_desc.raw, - (*algo).into(), + algo.into(), work_space_ptr, work_space_size, beta, @@ -836,6 +838,200 @@ impl CudnnContext { } } + /// This function applies a bias and then an activation to the convolutions or + /// cross-correlation output: + /// + /// y = act ( alpha * conv(x) + beta * z + bias ) + /// + /// Results are returned in y. + /// + /// # Arguments + /// + /// * `alpha` - scaling parameter. + /// + /// * `x_desc` - input map descriptor. + /// + /// * `x` - input map data. + /// + /// * `w_desc` - filter descriptor. + /// + /// * `w` - filter data. + /// + /// * `conv_desc` - convolution descriptor. + /// + /// * `algo` - convolution algorithm that should be used to compute the result. + /// + /// * `work_space` - a buffer to GPU memory to a workspace needed to be able to execute the + /// specified algorithm. Must be left to `None` if the algorithm works in-place. The workspace + /// dimension can be obtained with `get_convolution_forward_workspace_size`. + /// + /// * `beta` - scaling parameter. + /// + /// * `z_desc` - descriptor for the z tensor. + /// + /// * `z` - data for the z tensor. + /// + /// * `bias_desc` - descriptor for the bias tensor. + /// + /// * `bias` - data for the bias tensor. + /// + /// * `activation_desc` - neuron activation function descriptor. + /// + /// * `y_desc` - output map descriptor. + /// + /// * `y` - data for the output map. + /// + /// **Do note** that `y_desc` and `z_desc` should match. + /// + /// # Errors + /// + /// Returns errors if an invalid or unsupported combination of argument is passed. + /// + /// # Examples + /// + /// ``` + /// # use std::error::Error; + /// # + /// # fn main() -> Result<(), Box> { + /// use cudnn::{ + /// ActivationDescriptor, ActivationMode, ConvDescriptor, ConvFwdAlgo, ConvMode, + /// CudnnContext, FilterDescriptor, NanPropagation, ScalarC, TensorDescriptor + /// }; + /// use cust::memory::DeviceBuffer; + /// + /// let ctx = CudnnContext::new()?; + /// + /// let padding = [0, 0]; + /// let stride = [1, 1]; + /// let dilation = [1, 1]; + /// let mode = ConvMode::CrossCorrelation; + /// + /// let conv_desc = ConvDescriptor::::new(padding, stride, dilation, mode)?; + /// + /// # let data = vec![1.0_f32; 150]; + /// # let x = DeviceBuffer::from_slice(&data)?; + /// # let w = DeviceBuffer::from_slice(&data[..24])?; + /// # let z = DeviceBuffer::from_slice(&data[..144])?; + /// # let bias = DeviceBuffer::from_slice(&data[..3])?; + /// # let mut y = DeviceBuffer::from_slice(&data[..144])?; + /// let x_desc = TensorDescriptor::::new_format(&[3, 2, 5, 5,], ScalarC::Nchw)?; + /// let w_desc = FilterDescriptor::::new(&[3, 2, 2, 2], ScalarC::Nchw)?; + /// let y_desc = TensorDescriptor::::new_format(&[3, 3, 4, 4], ScalarC::Nchw)?; + /// + /// let algo = ConvFwdAlgo::ImplicitPrecompGemm; + /// + /// let size = ctx.get_convolution_forward_workspace_size( + /// &x_desc, + /// &w_desc, + /// &y_desc, + /// &conv_desc, + /// algo, + /// )?; + /// + /// let mut workspace = size.map(|size| unsafe { DeviceBuffer::::uninitialized(size).unwrap() }); + /// + /// let alpha = 1.; + /// let beta = 0.; + /// + /// let z_desc = TensorDescriptor::::new_format(&[3, 3, 4, 4], ScalarC::Nchw)?; + /// let bias_desc = TensorDescriptor::::new_format(&[1, 3, 1, 1], ScalarC::Nchw)?; + /// + /// let mode = ActivationMode::Relu; + /// let nan_opt = NanPropagation::NotPropagateNaN; + /// let coefficient = None; + /// + /// let activation_desc = ActivationDescriptor::new(mode, nan_opt, coefficient)?; + /// + /// ctx.convolution_bias_act_forward( + /// alpha, + /// &x_desc, + /// &x, + /// &w_desc, + /// &w, + /// &conv_desc, + /// algo, + /// workspace.as_mut(), + /// beta, + /// &z_desc, + /// &z, + /// &bias_desc, + /// &bias, + /// &activation_desc, + /// &y_desc, + /// &mut y + /// )?; + /// # Ok(()) + /// # } + /// ``` + pub fn convolution_bias_act_forward( + &self, + alpha: CompT, + x_desc: &TensorDescriptor, + x: &impl GpuBuffer, + w_desc: &FilterDescriptor, + w: &impl GpuBuffer, + conv_desc: &ConvDescriptor, + algo: ConvFwdAlgo, + work_space: Option<&mut W>, + beta: CompT, + z_desc: &TensorDescriptor, + z: &impl GpuBuffer, + bias_desc: &TensorDescriptor, + bias: &impl GpuBuffer, + activation_desc: &ActivationDescriptor, + y_desc: &TensorDescriptor, + y: &mut impl GpuBuffer, + ) -> Result<(), CudnnError> + where + T1: DataType, + T2: DataType, + CompT: SupportedConv, + T3: DataType, + W: GpuBuffer, + { + let x_data = x.as_device_ptr().as_ptr() as *const std::ffi::c_void; + let w_data = w.as_device_ptr().as_ptr() as *const std::ffi::c_void; + let z_data = z.as_device_ptr().as_ptr() as *const std::ffi::c_void; + let bias_data = bias.as_device_ptr().as_ptr() as *const std::ffi::c_void; + let y_data = y.as_device_ptr().as_mut_ptr() as *mut std::ffi::c_void; + + let alpha = &alpha as *const CompT as *const std::ffi::c_void; + let beta = &beta as *const CompT as *const std::ffi::c_void; + + let (work_space_ptr, work_space_size) = { + work_space.map_or((std::ptr::null_mut(), 0), |work_space| { + ( + work_space.as_device_ptr().as_mut_ptr() as *mut std::ffi::c_void, + work_space.len(), + ) + }) + }; + + unsafe { + sys::cudnnConvolutionBiasActivationForward( + self.raw, + alpha, + x_desc.raw, + x_data, + w_desc.raw, + w_data, + conv_desc.raw, + algo.into(), + work_space_ptr, + work_space_size, + beta, + z_desc.raw, + z_data, + bias_desc.raw, + bias_data, + activation_desc.raw, + y_desc.raw, + y_data, + ) + .into_result() + } + } + /// This function computes the convolution data gradient of the tensor `dy`, where `y` is the /// output of the forward convolution in `convolution_forward`. /// @@ -909,7 +1105,7 @@ impl CudnnContext { /// &dy_desc, /// &dx_desc, /// &conv_desc, - /// &algo, + /// algo, /// )?; /// /// let mut workspace = size.map(|size| unsafe { DeviceBuffer::::uninitialized(size).unwrap() }); @@ -924,7 +1120,7 @@ impl CudnnContext { /// &dy_desc, /// &dy, /// &conv_desc, - /// &algo, + /// algo, /// workspace.as_mut(), /// beta, /// &dx_desc, @@ -933,24 +1129,24 @@ impl CudnnContext { /// # Ok(()) /// # } /// ``` - pub fn convolution_backward_data( + pub fn convolution_backward_data( &self, - alpha: CompType, + alpha: CompT, w_desc: &FilterDescriptor, w: &impl GpuBuffer, dy_desc: &TensorDescriptor, dy: &impl GpuBuffer, - conv_desc: &ConvDescriptor, - algo: &ConvBwdDataAlgo, + conv_desc: &ConvDescriptor, + algo: ConvBwdDataAlgo, work_space: Option<&mut W>, - beta: CompType, + beta: CompT, dx_desc: &TensorDescriptor, dx: &mut impl GpuBuffer, ) -> Result<(), CudnnError> where T1: DataType, T2: DataType, - CompType: SupportedConv, + CompT: SupportedConv, T3: DataType, W: GpuBuffer, { @@ -958,8 +1154,8 @@ impl CudnnContext { let dy_data = dy.as_device_ptr().as_ptr() as *const std::ffi::c_void; let dx_data = dx.as_device_ptr().as_mut_ptr() as *mut std::ffi::c_void; - let alpha = &alpha as *const CompType as *const std::ffi::c_void; - let beta = &beta as *const CompType as *const std::ffi::c_void; + let alpha = &alpha as *const CompT as *const std::ffi::c_void; + let beta = &beta as *const CompT as *const std::ffi::c_void; let (work_space_ptr, work_space_size) = { work_space.map_or((std::ptr::null_mut(), 0), |work_space| { @@ -979,7 +1175,7 @@ impl CudnnContext { dy_desc.raw, dy_data, conv_desc.raw, - (*algo).into(), + algo.into(), work_space_ptr, work_space_size, beta, @@ -1063,7 +1259,7 @@ impl CudnnContext { /// &dy_desc, /// &dw_desc, /// &conv_desc, - /// &algo, + /// algo, /// )?; /// /// let mut workspace = size.map(|size| unsafe { DeviceBuffer::::uninitialized(size).unwrap() }); @@ -1078,7 +1274,7 @@ impl CudnnContext { /// &dy_desc, /// &dy, /// &conv_desc, - /// &algo, + /// algo, /// workspace.as_mut(), /// beta, /// &dw_desc, @@ -1087,24 +1283,24 @@ impl CudnnContext { /// # Ok(()) /// # } /// ``` - pub fn convolution_backward_filter( + pub fn convolution_backward_filter( &self, - alpha: CompType, + alpha: CompT, x_desc: &TensorDescriptor, x: &impl GpuBuffer, dy_desc: &TensorDescriptor, y: &impl GpuBuffer, - conv_desc: &ConvDescriptor, - algo: &ConvBwdFilterAlgo, + conv_desc: &ConvDescriptor, + algo: ConvBwdFilterAlgo, work_space: Option<&mut W>, - beta: CompType, + beta: CompT, dw_desc: &FilterDescriptor, dw: &mut impl GpuBuffer, ) -> Result<(), CudnnError> where T1: DataType, T2: DataType, - CompType: SupportedConv, + CompT: SupportedConv, T3: DataType, W: GpuBuffer, { @@ -1112,8 +1308,8 @@ impl CudnnContext { let dy_data = y.as_device_ptr().as_ptr() as *const std::ffi::c_void; let dw_data = dw.as_device_ptr().as_mut_ptr() as *mut std::ffi::c_void; - let alpha = &alpha as *const CompType as *const std::ffi::c_void; - let beta = &beta as *const CompType as *const std::ffi::c_void; + let alpha = &alpha as *const CompT as *const std::ffi::c_void; + let beta = &beta as *const CompT as *const std::ffi::c_void; let (work_space_ptr, work_space_size) = { work_space.map_or((std::ptr::null_mut(), 0), |work_space| { @@ -1133,7 +1329,7 @@ impl CudnnContext { dy_desc.raw, dy_data, conv_desc.raw, - (*algo).into(), + algo.into(), work_space_ptr, work_space_size, beta,