Skip to content

Commit ea91231

Browse files
committed
Impl bk, solveh, and invh using LAPACK
1 parent 43908b0 commit ea91231

File tree

1 file changed

+44
-28
lines changed

1 file changed

+44
-28
lines changed

lax/src/solveh.rs

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
use super::*;
66
use crate::{error::*, layout::MatrixLayout};
77
use cauchy::*;
8+
use num_traits::{ToPrimitive, Zero};
89

910
pub trait Solveh_: Sized {
1011
/// Bunch-Kaufman: wrapper of `*sytrf` and `*hetrf`
@@ -28,13 +29,39 @@ macro_rules! impl_solveh {
2829
let (n, _) = l.size();
2930
let mut ipiv = vec![0; n as usize];
3031
if n == 0 {
31-
// Work around bug in LAPACKE functions.
32-
Ok(ipiv)
33-
} else {
34-
$trf(l.lapacke_layout(), uplo as u8, n, a, l.lda(), &mut ipiv)
35-
.as_lapack_result()?;
36-
Ok(ipiv)
32+
return Ok(Vec::new());
3733
}
34+
35+
// calc work size
36+
let mut info = 0;
37+
let mut work_size = [Self::zero()];
38+
$trf(
39+
uplo as u8,
40+
n,
41+
a,
42+
l.lda(),
43+
&mut ipiv,
44+
&mut work_size,
45+
-1,
46+
&mut info,
47+
);
48+
info.as_lapack_result()?;
49+
50+
// actual
51+
let lwork = work_size[0].to_usize().unwrap();
52+
let mut work = vec![Self::zero(); lwork];
53+
$trf(
54+
uplo as u8,
55+
n,
56+
a,
57+
l.lda(),
58+
&mut ipiv,
59+
&mut work,
60+
lwork as i32,
61+
&mut info,
62+
);
63+
info.as_lapack_result()?;
64+
Ok(ipiv)
3865
}
3966

4067
unsafe fn invh(
@@ -44,7 +71,10 @@ macro_rules! impl_solveh {
4471
ipiv: &Pivot,
4572
) -> Result<()> {
4673
let (n, _) = l.size();
47-
$tri(l.lapacke_layout(), uplo as u8, n, a, l.lda(), ipiv).as_lapack_result()?;
74+
let mut info = 0;
75+
let mut work = vec![Self::zero(); n as usize];
76+
$tri(uplo as u8, n, a, l.lda(), ipiv, &mut work, &mut info);
77+
info.as_lapack_result()?;
4878
Ok(())
4979
}
5080

@@ -56,30 +86,16 @@ macro_rules! impl_solveh {
5686
b: &mut [Self],
5787
) -> Result<()> {
5888
let (n, _) = l.size();
59-
let nrhs = 1;
60-
let ldb = match l {
61-
MatrixLayout::C { .. } => 1,
62-
MatrixLayout::F { .. } => n,
63-
};
64-
$trs(
65-
l.lapacke_layout(),
66-
uplo as u8,
67-
n,
68-
nrhs,
69-
a,
70-
l.lda(),
71-
ipiv,
72-
b,
73-
ldb,
74-
)
75-
.as_lapack_result()?;
89+
let mut info = 0;
90+
$trs(uplo as u8, n, 1, a, l.lda(), ipiv, b, n, &mut info);
91+
info.as_lapack_result()?;
7692
Ok(())
7793
}
7894
}
7995
};
8096
} // impl_solveh!
8197

82-
impl_solveh!(f64, lapacke::dsytrf, lapacke::dsytri, lapacke::dsytrs);
83-
impl_solveh!(f32, lapacke::ssytrf, lapacke::ssytri, lapacke::ssytrs);
84-
impl_solveh!(c64, lapacke::zhetrf, lapacke::zhetri, lapacke::zhetrs);
85-
impl_solveh!(c32, lapacke::chetrf, lapacke::chetri, lapacke::chetrs);
98+
impl_solveh!(f64, lapack::dsytrf, lapack::dsytri, lapack::dsytrs);
99+
impl_solveh!(f32, lapack::ssytrf, lapack::ssytri, lapack::ssytrs);
100+
impl_solveh!(c64, lapack::zhetrf, lapack::zhetri, lapack::zhetrs);
101+
impl_solveh!(c32, lapack::chetrf, lapack::chetri, lapack::chetrs);

0 commit comments

Comments
 (0)