Skip to content

Commit 791713f

Browse files
authored
Merge pull request #328 from rust-ndarray/use-lapack-sys-directly
Use lapack-sys crate directly from lax crate
2 parents 0cfde73 + 2fc0ac8 commit 791713f

File tree

15 files changed

+682
-411
lines changed

15 files changed

+682
-411
lines changed

lax/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ intel-mkl-system = ["intel-mkl-src/mkl-dynamic-lp64-seq"]
3232
thiserror = "1.0.24"
3333
cauchy = "0.4.0"
3434
num-traits = "0.2.14"
35-
lapack = "0.18.0"
35+
lapack-sys = "0.14.0"
3636

3737
[dependencies.intel-mkl-src]
3838
version = "0.7.0"

lax/src/cholesky.rs

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ macro_rules! impl_cholesky {
2929
}
3030
let mut info = 0;
3131
unsafe {
32-
$trf(uplo as u8, n, a, n, &mut info);
32+
$trf(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &n, &mut info);
3333
}
3434
info.as_lapack_result()?;
3535
if matches!(l, MatrixLayout::C { .. }) {
@@ -45,7 +45,7 @@ macro_rules! impl_cholesky {
4545
}
4646
let mut info = 0;
4747
unsafe {
48-
$tri(uplo as u8, n, a, l.lda(), &mut info);
48+
$tri(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &l.lda(), &mut info);
4949
}
5050
info.as_lapack_result()?;
5151
if matches!(l, MatrixLayout::C { .. }) {
@@ -70,7 +70,16 @@ macro_rules! impl_cholesky {
7070
}
7171
}
7272
unsafe {
73-
$trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info);
73+
$trs(
74+
uplo.as_ptr(),
75+
&n,
76+
&nrhs,
77+
AsPtr::as_ptr(a),
78+
&l.lda(),
79+
AsPtr::as_mut_ptr(b),
80+
&n,
81+
&mut info,
82+
);
7483
}
7584
info.as_lapack_result()?;
7685
if matches!(l, MatrixLayout::C { .. }) {
@@ -84,7 +93,27 @@ macro_rules! impl_cholesky {
8493
};
8594
} // end macro_rules
8695

87-
impl_cholesky!(f64, lapack::dpotrf, lapack::dpotri, lapack::dpotrs);
88-
impl_cholesky!(f32, lapack::spotrf, lapack::spotri, lapack::spotrs);
89-
impl_cholesky!(c64, lapack::zpotrf, lapack::zpotri, lapack::zpotrs);
90-
impl_cholesky!(c32, lapack::cpotrf, lapack::cpotri, lapack::cpotrs);
96+
impl_cholesky!(
97+
f64,
98+
lapack_sys::dpotrf_,
99+
lapack_sys::dpotri_,
100+
lapack_sys::dpotrs_
101+
);
102+
impl_cholesky!(
103+
f32,
104+
lapack_sys::spotrf_,
105+
lapack_sys::spotri_,
106+
lapack_sys::spotrs_
107+
);
108+
impl_cholesky!(
109+
c64,
110+
lapack_sys::zpotrf_,
111+
lapack_sys::zpotri_,
112+
lapack_sys::zpotrs_
113+
);
114+
impl_cholesky!(
115+
c32,
116+
lapack_sys::cpotrf_,
117+
lapack_sys::cpotri_,
118+
lapack_sys::cpotrs_
119+
);

lax/src/eig.rs

Lines changed: 82 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ macro_rules! impl_eig_complex {
2020
fn eig(
2121
calc_v: bool,
2222
l: MatrixLayout,
23-
mut a: &mut [Self],
23+
a: &mut [Self],
2424
) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
2525
let (n, _) = l.size();
2626
// LAPACK assumes a column-major input. A row-major input can
@@ -35,74 +35,69 @@ macro_rules! impl_eig_complex {
3535
// eigenvalues are the eigenvalues computed with `A`.
3636
let (jobvl, jobvr) = if calc_v {
3737
match l {
38-
MatrixLayout::C { .. } => (b'V', b'N'),
39-
MatrixLayout::F { .. } => (b'N', b'V'),
38+
MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not),
39+
MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc),
4040
}
4141
} else {
42-
(b'N', b'N')
42+
(EigenVectorFlag::Not, EigenVectorFlag::Not)
4343
};
4444
let mut eigs = unsafe { vec_uninit(n as usize) };
45-
let mut rwork = unsafe { vec_uninit(2 * n as usize) };
45+
let mut rwork: Vec<Self::Real> = unsafe { vec_uninit(2 * n as usize) };
4646

