Skip to content

Commit 4771438

Browse files
committed
impl Index/IndexMut & impl opnorm
1 parent 5196445 commit 4771438

File tree

4 files changed

+171
-20
lines changed

4 files changed

+171
-20
lines changed

src/lapack/tridiagonal.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,17 @@ use super::{into_result, Pivot, Transpose};
1010

1111
use crate::error::*;
1212
use crate::layout::MatrixLayout;
13+
use crate::opnorm::*;
1314
use crate::tridiagonal::{LUFactorizedTriDiagonal, TriDiagonal};
1415
use crate::types::*;
1516

1617
/// Wraps `*gttrf`, `*gtcon` and `*gttrs`
1718
pub trait TriDiagonal_: Scalar + Sized {
1819
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
1920
/// partial pivoting with row interchanges.
20-
unsafe fn lu_tridiagonal(a: &mut TriDiagonal<Self>) -> Result<(Array1<Self>, Pivot)>;
21+
unsafe fn lu_tridiagonal(
22+
a: &mut TriDiagonal<Self>,
23+
) -> Result<(Array1<Self>, Self::Real, Pivot)>;
2124
/// Estimates the the reciprocal of the condition number of the tridiagonal matrix in 1-norm.
2225
unsafe fn rcond_tridiagonal(lu: &LUFactorizedTriDiagonal<Self>) -> Result<Self::Real>;
2326
unsafe fn solve_tridiagonal(
@@ -31,15 +34,18 @@ pub trait TriDiagonal_: Scalar + Sized {
3134
macro_rules! impl_tridiagonal {
3235
($scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
3336
impl TriDiagonal_ for $scalar {
34-
unsafe fn lu_tridiagonal(a: &mut TriDiagonal<Self>) -> Result<(Array1<Self>, Pivot)> {
37+
unsafe fn lu_tridiagonal(
38+
a: &mut TriDiagonal<Self>,
39+
) -> Result<(Array1<Self>, Self::Real, Pivot)> {
3540
let (n, _) = a.l.size();
41+
let anom = a.opnorm_one()?;
3642
let dl = a.dl.as_slice_mut().unwrap();
3743
let d = a.d.as_slice_mut().unwrap();
3844
let du = a.du.as_slice_mut().unwrap();
3945
let mut du2 = vec![Zero::zero(); (n - 2) as usize];
4046
let mut ipiv = vec![0; n as usize];
4147
let info = $gttrf(n, dl, d, du, &mut du2, &mut ipiv);
42-
into_result(info, (arr1(&du2), ipiv))
48+
into_result(info, (arr1(&du2), anom, ipiv))
4349
}
4450

4551
unsafe fn rcond_tridiagonal(lu: &LUFactorizedTriDiagonal<Self>) -> Result<Self::Real> {
@@ -49,7 +55,7 @@ macro_rules! impl_tridiagonal {
4955
let du = lu.a.du.as_slice().unwrap();
5056
let du2 = lu.du2.as_slice().unwrap();
5157
let ipiv = &lu.ipiv;
52-
let anorm = lu.a.n1;
58+
let anorm = lu.anom;
5359
let mut rcond = Self::Real::zero();
5460
let info = $gtcon(
5561
NormType::One as u8,

src/opnorm.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
33
use ndarray::*;
44

5+
use crate::convert::*;
56
use crate::error::*;
67
use crate::layout::*;
8+
use crate::tridiagonal::TriDiagonal;
79
use crate::types::*;
810

911
pub use crate::lapack::NormType;
@@ -46,3 +48,54 @@ where
4648
Ok(unsafe { A::opnorm(t, l, a) })
4749
}
4850
}
51+
52+
impl<A> OperationNorm for TriDiagonal<A>
53+
where
54+
A: Scalar + Lapack,
55+
{
56+
type Output = A::Real;
57+
58+
fn opnorm(&self, t: NormType) -> Result<Self::Output> {
59+
let arr = match t {
60+
NormType::One => {
61+
let zl: Array1<A> = Array::zeros(1);
62+
let zu: Array1<A> = Array::zeros(1);
63+
let dl = stack![Axis(0), self.dl.to_owned(), zl];
64+
let du = stack![Axis(0), zu, self.du.to_owned()];
65+
let arr = stack![
66+
Axis(0),
67+
into_row(du),
68+
into_row(self.d.to_owned()),
69+
into_row(dl)
70+
];
71+
arr
72+
}
73+
NormType::Infinity => {
74+
let zl: Array1<A> = Array::zeros(1);
75+
let zu: Array1<A> = Array::zeros(1);
76+
let dl = stack![Axis(0), zl, self.dl.to_owned()];
77+
let du = stack![Axis(0), self.du.to_owned(), zu];
78+
let arr = stack![
79+
Axis(1),
80+
into_col(dl),
81+
into_col(self.d.to_owned()),
82+
into_col(du)
83+
];
84+
arr
85+
}
86+
NormType::Frobenius => {
87+
let arr = stack![
88+
Axis(1),
89+
into_row(self.dl.to_owned()),
90+
into_row(self.d.to_owned()),
91+
into_row(self.du.to_owned())
92+
];
93+
arr
94+
}
95+
};
96+
97+
let l = arr.layout()?;
98+
let a = arr.as_allocated()?;
99+
Ok(unsafe { A::opnorm(t, l, a) })
100+
}
101+
}

src/tridiagonal.rs

Lines changed: 77 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,23 @@
22
//! &
33
//! Methods for tridiagonal matrices
44
5+
use std::ops::{Index, IndexMut};
6+
57
use cauchy::Scalar;
68
use ndarray::*;
79
use num_traits::One;
810

9-
use crate::opnorm::OperationNorm;
10-
1111
use super::convert::*;
1212
use super::error::*;
1313
use super::lapack::*;
1414
use super::layout::*;
1515

1616
/// Represents a tridiagonal matrix as 3 one-dimensional vectors.
17-
/// This struct also holds the layout and 1-norm of the raw matrix
18-
/// for some methods (eg. rcond_tridiagonal()).
19-
#[derive(Clone)]
17+
/// This struct also holds the layout of the raw matrix.
18+
#[derive(Clone, PartialEq)]
2019
pub struct TriDiagonal<A: Scalar> {
2120
/// layout of raw matrix
2221
pub l: MatrixLayout,
23-
/// the one norm of raw matrix
24-
pub n1: <A as Scalar>::Real,
2522
/// (n-1) sub-diagonal elements of matrix.
2623
pub dl: Array1<A>,
2724
/// (n) diagonal elements of matrix.
@@ -30,10 +27,73 @@ pub struct TriDiagonal<A: Scalar> {
3027
pub du: Array1<A>,
3128
}
3229

30+
pub trait TridiagIndex {
31+
fn to_tuple(&self) -> (i32, i32);
32+
}
33+
impl TridiagIndex for [Ix; 2] {
34+
fn to_tuple(&self) -> (i32, i32) {
35+
(self[0] as i32, self[1] as i32)
36+
}
37+
}
38+
39+
fn debug_bounds_check_tridiag(n: i32, row: i32, col: i32) {
40+
if std::cmp::max(row, col) >= n {
41+
panic!(
42+
"ndarray: index {:?} is out of bounds for array of shape {}",
43+
[row, col],
44+
n
45+
);
46+
}
47+
}
48+
49+
impl<A, I> Index<I> for TriDiagonal<A>
50+
where
51+
A: Scalar,
52+
I: TridiagIndex,
53+
{
54+
type Output = A;
55+
#[inline]
56+
fn index(&self, index: I) -> &A {
57+
let (n, _) = self.l.size();
58+
let (row, col) = index.to_tuple();
59+
debug_bounds_check_tridiag(n, row, col);
60+
match row - col {
61+
0 => &self.d[row as usize],
62+
1 => &self.dl[col as usize],
63+
-1 => &self.du[row as usize],
64+
_ => panic!(
65+
"ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
66+
[row, col]
67+
),
68+
}
69+
}
70+
}
71+
72+
impl<A, I> IndexMut<I> for TriDiagonal<A>
73+
where
74+
A: Scalar,
75+
I: TridiagIndex,
76+
{
77+
#[inline]
78+
fn index_mut(&mut self, index: I) -> &mut A {
79+
let (n, _) = self.l.size();
80+
let (row, col) = index.to_tuple();
81+
debug_bounds_check_tridiag(n, row, col);
82+
match row - col {
83+
0 => &mut self.d[row as usize],
84+
1 => &mut self.dl[col as usize],
85+
-1 => &mut self.du[row as usize],
86+
_ => panic!(
87+
"ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
88+
[row, col]
89+
),
90+
}
91+
}
92+
}
93+
3394
/// An interface for making a TriDiagonal struct.
3495
pub trait ToTriDiagonal<A: Scalar> {
3596
/// Extract tridiagonal elements and layout of the raw matrix.
36-
/// And also calculate 1-norm.
3797
///
3898
/// If the raw matrix has some non-tridiagonal elements,
3999
/// they will be ignored.
@@ -53,12 +113,11 @@ where
53113
if n < 2 {
54114
panic!("Cannot make a tridiagonal matrix of shape=(1, 1)!");
55115
}
56-
let n1 = self.opnorm_one()?;
57116

58117
let dl = self.slice(s![1..n, 0..n - 1]).diag().to_owned();
59118
let d = self.diag().to_owned();
60119
let du = self.slice(s![0..n - 1, 1..n]).diag().to_owned();
61-
Ok(TriDiagonal { l, n1, dl, d, du })
120+
Ok(TriDiagonal { l, dl, d, du })
62121
}
63122
}
64123

@@ -130,13 +189,14 @@ pub trait SolveTriDiagonalInplace<A: Scalar, D: Dimension> {
130189
pub struct LUFactorizedTriDiagonal<A: Scalar> {
131190
/// A tridiagonal matrix which consists of
132191
/// - l : layout of raw matrix
133-
/// - n1: the one norm of raw matrix
134192
/// - dl: (n-1) multipliers that define the matrix L.
135193
/// - d : (n) diagonal elements of the upper triangular matrix U.
136194
/// - du: (n-1) elements of the first super-diagonal of U.
137195
pub a: TriDiagonal<A>,
138196
/// (n-2) elements of the second super-diagonal of U.
139197
pub du2: Array1<A>,
198+
/// 1-norm of raw matrix (used in .rcond_tridiagonal()).
199+
pub anom: A::Real,
140200
/// The pivot indices that define the permutation matrix `P`.
141201
pub ipiv: Pivot,
142202
}
@@ -598,10 +658,11 @@ where
598658
A: Scalar + Lapack,
599659
{
600660
fn factorize_tridiagonal_into(mut self) -> Result<LUFactorizedTriDiagonal<A>> {
601-
let (du2, ipiv) = unsafe { A::lu_tridiagonal(&mut self)? };
661+
let (du2, anom, ipiv) = unsafe { A::lu_tridiagonal(&mut self)? };
602662
Ok(LUFactorizedTriDiagonal {
603663
a: self,
604664
du2: du2,
665+
anom: anom,
605666
ipiv: ipiv,
606667
})
607668
}
@@ -613,8 +674,8 @@ where
613674
{
614675
fn factorize_tridiagonal(&self) -> Result<LUFactorizedTriDiagonal<A>> {
615676
let mut a = self.clone();
616-
let (du2, ipiv) = unsafe { A::lu_tridiagonal(&mut a)? };
617-
Ok(LUFactorizedTriDiagonal { a, du2, ipiv })
677+
let (du2, anom, ipiv) = unsafe { A::lu_tridiagonal(&mut a)? };
678+
Ok(LUFactorizedTriDiagonal { a, du2, anom, ipiv })
618679
}
619680
}
620681

@@ -625,8 +686,8 @@ where
625686
{
626687
fn factorize_tridiagonal(&self) -> Result<LUFactorizedTriDiagonal<A>> {
627688
let mut a = self.to_tridiagonal()?;
628-
let (du2, ipiv) = unsafe { A::lu_tridiagonal(&mut a)? };
629-
Ok(LUFactorizedTriDiagonal { a, du2, ipiv })
689+
let (du2, anom, ipiv) = unsafe { A::lu_tridiagonal(&mut a)? };
690+
Ok(LUFactorizedTriDiagonal { a, du2, anom, ipiv })
630691
}
631692
}
632693

tests/tridiagonal.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,37 @@ fn to_tridiagonal() {
1010
assert_close_l2!(&t.du, &arr1(&[2.0, 6.0]), 1e-7);
1111
}
1212

13+
#[test]
14+
fn tridiagonal_index() {
15+
let a: Array2<f64> = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
16+
let t1 = a.to_tridiagonal().unwrap();
17+
let mut t2 = Array2::<f64>::eye(3).to_tridiagonal().unwrap();
18+
t2[[0, 1]] = 2.0;
19+
t2[[1, 0]] = 4.0;
20+
t2[[1, 1]] += 4.0;
21+
t2[[1, 2]] = 6.0;
22+
t2[[2, 1]] = 8.0;
23+
t2[[2, 2]] += 8.0;
24+
assert_eq!(t1.dl, t2.dl);
25+
assert_eq!(t1.d, t2.d);
26+
assert_eq!(t1.du, t2.du);
27+
}
28+
29+
#[test]
30+
fn opnorm_tridiagonal() {
31+
let mut a: Array2<f64> = random((4, 4));
32+
a[[0, 2]] = 0.0;
33+
a[[0, 3]] = 0.0;
34+
a[[1, 3]] = 0.0;
35+
a[[2, 0]] = 0.0;
36+
a[[3, 0]] = 0.0;
37+
a[[3, 1]] = 0.0;
38+
let t = a.to_tridiagonal().unwrap();
39+
assert_aclose!(a.opnorm_one().unwrap(), t.opnorm_one().unwrap(), 1e-7);
40+
assert_aclose!(a.opnorm_inf().unwrap(), t.opnorm_inf().unwrap(), 1e-7);
41+
assert_aclose!(a.opnorm_fro().unwrap(), t.opnorm_fro().unwrap(), 1e-7);
42+
}
43+
1344
#[test]
1445
fn solve_tridiagonal_f64() {
1546
// https://www.nag-j.co.jp/lapack/dgttrs.htm

0 commit comments

Comments
 (0)