Skip to content

Commit 43908b0

Browse files
authored
Merge pull request #213 from rust-ndarray/lapack-solve
LU decomposition based algorithms using LAPACK
2 parents 261e79a + c9b6d13 commit 43908b0

File tree

4 files changed

+184
-125
lines changed

4 files changed

+184
-125
lines changed

lax/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ pub mod layout;
7070
pub mod least_squares;
7171
pub mod opnorm;
7272
pub mod qr;
73+
pub mod rcond;
7374
pub mod solve;
7475
pub mod solveh;
7576
pub mod svd;
@@ -83,6 +84,7 @@ pub use self::eigh::*;
8384
pub use self::least_squares::*;
8485
pub use self::opnorm::*;
8586
pub use self::qr::*;
87+
pub use self::rcond::*;
8688
pub use self::solve::*;
8789
pub use self::solveh::*;
8890
pub use self::svd::*;
@@ -107,6 +109,7 @@ pub trait Lapack:
107109
+ Eigh_
108110
+ Triangular_
109111
+ Tridiagonal_
112+
+ Rcond_
110113
{
111114
}
112115

lax/src/rcond.rs

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
use super::*;
2+
use crate::{error::*, layout::MatrixLayout};
3+
use cauchy::*;
4+
use num_traits::Zero;
5+
6+
/// Reciprocal condition number estimation.
///
/// Implemented in this file for `f32`/`f64` and `c32`/`c64` by forwarding to
/// the LAPACK `*gecon` routines.
pub trait Rcond_: Scalar + Sized {
7+
/// Estimates the reciprocal of the condition number of the matrix in 1-norm.
8+
///
9+
/// `anorm` should be the 1-norm of the matrix `a`.
///
/// Errors are derived from the LAPACK `info` code by the implementations.
///
/// NOTE(review): LAPACK `*gecon` operates on the LU factors produced by
/// `*getrf`, so `a` is presumably the factorized matrix rather than the
/// original one -- confirm against callers.
10+
fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real>;
11+
}
12+
13+
// Implements `Rcond_` for a real scalar type by calling the given LAPACK
// `*gecon` routine (`$gecon`).
macro_rules! impl_rcond_real {
14+
($scalar:ty, $gecon:path) => {
15+
impl Rcond_ for $scalar {
16+
fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real> {
17+
// Only the first component of the layout size is used as the order `n`.
let (n, _) = l.size();
18+
// Output slot written by `$gecon`.
let mut rcond = Self::Real::zero();
19+
let mut info = 0;
20+
21+
// Scratch buffers handed to `$gecon`: `4 * n` scalars and `n` integers.
let mut work = vec![Self::zero(); 4 * n as usize];
22+
let mut iwork = vec![0; n as usize];
23+
// Row-major (C) input selects the infinity-norm, column-major (F) the 1-norm.
// NOTE(review): presumably because a row-major buffer read as column-major
// is the transpose, whose infinity-norm equals the 1-norm of the original
// matrix -- confirm.
let norm_type = match l {
24+
MatrixLayout::C { .. } => NormType::Infinity,
25+
MatrixLayout::F { .. } => NormType::One,
26+
} as u8;
27+
// NOTE(review): the length of `a` is not checked against `n`/`lda` before
// this unsafe call -- confirm that callers guarantee it.
unsafe {
28+
$gecon(
29+
norm_type,
30+
n,
31+
a,
32+
l.lda(),
33+
anorm,
34+
&mut rcond,
35+
&mut work,
36+
&mut iwork,
37+
&mut info,
38+
)
39+
};
40+
// Convert the LAPACK status code into a `Result` and bail out on error.
info.as_lapack_result()?;
41+
42+
Ok(rcond)
43+
}
44+
}
45+
};
46+
}
47+
48+
// Real scalars: single precision uses `sgecon`, double precision uses `dgecon`.
impl_rcond_real!(f32, lapack::sgecon);
49+
impl_rcond_real!(f64, lapack::dgecon);
50+
51+
// Implements `Rcond_` for a complex scalar type by calling the given LAPACK
// `*gecon` routine (`$gecon`). Differs from the real variant only in the
// scratch buffers: a real-valued `rwork` replaces the integer `iwork`.
macro_rules! impl_rcond_complex {
52+
($scalar:ty, $gecon:path) => {
53+
impl Rcond_ for $scalar {
54+
fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real> {
55+
// Only the first component of the layout size is used as the order `n`.
let (n, _) = l.size();
56+
// Output slot written by `$gecon`.
let mut rcond = Self::Real::zero();
57+
let mut info = 0;
58+
// Scratch buffers handed to `$gecon`: `2 * n` complex and `2 * n` real values.
let mut work = vec![Self::zero(); 2 * n as usize];
59+
let mut rwork = vec![Self::Real::zero(); 2 * n as usize];
60+
// Row-major (C) input selects the infinity-norm, column-major (F) the 1-norm.
// NOTE(review): presumably because a row-major buffer read as column-major
// is the transpose, whose infinity-norm equals the 1-norm of the original
// matrix -- confirm.
let norm_type = match l {
61+
MatrixLayout::C { .. } => NormType::Infinity,
62+
MatrixLayout::F { .. } => NormType::One,
63+
} as u8;
64+
// NOTE(review): the length of `a` is not checked against `n`/`lda` before
// this unsafe call -- confirm that callers guarantee it.
unsafe {
65+
$gecon(
66+
norm_type,
67+
n,
68+
a,
69+
l.lda(),
70+
anorm,
71+
&mut rcond,
72+
&mut work,
73+
&mut rwork,
74+
&mut info,
75+
)
76+
};
77+
// Convert the LAPACK status code into a `Result` and bail out on error.
info.as_lapack_result()?;
78+
79+
Ok(rcond)
80+
}
81+
}
82+
};
83+
}
84+
85+
// Complex scalars: single precision uses `cgecon`, double precision uses `zgecon`.
impl_rcond_complex!(c32, lapack::cgecon);
86+
impl_rcond_complex!(c64, lapack::zgecon);