47-
let mut vl = if jobvl == b'V' {
48-
Some(unsafe { vec_uninit((n * n) as usize) })
49-
} else {
50-
None
51-
};
52-
let mut vr = if jobvr == b'V' {
53-
Some(unsafe { vec_uninit((n * n) as usize) })
54-
} else {
55-
None
56-
};
47+
let mut vl: Option<Vec<Self>> =
48+
jobvl.then(|| unsafe { vec_uninit((n * n) as usize) });
49+
let mut vr: Option<Vec<Self>> =
50+
jobvr.then(|| unsafe { vec_uninit((n * n) as usize) });
5751

5852
// calc work size
5953
let mut info = 0;
6054
let mut work_size = [Self::zero()];
6155
unsafe {
6256
$ev(
63-
jobvl,
64-
jobvr,
65-
n,
66-
&mut a,
67-
n,
68-
&mut eigs,
69-
&mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
70-
n,
71-
&mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
72-
n,
73-
&mut work_size,
74-
-1,
75-
&mut rwork,
57+
jobvl.as_ptr(),
58+
jobvr.as_ptr(),
59+
&n,
60+
AsPtr::as_mut_ptr(a),
61+
&n,
62+
AsPtr::as_mut_ptr(&mut eigs),
63+
AsPtr::as_mut_ptr(vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])),
64+
&n,
65+
AsPtr::as_mut_ptr(vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])),
66+
&n,
67+
AsPtr::as_mut_ptr(&mut work_size),
68+
&(-1),
69+
AsPtr::as_mut_ptr(&mut rwork),
7670
&mut info,
7771
)
7872
};
7973
info.as_lapack_result()?;
8074

8175
// actal ev
8276
let lwork = work_size[0].to_usize().unwrap();
83-
let mut work = unsafe { vec_uninit(lwork) };
77+
let mut work: Vec<Self> = unsafe { vec_uninit(lwork) };
78+
let lwork = lwork as i32;
8479
unsafe {
8580
$ev(
86-
jobvl,
87-
jobvr,
88-
n,
89-
&mut a,
90-
n,
91-
&mut eigs,
92-
&mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
93-
n,
94-
&mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
95-
n,
96-
&mut work,
97-
lwork as i32,
98-
&mut rwork,
81+
jobvl.as_ptr(),
82+
jobvr.as_ptr(),
83+
&n,
84+
AsPtr::as_mut_ptr(a),
85+
&n,
86+
AsPtr::as_mut_ptr(&mut eigs),
87+
AsPtr::as_mut_ptr(vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])),
88+
&n,
89+
AsPtr::as_mut_ptr(vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])),
90+
&n,
91+
AsPtr::as_mut_ptr(&mut work),
92+
&lwork,
93+
AsPtr::as_mut_ptr(&mut rwork),
9994
&mut info,
10095
)
10196
};
10297
info.as_lapack_result()?;
10398

