@@ -115,6 +115,69 @@ impl<A, S, D> ArrayBase<S, D>
115
115
sum / & aview0 ( & cnt)
116
116
}
117
117
118
+ /// Return variance along `axis`.
119
+ ///
120
+ /// The variance is computed using the [Welford one-pass
121
+ /// algorithm](https://www.jstor.org/stable/1266577).
122
+ ///
123
+ /// The parameter `ddof` specifies the "delta degrees of freedom". For
124
+ /// example, to calculate the population variance, use `ddof = 0`, or to
125
+ /// calculate the sample variance, use `ddof = 1`.
126
+ ///
127
+ /// The variance is defined as:
128
+ ///
129
+ /// ```text
130
+ /// 1 n
131
+ /// variance = ―――――――― ∑ (xᵢ - x̅)²
132
+ /// n - ddof i=1
133
+ /// ```
134
+ ///
135
+ /// where
136
+ ///
137
+ /// ```text
138
+ /// 1 n
139
+ /// x̅ = ― ∑ xᵢ
140
+ /// n i=1
141
+ /// ```
142
+ ///
143
+ /// **Panics** if `ddof` is greater equal than the length of `axis`.
144
+ /// **Panics** if `axis` is out of bounds or if length of `axis` is zero.
145
+ ///
146
+ /// # Example
147
+ ///
148
+ /// ```
149
+ /// use ndarray::{aview1, arr2, Axis};
150
+ ///
151
+ /// let a = arr2(&[[1., 2.],
152
+ /// [3., 4.]]);
153
+ /// let var = a.var_axis(Axis(0), 0.);
154
+ /// assert_eq!(var, aview1(&[1., 1.]));
155
+ /// ```
156
+ pub fn var_axis ( & self , axis : Axis , ddof : A ) -> Array < A , D :: Smaller >
157
+ where
158
+ A : Float ,
159
+ D : RemoveAxis ,
160
+ {
161
+ let mut count = A :: zero ( ) ;
162
+ let mut mean = Array :: < A , _ > :: zeros ( self . dim . remove_axis ( axis) ) ;
163
+ let mut sum_sq = Array :: < A , _ > :: zeros ( self . dim . remove_axis ( axis) ) ;
164
+ for subview in self . axis_iter ( axis) {
165
+ count = count + A :: one ( ) ;
166
+ azip ! ( mut mean, mut sum_sq, x ( subview) in {
167
+ let delta = x - * mean;
168
+ * mean = * mean + delta / count;
169
+ * sum_sq = * sum_sq + delta * ( x - * mean) ;
170
+ } ) ;
171
+ }
172
+ if ddof >= count {
173
+ panic ! ( "Ddof needs to be strictly smaller than the length \
174
+ of the axis you are computing the variance for!")
175
+ } else {
176
+ let dof = count - ddof;
177
+ sum_sq. mapv ( |s| s / dof)
178
+ }
179
+ }
180
+
118
181
/// Return `true` if the arrays' elementwise differences are all within
119
182
/// the given absolute tolerance, `false` otherwise.
120
183
///
0 commit comments