2
2
//! &
3
3
//! Methods for tridiagonal matrices
4
4
5
+ use std:: ops:: { Index , IndexMut } ;
6
+
5
7
use cauchy:: Scalar ;
6
8
use ndarray:: * ;
7
9
use num_traits:: One ;
8
10
9
- use crate :: opnorm:: OperationNorm ;
10
-
11
11
use super :: convert:: * ;
12
12
use super :: error:: * ;
13
13
use super :: lapack:: * ;
14
14
use super :: layout:: * ;
15
15
16
16
/// Represents a tridiagonal matrix as 3 one-dimensional vectors.
17
- /// This struct also holds the layout and 1-norm of the raw matrix
18
- /// for some methods (eg. rcond_tridiagonal()).
19
- #[ derive( Clone ) ]
17
+ /// This struct also holds the layout of the raw matrix.
18
+ #[ derive( Clone , PartialEq ) ]
20
19
pub struct TriDiagonal < A : Scalar > {
21
20
/// layout of raw matrix
22
21
pub l : MatrixLayout ,
23
- /// the one norm of raw matrix
24
- pub n1 : <A as Scalar >:: Real ,
25
22
/// (n-1) sub-diagonal elements of matrix.
26
23
pub dl : Array1 < A > ,
27
24
/// (n) diagonal elements of matrix.
@@ -30,10 +27,73 @@ pub struct TriDiagonal<A: Scalar> {
30
27
pub du : Array1 < A > ,
31
28
}
32
29
30
+ pub trait TridiagIndex {
31
+ fn to_tuple ( & self ) -> ( i32 , i32 ) ;
32
+ }
33
+ impl TridiagIndex for [ Ix ; 2 ] {
34
+ fn to_tuple ( & self ) -> ( i32 , i32 ) {
35
+ ( self [ 0 ] as i32 , self [ 1 ] as i32 )
36
+ }
37
+ }
38
+
39
+ fn debug_bounds_check_tridiag ( n : i32 , row : i32 , col : i32 ) {
40
+ if std:: cmp:: max ( row, col) >= n {
41
+ panic ! (
42
+ "ndarray: index {:?} is out of bounds for array of shape {}" ,
43
+ [ row, col] ,
44
+ n
45
+ ) ;
46
+ }
47
+ }
48
+
49
+ impl < A , I > Index < I > for TriDiagonal < A >
50
+ where
51
+ A : Scalar ,
52
+ I : TridiagIndex ,
53
+ {
54
+ type Output = A ;
55
+ #[ inline]
56
+ fn index ( & self , index : I ) -> & A {
57
+ let ( n, _) = self . l . size ( ) ;
58
+ let ( row, col) = index. to_tuple ( ) ;
59
+ debug_bounds_check_tridiag ( n, row, col) ;
60
+ match row - col {
61
+ 0 => & self . d [ row as usize ] ,
62
+ 1 => & self . dl [ col as usize ] ,
63
+ -1 => & self . du [ row as usize ] ,
64
+ _ => panic ! (
65
+ "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element" ,
66
+ [ row, col]
67
+ ) ,
68
+ }
69
+ }
70
+ }
71
+
72
+ impl < A , I > IndexMut < I > for TriDiagonal < A >
73
+ where
74
+ A : Scalar ,
75
+ I : TridiagIndex ,
76
+ {
77
+ #[ inline]
78
+ fn index_mut ( & mut self , index : I ) -> & mut A {
79
+ let ( n, _) = self . l . size ( ) ;
80
+ let ( row, col) = index. to_tuple ( ) ;
81
+ debug_bounds_check_tridiag ( n, row, col) ;
82
+ match row - col {
83
+ 0 => & mut self . d [ row as usize ] ,
84
+ 1 => & mut self . dl [ col as usize ] ,
85
+ -1 => & mut self . du [ row as usize ] ,
86
+ _ => panic ! (
87
+ "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element" ,
88
+ [ row, col]
89
+ ) ,
90
+ }
91
+ }
92
+ }
93
+
33
94
/// An interface for making a TriDiagonal struct.
34
95
pub trait ToTriDiagonal < A : Scalar > {
35
96
/// Extract tridiagonal elements and layout of the raw matrix.
36
- /// And also calculate 1-norm.
37
97
///
38
98
/// If the raw matrix has some non-tridiagonal elements,
39
99
/// they will be ignored.
@@ -53,12 +113,11 @@ where
53
113
if n < 2 {
54
114
panic ! ( "Cannot make a tridiagonal matrix of shape=(1, 1)!" ) ;
55
115
}
56
- let n1 = self . opnorm_one ( ) ?;
57
116
58
117
let dl = self . slice ( s ! [ 1 ..n, 0 ..n - 1 ] ) . diag ( ) . to_owned ( ) ;
59
118
let d = self . diag ( ) . to_owned ( ) ;
60
119
let du = self . slice ( s ! [ 0 ..n - 1 , 1 ..n] ) . diag ( ) . to_owned ( ) ;
61
- Ok ( TriDiagonal { l, n1 , dl, d, du } )
120
+ Ok ( TriDiagonal { l, dl, d, du } )
62
121
}
63
122
}
64
123
@@ -130,13 +189,14 @@ pub trait SolveTriDiagonalInplace<A: Scalar, D: Dimension> {
130
189
pub struct LUFactorizedTriDiagonal < A : Scalar > {
131
190
/// A tridiagonal matrix which consists of
132
191
/// - l : layout of raw matrix
133
- /// - n1: the one norm of raw matrix
134
192
/// - dl: (n-1) multipliers that define the matrix L.
135
193
/// - d : (n) diagonal elements of the upper triangular matrix U.
136
194
/// - du: (n-1) elements of the first super-diagonal of U.
137
195
pub a : TriDiagonal < A > ,
138
196
/// (n-2) elements of the second super-diagonal of U.
139
197
pub du2 : Array1 < A > ,
198
+ /// 1-norm of raw matrix (used in .rcond_tridiagonal()).
199
+ pub anom : A :: Real ,
140
200
/// The pivot indices that define the permutation matrix `P`.
141
201
pub ipiv : Pivot ,
142
202
}
@@ -598,10 +658,11 @@ where
598
658
A : Scalar + Lapack ,
599
659
{
600
660
fn factorize_tridiagonal_into ( mut self ) -> Result < LUFactorizedTriDiagonal < A > > {
601
- let ( du2, ipiv) = unsafe { A :: lu_tridiagonal ( & mut self ) ? } ;
661
+ let ( du2, anom , ipiv) = unsafe { A :: lu_tridiagonal ( & mut self ) ? } ;
602
662
Ok ( LUFactorizedTriDiagonal {
603
663
a : self ,
604
664
du2 : du2,
665
+ anom : anom,
605
666
ipiv : ipiv,
606
667
} )
607
668
}
@@ -613,8 +674,8 @@ where
613
674
{
614
675
fn factorize_tridiagonal ( & self ) -> Result < LUFactorizedTriDiagonal < A > > {
615
676
let mut a = self . clone ( ) ;
616
- let ( du2, ipiv) = unsafe { A :: lu_tridiagonal ( & mut a) ? } ;
617
- Ok ( LUFactorizedTriDiagonal { a, du2, ipiv } )
677
+ let ( du2, anom , ipiv) = unsafe { A :: lu_tridiagonal ( & mut a) ? } ;
678
+ Ok ( LUFactorizedTriDiagonal { a, du2, anom , ipiv } )
618
679
}
619
680
}
620
681
@@ -625,8 +686,8 @@ where
625
686
{
626
687
fn factorize_tridiagonal ( & self ) -> Result < LUFactorizedTriDiagonal < A > > {
627
688
let mut a = self . to_tridiagonal ( ) ?;
628
- let ( du2, ipiv) = unsafe { A :: lu_tridiagonal ( & mut a) ? } ;
629
- Ok ( LUFactorizedTriDiagonal { a, du2, ipiv } )
689
+ let ( du2, anom , ipiv) = unsafe { A :: lu_tridiagonal ( & mut a) ? } ;
690
+ Ok ( LUFactorizedTriDiagonal { a, du2, anom , ipiv } )
630
691
}
631
692
}
632
693
0 commit comments