60
60
//! // `a` and `b` have been moved, no longer valid
61
61
//! ```
62
62
63
- use ndarray:: { s, Array , Array1 , Array2 , ArrayBase , Axis , Data , DataMut , Ix1 , Ix2 } ;
63
+ use ndarray:: { s, Array , Array1 , Array2 , ArrayBase , Axis , Data , DataMut , Dimension , Ix0 , Ix1 , Ix2 } ;
64
64
65
65
use crate :: error:: * ;
66
66
use crate :: lapack:: least_squares:: * ;
67
67
use crate :: layout:: * ;
68
68
use crate :: types:: * ;
69
69
70
- pub trait Ix1OrIx2 < E : Scalar > {
71
- type ScalarOrArray1 ;
72
- }
73
-
74
- impl < E : Scalar > Ix1OrIx2 < E > for Ix1 {
75
- type ScalarOrArray1 = E :: Real ;
76
- }
77
-
78
- impl < E : Scalar > Ix1OrIx2 < E > for Ix2 {
79
- type ScalarOrArray1 = Array1 < E :: Real > ;
80
- }
81
-
82
70
/// Result of a LeastSquares computation
83
71
///
84
72
/// Takes two type parameters, `E`, the element type of the matrix
@@ -88,7 +76,7 @@ impl<E: Scalar> Ix1OrIx2<E> for Ix2 {
88
76
/// is a `m x 1` column vector. If `I` is `Ix2`, the RHS is a `n x k` matrix
89
77
/// (which can be seen as solving `Ax = b` k times for different b) and
90
78
/// the solution is a `m x k` matrix.
91
- pub struct LeastSquaresResult < E : Scalar , I : Ix1OrIx2 < E > > {
79
+ pub struct LeastSquaresResult < E : Scalar , I : Dimension > {
92
80
/// The singular values of the matrix A in `Ax = b`
93
81
pub singular_values : Array1 < E :: Real > ,
94
82
/// The solution vector or matrix `x` which is the best
@@ -97,16 +85,16 @@ pub struct LeastSquaresResult<E: Scalar, I: Ix1OrIx2<E>> {
97
85
/// The rank of the matrix A in `Ax = b`
98
86
pub rank : i32 ,
99
87
/// If n < m and rank(A) == n, the sum of squares
100
- /// If b is a (m x 1) vector, this is a single value
101
- /// If b is a m x k matrix, this is a k x 1 column vector
102
- pub residual_sum_of_squares : Option < I :: ScalarOrArray1 > ,
88
+ /// If b is a (m x 1) vector, this is a 0-dimensional array ( single value)
89
+ /// If b is a ( m x k) matrix, this is a ( k x 1) column vector
90
+ pub residual_sum_of_squares : Option < Array < E :: Real , I :: Smaller > > ,
103
91
}
104
92
/// Solve least squares for immutable references
105
93
pub trait LeastSquaresSvd < D , E , I >
106
94
where
107
95
D : Data < Elem = E > ,
108
96
E : Scalar + Lapack ,
109
- I : Ix1OrIx2 < E > ,
97
+ I : Dimension ,
110
98
{
111
99
/// Solve a least squares problem of the form `Ax = rhs`
112
100
/// by calling `A.least_squares(&rhs)`. `A` and `rhs`
@@ -123,7 +111,7 @@ pub trait LeastSquaresSvdInto<D, E, I>
123
111
where
124
112
D : Data < Elem = E > ,
125
113
E : Scalar + Lapack ,
126
- I : Ix1OrIx2 < E > ,
114
+ I : Dimension ,
127
115
{
128
116
/// Solve a least squares problem of the form `Ax = rhs`
129
117
/// by calling `A.least_squares(rhs)`, consuming both `A`
@@ -142,7 +130,7 @@ pub trait LeastSquaresSvdInPlace<D, E, I>
142
130
where
143
131
D : Data < Elem = E > ,
144
132
E : Scalar + Lapack ,
145
- I : Ix1OrIx2 < E > ,
133
+ I : Dimension ,
146
134
{
147
135
/// Solve a least squares problem of the form `Ax = rhs`
148
136
/// by calling `A.least_squares(&mut rhs)`, overwriting both `A`
@@ -328,11 +316,13 @@ fn compute_residual_scalar<E: Scalar, D: Data<Elem = E>>(
328
316
n : usize ,
329
317
rank : i32 ,
330
318
b : & ArrayBase < D , Ix1 > ,
331
- ) -> Option < E :: Real > {
319
+ ) -> Option < Array < E :: Real , Ix0 > > {
332
320
if m < n || n != rank as usize {
333
321
return None ;
334
322
}
335
- Some ( b. slice ( s ! [ n..] ) . mapv ( |x| x. powi ( 2 ) . abs ( ) ) . sum ( ) )
323
+ let mut arr: Array < E :: Real , Ix0 > = Array :: zeros ( ( ) ) ;
324
+ arr[ ( ) ] = b. slice ( s ! [ n..] ) . mapv ( |x| x. powi ( 2 ) . abs ( ) ) . sum ( ) ;
325
+ Some ( arr)
336
326
}
337
327
338
328
/// Solve least squares for mutable references and a matrix
@@ -429,11 +419,10 @@ mod tests {
429
419
use ndarray:: { ArcArray1 , ArcArray2 , Array1 , Array2 , CowArray } ;
430
420
use num_complex:: Complex ;
431
421
432
- ///////////////////////////////////////////////////////////////////////////
433
- /// Test cases taken from the scipy test suite for the scipy lstsq function
434
- /// https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/test_basic.py
435
- ///////////////////////////////////////////////////////////////////////////
436
-
422
+ //
423
+ // Test cases taken from the scipy test suite for the scipy lstsq function
424
+ // https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/test_basic.py
425
+ //
437
426
#[ test]
438
427
fn scipy_test_simple_exact ( ) {
439
428
let a = array ! [ [ 1. , 20. ] , [ -30. , 4. ] ] ;
@@ -463,10 +452,7 @@ mod tests {
463
452
assert_eq ! ( res. rank, 2 ) ;
464
453
let b_hat = a. dot ( & res. solution ) ;
465
454
let rssq = ( & b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum ( ) ;
466
- assert ! ( res
467
- . residual_sum_of_squares
468
- . unwrap( )
469
- . abs_diff_eq( & rssq, 1e-12 ) ) ;
455
+ assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-12 ) ) ;
470
456
assert ! ( res
471
457
. solution
472
458
. abs_diff_eq( & array![ -0.428571428571429 , 0.85714285714285 ] , 1e-12 ) ) ;
@@ -480,10 +466,7 @@ mod tests {
480
466
assert_eq ! ( res. rank, 2 ) ;
481
467
let b_hat = a. dot ( & res. solution ) ;
482
468
let rssq = ( & b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum ( ) ;
483
- assert ! ( res
484
- . residual_sum_of_squares
485
- . unwrap( )
486
- . abs_diff_eq( & rssq, 1e-6 ) ) ;
469
+ assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-6 ) ) ;
487
470
assert ! ( res
488
471
. solution
489
472
. abs_diff_eq( & array![ -0.428571428571429 , 0.85714285714285 ] , 1e-6 ) ) ;
@@ -505,10 +488,7 @@ mod tests {
505
488
assert_eq ! ( res. rank, 2 ) ;
506
489
let b_hat = a. dot ( & res. solution ) ;
507
490
let rssq = ( & b_hat - & b) . mapv ( |x| x. powi ( 2 ) . abs ( ) ) . sum ( ) ;
508
- assert ! ( res
509
- . residual_sum_of_squares
510
- . unwrap( )
511
- . abs_diff_eq( & rssq, 1e-12 ) ) ;
491
+ assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-12 ) ) ;
512
492
assert ! ( res. solution. abs_diff_eq(
513
493
& array![
514
494
c( -0.4831460674157303 , 0.258426966292135 ) ,
@@ -546,18 +526,18 @@ mod tests {
546
526
assert ! ( res. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
547
527
}
548
528
549
- ///////////////////////////////////////////////////////////////////////////
550
- /// Test that the different lest squares traits work as intended on the
551
- /// different array types.
552
- ///
553
- /// | least_squares | ls_into | ls_in_place |
554
- /// --------------+---------------+---------+-------------+
555
- /// Array | yes | yes | yes |
556
- /// ArcArray | yes | no | no |
557
- /// CowArray | yes | yes | yes |
558
- /// ArrayView | yes | no | no |
559
- /// ArrayViewMut | yes | no | yes |
560
- ///////////////////////////////////////////////////////////////////////////
529
+ //
530
+ // Test that the different lest squares traits work as intended on the
531
+ // different array types.
532
+ //
533
+ // | least_squares | ls_into | ls_in_place |
534
+ // --------------+---------------+---------+-------------+
535
+ // Array | yes | yes | yes |
536
+ // ArcArray | yes | no | no |
537
+ // CowArray | yes | yes | yes |
538
+ // ArrayView | yes | no | no |
539
+ // ArrayViewMut | yes | no | yes |
540
+ //
561
541
562
542
fn assert_result < D : Data < Elem = f64 > > (
563
543
a : & ArrayBase < D , Ix2 > ,
@@ -567,10 +547,7 @@ mod tests {
567
547
assert_eq ! ( res. rank, 2 ) ;
568
548
let b_hat = a. dot ( & res. solution ) ;
569
549
let rssq = ( b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum ( ) ;
570
- assert ! ( res
571
- . residual_sum_of_squares
572
- . unwrap( )
573
- . abs_diff_eq( & rssq, 1e-12 ) ) ;
550
+ assert ! ( res. residual_sum_of_squares. as_ref( ) . unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-12 ) ) ;
574
551
assert ! ( res
575
552
. solution
576
553
. abs_diff_eq( & array![ -0.428571428571429 , 0.85714285714285 ] , 1e-12 ) ) ;
@@ -674,10 +651,10 @@ mod tests {
674
651
assert_result ( & a, & b, & res) ;
675
652
}
676
653
677
- ///////////////////////////////////////////////////////////////////////////
678
- /// Test cases taken from the netlib documentation at
679
- /// https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code
680
- ///////////////////////////////////////////////////////////////////////////
654
+ //
655
+ // Test cases taken from the netlib documentation at
656
+ // https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code
657
+ //
681
658
#[ test]
682
659
fn netlib_lapack_example_for_dgels_1 ( ) {
683
660
let a: Array2 < f64 > = array ! [
@@ -694,7 +671,7 @@ mod tests {
694
671
695
672
let residual = b - a. dot ( & result. solution ) ;
696
673
let resid_ssq = result. residual_sum_of_squares . unwrap ( ) ;
697
- assert ! ( ( resid_ssq - residual. dot( & residual) ) . abs( ) < 1e-12 ) ;
674
+ assert ! ( ( resid_ssq[ ( ) ] - residual. dot( & residual) ) . abs( ) < 1e-12 ) ;
698
675
}
699
676
700
677
#[ test]
@@ -713,7 +690,7 @@ mod tests {
713
690
714
691
let residual = b - a. dot ( & result. solution ) ;
715
692
let resid_ssq = result. residual_sum_of_squares . unwrap ( ) ;
716
- assert ! ( ( resid_ssq - residual. dot( & residual) ) . abs( ) < 1e-12 ) ;
693
+ assert ! ( ( resid_ssq[ ( ) ] - residual. dot( & residual) ) . abs( ) < 1e-12 ) ;
717
694
}
718
695
719
696
#[ test]
@@ -738,9 +715,9 @@ mod tests {
738
715
. abs_diff_eq( & residual_ssq, 1e-12 ) ) ;
739
716
}
740
717
741
- ///////////////////////////////////////////////////////////////////////////
742
- /// Testing error cases
743
- ///////////////////////////////////////////////////////////////////////////
718
+ //
719
+ // Testing error cases
720
+ //
744
721
use crate :: layout:: MatrixLayout ;
745
722
use ndarray:: ErrorKind ;
746
723
0 commit comments