|
| 1 | +//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs] |
| 2 | +//! |
| 3 | +//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main |
| 4 | +//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain |
| 5 | +//! |
| 6 | +//! [MSRV]: https://img.shields.io/badge/MSRV-1.76.0-blue |
| 7 | +//! [repo]: https://github.com/juntyr/numcodecs-rs |
| 8 | +//! |
| 9 | +//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-asinh |
| 10 | +//! [crates.io]: https://crates.io/crates/numcodecs-asinh |
| 11 | +//! |
| 12 | +//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-asinh |
| 13 | +//! [docs.rs]: https://docs.rs/numcodecs-asinh/ |
| 14 | +//! |
| 15 | +//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue |
| 16 | +//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_asinh |
| 17 | +//! |
| 18 | +//! `asinh(x)` codec implementation for the [`numcodecs`] API. |
| 19 | +
|
| 20 | +use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension, Zip}; |
| 21 | +use num_traits::{Float, Signed}; |
| 22 | +use numcodecs::{ |
| 23 | + AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray, |
| 24 | + Codec, StaticCodec, StaticCodecConfig, |
| 25 | +}; |
| 26 | +use schemars::JsonSchema; |
| 27 | +use serde::{Deserialize, Serialize}; |
| 28 | +use thiserror::Error; |
| 29 | + |
| 30 | +#[derive(Clone, Serialize, Deserialize, JsonSchema)] |
| 31 | +#[serde(deny_unknown_fields)] |
| 32 | +/// Asinh codec, which applies a quasi-logarithmic transformation on encoding. |
| 33 | +/// |
| 34 | +/// For values close to zero that are within the codec's `linear_width`, the |
| 35 | +/// transform is close to linear. For values of larger magnitudes, the |
| 36 | +/// transform is asymptotically logarithmic. Unlike a logarithmic transform, |
| 37 | +/// this codec supports all finite values, including negative values and zero. |
| 38 | +/// |
| 39 | +/// In detail, the codec calculates `c = asinh(x/w) * w` on encoding and |
| 40 | +/// `d = sinh(c/w) * w` on decoding, where `w` is the codec's `linear_width`. |
| 41 | +/// |
| 42 | +/// The codec only supports finite floating point numbers. |
| 43 | +pub struct AsinhCodec { |
| 44 | + /// The width of the close-to-zero input value range where the transform is |
| 45 | + /// nearly linear |
| 46 | + linear_width: f64, |
| 47 | +} |
| 48 | + |
| 49 | +impl Codec for AsinhCodec { |
| 50 | + type Error = AsinhCodecError; |
| 51 | + |
| 52 | + fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> { |
| 53 | + match data { |
| 54 | + #[allow(clippy::cast_possible_truncation)] |
| 55 | + AnyCowArray::F32(data) => Ok(AnyArray::F32(asinh(data, self.linear_width as f32)?)), |
| 56 | + AnyCowArray::F64(data) => Ok(AnyArray::F64(asinh(data, self.linear_width)?)), |
| 57 | + encoded => Err(AsinhCodecError::UnsupportedDtype(encoded.dtype())), |
| 58 | + } |
| 59 | + } |
| 60 | + |
| 61 | + fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> { |
| 62 | + match encoded { |
| 63 | + #[allow(clippy::cast_possible_truncation)] |
| 64 | + AnyCowArray::F32(encoded) => { |
| 65 | + Ok(AnyArray::F32(sinh(encoded, self.linear_width as f32)?)) |
| 66 | + } |
| 67 | + AnyCowArray::F64(encoded) => Ok(AnyArray::F64(sinh(encoded, self.linear_width)?)), |
| 68 | + encoded => Err(AsinhCodecError::UnsupportedDtype(encoded.dtype())), |
| 69 | + } |
| 70 | + } |
| 71 | + |
| 72 | + fn decode_into( |
| 73 | + &self, |
| 74 | + encoded: AnyArrayView, |
| 75 | + decoded: AnyArrayViewMut, |
| 76 | + ) -> Result<(), Self::Error> { |
| 77 | + match (encoded, decoded) { |
| 78 | + #[allow(clippy::cast_possible_truncation)] |
| 79 | + (AnyArrayView::F32(encoded), AnyArrayViewMut::F32(decoded)) => { |
| 80 | + sinh_into(encoded, decoded, self.linear_width as f32) |
| 81 | + } |
| 82 | + (AnyArrayView::F64(encoded), AnyArrayViewMut::F64(decoded)) => { |
| 83 | + sinh_into(encoded, decoded, self.linear_width) |
| 84 | + } |
| 85 | + (encoded @ (AnyArrayView::F32(_) | AnyArrayView::F64(_)), decoded) => { |
| 86 | + Err(AsinhCodecError::MismatchedDecodeIntoArray { |
| 87 | + source: AnyArrayAssignError::DTypeMismatch { |
| 88 | + src: encoded.dtype(), |
| 89 | + dst: decoded.dtype(), |
| 90 | + }, |
| 91 | + }) |
| 92 | + } |
| 93 | + (encoded, _decoded) => Err(AsinhCodecError::UnsupportedDtype(encoded.dtype())), |
| 94 | + } |
| 95 | + } |
| 96 | +} |
| 97 | + |
| 98 | +impl StaticCodec for AsinhCodec { |
| 99 | + const CODEC_ID: &'static str = "asinh"; |
| 100 | + |
| 101 | + type Config<'de> = Self; |
| 102 | + |
| 103 | + fn from_config(config: Self::Config<'_>) -> Self { |
| 104 | + config |
| 105 | + } |
| 106 | + |
| 107 | + fn get_config(&self) -> StaticCodecConfig<Self> { |
| 108 | + StaticCodecConfig::from(self) |
| 109 | + } |
| 110 | +} |
| 111 | + |
| 112 | +#[derive(Debug, Error)] |
| 113 | +/// Errors that may occur when applying the [`AsinhCodec`]. |
| 114 | +pub enum AsinhCodecError { |
| 115 | + /// [`AsinhCodec`] does not support the dtype |
| 116 | + #[error("Asinh does not support the dtype {0}")] |
| 117 | + UnsupportedDtype(AnyArrayDType), |
| 118 | + /// [`AsinhCodec`] does not support non-finite (infinite or NaN) floating |
| 119 | + /// point data |
| 120 | + #[error("Asinh does not support non-finite (infinite or NaN) floating point data")] |
| 121 | + NonFiniteData, |
| 122 | + /// [`AsinhCodec`] cannot decode into the provided array |
| 123 | + #[error("Asinh cannot decode into the provided array")] |
| 124 | + MismatchedDecodeIntoArray { |
| 125 | + /// The source of the error |
| 126 | + #[from] |
| 127 | + source: AnyArrayAssignError, |
| 128 | + }, |
| 129 | +} |
| 130 | + |
| 131 | +/// Compute `asinh(x/w) * w` over the elements of the input `data` array. |
| 132 | +/// |
| 133 | +/// # Errors |
| 134 | +/// |
| 135 | +/// Errors with |
| 136 | +/// - [`AsinhCodecError::NonFiniteData`] if any data element is non-finite |
| 137 | +/// (infinite or NaN) |
| 138 | +pub fn asinh<T: Float + Signed, S: Data<Elem = T>, D: Dimension>( |
| 139 | + data: ArrayBase<S, D>, |
| 140 | + linear_width: T, |
| 141 | +) -> Result<Array<T, D>, AsinhCodecError> { |
| 142 | + if !Zip::from(&data).all(|x| x.is_finite()) { |
| 143 | + return Err(AsinhCodecError::NonFiniteData); |
| 144 | + } |
| 145 | + |
| 146 | + let mut data = data.into_owned(); |
| 147 | + data.mapv_inplace(|x| (x / linear_width).asinh() * linear_width); |
| 148 | + |
| 149 | + Ok(data) |
| 150 | +} |
| 151 | + |
| 152 | +/// Compute `sinh(x/w) * w` over the elements of the input `data` array. |
| 153 | +/// |
| 154 | +/// # Errors |
| 155 | +/// |
| 156 | +/// Errors with |
| 157 | +/// - [`AsinhCodecError::NonFiniteData`] if any data element is non-finite |
| 158 | +/// (infinite or NaN) |
| 159 | +pub fn sinh<T: Float, S: Data<Elem = T>, D: Dimension>( |
| 160 | + data: ArrayBase<S, D>, |
| 161 | + linear_width: T, |
| 162 | +) -> Result<Array<T, D>, AsinhCodecError> { |
| 163 | + if !Zip::from(&data).all(|x| x.is_finite()) { |
| 164 | + return Err(AsinhCodecError::NonFiniteData); |
| 165 | + } |
| 166 | + |
| 167 | + let mut data = data.into_owned(); |
| 168 | + data.mapv_inplace(|x| (x / linear_width).sinh() * linear_width); |
| 169 | + |
| 170 | + Ok(data) |
| 171 | +} |
| 172 | + |
| 173 | +#[allow(clippy::needless_pass_by_value)] |
| 174 | +/// Compute `sinh(x/w) * w` over the elements of the input `data` array and |
| 175 | +/// write them into the `out`put array. |
| 176 | +/// |
| 177 | +/// # Errors |
| 178 | +/// |
| 179 | +/// Errors with |
| 180 | +/// - [`AsinhCodecError::NonFiniteData`] if any data element is non-finite |
| 181 | +/// (infinite or NaN) |
| 182 | +/// - [`AsinhCodecError::MismatchedDecodeIntoArray`] if the `data` array's shape |
| 183 | +/// does not match the `out`put array's shape |
| 184 | +pub fn sinh_into<T: Float, D: Dimension>( |
| 185 | + data: ArrayView<T, D>, |
| 186 | + mut out: ArrayViewMut<T, D>, |
| 187 | + linear_width: T, |
| 188 | +) -> Result<(), AsinhCodecError> { |
| 189 | + if data.shape() != out.shape() { |
| 190 | + return Err(AsinhCodecError::MismatchedDecodeIntoArray { |
| 191 | + source: AnyArrayAssignError::ShapeMismatch { |
| 192 | + src: data.shape().to_vec(), |
| 193 | + dst: out.shape().to_vec(), |
| 194 | + }, |
| 195 | + }); |
| 196 | + } |
| 197 | + |
| 198 | + if !Zip::from(&data).all(|x| x.is_finite()) { |
| 199 | + return Err(AsinhCodecError::NonFiniteData); |
| 200 | + } |
| 201 | + |
| 202 | + // iteration must occur in synchronised (standard) order |
| 203 | + for (d, o) in data.iter().zip(out.iter_mut()) { |
| 204 | + *o = ((*d) / linear_width).sinh() * linear_width; |
| 205 | + } |
| 206 | + |
| 207 | + Ok(()) |
| 208 | +} |
| 209 | + |
| 210 | +#[cfg(test)] |
| 211 | +mod tests { |
| 212 | + use super::*; |
| 213 | + |
| 214 | + #[test] |
| 215 | + fn roundtrip() -> Result<(), AsinhCodecError> { |
| 216 | + let data = (-1000..1000).map(f64::from).collect::<Vec<_>>(); |
| 217 | + let data = Array::from_vec(data); |
| 218 | + |
| 219 | + let encoded = asinh(data.view(), 1.0)?; |
| 220 | + |
| 221 | + for (r, e) in data.iter().zip(encoded.iter()) { |
| 222 | + assert_eq!((*r).asinh().to_bits(), (*e).to_bits()); |
| 223 | + } |
| 224 | + |
| 225 | + let decoded = sinh(encoded, 1.0)?; |
| 226 | + |
| 227 | + for (r, d) in data.iter().zip(decoded.iter()) { |
| 228 | + assert!(((*r) - (*d)).abs() < 1e-12); |
| 229 | + } |
| 230 | + |
| 231 | + Ok(()) |
| 232 | + } |
| 233 | + |
| 234 | + #[test] |
| 235 | + fn roundtrip_widths() -> Result<(), AsinhCodecError> { |
| 236 | + let data = (-1000..1000).map(f64::from).collect::<Vec<_>>(); |
| 237 | + let data = Array::from_vec(data); |
| 238 | + |
| 239 | + for linear_width in [-100.0, -10.0, -1.0, -0.1, 0.1, 1.0, 10.0, 100.0] { |
| 240 | + let encoded = asinh(data.view(), linear_width)?; |
| 241 | + let decoded = sinh(encoded, linear_width)?; |
| 242 | + |
| 243 | + for (r, d) in data.iter().zip(decoded.iter()) { |
| 244 | + assert!(((*r) - (*d)).abs() < 1e-12); |
| 245 | + } |
| 246 | + } |
| 247 | + |
| 248 | + Ok(()) |
| 249 | + } |
| 250 | +} |
0 commit comments