Skip to content

Commit 54f7d25

Browse files
authored
Merge pull request #197 from paulkoerbitz/master
Expose lapack routines for solving least squares problems
2 parents 7e1f1bf + 7ab8f33 commit 54f7d25

File tree

5 files changed

+895
-1
lines changed

5 files changed

+895
-1
lines changed

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ rand = "0.5"
3131

3232
[dependencies.ndarray]
3333
version = "0.13.0"
34-
features = ["blas"]
34+
features = ["blas", "approx"]
3535
default-features = false
3636

3737
[dependencies.blas-src]
@@ -51,6 +51,7 @@ optional = true
5151
[dev-dependencies]
5252
paste = "0.1.9"
5353
criterion = "0.3.1"
54+
approx = { version = "0.3.2", features = ["num-complex"] }
5455

5556
[[bench]]
5657
name = "truncated_eig"

src/lapack/least_squares.rs

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
//! Least squares
2+
3+
use lapacke;
4+
use ndarray::{ErrorKind, ShapeError};
5+
use num_traits::Zero;
6+
7+
use crate::error::*;
8+
use crate::layout::MatrixLayout;
9+
use crate::types::*;
10+
11+
use super::into_result;
12+
13+
/// Result of LeastSquares
14+
pub struct LeastSquaresOutput<A: Scalar> {
15+
/// singular values
16+
pub singular_values: Vec<A::Real>,
17+
/// The rank of the input matrix A
18+
pub rank: i32,
19+
}
20+
21+
/// Wraps `*gelsd`
22+
pub trait LeastSquaresSvdDivideConquer_: Scalar {
23+
unsafe fn least_squares(
24+
a_layout: MatrixLayout,
25+
a: &mut [Self],
26+
b: &mut [Self],
27+
) -> Result<LeastSquaresOutput<Self>>;
28+
29+
unsafe fn least_squares_nrhs(
30+
a_layout: MatrixLayout,
31+
a: &mut [Self],
32+
b_layout: MatrixLayout,
33+
b: &mut [Self],
34+
) -> Result<LeastSquaresOutput<Self>>;
35+
}
36+
37+
macro_rules! impl_least_squares {
38+
($scalar:ty, $gelsd:path) => {
39+
impl LeastSquaresSvdDivideConquer_ for $scalar {
40+
unsafe fn least_squares(
41+
a_layout: MatrixLayout,
42+
a: &mut [Self],
43+
b: &mut [Self],
44+
) -> Result<LeastSquaresOutput<Self>> {
45+
let (m, n) = a_layout.size();
46+
if (m as usize) > b.len() || (n as usize) > b.len() {
47+
return Err(LinalgError::Shape(ShapeError::from_kind(
48+
ErrorKind::IncompatibleShape,
49+
)));
50+
}
51+
let k = ::std::cmp::min(m, n);
52+
let nrhs = 1;
53+
let rcond: Self::Real = -1.;
54+
let mut singular_values: Vec<Self::Real> = vec![Self::Real::zero(); k as usize];
55+
let mut rank: i32 = 0;
56+
57+
let status = $gelsd(
58+
a_layout.lapacke_layout(),
59+
m,
60+
n,
61+
nrhs,
62+
a,
63+
a_layout.lda(),
64+
b,
65+
// this is the 'leading dimension of b', in the case where
66+
// b is a single vector, this is 1
67+
nrhs,
68+
&mut singular_values,
69+
rcond,
70+
&mut rank,
71+
);
72+
73+
into_result(
74+
status,
75+
LeastSquaresOutput {
76+
singular_values,
77+
rank,
78+
},
79+
)
80+
}
81+
82+
unsafe fn least_squares_nrhs(
83+
a_layout: MatrixLayout,
84+
a: &mut [Self],
85+
b_layout: MatrixLayout,
86+
b: &mut [Self],
87+
) -> Result<LeastSquaresOutput<Self>> {
88+
let (m, n) = a_layout.size();
89+
if (m as usize) > b.len()
90+
|| (n as usize) > b.len()
91+
|| a_layout.lapacke_layout() != b_layout.lapacke_layout()
92+
{
93+
return Err(LinalgError::Shape(ShapeError::from_kind(
94+
ErrorKind::IncompatibleShape,
95+
)));
96+
}
97+
let k = ::std::cmp::min(m, n);
98+
let nrhs = b_layout.size().1;
99+
let rcond: Self::Real = -1.;
100+
let mut singular_values: Vec<Self::Real> = vec![Self::Real::zero(); k as usize];
101+
let mut rank: i32 = 0;
102+
103+
let status = $gelsd(
104+
a_layout.lapacke_layout(),
105+
m,
106+
n,
107+
nrhs,
108+
a,
109+
a_layout.lda(),
110+
b,
111+
b_layout.lda(),
112+
&mut singular_values,
113+
rcond,
114+
&mut rank,
115+
);
116+
117+
into_result(
118+
status,
119+
LeastSquaresOutput {
120+
singular_values,
121+
rank,
122+
},
123+
)
124+
}
125+
}
126+
};
127+
}
128+
129+
impl_least_squares!(f64, lapacke::dgelsd);
130+
impl_least_squares!(f32, lapacke::sgelsd);
131+
impl_least_squares!(c64, lapacke::zgelsd);
132+
impl_least_squares!(c32, lapacke::cgelsd);

src/lapack/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
pub mod cholesky;
44
pub mod eig;
55
pub mod eigh;
6+
pub mod least_squares;
67
pub mod opnorm;
78
pub mod qr;
89
pub mod solve;
@@ -14,6 +15,7 @@ pub mod triangular;
1415
pub use self::cholesky::*;
1516
pub use self::eig::*;
1617
pub use self::eigh::*;
18+
pub use self::least_squares::*;
1719
pub use self::opnorm::*;
1820
pub use self::qr::*;
1921
pub use self::solve::*;

0 commit comments

Comments
 (0)