Skip to content

Commit 3fba662

Browse files
committed
add calculation for tridiagonal matrices (solve, factorize, det, rcond)
1 parent a595a40 commit 3fba662

File tree

6 files changed

+849
-1
lines changed

6 files changed

+849
-1
lines changed

examples/tridiagonal.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
use ndarray::*;
2+
use ndarray_linalg::*;
3+
4+
// Solve `Ax=b` for tridiagonal matrix
5+
fn solve() -> Result<(), error::LinalgError> {
6+
let mut a: Array2<f64> = random((3, 3));
7+
let b: Array1<f64> = random(3);
8+
a[[0, 2]] = 0.0;
9+
a[[2, 0]] = 0.0;
10+
let _x = a.solve_tridiagonal(&b)?;
11+
Ok(())
12+
}
13+
14+
// Solve `Ax=b` for many b with fixed A
15+
fn factorize() -> Result<(), error::LinalgError> {
16+
let mut a: Array2<f64> = random((3, 3));
17+
a[[0, 2]] = 0.0;
18+
a[[2, 0]] = 0.0;
19+
let f = a.factorize_tridiagonal()?; // LU factorize A (A is *not* consumed)
20+
for _ in 0..10 {
21+
let b: Array1<f64> = random(3);
22+
let _x = f.solve_tridiagonal_into(b)?; // solve Ax=b using factorized L, U
23+
}
24+
Ok(())
25+
}
26+
27+
fn main() {
28+
solve().unwrap();
29+
factorize().unwrap();
30+
}

src/lapack/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ pub mod solveh;
1010
pub mod svd;
1111
pub mod svddc;
1212
pub mod triangular;
13+
pub mod tridiagonal;
1314

1415
pub use self::cholesky::*;
1516
pub use self::eig::*;
@@ -21,6 +22,7 @@ pub use self::solveh::*;
2122
pub use self::svd::*;
2223
pub use self::svddc::*;
2324
pub use self::triangular::*;
25+
pub use self::tridiagonal::*;
2426

2527
use super::error::*;
2628
use super::types::*;
@@ -29,7 +31,7 @@ pub type Pivot = Vec<i32>;
2931

3032
/// Trait for primitive types which implements LAPACK subroutines
3133
pub trait Lapack:
32-
OperatorNorm_ + QR_ + SVD_ + SVDDC_ + Solve_ + Solveh_ + Cholesky_ + Eig_ + Eigh_ + Triangular_
34+
OperatorNorm_ + QR_ + SVD_ + SVDDC_ + Solve_ + Solveh_ + Cholesky_ + Eig_ + Eigh_ + Triangular_ + TriDiagonal_
3335
{
3436
}
3537