lax/src/solve.rs

Lines changed: 62 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -3,119 +3,99 @@
33
use super::*;
44
use crate::{error::*, layout::MatrixLayout};
55
use cauchy::*;
6-
use num_traits::Zero;
6+
use num_traits::{ToPrimitive, Zero};
77

8-
/// Wraps `*getrf`, `*getri`, and `*getrs`
98
pub trait Solve_: Scalar + Sized {
109
/// Computes the LU factorization of a general `m x n` matrix `a` using
1110
/// partial pivoting with row interchanges.
1211
///
13-
/// If the result matches `Err(LinalgError::Lapack(LapackError {
14-
/// return_code )) if return_code > 0`, then `U[(return_code-1,
15-
/// return_code-1)]` is exactly zero. The factorization has been completed,
16-
/// but the factor `U` is exactly singular, and division by zero will occur
17-
/// if it is used to solve a system of equations.
18-
unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>;
19-
unsafe fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>;
20-
/// Estimates the the reciprocal of the condition number of the matrix in 1-norm.
12+
/// $ PA = LU $
2113
///
22-
/// `anorm` should be the 1-norm of the matrix `a`.
23-
unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real>;
24-
unsafe fn solve(
25-
l: MatrixLayout,
26-
t: Transpose,
27-
a: &[Self],
28-
p: &Pivot,
29-
b: &mut [Self],
30-
) -> Result<()>;
14+
/// Error
15+
/// ------
16+
/// - `LapackComputationalFailure { return_code }` when the matrix is singular
17+
/// - Division by zero will occur if it is used to solve a system of equations
18+
/// because `U[(return_code-1, return_code-1)]` is exactly zero.
19+
fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>;
20+
21+
/// Overwrites `a` in place using the pivot `p`; the implementations in this
/// file forward to LAPACK `*getri`, so `a` is presumably the LU-factorized
/// matrix returned by `lu` and the result its inverse -- confirm.
fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>;
22+
23+
/// Solves a linear system using `a` and the pivot `p`, overwriting `b`; the
/// implementations in this file forward to LAPACK `*getrs` with a single
/// right-hand side, and `t` selects the transpose mode passed to it.
fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>;
3124
}
3225

