1
+ //! Implement linear solver using LU decomposition
2
+ //! for tridiagonal matrix
3
+
4
+ use lapacke;
5
+ use ndarray:: * ;
6
+ use num_traits:: Zero ;
7
+
8
+ use super :: NormType ;
9
+ use super :: { into_result, Pivot , Transpose } ;
10
+
11
+ use crate :: error:: * ;
12
+ use crate :: layout:: MatrixLayout ;
13
+ use crate :: tridiagonal:: { TriDiagonal , LUFactorizedTriDiagonal } ;
14
+ use crate :: types:: * ;
15
+
16
+ /// Wraps `*gttrf`, `*gtcon` and `*gttrs`
17
+ pub trait TriDiagonal_ : Scalar + Sized {
18
+ /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
19
+ /// partial pivoting with row interchanges.
20
+ unsafe fn lu_tridiagonal ( a : & mut TriDiagonal < Self > ) -> Result < ( Array1 < Self > , Pivot ) > ;
21
+ /// Estimates the the reciprocal of the condition number of the tridiagonal matrix in 1-norm.
22
+ unsafe fn rcond_tridiagonal ( lu : & LUFactorizedTriDiagonal < Self > ) -> Result < Self :: Real > ;
23
+ unsafe fn solve_tridiagonal (
24
+ lu : & LUFactorizedTriDiagonal < Self > ,
25
+ bl : MatrixLayout ,
26
+ t : Transpose ,
27
+ b : & mut [ Self ] ) -> Result < ( ) > ;
28
+ }
29
+
30
+ macro_rules! impl_tridiagonal {
31
+ ( $scalar: ty, $gttrf: path, $gtcon: path, $gttrs: path) => {
32
+ impl TriDiagonal_ for $scalar {
33
+ unsafe fn lu_tridiagonal( a: & mut TriDiagonal <Self >) -> Result <( Array1 <Self >, Pivot ) > {
34
+ let ( n, _) = a. l. size( ) ;
35
+ let dl = a. dl. as_slice_mut( ) . unwrap( ) ;
36
+ let d = a. d. as_slice_mut( ) . unwrap( ) ;
37
+ let du = a. du. as_slice_mut( ) . unwrap( ) ;
38
+ let mut du2 = vec![ Zero :: zero( ) ; ( n-2 ) as usize ] ;
39
+ let mut ipiv = vec![ 0 ; n as usize ] ;
40
+ let info = $gttrf( n, dl, d, du, & mut du2, & mut ipiv) ;
41
+ into_result( info, ( arr1( & du2) , ipiv) )
42
+ }
43
+
44
+ unsafe fn rcond_tridiagonal( lu: & LUFactorizedTriDiagonal <Self >) -> Result <Self :: Real > {
45
+ let ( n, _) = lu. a. l. size( ) ;
46
+ let dl = lu. a. dl. as_slice( ) . unwrap( ) ;
47
+ let d = lu. a. d. as_slice( ) . unwrap( ) ;
48
+ let du = lu. a. du. as_slice( ) . unwrap( ) ;
49
+ let du2 = lu. du2. as_slice( ) . unwrap( ) ;
50
+ let ipiv = & lu. ipiv;
51
+ let anorm = lu. a. n1;
52
+ let mut rcond = Self :: Real :: zero( ) ;
53
+ let info = $gtcon(
54
+ NormType :: One as u8 ,
55
+ n,
56
+ dl,
57
+ d,
58
+ du,
59
+ du2,
60
+ ipiv,
61
+ anorm,
62
+ & mut rcond,
63
+ ) ;
64
+ into_result( info, rcond)
65
+ }
66
+
67
+ unsafe fn solve_tridiagonal(
68
+ lu: & LUFactorizedTriDiagonal <Self >,
69
+ bl: MatrixLayout ,
70
+ t: Transpose ,
71
+ b: & mut [ Self ]
72
+ ) -> Result <( ) > {
73
+ let ( n, _) = lu. a. l. size( ) ;
74
+ let ( _, nrhs) = bl. size( ) ;
75
+ let dl = lu. a. dl. as_slice( ) . unwrap( ) ;
76
+ let d = lu. a. d. as_slice( ) . unwrap( ) ;
77
+ let du = lu. a. du. as_slice( ) . unwrap( ) ;
78
+ let du2 = lu. du2. as_slice( ) . unwrap( ) ;
79
+ let ipiv = & lu. ipiv;
80
+ let ldb = bl. lda( ) ;
81
+ let info = $gttrs(
82
+ lu. a. l. lapacke_layout( ) ,
83
+ t as u8 ,
84
+ n,
85
+ nrhs,
86
+ dl,
87
+ d,
88
+ du,
89
+ du2,
90
+ ipiv,
91
+ b,
92
+ ldb,
93
+ ) ;
94
+ into_result( info, ( ) )
95
+ }
96
+ }
97
+ } ;
98
+ } // impl_tridiagonal!
99
+
100
+ impl_tridiagonal ! ( f64 , lapacke:: dgttrf, lapacke:: dgtcon, lapacke:: dgttrs) ;
101
+ impl_tridiagonal ! ( f32 , lapacke:: sgttrf, lapacke:: sgtcon, lapacke:: sgttrs) ;
102
+ impl_tridiagonal ! ( c64, lapacke:: zgttrf, lapacke:: zgtcon, lapacke:: zgttrs) ;
103
+ impl_tridiagonal ! ( c32, lapacke:: cgttrf, lapacke:: cgtcon, lapacke:: cgttrs) ;
0 commit comments