|
3 | 3 | use super::*;
|
4 | 4 | use crate::{error::*, layout::MatrixLayout};
|
5 | 5 | use cauchy::*;
|
6 |
| -use num_traits::Zero; |
| 6 | +use num_traits::{ToPrimitive, Zero}; |
7 | 7 |
|
8 |
| -/// Wraps `*getrf`, `*getri`, and `*getrs` |
9 | 8 | pub trait Solve_: Scalar + Sized {
|
10 | 9 | /// Computes the LU factorization of a general `m x n` matrix `a` using
|
11 | 10 | /// partial pivoting with row interchanges.
|
12 | 11 | ///
|
13 |
| - /// If the result matches `Err(LinalgError::Lapack(LapackError { |
14 |
| - /// return_code )) if return_code > 0`, then `U[(return_code-1, |
15 |
| - /// return_code-1)]` is exactly zero. The factorization has been completed, |
16 |
| - /// but the factor `U` is exactly singular, and division by zero will occur |
17 |
| - /// if it is used to solve a system of equations. |
18 |
| - unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>; |
19 |
| - unsafe fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; |
20 |
| - /// Estimates the the reciprocal of the condition number of the matrix in 1-norm. |
| 12 | + /// $ PA = LU $ |
21 | 13 | ///
|
22 |
| - /// `anorm` should be the 1-norm of the matrix `a`. |
23 |
| - unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real>; |
24 |
| - unsafe fn solve( |
25 |
| - l: MatrixLayout, |
26 |
| - t: Transpose, |
27 |
| - a: &[Self], |
28 |
| - p: &Pivot, |
29 |
| - b: &mut [Self], |
30 |
| - ) -> Result<()>; |
| 14 | + /// Error |
| 15 | + /// ------ |
| 16 | + /// - `LapackComputationalFailure { return_code }` when the matrix is singular |
| 17 | + /// - Division by zero will occur if it is used to solve a system of equations |
| 18 | + /// because `U[(return_code-1, return_code-1)]` is exactly zero. |
| 19 | + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>; |
| 20 | + |
| 21 | + fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; |
| 22 | + |
| 23 | + fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; |
31 | 24 | }
|
32 | 25 |
|
33 | 26 | macro_rules! impl_solve {
|
34 |
| - ($scalar:ty, $getrf:path, $getri:path, $gecon:path, $getrs:path) => { |
| 27 | + ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => { |
35 | 28 | impl Solve_ for $scalar {
|
36 |
| - unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> { |
| 29 | + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> { |
37 | 30 | let (row, col) = l.size();
|
| 31 | + assert_eq!(a.len() as i32, row * col); |
| 32 | + if row == 0 || col == 0 { |
| 33 | + // Do nothing for empty matrix |
| 34 | + return Ok(Vec::new()); |
| 35 | + } |
38 | 36 | let k = ::std::cmp::min(row, col);
|
39 | 37 | let mut ipiv = vec![0; k as usize];
|
40 |
| - $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv).as_lapack_result()?; |
| 38 | + let mut info = 0; |
| 39 | + unsafe { $getrf(l.lda(), l.len(), a, l.lda(), &mut ipiv, &mut info) }; |
| 40 | + info.as_lapack_result()?; |
41 | 41 | Ok(ipiv)
|
42 | 42 | }
|
43 | 43 |
|
44 |
| - unsafe fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { |
| 44 | + fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { |
45 | 45 | let (n, _) = l.size();
|
46 |
| - $getri(l.lapacke_layout(), n, a, l.lda(), ipiv).as_lapack_result()?; |
47 |
| - Ok(()) |
48 |
| - } |
49 | 46 |
|
50 |
| - unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real> { |
51 |
| - let (n, _) = l.size(); |
52 |
| - let mut rcond = Self::Real::zero(); |
53 |
| - $gecon( |
54 |
| - l.lapacke_layout(), |
55 |
| - NormType::One as u8, |
56 |
| - n, |
57 |
| - a, |
58 |
| - l.lda(), |
59 |
| - anorm, |
60 |
| - &mut rcond, |
61 |
| - ) |
62 |
| - .as_lapack_result()?; |
63 |
| - Ok(rcond) |
| 47 | + // calc work size |
| 48 | + let mut info = 0; |
| 49 | + let mut work_size = [Self::zero()]; |
| 50 | + unsafe { $getri(n, a, l.lda(), ipiv, &mut work_size, -1, &mut info) }; |
| 51 | + info.as_lapack_result()?; |
| 52 | + |
| 53 | + // actual |
| 54 | + let lwork = work_size[0].to_usize().unwrap(); |
| 55 | + let mut work = vec![Self::zero(); lwork]; |
| 56 | + unsafe { |
| 57 | + $getri( |
| 58 | + l.len(), |
| 59 | + a, |
| 60 | + l.lda(), |
| 61 | + ipiv, |
| 62 | + &mut work, |
| 63 | + lwork as i32, |
| 64 | + &mut info, |
| 65 | + ) |
| 66 | + }; |
| 67 | + info.as_lapack_result()?; |
| 68 | + |
| 69 | + Ok(()) |
64 | 70 | }
|
65 | 71 |
|
66 |
| - unsafe fn solve( |
| 72 | + fn solve( |
67 | 73 | l: MatrixLayout,
|
68 | 74 | t: Transpose,
|
69 | 75 | a: &[Self],
|
70 | 76 | ipiv: &Pivot,
|
71 | 77 | b: &mut [Self],
|
72 | 78 | ) -> Result<()> {
|
| 79 | + let t = match l { |
| 80 | + MatrixLayout::C { .. } => match t { |
| 81 | + Transpose::No => Transpose::Transpose, |
| 82 | + Transpose::Transpose | Transpose::Hermite => Transpose::No, |
| 83 | + }, |
| 84 | + _ => t, |
| 85 | + }; |
73 | 86 | let (n, _) = l.size();
|
74 | 87 | let nrhs = 1;
|
75 |
| - let ldb = 1; |
76 |
| - $getrs( |
77 |
| - l.lapacke_layout(), |
78 |
| - t as u8, |
79 |
| - n, |
80 |
| - nrhs, |
81 |
| - a, |
82 |
| - l.lda(), |
83 |
| - ipiv, |
84 |
| - b, |
85 |
| - ldb, |
86 |
| - ) |
87 |
| - .as_lapack_result()?; |
| 88 | + let ldb = l.lda(); |
| 89 | + let mut info = 0; |
| 90 | + unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) }; |
| 91 | + info.as_lapack_result()?; |
88 | 92 | Ok(())
|
89 | 93 | }
|
90 | 94 | }
|
91 | 95 | };
|
92 | 96 | } // impl_solve!
|
93 | 97 |
|
94 |
| -impl_solve!( |
95 |
| - f64, |
96 |
| - lapacke::dgetrf, |
97 |
| - lapacke::dgetri, |
98 |
| - lapacke::dgecon, |
99 |
| - lapacke::dgetrs |
100 |
| -); |
101 |
| -impl_solve!( |
102 |
| - f32, |
103 |
| - lapacke::sgetrf, |
104 |
| - lapacke::sgetri, |
105 |
| - lapacke::sgecon, |
106 |
| - lapacke::sgetrs |
107 |
| -); |
108 |
| -impl_solve!( |
109 |
| - c64, |
110 |
| - lapacke::zgetrf, |
111 |
| - lapacke::zgetri, |
112 |
| - lapacke::zgecon, |
113 |
| - lapacke::zgetrs |
114 |
| -); |
115 |
| -impl_solve!( |
116 |
| - c32, |
117 |
| - lapacke::cgetrf, |
118 |
| - lapacke::cgetri, |
119 |
| - lapacke::cgecon, |
120 |
| - lapacke::cgetrs |
121 |
| -); |
| 98 | +impl_solve!(f64, lapack::dgetrf, lapack::dgetri, lapack::dgetrs); |
| 99 | +impl_solve!(f32, lapack::sgetrf, lapack::sgetri, lapack::sgetrs); |
| 100 | +impl_solve!(c64, lapack::zgetrf, lapack::zgetri, lapack::zgetrs); |
| 101 | +impl_solve!(c32, lapack::cgetrf, lapack::cgetri, lapack::cgetrs); |
0 commit comments