Skip to content

Commit fc1fca3

Browse files
authored
Merge pull request #216 from rust-ndarray/lapack-solveh
Triangular factorization methods for real-symmetric, Hermitian matrix using LAPACK
2 parents 43908b0 + 9ccb6af commit fc1fca3

File tree

2 files changed

+54
-31
lines changed

2 files changed

+54
-31
lines changed

lax/src/solveh.rs

Lines changed: 44 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
use super::*;
66
use crate::{error::*, layout::MatrixLayout};
77
use cauchy::*;
8+
use num_traits::{ToPrimitive, Zero};
89

910
pub trait Solveh_: Sized {
1011
/// Bunch-Kaufman: wrapper of `*sytrf` and `*hetrf`
@@ -28,13 +29,39 @@ macro_rules! impl_solveh {
2829
let (n, _) = l.size();
2930
let mut ipiv = vec![0; n as usize];
3031
if n == 0 {
31-
// Work around bug in LAPACKE functions.
32-
Ok(ipiv)
33-
} else {
34-
$trf(l.lapacke_layout(), uplo as u8, n, a, l.lda(), &mut ipiv)
35-
.as_lapack_result()?;
36-
Ok(ipiv)
32+
return Ok(Vec::new());
3733
}
34+
35+
// calc work size
36+
let mut info = 0;
37+
let mut work_size = [Self::zero()];
38+
$trf(
39+
uplo as u8,
40+
n,
41+
a,
42+
l.lda(),
43+
&mut ipiv,
44+
&mut work_size,
45+
-1,
46+
&mut info,
47+
);
48+
info.as_lapack_result()?;
49+
50+
// actual
51+
let lwork = work_size[0].to_usize().unwrap();
52+
let mut work = vec![Self::zero(); lwork];
53+
$trf(
54+
uplo as u8,
55+
n,
56+
a,
57+
l.lda(),
58+
&mut ipiv,
59+
&mut work,
60+
lwork as i32,
61+
&mut info,
62+
);
63+
info.as_lapack_result()?;
64+
Ok(ipiv)
3865
}
3966

4067
unsafe fn invh(
@@ -44,7 +71,10 @@ macro_rules! impl_solveh {
4471
ipiv: &Pivot,
4572
) -> Result<()> {
4673
let (n, _) = l.size();
47-
$tri(l.lapacke_layout(), uplo as u8, n, a, l.lda(), ipiv).as_lapack_result()?;
74+
let mut info = 0;
75+
let mut work = vec![Self::zero(); n as usize];
76+
$tri(uplo as u8, n, a, l.lda(), ipiv, &mut work, &mut info);
77+
info.as_lapack_result()?;
4878
Ok(())
4979
}
5080

@@ -56,30 +86,16 @@ macro_rules! impl_solveh {
5686
b: &mut [Self],
5787
) -> Result<()> {
5888
let (n, _) = l.size();
59-
let nrhs = 1;
60-
let ldb = match l {
61-
MatrixLayout::C { .. } => 1,
62-
MatrixLayout::F { .. } => n,
63-
};
64-
$trs(
65-
l.lapacke_layout(),
66-
uplo as u8,
67-
n,
68-
nrhs,
69-
a,
70-
l.lda(),
71-
ipiv,
72-
b,
73-
ldb,
74-
)
75-
.as_lapack_result()?;
89+
let mut info = 0;
90+
$trs(uplo as u8, n, 1, a, l.lda(), ipiv, b, n, &mut info);
91+
info.as_lapack_result()?;
7692
Ok(())
7793
}
7894
}
7995
};
8096
} // impl_solveh!
8197

82-
impl_solveh!(f64, lapacke::dsytrf, lapacke::dsytri, lapacke::dsytrs);
83-
impl_solveh!(f32, lapacke::ssytrf, lapacke::ssytri, lapacke::ssytrs);
84-
impl_solveh!(c64, lapacke::zhetrf, lapacke::zhetri, lapacke::zhetrs);
85-
impl_solveh!(c32, lapacke::chetrf, lapacke::chetri, lapacke::chetrs);
98+
impl_solveh!(f64, lapack::dsytrf, lapack::dsytri, lapack::dsytrs);
99+
impl_solveh!(f32, lapack::ssytrf, lapack::ssytri, lapack::ssytrs);
100+
impl_solveh!(c64, lapack::zhetrf, lapack::zhetri, lapack::zhetrs);
101+
impl_solveh!(c32, lapack::chetrf, lapack::chetri, lapack::chetrs);

ndarray-linalg/src/solveh.rs

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -314,6 +314,7 @@ where
314314
S: Data<Elem = A>,
315315
A: Scalar + Lapack,
316316
{
317+
let layout = a.layout().unwrap();
317318
let mut sign = A::Real::one();
318319
let mut ln_det = A::Real::zero();
319320
let mut ipiv_enum = ipiv_iter.enumerate();
@@ -337,9 +338,15 @@ where
337338
debug_assert_eq!(lower_diag.im(), Zero::zero());
338339

339340
// Off-diagonal elements, can be complex.
340-
let off_diag = match uplo {
341-
UPLO::Upper => unsafe { a.uget((k, k + 1)) },
342-
UPLO::Lower => unsafe { a.uget((k + 1, k)) },
341+
let off_diag = match layout {
342+
MatrixLayout::C { .. } => match uplo {
343+
UPLO::Upper => unsafe { a.uget((k + 1, k)) },
344+
UPLO::Lower => unsafe { a.uget((k, k + 1)) },
345+
},
346+
MatrixLayout::F { .. } => match uplo {
347+
UPLO::Upper => unsafe { a.uget((k, k + 1)) },
348+
UPLO::Lower => unsafe { a.uget((k + 1, k)) },
349+
},
343350
};
344351

345352
// Determinant of 2x2 block.

0 commit comments

Comments
 (0)