Skip to content

Commit dd425f4

Browse files
committed
Use vec_uninit2 in all
1 parent 41a3247 commit dd425f4

File tree

8 files changed

+57
-39
lines changed

8 files changed

+57
-39
lines changed

lax/src/least_squares.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ macro_rules! impl_least_squares {
8787
};
8888

8989
let rcond: Self::Real = -1.;
90-
let mut singular_values: Vec<Self::Real> = unsafe { vec_uninit( k as usize) };
90+
let mut singular_values: Vec<MaybeUninit<Self::Real>> = unsafe { vec_uninit2( k as usize) };
9191
let mut rank: i32 = 0;
9292

9393
// eval work size
@@ -120,12 +120,12 @@ macro_rules! impl_least_squares {
120120

121121
// calc
122122
let lwork = work_size[0].to_usize().unwrap();
123-
let mut work: Vec<Self> = unsafe { vec_uninit(lwork) };
123+
let mut work: Vec<MaybeUninit<Self>> = unsafe { vec_uninit2(lwork) };
124124
let liwork = iwork_size[0].to_usize().unwrap();
125-
let mut iwork = unsafe { vec_uninit(liwork) };
125+
let mut iwork: Vec<MaybeUninit<i32>> = unsafe { vec_uninit2(liwork) };
126126
$(
127127
let lrwork = $rwork[0].to_usize().unwrap();
128-
let mut $rwork: Vec<Self::Real> = unsafe { vec_uninit(lrwork) };
128+
let mut $rwork: Vec<MaybeUninit<Self::Real>> = unsafe { vec_uninit2(lrwork) };
129129
)*
130130
unsafe {
131131
$gelsd(
@@ -142,12 +142,14 @@ macro_rules! impl_least_squares {
142142
AsPtr::as_mut_ptr(&mut work),
143143
&(lwork as i32),
144144
$(AsPtr::as_mut_ptr(&mut $rwork),)*
145-
iwork.as_mut_ptr(),
145+
AsPtr::as_mut_ptr(&mut iwork),
146146
&mut info,
147147
);
148148
}
149149
info.as_lapack_result()?;
150150

151+
let singular_values = unsafe { singular_values.assume_init() };
152+
151153
// Skip a_t -> a transpose because A has been destroyed
152154
// Re-transpose b
153155
if let Some(b_t) = b_t {

lax/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,12 @@ macro_rules! impl_as_ptr {
147147
}
148148
};
149149
}
150+
impl_as_ptr!(i32, i32);
150151
impl_as_ptr!(f32, f32);
151152
impl_as_ptr!(f64, f64);
152153
impl_as_ptr!(c32, lapack_sys::__BindgenComplex<f32>);
153154
impl_as_ptr!(c64, lapack_sys::__BindgenComplex<f64>);
155+
impl_as_ptr!(MaybeUninit<i32>, i32);
154156
impl_as_ptr!(MaybeUninit<f32>, f32);
155157
impl_as_ptr!(MaybeUninit<f64>, f64);
156158
impl_as_ptr!(MaybeUninit<c32>, lapack_sys::__BindgenComplex<f32>);

lax/src/rcond.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ macro_rules! impl_rcond_real {
1717
let mut rcond = Self::Real::zero();
1818
let mut info = 0;
1919

20-
let mut work: Vec<Self> = unsafe { vec_uninit(4 * n as usize) };
21-
let mut iwork = unsafe { vec_uninit(n as usize) };
20+
let mut work: Vec<MaybeUninit<Self>> = unsafe { vec_uninit2(4 * n as usize) };
21+
let mut iwork: Vec<MaybeUninit<i32>> = unsafe { vec_uninit2(n as usize) };
2222
let norm_type = match l {
2323
MatrixLayout::C { .. } => NormType::Infinity,
2424
MatrixLayout::F { .. } => NormType::One,
@@ -32,7 +32,7 @@ macro_rules! impl_rcond_real {
3232
&anorm,
3333
&mut rcond,
3434
AsPtr::as_mut_ptr(&mut work),
35-
iwork.as_mut_ptr(),
35+
AsPtr::as_mut_ptr(&mut iwork),
3636
&mut info,
3737
)
3838
};
@@ -54,8 +54,9 @@ macro_rules! impl_rcond_complex {
5454
let (n, _) = l.size();
5555
let mut rcond = Self::Real::zero();
5656
let mut info = 0;
57-
let mut work: Vec<Self> = unsafe { vec_uninit(2 * n as usize) };
58-
let mut rwork: Vec<Self::Real> = unsafe { vec_uninit(2 * n as usize) };
57+
let mut work: Vec<MaybeUninit<Self>> = unsafe { vec_uninit2(2 * n as usize) };
58+
let mut rwork: Vec<MaybeUninit<Self::Real>> =
59+
unsafe { vec_uninit2(2 * n as usize) };
5960
let norm_type = match l {
6061
MatrixLayout::C { .. } => NormType::Infinity,
6162
MatrixLayout::F { .. } => NormType::One,

lax/src/solve.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,20 @@ macro_rules! impl_solve {
3333
return Ok(Vec::new());
3434
}
3535
let k = ::std::cmp::min(row, col);
36-
let mut ipiv = unsafe { vec_uninit(k as usize) };
36+
let mut ipiv = unsafe { vec_uninit2(k as usize) };
3737
let mut info = 0;
3838
unsafe {
3939
$getrf(
4040
&l.lda(),
4141
&l.len(),
4242
AsPtr::as_mut_ptr(a),
4343
&l.lda(),
44-
ipiv.as_mut_ptr(),
44+
AsPtr::as_mut_ptr(&mut ipiv),
4545
&mut info,
4646
)
4747
};
4848
info.as_lapack_result()?;
49+
let ipiv = unsafe { ipiv.assume_init() };
4950
Ok(ipiv)
5051
}
5152

@@ -74,7 +75,7 @@ macro_rules! impl_solve {
7475

7576
// actual
7677
let lwork = work_size[0].to_usize().unwrap();
77-
let mut work: Vec<Self> = unsafe { vec_uninit(lwork) };
78+
let mut work: Vec<MaybeUninit<Self>> = unsafe { vec_uninit2(lwork) };
7879
unsafe {
7980
$getri(
8081
&l.len(),

lax/src/solveh.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ macro_rules! impl_solveh {
2020
impl Solveh_ for $scalar {
2121
fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot> {
2222
let (n, _) = l.size();
23-
let mut ipiv = unsafe { vec_uninit(n as usize) };
23+
let mut ipiv = unsafe { vec_uninit2(n as usize) };
2424
if n == 0 {
2525
return Ok(Vec::new());
2626
}
@@ -34,7 +34,7 @@ macro_rules! impl_solveh {
3434
&n,
3535
AsPtr::as_mut_ptr(a),
3636
&l.lda(),
37-
ipiv.as_mut_ptr(),
37+
AsPtr::as_mut_ptr(&mut ipiv),
3838
AsPtr::as_mut_ptr(&mut work_size),
3939
&(-1),
4040
&mut info,
@@ -44,27 +44,28 @@ macro_rules! impl_solveh {
4444

4545
// actual
4646
let lwork = work_size[0].to_usize().unwrap();
47-
let mut work: Vec<Self> = unsafe { vec_uninit(lwork) };
47+
let mut work: Vec<MaybeUninit<Self>> = unsafe { vec_uninit2(lwork) };
4848
unsafe {
4949
$trf(
5050
uplo.as_ptr(),
5151
&n,
5252
AsPtr::as_mut_ptr(a),
5353
&l.lda(),
54-
ipiv.as_mut_ptr(),
54+
AsPtr::as_mut_ptr(&mut ipiv),
5555
AsPtr::as_mut_ptr(&mut work),
5656
&(lwork as i32),
5757
&mut info,
5858
)
5959
};
6060
info.as_lapack_result()?;
61+
let ipiv = unsafe { ipiv.assume_init() };
6162
Ok(ipiv)
6263
}
6364

6465
fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
6566
let (n, _) = l.size();
6667
let mut info = 0;
67-
let mut work: Vec<Self> = unsafe { vec_uninit(n as usize) };
68+
let mut work: Vec<MaybeUninit<Self>> = unsafe { vec_uninit2(n as usize) };
6869
unsafe {
6970
$tri(
7071
uplo.as_ptr(),

lax/src/svd.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,21 +65,21 @@ macro_rules! impl_svd {
6565

6666
let m = l.lda();
6767
let mut u = match ju {
68-
FlagSVD::All => Some(unsafe { vec_uninit( (m * m) as usize) }),
68+
FlagSVD::All => Some(unsafe { vec_uninit2( (m * m) as usize) }),
6969
FlagSVD::No => None,
7070
};
7171

7272
let n = l.len();
7373
let mut vt = match jvt {
74-
FlagSVD::All => Some(unsafe { vec_uninit( (n * n) as usize) }),
74+
FlagSVD::All => Some(unsafe { vec_uninit2( (n * n) as usize) }),
7575
FlagSVD::No => None,
7676
};
7777

7878
let k = std::cmp::min(m, n);
79-
let mut s = unsafe { vec_uninit( k as usize) };
79+
let mut s = unsafe { vec_uninit2( k as usize) };
8080

8181
$(
82-
let mut $rwork_ident: Vec<Self::Real> = unsafe { vec_uninit( 5 * k as usize) };
82+
let mut $rwork_ident: Vec<MaybeUninit<Self::Real>> = unsafe { vec_uninit2( 5 * k as usize) };
8383
)*
8484

8585
// eval work size
@@ -108,7 +108,7 @@ macro_rules! impl_svd {
108108

109109
// calc
110110
let lwork = work_size[0].to_usize().unwrap();
111-
let mut work: Vec<Self> = unsafe { vec_uninit( lwork) };
111+
let mut work: Vec<MaybeUninit<Self>> = unsafe { vec_uninit2( lwork) };
112112
unsafe {
113113
$gesvd(
114114
ju.as_ptr(),
@@ -129,6 +129,11 @@ macro_rules! impl_svd {
129129
);
130130
}
131131
info.as_lapack_result()?;
132+
133+
let s = unsafe { s.assume_init() };
134+
let u = u.map(|v| unsafe { v.assume_init() });
135+
let vt = vt.map(|v| unsafe { v.assume_init() });
136+
132137
match l {
133138
MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }),
134139
MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }),

lax/src/svddc.rs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,20 @@ macro_rules! impl_svddc {
3939
let m = l.lda();
4040
let n = l.len();
4141
let k = m.min(n);
42-
let mut s = unsafe { vec_uninit( k as usize) };
42+
let mut s = unsafe { vec_uninit2( k as usize) };
4343

4444
let (u_col, vt_row) = match jobz {
4545
UVTFlag::Full | UVTFlag::None => (m, n),
4646
UVTFlag::Some => (k, k),
4747
};
4848
let (mut u, mut vt) = match jobz {
4949
UVTFlag::Full => (
50-
Some(unsafe { vec_uninit( (m * m) as usize) }),
51-
Some(unsafe { vec_uninit( (n * n) as usize) }),
50+
Some(unsafe { vec_uninit2( (m * m) as usize) }),
51+
Some(unsafe { vec_uninit2( (n * n) as usize) }),
5252
),
5353
UVTFlag::Some => (
54-
Some(unsafe { vec_uninit( (m * u_col) as usize) }),
55-
Some(unsafe { vec_uninit( (n * vt_row) as usize) }),
54+
Some(unsafe { vec_uninit2( (m * u_col) as usize) }),
55+
Some(unsafe { vec_uninit2( (n * vt_row) as usize) }),
5656
),
5757
UVTFlag::None => (None, None),
5858
};
@@ -64,12 +64,12 @@ macro_rules! impl_svddc {
6464
UVTFlag::None => 7 * mn,
6565
_ => std::cmp::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn),
6666
};
67-
let mut $rwork_ident: Vec<Self::Real> = unsafe { vec_uninit( lrwork) };
67+
let mut $rwork_ident: Vec<MaybeUninit<Self::Real>> = unsafe { vec_uninit2( lrwork) };
6868
)*
6969

7070
// eval work size
7171
let mut info = 0;
72-
let mut iwork = unsafe { vec_uninit( 8 * k as usize) };
72+
let mut iwork: Vec<MaybeUninit<i32>> = unsafe { vec_uninit2( 8 * k as usize) };
7373
let mut work_size = [Self::zero()];
7474
unsafe {
7575
$gesdd(
@@ -86,15 +86,15 @@ macro_rules! impl_svddc {
8686
AsPtr::as_mut_ptr(&mut work_size),
8787
&(-1),
8888
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
89-
iwork.as_mut_ptr(),
89+
AsPtr::as_mut_ptr(&mut iwork),
9090
&mut info,
9191
);
9292
}
9393
info.as_lapack_result()?;
9494

9595
// do svd
9696
let lwork = work_size[0].to_usize().unwrap();
97-
let mut work: Vec<Self> = unsafe { vec_uninit( lwork) };
97+
let mut work: Vec<MaybeUninit<Self>> = unsafe { vec_uninit2( lwork) };
9898
unsafe {
9999
$gesdd(
100100
jobz.as_ptr(),
@@ -110,12 +110,16 @@ macro_rules! impl_svddc {
110110
AsPtr::as_mut_ptr(&mut work),
111111
&(lwork as i32),
112112
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
113-
iwork.as_mut_ptr(),
113+
AsPtr::as_mut_ptr(&mut iwork),
114114
&mut info,
115115
);
116116
}
117117
info.as_lapack_result()?;
118118

119+
let s = unsafe { s.assume_init() };
120+
let u = u.map(|v| unsafe { v.assume_init() });
121+
let vt = vt.map(|v| unsafe { v.assume_init() });
122+
119123
match l {
120124
MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }),
121125
MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }),

lax/src/tridiagonal.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ macro_rules! impl_tridiagonal {
152152
impl Tridiagonal_ for $scalar {
153153
fn lu_tridiagonal(mut a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>> {
154154
let (n, _) = a.l.size();
155-
let mut du2 = unsafe { vec_uninit( (n - 2) as usize) };
156-
let mut ipiv = unsafe { vec_uninit( n as usize) };
155+
let mut du2 = unsafe { vec_uninit2( (n - 2) as usize) };
156+
let mut ipiv = unsafe { vec_uninit2( n as usize) };
157157
// We have to calc one-norm before LU factorization
158158
let a_opnorm_one = a.opnorm_one();
159159
let mut info = 0;
@@ -164,11 +164,13 @@ macro_rules! impl_tridiagonal {
164164
AsPtr::as_mut_ptr(&mut a.d),
165165
AsPtr::as_mut_ptr(&mut a.du),
166166
AsPtr::as_mut_ptr(&mut du2),
167-
ipiv.as_mut_ptr(),
167+
AsPtr::as_mut_ptr(&mut ipiv),
168168
&mut info,
169169
)
170170
};
171171
info.as_lapack_result()?;
172+
let du2 = unsafe { du2.assume_init() };
173+
let ipiv = unsafe { ipiv.assume_init() };
172174
Ok(LUFactorizedTridiagonal {
173175
a,
174176
du2,
@@ -180,9 +182,9 @@ macro_rules! impl_tridiagonal {
180182
fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
181183
let (n, _) = lu.a.l.size();
182184
let ipiv = &lu.ipiv;
183-
let mut work: Vec<Self> = unsafe { vec_uninit( 2 * n as usize) };
185+
let mut work: Vec<MaybeUninit<Self>> = unsafe { vec_uninit2( 2 * n as usize) };
184186
$(
185-
let mut $iwork = unsafe { vec_uninit( n as usize) };
187+
let mut $iwork: Vec<MaybeUninit<i32>> = unsafe { vec_uninit2( n as usize) };
186188
)*
187189
let mut rcond = Self::Real::zero();
188190
let mut info = 0;
@@ -198,7 +200,7 @@ macro_rules! impl_tridiagonal {
198200
&lu.a_opnorm_one,
199201
&mut rcond,
200202
AsPtr::as_mut_ptr(&mut work),
201-
$($iwork.as_mut_ptr(),)*
203+
$(AsPtr::as_mut_ptr(&mut $iwork),)*
202204
&mut info,
203205
);
204206
}

0 commit comments

Comments
 (0)