Skip to content

Commit 7ab8f33

Browse files
committed
Address comments from code review
1 parent 6e67fd5 commit 7ab8f33

File tree

1 file changed

+41
-64
lines changed

1 file changed

+41
-64
lines changed

src/least_squares.rs

Lines changed: 41 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -60,25 +60,13 @@
6060
//! // `a` and `b` have been moved, no longer valid
6161
//! ```
6262
63-
use ndarray::{s, Array, Array1, Array2, ArrayBase, Axis, Data, DataMut, Ix1, Ix2};
63+
use ndarray::{s, Array, Array1, Array2, ArrayBase, Axis, Data, DataMut, Dimension, Ix0, Ix1, Ix2};
6464

6565
use crate::error::*;
6666
use crate::lapack::least_squares::*;
6767
use crate::layout::*;
6868
use crate::types::*;
6969

70-
pub trait Ix1OrIx2<E: Scalar> {
71-
type ScalarOrArray1;
72-
}
73-
74-
impl<E: Scalar> Ix1OrIx2<E> for Ix1 {
75-
type ScalarOrArray1 = E::Real;
76-
}
77-
78-
impl<E: Scalar> Ix1OrIx2<E> for Ix2 {
79-
type ScalarOrArray1 = Array1<E::Real>;
80-
}
81-
8270
/// Result of a LeastSquares computation
8371
///
8472
/// Takes two type parameters, `E`, the element type of the matrix
@@ -88,7 +76,7 @@ impl<E: Scalar> Ix1OrIx2<E> for Ix2 {
8876
/// is a `m x 1` column vector. If `I` is `Ix2`, the RHS is a `n x k` matrix
8977
/// (which can be seen as solving `Ax = b` k times for different b) and
9078
/// the solution is a `m x k` matrix.
91-
pub struct LeastSquaresResult<E: Scalar, I: Ix1OrIx2<E>> {
79+
pub struct LeastSquaresResult<E: Scalar, I: Dimension> {
9280
/// The singular values of the matrix A in `Ax = b`
9381
pub singular_values: Array1<E::Real>,
9482
/// The solution vector or matrix `x` which is the best
@@ -97,16 +85,16 @@ pub struct LeastSquaresResult<E: Scalar, I: Ix1OrIx2<E>> {
9785
/// The rank of the matrix A in `Ax = b`
9886
pub rank: i32,
9987
/// If n < m and rank(A) == n, the sum of squares
100-
/// If b is a (m x 1) vector, this is a single value
101-
/// If b is a m x k matrix, this is a k x 1 column vector
102-
pub residual_sum_of_squares: Option<I::ScalarOrArray1>,
88+
/// If b is a (m x 1) vector, this is a 0-dimensional array (single value)
89+
/// If b is a (m x k) matrix, this is a (k x 1) column vector
90+
pub residual_sum_of_squares: Option<Array<E::Real, I::Smaller>>,
10391
}
10492
/// Solve least squares for immutable references
10593
pub trait LeastSquaresSvd<D, E, I>
10694
where
10795
D: Data<Elem = E>,
10896
E: Scalar + Lapack,
109-
I: Ix1OrIx2<E>,
97+
I: Dimension,
11098
{
11199
/// Solve a least squares problem of the form `Ax = rhs`
112100
/// by calling `A.least_squares(&rhs)`. `A` and `rhs`
@@ -123,7 +111,7 @@ pub trait LeastSquaresSvdInto<D, E, I>
123111
where
124112
D: Data<Elem = E>,
125113
E: Scalar + Lapack,
126-
I: Ix1OrIx2<E>,
114+
I: Dimension,
127115
{
128116
/// Solve a least squares problem of the form `Ax = rhs`
129117
/// by calling `A.least_squares(rhs)`, consuming both `A`
@@ -142,7 +130,7 @@ pub trait LeastSquaresSvdInPlace<D, E, I>
142130
where
143131
D: Data<Elem = E>,
144132
E: Scalar + Lapack,
145-
I: Ix1OrIx2<E>,
133+
I: Dimension,
146134
{
147135
/// Solve a least squares problem of the form `Ax = rhs`
148136
/// by calling `A.least_squares(&mut rhs)`, overwriting both `A`
@@ -328,11 +316,13 @@ fn compute_residual_scalar<E: Scalar, D: Data<Elem = E>>(
328316
n: usize,
329317
rank: i32,
330318
b: &ArrayBase<D, Ix1>,
331-
) -> Option<E::Real> {
319+
) -> Option<Array<E::Real, Ix0>> {
332320
if m < n || n != rank as usize {
333321
return None;
334322
}
335-
Some(b.slice(s![n..]).mapv(|x| x.powi(2).abs()).sum())
323+
let mut arr: Array<E::Real, Ix0> = Array::zeros(());
324+
arr[()] = b.slice(s![n..]).mapv(|x| x.powi(2).abs()).sum();
325+
Some(arr)
336326
}
337327

338328
/// Solve least squares for mutable references and a matrix
@@ -429,11 +419,10 @@ mod tests {
429419
use ndarray::{ArcArray1, ArcArray2, Array1, Array2, CowArray};
430420
use num_complex::Complex;
431421

432-
///////////////////////////////////////////////////////////////////////////
433-
/// Test cases taken from the scipy test suite for the scipy lstsq function
434-
/// https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/test_basic.py
435-
///////////////////////////////////////////////////////////////////////////
436-
422+
//
423+
// Test cases taken from the scipy test suite for the scipy lstsq function
424+
// https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/test_basic.py
425+
//
437426
#[test]
438427
fn scipy_test_simple_exact() {
439428
let a = array![[1., 20.], [-30., 4.]];
@@ -463,10 +452,7 @@ mod tests {
463452
assert_eq!(res.rank, 2);
464453
let b_hat = a.dot(&res.solution);
465454
let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum();
466-
assert!(res
467-
.residual_sum_of_squares
468-
.unwrap()
469-
.abs_diff_eq(&rssq, 1e-12));
455+
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12));
470456
assert!(res
471457
.solution
472458
.abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-12));
@@ -480,10 +466,7 @@ mod tests {
480466
assert_eq!(res.rank, 2);
481467
let b_hat = a.dot(&res.solution);
482468
let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum();
483-
assert!(res
484-
.residual_sum_of_squares
485-
.unwrap()
486-
.abs_diff_eq(&rssq, 1e-6));
469+
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-6));
487470
assert!(res
488471
.solution
489472
.abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-6));
@@ -505,10 +488,7 @@ mod tests {
505488
assert_eq!(res.rank, 2);
506489
let b_hat = a.dot(&res.solution);
507490
let rssq = (&b_hat - &b).mapv(|x| x.powi(2).abs()).sum();
508-
assert!(res
509-
.residual_sum_of_squares
510-
.unwrap()
511-
.abs_diff_eq(&rssq, 1e-12));
491+
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12));
512492
assert!(res.solution.abs_diff_eq(
513493
&array![
514494
c(-0.4831460674157303, 0.258426966292135),
@@ -546,18 +526,18 @@ mod tests {
546526
assert!(res.solution.abs_diff_eq(&expected, 1e-12));
547527
}
548528

549-
///////////////////////////////////////////////////////////////////////////
550-
/// Test that the different lest squares traits work as intended on the
551-
/// different array types.
552-
///
553-
/// | least_squares | ls_into | ls_in_place |
554-
/// --------------+---------------+---------+-------------+
555-
/// Array | yes | yes | yes |
556-
/// ArcArray | yes | no | no |
557-
/// CowArray | yes | yes | yes |
558-
/// ArrayView | yes | no | no |
559-
/// ArrayViewMut | yes | no | yes |
560-
///////////////////////////////////////////////////////////////////////////
529+
//
530+
// Test that the different lest squares traits work as intended on the
531+
// different array types.
532+
//
533+
// | least_squares | ls_into | ls_in_place |
534+
// --------------+---------------+---------+-------------+
535+
// Array | yes | yes | yes |
536+
// ArcArray | yes | no | no |
537+
// CowArray | yes | yes | yes |
538+
// ArrayView | yes | no | no |
539+
// ArrayViewMut | yes | no | yes |
540+
//
561541

562542
fn assert_result<D: Data<Elem = f64>>(
563543
a: &ArrayBase<D, Ix2>,
@@ -567,10 +547,7 @@ mod tests {
567547
assert_eq!(res.rank, 2);
568548
let b_hat = a.dot(&res.solution);
569549
let rssq = (b - &b_hat).mapv(|x| x.powi(2)).sum();
570-
assert!(res
571-
.residual_sum_of_squares
572-
.unwrap()
573-
.abs_diff_eq(&rssq, 1e-12));
550+
assert!(res.residual_sum_of_squares.as_ref().unwrap()[()].abs_diff_eq(&rssq, 1e-12));
574551
assert!(res
575552
.solution
576553
.abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-12));
@@ -674,10 +651,10 @@ mod tests {
674651
assert_result(&a, &b, &res);
675652
}
676653

677-
///////////////////////////////////////////////////////////////////////////
678-
/// Test cases taken from the netlib documentation at
679-
/// https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code
680-
///////////////////////////////////////////////////////////////////////////
654+
//
655+
// Test cases taken from the netlib documentation at
656+
// https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code
657+
//
681658
#[test]
682659
fn netlib_lapack_example_for_dgels_1() {
683660
let a: Array2<f64> = array![
@@ -694,7 +671,7 @@ mod tests {
694671

695672
let residual = b - a.dot(&result.solution);
696673
let resid_ssq = result.residual_sum_of_squares.unwrap();
697-
assert!((resid_ssq - residual.dot(&residual)).abs() < 1e-12);
674+
assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12);
698675
}
699676

700677
#[test]
@@ -713,7 +690,7 @@ mod tests {
713690

714691
let residual = b - a.dot(&result.solution);
715692
let resid_ssq = result.residual_sum_of_squares.unwrap();
716-
assert!((resid_ssq - residual.dot(&residual)).abs() < 1e-12);
693+
assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12);
717694
}
718695

719696
#[test]
@@ -738,9 +715,9 @@ mod tests {
738715
.abs_diff_eq(&residual_ssq, 1e-12));
739716
}
740717

741-
///////////////////////////////////////////////////////////////////////////
742-
/// Testing error cases
743-
///////////////////////////////////////////////////////////////////////////
718+
//
719+
// Testing error cases
720+
//
744721
use crate::layout::MatrixLayout;
745722
use ndarray::ErrorKind;
746723

0 commit comments

Comments
 (0)