src/lapack/tridiagonal.rs

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
//! Implement linear solver using LU decomposition
2+
//! for tridiagonal matrix
3+
4+
use lapacke;
5+
use ndarray::*;
6+
use num_traits::Zero;
7+
8+
use super::NormType;
9+
use super::{into_result, Pivot, Transpose};
10+
11+
use crate::error::*;
12+
use crate::layout::MatrixLayout;
13+
use crate::tridiagonal::{TriDiagonal, LUFactorizedTriDiagonal};
14+
use crate::types::*;
15+
16+
/// Wraps `*gttrf`, `*gtcon` and `*gttrs`
17+
pub trait TriDiagonal_: Scalar + Sized {
18+
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
19+
/// partial pivoting with row interchanges.
20+
unsafe fn lu_tridiagonal(a: &mut TriDiagonal<Self>) -> Result<(Array1<Self>, Pivot)>;
21+
/// Estimates the the reciprocal of the condition number of the tridiagonal matrix in 1-norm.
22+
unsafe fn rcond_tridiagonal(lu: &LUFactorizedTriDiagonal<Self>) -> Result<Self::Real>;
23+
unsafe fn solve_tridiagonal(
24+
lu: &LUFactorizedTriDiagonal<Self>,
25+
bl: MatrixLayout,
26+
t: Transpose,
27+
b: &mut [Self]) -> Result<()>;
28+
}
29+
30+
macro_rules! impl_tridiagonal {
31+
($scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
32+
impl TriDiagonal_ for $scalar {
33+
unsafe fn lu_tridiagonal(a: &mut TriDiagonal<Self>) -> Result<(Array1<Self>, Pivot)> {
34+
let (n, _) = a.l.size();
35+
let dl = a.dl.as_slice_mut().unwrap();
36+
let d = a.d.as_slice_mut().unwrap();
37+
let du = a.du.as_slice_mut().unwrap();
38+
let mut du2 = vec![Zero::zero(); (n-2) as usize];
39+
let mut ipiv = vec![0; n as usize];
40+
let info = $gttrf(n, dl, d, du, &mut du2, &mut ipiv);
41+
into_result(info, (arr1(&du2), ipiv))
42+
}
43+
44+
unsafe fn rcond_tridiagonal(lu: &LUFactorizedTriDiagonal<Self>) -> Result<Self::Real> {
45+
let (n, _) = lu.a.l.size();
46+
let dl = lu.a.dl.as_slice().unwrap();
47+
let d = lu.a.d.as_slice().unwrap();
48+
let du = lu.a.du.as_slice().unwrap();
49+
let du2 = lu.du2.as_slice().unwrap();
50+
let ipiv = &lu.ipiv;
51+
let anorm = lu.a.n1;
52+
let mut rcond = Self::Real::zero();
53+
let info = $gtcon(
54+
NormType::One as u8,
55+
n,
56+
dl,
57+
d,
58+
du,
59+
du2,
60+
ipiv,
61+
anorm,
62+
&mut rcond,
63+
);
64+
into_result(info, rcond)
65+
}
66+
67+
unsafe fn solve_tridiagonal(
68+
lu: &LUFactorizedTriDiagonal<Self>,
69+
bl: MatrixLayout,
70+
t: Transpose,
71+
b: &mut [Self]
72+
) -> Result<()> {
73+
let (n, _) = lu.a.l.size();
74+
let (_, nrhs) = bl.size();
75+
let dl = lu.a.dl.as_slice().unwrap();
76+
let d = lu.a.d.as_slice().unwrap();
77+
let du = lu.a.du.as_slice().unwrap();
78+
let du2 = lu.du2.as_slice().unwrap();
79+
let ipiv = &lu.ipiv;
80+
let ldb = bl.lda();
81+
let info = $gttrs(
82+
lu.a.l.lapacke_layout(),
83+
t as u8,
84+
n,
85+
nrhs,
86+
dl,
87+
d,
88+
du,
89+
du2,
90+
ipiv,
91+
b,
92+
ldb,
93+
);
94+
into_result(info, ())
95+
}
96+
}
97+
};
98+
} // impl_tridiagonal!
99+
100+
impl_tridiagonal!(f64, lapacke::dgttrf, lapacke::dgtcon, lapacke::dgttrs);
101+
impl_tridiagonal!(f32, lapacke::sgttrf, lapacke::sgtcon, lapacke::sgttrs);
102+
impl_tridiagonal!(c64, lapacke::zgttrf, lapacke::zgtcon, lapacke::zgttrs);
103+
impl_tridiagonal!(c32, lapacke::cgttrf, lapacke::cgtcon, lapacke::cgttrs);

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
//! - [General matrices](solve/index.html)
1515
//! - [Triangular matrices](triangular/index.html)
1616
//! - [Hermitian/real symmetric matrices](solveh/index.html)
17+
//! - [Tridiagonal matrices](tridiagonal/index.html)
1718
//! - [Inverse matrix computation](solve/trait.Inverse.html)
1819
//!
1920
//! Naming Convention
@@ -66,6 +67,7 @@ pub mod svd;
6667
pub mod svddc;
6768
pub mod trace;
6869
pub mod triangular;
70+
pub mod tridiagonal;
6971
pub mod types;
7072

7173
pub use assert::*;
@@ -88,4 +90,5 @@ pub use svd::*;
8890
pub use svddc::*;
8991
pub use trace::*;
9092
pub use triangular::*;
93+
pub use tridiagonal::*;
9194
pub use types::*;

0 commit comments

Comments
 (0)