5
5
use super :: * ;
6
6
use crate :: { error:: * , layout:: MatrixLayout } ;
7
7
use cauchy:: * ;
8
+ use num_traits:: { ToPrimitive , Zero } ;
8
9
9
10
pub trait Solveh_ : Sized {
10
11
/// Bunch-Kaufman: wrapper of `*sytrf` and `*hetrf`
@@ -28,13 +29,39 @@ macro_rules! impl_solveh {
28
29
let ( n, _) = l. size( ) ;
29
30
let mut ipiv = vec![ 0 ; n as usize ] ;
30
31
if n == 0 {
31
- // Work around bug in LAPACKE functions.
32
- Ok ( ipiv)
33
- } else {
34
- $trf( l. lapacke_layout( ) , uplo as u8 , n, a, l. lda( ) , & mut ipiv)
35
- . as_lapack_result( ) ?;
36
- Ok ( ipiv)
32
+ return Ok ( Vec :: new( ) ) ;
37
33
}
34
+
35
+ // calc work size
36
+ let mut info = 0 ;
37
+ let mut work_size = [ Self :: zero( ) ] ;
38
+ $trf(
39
+ uplo as u8 ,
40
+ n,
41
+ a,
42
+ l. lda( ) ,
43
+ & mut ipiv,
44
+ & mut work_size,
45
+ -1 ,
46
+ & mut info,
47
+ ) ;
48
+ info. as_lapack_result( ) ?;
49
+
50
+ // actual
51
+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
52
+ let mut work = vec![ Self :: zero( ) ; lwork] ;
53
+ $trf(
54
+ uplo as u8 ,
55
+ n,
56
+ a,
57
+ l. lda( ) ,
58
+ & mut ipiv,
59
+ & mut work,
60
+ lwork as i32 ,
61
+ & mut info,
62
+ ) ;
63
+ info. as_lapack_result( ) ?;
64
+ Ok ( ipiv)
38
65
}
39
66
40
67
unsafe fn invh(
@@ -44,7 +71,10 @@ macro_rules! impl_solveh {
44
71
ipiv: & Pivot ,
45
72
) -> Result <( ) > {
46
73
let ( n, _) = l. size( ) ;
47
- $tri( l. lapacke_layout( ) , uplo as u8 , n, a, l. lda( ) , ipiv) . as_lapack_result( ) ?;
74
+ let mut info = 0 ;
75
+ let mut work = vec![ Self :: zero( ) ; n as usize ] ;
76
+ $tri( uplo as u8 , n, a, l. lda( ) , ipiv, & mut work, & mut info) ;
77
+ info. as_lapack_result( ) ?;
48
78
Ok ( ( ) )
49
79
}
50
80
@@ -56,30 +86,16 @@ macro_rules! impl_solveh {
56
86
b: & mut [ Self ] ,
57
87
) -> Result <( ) > {
58
88
let ( n, _) = l. size( ) ;
59
- let nrhs = 1 ;
60
- let ldb = match l {
61
- MatrixLayout :: C { .. } => 1 ,
62
- MatrixLayout :: F { .. } => n,
63
- } ;
64
- $trs(
65
- l. lapacke_layout( ) ,
66
- uplo as u8 ,
67
- n,
68
- nrhs,
69
- a,
70
- l. lda( ) ,
71
- ipiv,
72
- b,
73
- ldb,
74
- )
75
- . as_lapack_result( ) ?;
89
+ let mut info = 0 ;
90
+ $trs( uplo as u8 , n, 1 , a, l. lda( ) , ipiv, b, n, & mut info) ;
91
+ info. as_lapack_result( ) ?;
76
92
Ok ( ( ) )
77
93
}
78
94
}
79
95
} ;
80
96
} // impl_solveh!
81
97
82
- impl_solveh ! ( f64 , lapacke :: dsytrf, lapacke :: dsytri, lapacke :: dsytrs) ;
83
- impl_solveh ! ( f32 , lapacke :: ssytrf, lapacke :: ssytri, lapacke :: ssytrs) ;
84
- impl_solveh ! ( c64, lapacke :: zhetrf, lapacke :: zhetri, lapacke :: zhetrs) ;
85
- impl_solveh ! ( c32, lapacke :: chetrf, lapacke :: chetri, lapacke :: chetrs) ;
98
+ impl_solveh ! ( f64 , lapack :: dsytrf, lapack :: dsytri, lapack :: dsytrs) ;
99
+ impl_solveh ! ( f32 , lapack :: ssytrf, lapack :: ssytri, lapack :: ssytrs) ;
100
+ impl_solveh ! ( c64, lapack :: zhetrf, lapack :: zhetri, lapack :: zhetrs) ;
101
+ impl_solveh ! ( c32, lapack :: chetrf, lapack :: chetri, lapack :: chetrs) ;
0 commit comments