Entropy #24
Merged
Changes from all commits (51 commits)
76986a9
Add entropy trait
LukeMathWalker 1395859
Remove unnecessary imports
LukeMathWalker b741b3d
Implemented entropy
LukeMathWalker 7b4f6d5
Return entropy with reversed sign, as per definition
LukeMathWalker cc221e2
Fixed tests
LukeMathWalker 7473815
Added signature for cross entropy
LukeMathWalker ea3e81f
Fixed typo.
LukeMathWalker 2998395
Implemented cross_entropy
LukeMathWalker bfc3d22
Added tests
LukeMathWalker a69ed91
Refined panic condition
LukeMathWalker 3d7929b
Added test vs SciPy
LukeMathWalker 27dbd00
Added test vs SciPy
LukeMathWalker 873871b
Added KL divergence
LukeMathWalker dc85e9a
Added KL tests
LukeMathWalker ca95788
Renamed to kl_divergence
LukeMathWalker d21a0bb
Update src/entropy.rs
jturner314 c127428
Improved docs on behaviour with not normalised arrays
LukeMathWalker 0106d65
Improved docs on behaviour with not normalised arrays
LukeMathWalker b28f461
Use mapv
LukeMathWalker ddf358b
Styling on closures (avoid dereferencing)
LukeMathWalker afdcf06
Allow different data ownership to interact in kl_divergence
LukeMathWalker 28b4efd
Allow different data ownership to interact in kl_divergence
LukeMathWalker 8c04f9c
Allow different data ownership to interact in cross_entropy
LukeMathWalker 450cfb4
Add a test
LukeMathWalker 5d45bdf
Doc improvement
LukeMathWalker 5c72f55
Check the whole shape
LukeMathWalker 168ffa5
Merge remote-tracking branch 'origin/entropy' into entropy
LukeMathWalker bb38763
Fix docs
LukeMathWalker c470a3a
Broken usage of Zip
LukeMathWalker e4be9b9
Fixed zip, mistery
LukeMathWalker 57537c3
Use Zip for cross_entropy
LukeMathWalker 80198bc
Add failure crate as dependency
93371f8
Errors module
5f6a004
Use failure crate
42c3600
Add ShapeMismatch error
02a63de
Merge branch 'entropy' of github.com:LukeMathWalker/ndarray-stats int…
05d5c66
Return Result
99a391e
Fix test suite
3a3d1f6
Fix docs
ca31af8
Fix docs
e65ef61
Add docs to error
e39025c
Update src/entropy.rs
jturner314 99b999f
Update src/entropy.rs
jturner314 ac4c159
Update src/entropy.rs
jturner314 b429ec7
Better semantic
LukeMathWalker e9679fa
Use Error instead of Fail
LukeMathWalker d2dfe8f
Formatting
LukeMathWalker d8583d2
Merge branch 'master' into entropy
LukeMathWalker b8ed3ed
Fix TOC
LukeMathWalker c961d9f
Module docstring
LukeMathWalker e7ec4b4
Merge branch 'entropy' of github.com:LukeMathWalker/ndarray-stats int…
LukeMathWalker

src/entropy.rs
@@ -0,0 +1,396 @@
//! Information theory (e.g. entropy, KL divergence, etc.).
use crate::errors::ShapeMismatch;
use ndarray::{Array, ArrayBase, Data, Dimension, Zip};
use num_traits::Float;

/// Extension trait for `ArrayBase` providing methods
/// to compute information theory quantities
/// (e.g. entropy, Kullback–Leibler divergence, etc.).
pub trait EntropyExt<A, S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Computes the [entropy] *S* of the array values, defined as
    ///
    /// ```text
    ///       n
    /// S = - ∑ xᵢ ln(xᵢ)
    ///      i=1
    /// ```
    ///
    /// If the array is empty, `None` is returned.
    ///
    /// **Panics** if `ln` of any element in the array panics (which can occur for negative values for some `A`).
    ///
    /// ## Remarks
    ///
    /// The entropy is a measure used in [Information Theory]
    /// to describe a probability distribution: it only makes sense
    /// when the array values sum to 1, with each entry between
    /// 0 and 1 (extremes included).
    ///
    /// The array values are **not** normalised by this function before
    /// computing the entropy to avoid introducing potentially
    /// unnecessary numerical errors (e.g. if the array were to be already normalised).
    ///
    /// By definition, *xᵢ ln(xᵢ)* is set to 0 if *xᵢ* is 0.
    ///
    /// [entropy]: https://en.wikipedia.org/wiki/Entropy_(information_theory)
    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
    fn entropy(&self) -> Option<A>
    where
        A: Float;

    /// Computes the [Kullback-Leibler divergence] *Dₖₗ(p,q)* between two arrays,
    /// where `self`=*p*.
    ///
    /// The Kullback-Leibler divergence is defined as:
    ///
    /// ```text
    ///              n
    /// Dₖₗ(p,q) = - ∑ pᵢ ln(qᵢ/pᵢ)
    ///             i=1
    /// ```
    ///
    /// If the arrays are empty, `Ok(None)` is returned.
    /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
    ///
    /// **Panics** if, for a pair of elements *(pᵢ, qᵢ)* from *p* and *q*, computing
    /// *ln(qᵢ/pᵢ)* is a panic cause for `A`.
    ///
    /// ## Remarks
    ///
    /// The Kullback-Leibler divergence is a measure used in [Information Theory]
    /// to describe the relationship between two probability distributions: it only makes sense
    /// when each array sums to 1 with entries between 0 and 1 (extremes included).
    ///
    /// The array values are **not** normalised by this function before
    /// computing the entropy to avoid introducing potentially
    /// unnecessary numerical errors (e.g. if the array were to be already normalised).
    ///
    /// By definition, *pᵢ ln(qᵢ/pᵢ)* is set to 0 if *pᵢ* is 0.
    ///
    /// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
    where
        S2: Data<Elem = A>,
        A: Float;

    /// Computes the [cross entropy] *H(p,q)* between two arrays,
    /// where `self`=*p*.
    ///
    /// The cross entropy is defined as:
    ///
    /// ```text
    ///            n
    /// H(p,q) = - ∑ pᵢ ln(qᵢ)
    ///           i=1
    /// ```
    ///
    /// If the arrays are empty, `Ok(None)` is returned.
    /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
    ///
    /// **Panics** if any element in *q* is negative and taking the logarithm of a negative number
    /// is a panic cause for `A`.
    ///
    /// ## Remarks
    ///
    /// The cross entropy is a measure used in [Information Theory]
    /// to describe the relationship between two probability distributions: it only makes sense
    /// when each array sums to 1 with entries between 0 and 1 (extremes included).
    ///
    /// The array values are **not** normalised by this function before
    /// computing the entropy to avoid introducing potentially
    /// unnecessary numerical errors (e.g. if the array were to be already normalised).
    ///
    /// The cross entropy is often used as an objective/loss function in
    /// [optimization problems], including [machine learning].
    ///
    /// By definition, *pᵢ ln(qᵢ)* is set to 0 if *pᵢ* is 0.
    ///
    /// [cross entropy]: https://en.wikipedia.org/wiki/Cross-entropy
    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
    /// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method
    /// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression
    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
    where
        S2: Data<Elem = A>,
        A: Float;
}

impl<A, S, D> EntropyExt<A, S, D> for ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    fn entropy(&self) -> Option<A>
    where
        A: Float,
    {
        if self.len() == 0 {
            None
        } else {
            let entropy = self
                .mapv(|x| {
                    if x == A::zero() {
                        A::zero()
                    } else {
                        x * x.ln()
                    }
                })
                .sum();
            Some(-entropy)
        }
    }

    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
    where
        A: Float,
        S2: Data<Elem = A>,
    {
        if self.len() == 0 {
            return Ok(None);
        }
        if self.shape() != q.shape() {
            return Err(ShapeMismatch {
                first_shape: self.shape().to_vec(),
                second_shape: q.shape().to_vec(),
            });
        }

        let mut temp = Array::zeros(self.raw_dim());
        Zip::from(&mut temp)
            .and(self)
            .and(q)
            .apply(|result, &p, &q| {
                *result = {
                    if p == A::zero() {
                        A::zero()
                    } else {
                        p * (q / p).ln()
                    }
                }
            });
        let kl_divergence = -temp.sum();
        Ok(Some(kl_divergence))
    }

    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
    where
        S2: Data<Elem = A>,
        A: Float,
    {
        if self.len() == 0 {
            return Ok(None);
        }
        if self.shape() != q.shape() {
            return Err(ShapeMismatch {
                first_shape: self.shape().to_vec(),
                second_shape: q.shape().to_vec(),
            });
        }

        let mut temp = Array::zeros(self.raw_dim());
        Zip::from(&mut temp)
            .and(self)
            .and(q)
            .apply(|result, &p, &q| {
                *result = {
                    if p == A::zero() {
                        A::zero()
                    } else {
                        p * q.ln()
                    }
                }
            });
        let cross_entropy = -temp.sum();
        Ok(Some(cross_entropy))
    }
}

#[cfg(test)]
mod tests {
    use super::EntropyExt;
    use approx::assert_abs_diff_eq;
    use errors::ShapeMismatch;
    use ndarray::{array, Array1};
    use noisy_float::types::n64;
    use std::f64;

    #[test]
    fn test_entropy_with_nan_values() {
        let a = array![f64::NAN, 1.];
        assert!(a.entropy().unwrap().is_nan());
    }

    #[test]
    fn test_entropy_with_empty_array_of_floats() {
        let a: Array1<f64> = array![];
        assert!(a.entropy().is_none());
    }

    #[test]
    fn test_entropy_with_array_of_floats() {
        // Array of probability values - normalized and positive.
        let a: Array1<f64> = array![
            0.03602474, 0.01900344, 0.03510129, 0.03414964, 0.00525311, 0.03368976, 0.00065396,
            0.02906146, 0.00063687, 0.01597306, 0.00787625, 0.00208243, 0.01450896, 0.01803418,
            0.02055336, 0.03029759, 0.03323628, 0.01218822, 0.0001873, 0.01734179, 0.03521668,
            0.02564429, 0.02421992, 0.03540229, 0.03497635, 0.03582331, 0.026558, 0.02460495,
            0.02437716, 0.01212838, 0.00058464, 0.00335236, 0.02146745, 0.00930306, 0.01821588,
            0.02381928, 0.02055073, 0.01483779, 0.02284741, 0.02251385, 0.00976694, 0.02864634,
            0.00802828, 0.03464088, 0.03557152, 0.01398894, 0.01831756, 0.0227171, 0.00736204,
            0.01866295,
        ];
        // Computed using scipy.stats.entropy
        let expected_entropy = 3.721606155686918;

        assert_abs_diff_eq!(a.entropy().unwrap(), expected_entropy, epsilon = 1e-6);
    }

    #[test]
    fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), ShapeMismatch> {
        let a = array![f64::NAN, 1.];
        let b = array![2., 1.];
        assert!(a.cross_entropy(&b)?.unwrap().is_nan());
        assert!(b.cross_entropy(&a)?.unwrap().is_nan());
        assert!(a.kl_divergence(&b)?.unwrap().is_nan());
        assert!(b.kl_divergence(&a)?.unwrap().is_nan());
        Ok(())
    }

    #[test]
    fn test_cross_entropy_and_kl_with_same_n_dimension_but_different_n_elements() {
        let p = array![f64::NAN, 1.];
        let q = array![2., 1., 5.];
        assert!(q.cross_entropy(&p).is_err());
        assert!(p.cross_entropy(&q).is_err());
        assert!(q.kl_divergence(&p).is_err());
        assert!(p.kl_divergence(&q).is_err());
    }

    #[test]
    fn test_cross_entropy_and_kl_with_different_shape_but_same_n_elements() {
        // p: 3x2, 6 elements
        let p = array![[f64::NAN, 1.], [6., 7.], [10., 20.]];
        // q: 2x3, 6 elements
        let q = array![[2., 1., 5.], [1., 1., 7.],];
        assert!(q.cross_entropy(&p).is_err());
        assert!(p.cross_entropy(&q).is_err());
        assert!(q.kl_divergence(&p).is_err());
        assert!(p.kl_divergence(&q).is_err());
    }

    #[test]
    fn test_cross_entropy_and_kl_with_empty_array_of_floats() -> Result<(), ShapeMismatch> {
        let p: Array1<f64> = array![];
        let q: Array1<f64> = array![];
        assert!(p.cross_entropy(&q)?.is_none());
        assert!(p.kl_divergence(&q)?.is_none());
        Ok(())
    }

    #[test]
    fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), ShapeMismatch> {
        let p = array![1.];
        let q = array![-1.];
        let cross_entropy: f64 = p.cross_entropy(&q)?.unwrap();
        let kl_divergence: f64 = p.kl_divergence(&q)?.unwrap();
        assert!(cross_entropy.is_nan());
        assert!(kl_divergence.is_nan());
        Ok(())
    }

    #[test]
    #[should_panic]
    fn test_cross_entropy_with_noisy_negative_qs() {
        let p = array![n64(1.)];
        let q = array![n64(-1.)];
        let _ = p.cross_entropy(&q);
    }

    #[test]
    #[should_panic]
    fn test_kl_with_noisy_negative_qs() {
        let p = array![n64(1.)];
        let q = array![n64(-1.)];
        let _ = p.kl_divergence(&q);
    }

    #[test]
    fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), ShapeMismatch> {
        let p = array![0., 0.];
        let q = array![0., 0.5];
        assert_eq!(p.cross_entropy(&q)?.unwrap(), 0.);
        assert_eq!(p.kl_divergence(&q)?.unwrap(), 0.);
        Ok(())
    }

    #[test]
    fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership(
    ) -> Result<(), ShapeMismatch> {
        let p = array![0.5, 0.5];
        let mut q = array![0.5, 0.];
        assert_eq!(p.cross_entropy(&q.view_mut())?.unwrap(), f64::INFINITY);
        assert_eq!(p.kl_divergence(&q.view_mut())?.unwrap(), f64::INFINITY);
        Ok(())
    }

    #[test]
    fn test_cross_entropy() -> Result<(), ShapeMismatch> {
        // Arrays of probability values - normalized and positive.
        let p: Array1<f64> = array![
            0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189,
            0.0471258, 0.05594036, 0.01630048, 0.07085162, 0.05365855, 0.01959158, 0.05020174,
            0.03801479, 0.00092234, 0.08515856, 0.00580683, 0.0156542, 0.0860375, 0.0724246,
            0.00727477, 0.01004402, 0.01854399, 0.03504082,
        ];
        let q: Array1<f64> = array![
            0.06622616, 0.0478948, 0.03227816, 0.06460884, 0.05795974, 0.01377489, 0.05604812,
            0.01202684, 0.01647579, 0.03392697, 0.01656126, 0.00867528, 0.0625685, 0.07381292,
            0.05489067, 0.01385491, 0.03639174, 0.00511611, 0.05700415, 0.05183825, 0.06703064,
            0.01813342, 0.0007763, 0.0735472, 0.05857833,
        ];
        // Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q)
        let expected_cross_entropy = 3.385347705020779;

        assert_abs_diff_eq!(
            p.cross_entropy(&q)?.unwrap(),
            expected_cross_entropy,
            epsilon = 1e-6
        );
        Ok(())
    }

    #[test]
    fn test_kl() -> Result<(), ShapeMismatch> {
        // Arrays of probability values - normalized and positive.
        let p: Array1<f64> = array![
            0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516,
            0.02213802, 0.02745017, 0.02163975, 0.0324602, 0.03622766, 0.00782343, 0.00222498,
            0.03028156, 0.02346124, 0.00071105, 0.00794496, 0.0127609, 0.02899124, 0.01281487,
            0.0230803, 0.01531864, 0.00518158, 0.02233383, 0.0220279, 0.03196097, 0.03710063,
            0.01817856, 0.03524661, 0.02902393, 0.00853364, 0.01255615, 0.03556958, 0.00400151,
            0.01335932, 0.01864965, 0.02371322, 0.02026543, 0.0035375, 0.01988341, 0.02621831,
            0.03564644, 0.01389121, 0.03151622, 0.03195532, 0.00717521, 0.03547256, 0.00371394,
            0.01108706,
        ];
        let q: Array1<f64> = array![
            0.02038386, 0.03143914, 0.02630206, 0.0171595, 0.0067072, 0.00911324, 0.02635717,
            0.01269113, 0.0302361, 0.02243133, 0.01902902, 0.01297185, 0.02118908, 0.03309548,
            0.01266687, 0.0184529, 0.01830936, 0.03430437, 0.02898924, 0.02238251, 0.0139771,
            0.01879774, 0.02396583, 0.03019978, 0.01421278, 0.02078981, 0.03542451, 0.02887438,
            0.01261783, 0.01014241, 0.03263407, 0.0095969, 0.01923903, 0.0051315, 0.00924686,
            0.00148845, 0.00341391, 0.01480373, 0.01920798, 0.03519871, 0.03315135, 0.02099325,
            0.03251755, 0.00337555, 0.03432165, 0.01763753, 0.02038337, 0.01923023, 0.01438769,
            0.02082707,
        ];
        // Computed using scipy.stats.entropy(p, q)
        let expected_kl = 0.3555862567800096;

        assert_abs_diff_eq!(p.kl_divergence(&q)?.unwrap(), expected_kl, epsilon = 1e-6);
        Ok(())
    }
}
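
For reviewers trying the branch locally, here is a minimal usage sketch of the `EntropyExt` methods added in this PR. It assumes the trait is re-exported from the crate root as `ndarray_stats::EntropyExt`, which is not shown in this diff, and the distributions are illustrative only.

```rust
// Minimal sketch of the API added in this PR; the `ndarray_stats::EntropyExt`
// re-export path is an assumption (only src/entropy.rs appears in this diff).
use ndarray::array;
use ndarray_stats::EntropyExt;

fn main() {
    // Normalised probability distributions (each sums to 1).
    let p = array![0.25, 0.25, 0.25, 0.25];
    let q = array![0.5, 0.25, 0.125, 0.125];

    // `entropy` returns `Option<A>`: `None` only for an empty array.
    // Entropy of a uniform distribution over 4 outcomes is ln(4) ≈ 1.3863.
    let h_p = p.entropy().unwrap();

    // `kl_divergence` and `cross_entropy` return `Result<Option<A>, ShapeMismatch>`:
    // `Err(ShapeMismatch)` when the shapes differ, `Ok(None)` for empty arrays,
    // `Ok(Some(value))` otherwise.
    let d_kl = p.kl_divergence(&q).unwrap().unwrap();
    let h_pq = p.cross_entropy(&q).unwrap().unwrap();

    // With the definitions in this file, H(p,q) = S(p) + Dₖₗ(p,q).
    assert!((h_pq - (h_p + d_kl)).abs() < 1e-12);
    println!("S(p) = {}, Dkl(p||q) = {}, H(p,q) = {}", h_p, d_kl, h_pq);
}
```

The sketch also highlights the design choice settled during review: empty inputs map to `Ok(None)` rather than an error, while shape mismatches are reported through the `failure`-based `ShapeMismatch` error introduced in this PR.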