From f4ca159fc566c681357fbf39d7af7d5385a6fc91 Mon Sep 17 00:00:00 2001 From: frjnn Date: Sun, 13 Mar 2022 21:33:11 +0100 Subject: [PATCH 1/3] Fix: Missing references --- crates/cudnn/src/pooling/mod.rs | 2 +- crates/cudnn/src/softmax/mod.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/cudnn/src/pooling/mod.rs b/crates/cudnn/src/pooling/mod.rs index 44f3cdcd..6fb9972e 100644 --- a/crates/cudnn/src/pooling/mod.rs +++ b/crates/cudnn/src/pooling/mod.rs @@ -1,11 +1,11 @@ mod pooling_descriptor; mod pooling_mode; -use cust::memory::GpuBuffer; pub use pooling_descriptor::*; pub use pooling_mode::*; use crate::{private, sys, CudnnContext, CudnnError, DataType, IntoResult, TensorDescriptor}; +use cust::memory::GpuBuffer; impl CudnnContext { /// This function computes the pooling of the input tensor and produces a smaller tensor in diff --git a/crates/cudnn/src/softmax/mod.rs b/crates/cudnn/src/softmax/mod.rs index 12c0987a..ce2aff9c 100644 --- a/crates/cudnn/src/softmax/mod.rs +++ b/crates/cudnn/src/softmax/mod.rs @@ -38,7 +38,7 @@ impl CudnnContext { mode: SoftmaxMode, alpha: CompT, x_desc: &TensorDescriptor, - x: impl GpuBuffer, + x: &impl GpuBuffer, beta: CompT, y_desc: &TensorDescriptor, y: &mut impl GpuBuffer, @@ -103,7 +103,7 @@ impl CudnnContext { mode: SoftmaxMode, alpha: CompT, y_desc: &TensorDescriptor, - y: impl GpuBuffer, + y: &impl GpuBuffer, dy_desc: &TensorDescriptor, dy: &impl GpuBuffer, beta: CompT, From 001fd1435e5a077926e352449cd761335b19a906 Mon Sep 17 00:00:00 2001 From: frjnn Date: Mon, 14 Mar 2022 16:12:50 +0100 Subject: [PATCH 2/3] Feat: Add activation forward and backward --- .../src/activation/activation_descriptor.rs | 71 ++++++ .../cudnn/src/activation/activation_mode.rs | 31 +++ crates/cudnn/src/activation/mod.rs | 211 ++++++++++++++++++ .../src/attention/attention_descriptor.rs | 1 + .../src/convolution/convolution_descriptor.rs | 2 +- .../src/convolution/filter_descriptor.rs | 2 +- 
.../cudnn/src/dropout/dropout_descriptor.rs | 1 + crates/cudnn/src/lib.rs | 6 +- crates/cudnn/src/{op_tensor => op}/mod.rs | 0 .../{op_tensor => op}/op_tensor_descriptor.rs | 2 + .../src/{op_tensor => op}/op_tensor_op.rs | 0 crates/cudnn/src/pooling/mod.rs | 56 ++--- crates/cudnn/src/rnn/rnn_data_descriptor.rs | 2 + crates/cudnn/src/rnn/rnn_descriptor.rs | 2 +- crates/cudnn/src/tensor/tensor_descriptor.rs | 2 +- 15 files changed, 355 insertions(+), 34 deletions(-) create mode 100644 crates/cudnn/src/activation/activation_descriptor.rs create mode 100644 crates/cudnn/src/activation/activation_mode.rs create mode 100644 crates/cudnn/src/activation/mod.rs rename crates/cudnn/src/{op_tensor => op}/mod.rs (100%) rename crates/cudnn/src/{op_tensor => op}/op_tensor_descriptor.rs (98%) rename crates/cudnn/src/{op_tensor => op}/op_tensor_op.rs (100%) diff --git a/crates/cudnn/src/activation/activation_descriptor.rs b/crates/cudnn/src/activation/activation_descriptor.rs new file mode 100644 index 00000000..eb4ed2d8 --- /dev/null +++ b/crates/cudnn/src/activation/activation_descriptor.rs @@ -0,0 +1,71 @@ +use crate::{sys, ActivationMode, CudnnContext, CudnnError, IntoResult, NanPropagation}; +use std::mem::MaybeUninit; + +/// The descriptor of a neuron activation operation. +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ActivationDescriptor { + pub(crate) raw: sys::cudnnActivationDescriptor_t, +} + +impl ActivationDescriptor { + /// Creates a new neuron activation descriptor. + /// + /// # Arguments + /// + /// * `mode` - activation function to compute. + /// + /// * `nan_opt` - NaN propagation policy for the operation. + /// + /// * `coefficient` - optional coefficient for the given function. It specifies the clipping + /// threshold for `ActivationMode::ClippedRelu`. 
+ /// + /// # Examples + /// + /// ``` + /// # use std::error::Error; + /// # + /// # fn main() -> Result<(), Box> { + /// use cudnn::{ActivationDescriptor, ActivationMode, CudnnContext, NanPropagation}; + /// + /// let ctx = CudnnContext::new()?; + /// + /// let mode = ActivationMode::Swish; + /// let nan_opt = NanPropagation::PropagateNaN; + /// let coefficient = None; + /// + /// let desc = ActivationDescriptor::new(mode, nan_opt, coefficient)?; + /// # Ok(()) + /// # } + /// ``` + pub fn new( + mode: ActivationMode, + nan_opt: NanPropagation, + coefficient: impl Into>, + ) -> Result { + let mut raw = MaybeUninit::uninit(); + + unsafe { + sys::cudnnCreateActivationDescriptor(raw.as_mut_ptr()).into_result()?; + + let mut raw = raw.assume_init(); + + let coefficient = coefficient.into().unwrap_or_else(|| match mode { + ActivationMode::ClippedRelu => std::f64::MAX, + _ => 1.0, + }); + + sys::cudnnSetActivationDescriptor(raw, mode.into(), nan_opt.into(), coefficient) + .into_result()?; + + Ok(Self { raw }) + } + } +} + +impl Drop for ActivationDescriptor { + fn drop(&mut self) { + unsafe { + sys::cudnnDestroyActivationDescriptor(self.raw); + } + } +} diff --git a/crates/cudnn/src/activation/activation_mode.rs b/crates/cudnn/src/activation/activation_mode.rs new file mode 100644 index 00000000..82437d92 --- /dev/null +++ b/crates/cudnn/src/activation/activation_mode.rs @@ -0,0 +1,31 @@ +use crate::sys; + +/// Specifies a neuron activation function. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ActivationMode { + /// Selects the sigmoid function. + Sigmoid, + /// Selects the rectified linear function. + Relu, + /// Selects the hyperbolic tangent function. + Tanh, + /// Selects the clipped rectified linear function. + ClippedRelu, + /// Selects the exponential linear function. + Elu, + /// Selects the swish function. 
+ Swish, +} + +impl From for sys::cudnnActivationMode_t { + fn from(mode: ActivationMode) -> Self { + match mode { + ActivationMode::Sigmoid => Self::CUDNN_ACTIVATION_SIGMOID, + ActivationMode::Relu => Self::CUDNN_ACTIVATION_RELU, + ActivationMode::Tanh => Self::CUDNN_ACTIVATION_TANH, + ActivationMode::ClippedRelu => Self::CUDNN_ACTIVATION_CLIPPED_RELU, + ActivationMode::Elu => Self::CUDNN_ACTIVATION_ELU, + ActivationMode::Swish => Self::CUDNN_ACTIVATION_SWISH, + } + } +} diff --git a/crates/cudnn/src/activation/mod.rs b/crates/cudnn/src/activation/mod.rs new file mode 100644 index 00000000..71d310cb --- /dev/null +++ b/crates/cudnn/src/activation/mod.rs @@ -0,0 +1,211 @@ +mod activation_descriptor; +mod activation_mode; + +pub use activation_descriptor::*; +pub use activation_mode::*; + +use crate::{private, sys, CudnnContext, CudnnError, DataType, IntoResult, TensorDescriptor}; +use cust::memory::GpuBuffer; +use std::mem::MaybeUninit; + +impl CudnnContext { + /// Applies a specific neuron activation function element-wise to the provided tensor. + /// + /// # Arguments + /// + /// * `activation_desc` - activation descriptor. + /// + /// * `alpha` - scaling factor for the result. + /// + /// * `x_desc` - tensor descriptor for the input. + /// + /// * `x` - data for the input tensor. + /// + /// * `beta` - scaling factor for the destination tensor. + /// + /// * `y_desc` - tensor descriptor for the output. + /// + /// * `y` - data for the output. + /// + /// # Errors + /// + /// Returns errors if the shapes of the `y` and `x` tensors do not match or an unsupported + /// configuration of arguments is detected.
+ /// + /// # Examples + /// + /// ``` + /// # use std::error::Error; + /// # + /// # fn main() -> Result<(), Box> { + /// use cudnn::{ActivationDescriptor, ActivationMode, CudnnContext, NanPropagation, TensorDescriptor}; + /// use cust::memory::DeviceBuffer; + /// + /// let ctx = CudnnContext::new()?; + /// + /// let mode = ActivationMode::Swish; + /// let nan_opt = NanPropagation::PropagateNaN; + /// let coefficient = None; + /// + /// let desc = ActivationDescriptor::new(mode, nan_opt, coefficient)?; + /// + /// let alpha: f32 = 1.0; + /// let x_desc = TensorDescriptor::::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?; + /// let x = DeviceBuffer::::from_slice(&[10, 10, 10, 10, 10])?; + /// + /// let beta: f32 = 0.0; + /// let y_desc = TensorDescriptor::::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?; + /// let mut y = DeviceBuffer::::from_slice(&[0, 0, 0, 0, 0])?; + /// + /// ctx.activation_forward(&desc, alpha, &x_desc, &x, beta, &y_desc, &mut y)?; + /// + /// let y_host = y.as_host_vec()?; + /// + /// assert!(y_host.iter().all(|el| *el == 10)); + /// # Ok(()) + /// # } + /// ``` + pub fn activation_forward( + &self, + activation_desc: &ActivationDescriptor, + alpha: CompT, + x_desc: &TensorDescriptor, + x: &impl GpuBuffer, + beta: CompT, + y_desc: &TensorDescriptor, + y: &mut impl GpuBuffer, + ) -> Result<(), CudnnError> + where + CompT: SupportedActFwd, + T: DataType, + { + let alpha_ptr = &alpha as *const CompT as *const _; + let x_ptr = x.as_device_ptr().as_ptr() as *const _; + + let beta_ptr = &beta as *const CompT as *const _; + let y_ptr = y.as_device_ptr().as_mut_ptr() as *mut _; + + unsafe { + sys::cudnnActivationForward( + self.raw, + activation_desc.raw, + alpha_ptr, + x_desc.raw, + x_ptr, + beta_ptr, + y_desc.raw, + y_ptr, + ) + .into_result() + } + } + + /// Computes the gradient of a neuron activation function. + /// + /// # Arguments + /// + /// * `activation_desc` - descriptor of a neuron activation operation.
+ /// + /// * `alpha` - scaling factor for the result. + /// + /// * `y_desc` - tensor descriptor for the output map. + /// + /// * `y` - data for the output map. + /// + /// * `dy_desc` - tensor descriptor for the differential of the output map. + /// + /// * `dy` - data for the differential of the output map. + /// + /// * `x_desc` - tensor descriptor for the activation input. + /// + /// * `x` - data for the activation input. + /// + /// * `beta` - scaling factor for the destination tensor. + /// + /// * `dx_desc` - tensor descriptor for the input differential. + /// + /// * `dx` - data for the input differential. + /// + /// # Errors + /// + /// Returns errors if the shapes of the `dx` and `x` tensors do not match, the strides of the + /// tensors and their differential do not match, or an unsupported configuration of arguments + /// is detected. + pub fn activation_backward( + &self, + activation_desc: &ActivationDescriptor, + alpha: CompT, + y_desc: &TensorDescriptor, + y: &impl GpuBuffer, + dy_desc: &TensorDescriptor, + dy: &impl GpuBuffer, + x_desc: &TensorDescriptor, + x: &impl GpuBuffer, + beta: CompT, + dx_desc: &TensorDescriptor, + dx: &mut impl GpuBuffer, + ) -> Result<(), CudnnError> + where + CompT: SupportedActBwd, + T: DataType, + { + let alpha_ptr = &alpha as *const CompT as *const _; + + let y_ptr = y.as_device_ptr().as_ptr() as *const _; + let dy_ptr = dy.as_device_ptr().as_ptr() as *const _; + let x_ptr = x.as_device_ptr().as_ptr() as *const _; + + let beta_ptr = &beta as *const CompT as *const _; + + let dx_ptr = dx.as_device_ptr().as_mut_ptr() as *mut _; + + unsafe { + sys::cudnnActivationBackward( + self.raw, + activation_desc.raw, + alpha_ptr, + y_desc.raw, + y_ptr, + dy_desc.raw, + dy_ptr, + x_desc.raw, + x_ptr, + beta_ptr, + dx_desc.raw, + dx_ptr, + ) + .into_result() + } + } +} + +/// Supported data type configurations for the activation forward operation.
+pub trait SupportedActFwd: DataType + private::Sealed +where + T: DataType, +{ +} + +impl SupportedActFwd for f32 {} +impl SupportedActFwd for f32 {} +impl SupportedActFwd for f32 {} +impl SupportedActFwd for f32 {} +impl SupportedActFwd for f32 {} +impl SupportedActFwd for f32 {} + +impl SupportedActFwd for f64 {} +impl SupportedActFwd for f64 {} +impl SupportedActFwd for f64 {} +impl SupportedActFwd for f64 {} +impl SupportedActFwd for f64 {} +impl SupportedActFwd for f64 {} + +/// Supported type configurations for the activation backward operation. +pub trait SupportedActBwd: DataType + private::Sealed +where + T: DataType, +{ +} + +impl SupportedActBwd for f32 {} +impl SupportedActBwd for f64 {} diff --git a/crates/cudnn/src/attention/attention_descriptor.rs b/crates/cudnn/src/attention/attention_descriptor.rs index cc4db6a5..30c01710 100644 --- a/crates/cudnn/src/attention/attention_descriptor.rs +++ b/crates/cudnn/src/attention/attention_descriptor.rs @@ -21,6 +21,7 @@ bitflags::bitflags! { } /// A multi-head attention descriptor. +#[derive(Debug, PartialEq, Eq, Hash)] pub struct AttentionDescriptor where T: SeqDataType, diff --git a/crates/cudnn/src/convolution/convolution_descriptor.rs b/crates/cudnn/src/convolution/convolution_descriptor.rs index a35723dd..1451e1b4 100644 --- a/crates/cudnn/src/convolution/convolution_descriptor.rs +++ b/crates/cudnn/src/convolution/convolution_descriptor.rs @@ -9,7 +9,7 @@ use std::{marker::PhantomData, mem::MaybeUninit}; /// **Do note** that N can be either 2 or 3, respectively for a 2-d or a 3-d convolution, and that /// the same convolution descriptor can be reused in the backward path provided it corresponds to /// the same layer. 
-#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ConvDescriptor { pub(crate) raw: sys::cudnnConvolutionDescriptor_t, comp_type: PhantomData, diff --git a/crates/cudnn/src/convolution/filter_descriptor.rs b/crates/cudnn/src/convolution/filter_descriptor.rs index 81f9e43f..43583236 100644 --- a/crates/cudnn/src/convolution/filter_descriptor.rs +++ b/crates/cudnn/src/convolution/filter_descriptor.rs @@ -5,7 +5,7 @@ use std::{ }; /// A generic description of an n-dimensional filter dataset. -#[derive(Debug, Clone, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct FilterDescriptor where T: DataType, diff --git a/crates/cudnn/src/dropout/dropout_descriptor.rs b/crates/cudnn/src/dropout/dropout_descriptor.rs index 7e0c4a30..e10be80e 100644 --- a/crates/cudnn/src/dropout/dropout_descriptor.rs +++ b/crates/cudnn/src/dropout/dropout_descriptor.rs @@ -2,6 +2,7 @@ use crate::{error::CudnnError, sys, IntoResult}; use cust::memory::GpuBuffer; /// The descriptor of a dropout operation. 
+#[derive(Debug, PartialEq, Eq, Hash)] pub struct DropoutDescriptor where T: GpuBuffer, diff --git a/crates/cudnn/src/lib.rs b/crates/cudnn/src/lib.rs index bc682fc6..15a9774d 100644 --- a/crates/cudnn/src/lib.rs +++ b/crates/cudnn/src/lib.rs @@ -1,6 +1,7 @@ #![allow(warnings, clippy::all)] mod sys; +mod activation; mod attention; mod backend; mod context; @@ -11,13 +12,14 @@ mod dropout; mod error; mod math_type; mod nan_propagation; -mod op_tensor; +mod op; mod pooling; mod rnn; mod softmax; mod tensor; mod w_grad_mode; +pub use activation::*; pub use attention::*; pub use context::*; pub use convolution::*; @@ -27,7 +29,7 @@ pub use dropout::*; pub use error::*; pub use math_type::*; pub use nan_propagation::*; -pub use op_tensor::*; +pub use op::*; pub use pooling::*; pub use rnn::*; pub use softmax::*; diff --git a/crates/cudnn/src/op_tensor/mod.rs b/crates/cudnn/src/op/mod.rs similarity index 100% rename from crates/cudnn/src/op_tensor/mod.rs rename to crates/cudnn/src/op/mod.rs diff --git a/crates/cudnn/src/op_tensor/op_tensor_descriptor.rs b/crates/cudnn/src/op/op_tensor_descriptor.rs similarity index 98% rename from crates/cudnn/src/op_tensor/op_tensor_descriptor.rs rename to crates/cudnn/src/op/op_tensor_descriptor.rs index 126a7624..e388e0ac 100644 --- a/crates/cudnn/src/op_tensor/op_tensor_descriptor.rs +++ b/crates/cudnn/src/op/op_tensor_descriptor.rs @@ -30,6 +30,7 @@ unsafe fn init_raw_op_descriptor( /// As specified in the cuDNN [docs](https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters), /// admissible types for scaling parameters are `f32` and `f64` for `f32` and `f64` tensors /// respectively. 
+#[derive(Debug, PartialEq, Eq, Hash)] pub struct UnaryOpTensorDescriptor { pub(crate) raw: sys::cudnnOpTensorDescriptor_t, comp_type: PhantomData, @@ -91,6 +92,7 @@ impl Drop for UnaryOpTensorDescriptor { /// As specified in the cuDNN [docs](https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters), /// admissible types for scaling parameters are `f32` and `f64` for `f32` and `f64` tensors /// respectively. +#[derive(Debug, PartialEq, Eq, Hash)] pub struct BinaryOpTensorDescriptor { pub(crate) raw: sys::cudnnOpTensorDescriptor_t, comp_type: PhantomData, diff --git a/crates/cudnn/src/op_tensor/op_tensor_op.rs b/crates/cudnn/src/op/op_tensor_op.rs similarity index 100% rename from crates/cudnn/src/op_tensor/op_tensor_op.rs rename to crates/cudnn/src/op/op_tensor_op.rs diff --git a/crates/cudnn/src/pooling/mod.rs b/crates/cudnn/src/pooling/mod.rs index 6fb9972e..4b2c9309 100644 --- a/crates/cudnn/src/pooling/mod.rs +++ b/crates/cudnn/src/pooling/mod.rs @@ -31,24 +31,24 @@ impl CudnnContext { /// /// Returns errors if the batch size or channels dimensions of the two tensor differ or an /// invalid combination of arguments is detected. 
- pub fn pooling_forward( + pub fn pooling_forward( &self, pooling_desc: &PoolingDescriptor, - alpha: T, - x_desc: &TensorDescriptor, - x: &impl GpuBuffer, - beta: T, - y_desc: &TensorDescriptor, - y: &mut impl GpuBuffer, + alpha: CompT, + x_desc: &TensorDescriptor, + x: &impl GpuBuffer, + beta: CompT, + y_desc: &TensorDescriptor, + y: &mut impl GpuBuffer, ) -> Result<(), CudnnError> where - T: SupportedPoolFwd, - U: DataType, + CompT: SupportedPoolFwd, + T: DataType, { - let alpha_ptr = &alpha as *const T as *const _; + let alpha_ptr = &alpha as *const CompT as *const _; let x_ptr = x.as_device_ptr().as_ptr() as *const _; - let beta_ptr = &beta as *const T as *const _; + let beta_ptr = &beta as *const CompT as *const _; let y_ptr = y.as_device_ptr().as_mut_ptr() as *mut _; unsafe { @@ -82,9 +82,9 @@ impl CudnnContext { /// /// * `dy` - data foe the differential of the output map. /// - /// * `x_desc` - tensor descriptor for the dropout input. + /// * `x_desc` - tensor descriptor for the pooling input. /// - /// * `x` - data for the dropout input. + /// * `x` - data for the pooling input. /// /// * `beta` - scaling factor for the destination tensor. /// @@ -97,31 +97,31 @@ impl CudnnContext { /// Returns errors if the dimensions or the strides of `y` and `dy` tensors differ or if the /// dimensions or the strides of `x` and `dx` tensors differ or if an unsupported combination /// of arguments is detected. 
- pub fn pooling_backward( + pub fn pooling_backward( &self, pooling_desc: &PoolingDescriptor, - alpha: T, - y_desc: &TensorDescriptor, - y: &impl GpuBuffer, - dy_desc: &TensorDescriptor, - dy: &impl GpuBuffer, - x_desc: &TensorDescriptor, - x: &impl GpuBuffer, - beta: T, - dx_desc: &TensorDescriptor, - dx: &mut impl GpuBuffer, + alpha: CompT, + y_desc: &TensorDescriptor, + y: &impl GpuBuffer, + dy_desc: &TensorDescriptor, + dy: &impl GpuBuffer, + x_desc: &TensorDescriptor, + x: &impl GpuBuffer, + beta: CompT, + dx_desc: &TensorDescriptor, + dx: &mut impl GpuBuffer, ) -> Result<(), CudnnError> where - T: SupportedPoolBwd, - U: DataType, + CompT: SupportedPoolBwd, + T: DataType, { - let alpha_ptr = &alpha as *const T as *const _; + let alpha_ptr = &alpha as *const CompT as *const _; let y_ptr = y.as_device_ptr().as_ptr() as *const _; let dy_ptr = dy.as_device_ptr().as_ptr() as *const _; let x_ptr = x.as_device_ptr().as_ptr() as *const _; - let beta_ptr = &beta as *const T as *const _; + let beta_ptr = &beta as *const CompT as *const _; let dx_ptr = dx.as_device_ptr().as_mut_ptr() as *mut _; diff --git a/crates/cudnn/src/rnn/rnn_data_descriptor.rs b/crates/cudnn/src/rnn/rnn_data_descriptor.rs index 7fb50e0c..af6f255c 100644 --- a/crates/cudnn/src/rnn/rnn_data_descriptor.rs +++ b/crates/cudnn/src/rnn/rnn_data_descriptor.rs @@ -10,6 +10,8 @@ pub trait RnnDataType: DataType + private::Sealed {} impl RnnDataType for f32 {} impl RnnDataType for f64 {} +/// Descriptor of a recurrent neural network data container. +#[derive(Debug, PartialEq, Eq, Hash)] pub struct RnnDataDescriptor where T: RnnDataType, diff --git a/crates/cudnn/src/rnn/rnn_descriptor.rs b/crates/cudnn/src/rnn/rnn_descriptor.rs index 20088b21..3f906519 100644 --- a/crates/cudnn/src/rnn/rnn_descriptor.rs +++ b/crates/cudnn/src/rnn/rnn_descriptor.rs @@ -21,7 +21,7 @@ bitflags::bitflags! { /// /// This descriptor is generic over the data type of the parameters and the inputs, and the one of /// the computation. 
-#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct RnnDescriptor where T: DataType, diff --git a/crates/cudnn/src/tensor/tensor_descriptor.rs b/crates/cudnn/src/tensor/tensor_descriptor.rs index eaf4451c..aadb2e54 100644 --- a/crates/cudnn/src/tensor/tensor_descriptor.rs +++ b/crates/cudnn/src/tensor/tensor_descriptor.rs @@ -5,7 +5,7 @@ use std::{ }; /// A generic description of an n-dimensional dataset. -#[derive(Debug, Clone, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct TensorDescriptor where T: DataType, From 6bbf231f41877897fa076dfd8eea2b4373dbf87c Mon Sep 17 00:00:00 2001 From: Francesco Iannelli <54247008+frjnn@users.noreply.github.com> Date: Mon, 14 Mar 2022 16:17:59 +0100 Subject: [PATCH 3/3] Chore: Adds cudnn to crate lineup (#60) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index eddbc6c6..68b16043 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ The current line-up of libraries is the following: - `cuda_std` for GPU-side functions and utilities, such as thread index queries, memory allocation, warp intrinsics, etc. - *Not* a low level library, provides many utility functions to make it easier to write cleaner and more reliable GPU kernels. - Closely tied to `rustc_codegen_nvvm` which exposes GPU features through it internally. +- [`cudnn`](https://github.com/Rust-GPU/Rust-CUDA/tree/master/crates/cudnn) for a collection of GPU-accelerated primitives for deep neural networks. - `cust` for CPU-side CUDA features such as launching GPU kernels, GPU memory allocation, device queries, etc. - High level with features such as RAII and Rust Results that make it easier and cleaner to manage the interface to the GPU. - A high level wrapper for the CUDA Driver API, the lower level version of the more common CUDA Runtime API used from C++.