10499
// Hermite conjugate
105-
if jobvl == b'V' {
100+
if jobvl.is_calc() {
106101
for c in vl.as_mut().unwrap().iter_mut() {
107102
c.im = -c.im
108103
}
@@ -114,16 +109,16 @@ macro_rules! impl_eig_complex {
114109
};
115110
}
116111

117-
impl_eig_complex!(c64, lapack::zgeev);
118-
impl_eig_complex!(c32, lapack::cgeev);
112+
impl_eig_complex!(c64, lapack_sys::zgeev_);
113+
impl_eig_complex!(c32, lapack_sys::cgeev_);
119114

120115
macro_rules! impl_eig_real {
121116
($scalar:ty, $ev:path) => {
122117
impl Eig_ for $scalar {
123118
fn eig(
124119
calc_v: bool,
125120
l: MatrixLayout,
126-
mut a: &mut [Self],
121+
a: &mut [Self],
127122
) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
128123
let (n, _) = l.size();
129124
// LAPACK assumes a column-major input. A row-major input can
@@ -144,67 +139,62 @@ macro_rules! impl_eig_real {
144139
// `sgeev`/`dgeev`.
145140
let (jobvl, jobvr) = if calc_v {
146141
match l {
147-
MatrixLayout::C { .. } => (b'V', b'N'),
148-
MatrixLayout::F { .. } => (b'N', b'V'),
142+
MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not),
143+
MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc),
149144
}
150145
} else {
151-
(b'N', b'N')
146+
(EigenVectorFlag::Not, EigenVectorFlag::Not)
152147
};
153-
let mut eig_re = unsafe { vec_uninit(n as usize) };
154-
let mut eig_im = unsafe { vec_uninit(n as usize) };
148+
let mut eig_re: Vec<Self> = unsafe { vec_uninit(n as usize) };
149+
let mut eig_im: Vec<Self> = unsafe { vec_uninit(n as usize) };
155150

156-
let mut vl = if jobvl == b'V' {
157-
Some(unsafe { vec_uninit((n * n) as usize) })
158-
} else {
159-
None
160-
};
161-
let mut vr = if jobvr == b'V' {
162-
Some(unsafe { vec_uninit((n * n) as usize) })
163-
} else {
164-
None
165-
};
151+
let mut vl: Option<Vec<Self>> =
152+
jobvl.then(|| unsafe { vec_uninit((n * n) as usize) });
153+
let mut vr: Option<Vec<Self>> =
154+
jobvr.then(|| unsafe { vec_uninit((n * n) as usize) });
166155

167156
// calc work size
168157
let mut info = 0;
169-
let mut work_size = [0.0];
158+
let mut work_size: [Self; 1] = [0.0];
170159
unsafe {
171160
$ev(
172-
jobvl,
173-
jobvr,
174-
n,
175-
&mut a,
176-
n,
177-
&mut eig_re,
178-
&mut eig_im,
179-
vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
180-
n,
181-
vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
182-
n,
183-
&mut work_size,
184-
-1,
161+
jobvl.as_ptr(),
162+
jobvr.as_ptr(),
163+
&n,
164+
AsPtr::as_mut_ptr(a),
165+
&n,
166+
AsPtr::as_mut_ptr(&mut eig_re),
167+
AsPtr::as_mut_ptr(&mut eig_im),
168+
AsPtr::as_mut_ptr(vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])),
169+
&n,
170+
AsPtr::as_mut_ptr(vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])),
171+
&n,
172+
AsPtr::as_mut_ptr(&mut work_size),
173+
&(-1),
185174
&mut info,
186175
)
187176
};
188177
info.as_lapack_result()?;
189178

190179
// actual ev
191180
let lwork = work_size[0].to_usize().unwrap();
192-
let mut work = unsafe { vec_uninit(lwork) };
181+
let mut work: Vec<Self> = unsafe { vec_uninit(lwork) };
182+
let lwork = lwork as i32;
193183
unsafe {
194184
$ev(
195-
jobvl,
196-
jobvr,
197-
n,
198-
&mut a,
199-
n,
200-
&mut eig_re,
201-
&mut eig_im,
202-
vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
203-
n,
204-
vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
205-
n,
206-
&mut work,
207-
lwork as i32,
185+
jobvl.as_ptr(),
186+
jobvr.as_ptr(),
187+
&n,
188+
AsPtr::as_mut_ptr(a),
189+
&n,
190+
AsPtr::as_mut_ptr(&mut eig_re),
191+
AsPtr::as_mut_ptr(&mut eig_im),
192+
AsPtr::as_mut_ptr(vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])),
193+
&n,
194+
AsPtr::as_mut_ptr(vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])),
195+
&n,
196+
AsPtr::as_mut_ptr(&mut work),
197+
&lwork,
208198
&mut info,
209199
)
210200
};
@@ -254,7 +244,7 @@ macro_rules! impl_eig_real {
254244
for row in 0..n {
255245
let re = v[row + col * n];
256246
let mut im = v[row + (col + 1) * n];
257-
if jobvl == b'V' {
247+
if jobvl.is_calc() {
258248
im = -im;
259249
}
260250
eigvecs[row + col * n] = Self::complex(re, im);
@@ -270,5 +260,5 @@ macro_rules! impl_eig_real {
270260
};
271261
}
272262

273-
impl_eig_real!(f64, lapack::dgeev);
274-
impl_eig_real!(f32, lapack::sgeev);
263+
impl_eig_real!(f64, lapack_sys::dgeev_);
264+
impl_eig_real!(f32, lapack_sys::sgeev_);

0 commit comments

Comments
 (0)