3326
// Implements `Solve_` for one scalar type on top of the LAPACK routines passed
// in as paths (`*getrf` for `lu`, `*getri` for `inv`, `*getrs` for `solve`).
macro_rules! impl_solve {
34-
($scalar:ty, $getrf:path, $getri:path, $gecon:path, $getrs:path) => {
27+
($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {
3528
impl Solve_ for $scalar {
36-
unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> {
29+
fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> {
3730
let (row, col) = l.size();
31+
// Panics when the buffer length does not match the layout.
assert_eq!(a.len() as i32, row * col);
32+
if row == 0 || col == 0 {
33+
// Do nothing for empty matrix
34+
return Ok(Vec::new());
35+
}
3836
// One pivot entry per row/column of the smaller dimension.
let k = ::std::cmp::min(row, col);
3937
let mut ipiv = vec![0; k as usize];
40-
$getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv).as_lapack_result()?;
38+
let mut info = 0;
39+
// NOTE(review): the dimensions are passed as `(l.lda(), l.len())` rather
// than `(row, col)`; presumably so that a row-major buffer is factorized
// as its column-major transpose -- confirm.
unsafe { $getrf(l.lda(), l.len(), a, l.lda(), &mut ipiv, &mut info) };
40+
info.as_lapack_result()?;
4141
Ok(ipiv)
4242
}
4343

44-
unsafe fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
44+
fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
4545
let (n, _) = l.size();
46-
$getri(l.lapacke_layout(), n, a, l.lda(), ipiv).as_lapack_result()?;
47-
Ok(())
48-
}
4946

50-
unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real> {
51-
let (n, _) = l.size();
52-
let mut rcond = Self::Real::zero();
53-
$gecon(
54-
l.lapacke_layout(),
55-
NormType::One as u8,
56-
n,
57-
a,
58-
l.lda(),
59-
anorm,
60-
&mut rcond,
61-
)
62-
.as_lapack_result()?;
63-
Ok(rcond)
47+
// calc work size
48+
let mut info = 0;
49+
let mut work_size = [Self::zero()];
50+
// First call passes `-1` as the workspace length and a one-element buffer;
// the size is then read back from `work_size[0]` below.
unsafe { $getri(n, a, l.lda(), ipiv, &mut work_size, -1, &mut info) };
51+
info.as_lapack_result()?;
52+
53+
// actual
54+
// NOTE(review): `unwrap` panics if the reported size cannot be converted
// to `usize` -- confirm this cannot happen for the complex scalar types.
let lwork = work_size[0].to_usize().unwrap();
55+
let mut work = vec![Self::zero(); lwork];
56+
unsafe {
57+
$getri(
58+
// NOTE(review): the size query above passes `n` while this call passes
// `l.len()`; presumably both are equal for the square matrices expected
// here -- confirm.
l.len(),
59+
a,
60+
l.lda(),
61+
ipiv,
62+
&mut work,
63+
lwork as i32,
64+
&mut info,
65+
)
66+
};
67+
info.as_lapack_result()?;
68+
69+
Ok(())
6470
}
6571

66-
unsafe fn solve(
72+
fn solve(
6773
l: MatrixLayout,
6874
t: Transpose,
6975
a: &[Self],
7076
ipiv: &Pivot,
7177
b: &mut [Self],
7278
) -> Result<()> {
79+
// For row-major (C) input the requested transpose mode is flipped before
// it is handed to `$getrs`; other layouts pass `t` through unchanged.
// NOTE(review): `Hermite` is mapped to `No` exactly like `Transpose`, so for
// complex scalars the conjugation appears to be dropped -- confirm.
let t = match l {
80+
MatrixLayout::C { .. } => match t {
81+
Transpose::No => Transpose::Transpose,
82+
Transpose::Transpose | Transpose::Hermite => Transpose::No,
83+
},
84+
_ => t,
85+
};
7386
let (n, _) = l.size();
7487
// Exactly one right-hand side is solved per call.
let nrhs = 1;
75-
let ldb = 1;
76-
$getrs(
77-
l.lapacke_layout(),
78-
t as u8,
79-
n,
80-
nrhs,
81-
a,
82-
l.lda(),
83-
ipiv,
84-
b,
85-
ldb,
86-
)
87-
.as_lapack_result()?;
88+
// The leading dimension of `b` is taken from the matrix layout.
let ldb = l.lda();
89+
let mut info = 0;
90+
unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) };
91+
info.as_lapack_result()?;
8892
Ok(())
8993
}
9094
}
9195
};
9296
} // impl_solve!
9397

94-
impl_solve!(
95-
f64,
96-
lapacke::dgetrf,
97-
lapacke::dgetri,
98-
lapacke::dgecon,
99-
lapacke::dgetrs
100-
);
101-
impl_solve!(
102-
f32,
103-
lapacke::sgetrf,
104-
lapacke::sgetri,
105-
lapacke::sgecon,
106-
lapacke::sgetrs
107-
);
108-
impl_solve!(
109-
c64,
110-
lapacke::zgetrf,
111-
lapacke::zgetri,
112-
lapacke::zgecon,
113-
lapacke::zgetrs
114-
);
115-
impl_solve!(
116-
c32,
117-
lapacke::cgetrf,
118-
lapacke::cgetri,
119-
lapacke::cgecon,
120-
lapacke::cgetrs
121-
);
98+
// One `Solve_` implementation per supported scalar, each wired to the matching
// precision-specific `*getrf` / `*getri` / `*getrs` routines of the `lapack` crate.
impl_solve!(f64, lapack::dgetrf, lapack::dgetri, lapack::dgetrs);
99+
impl_solve!(f32, lapack::sgetrf, lapack::sgetri, lapack::sgetrs);
100+
impl_solve!(c64, lapack::zgetrf, lapack::zgetri, lapack::zgetrs);
101+
impl_solve!(c32, lapack::cgetrf, lapack::cgetri, lapack::cgetrs);

0 commit comments

Comments
 (0)