diff --git a/lax/src/tridiagonal.rs b/lax/src/tridiagonal.rs index 4eb8ff13..ea5bb119 100644 --- a/lax/src/tridiagonal.rs +++ b/lax/src/tridiagonal.rs @@ -2,7 +2,7 @@ //! for tridiagonal matrix use super::*; -use crate::{error::*, layout::MatrixLayout}; +use crate::{error::*, layout::*}; use cauchy::*; use num_traits::Zero; use std::ops::{Index, IndexMut}; @@ -130,11 +130,11 @@ impl IndexMut<[i32; 2]> for Tridiagonal { pub trait Tridiagonal_: Scalar + Sized { /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using /// partial pivoting with row interchanges. - unsafe fn lu_tridiagonal(a: Tridiagonal) -> Result>; + fn lu_tridiagonal(a: Tridiagonal) -> Result>; - unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result; + fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result; - unsafe fn solve_tridiagonal( + fn solve_tridiagonal( lu: &LUFactorizedTridiagonal, bl: MatrixLayout, t: Transpose, @@ -143,18 +143,23 @@ pub trait Tridiagonal_: Scalar + Sized { } macro_rules! impl_tridiagonal { - ($scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { + (@real, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { + impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, iwork); + }; + (@complex, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { + impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, ); + }; + (@body, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path, $($iwork:ident)*) => { impl Tridiagonal_ for $scalar { - unsafe fn lu_tridiagonal( - mut a: Tridiagonal, - ) -> Result> { + fn lu_tridiagonal(mut a: Tridiagonal) -> Result> { let (n, _) = a.l.size(); let mut du2 = vec![Zero::zero(); (n - 2) as usize]; let mut ipiv = vec![0; n as usize]; // We have to calc one-norm before LU factorization let a_opnorm_one = a.opnorm_one(); - $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv) - .as_lapack_result()?; + let mut info = 0; + unsafe { $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv, &mut info,) }; + info.as_lapack_result()?; Ok(LUFactorizedTridiagonal { a, du2, @@ -163,56 +168,80 @@ macro_rules! impl_tridiagonal { }) } - unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { + fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { let (n, _) = lu.a.l.size(); let ipiv = &lu.ipiv; + let mut work = vec![Self::zero(); 2 * n as usize]; + $( + let mut $iwork = vec![0; n as usize]; + )* let mut rcond = Self::Real::zero(); - $gtcon( - NormType::One as u8, - n, - &lu.a.dl, - &lu.a.d, - &lu.a.du, - &lu.du2, - ipiv, - lu.a_opnorm_one, - &mut rcond, - ) - .as_lapack_result()?; + let mut info = 0; + unsafe { + $gtcon( + NormType::One as u8, + n, + &lu.a.dl, + &lu.a.d, + &lu.a.du, + &lu.du2, + ipiv, + lu.a_opnorm_one, + &mut rcond, + &mut work, + $(&mut $iwork,)* + &mut info, + ); + } + info.as_lapack_result()?; Ok(rcond) } - unsafe fn solve_tridiagonal( + fn solve_tridiagonal( lu: &LUFactorizedTridiagonal, - bl: MatrixLayout, + b_layout: MatrixLayout, t: Transpose, b: &mut [Self], ) -> Result<()> { let (n, _) = lu.a.l.size(); - let (_, nrhs) = bl.size(); let ipiv = &lu.ipiv; - let ldb = bl.lda(); - $gttrs( - lu.a.l.lapacke_layout(), - t as u8, - n, - nrhs, - &lu.a.dl, - &lu.a.d, - &lu.a.du, - &lu.du2, - ipiv, - b, - ldb, - ) - .as_lapack_result()?; + // Transpose if b is C-continuous + let mut b_t = None; + let b_layout = match b_layout { + MatrixLayout::C { .. } => { + b_t = Some(vec![Self::zero(); b.len()]); + transpose(b_layout, b, b_t.as_mut().unwrap()) + } + MatrixLayout::F { .. } => b_layout, + }; + let (ldb, nrhs) = b_layout.size(); + let mut info = 0; + unsafe { + $gttrs( + t as u8, + n, + nrhs, + &lu.a.dl, + &lu.a.d, + &lu.a.du, + &lu.du2, + ipiv, + b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), + ldb, + &mut info, + ); + } + info.as_lapack_result()?; + if let Some(b_t) = b_t { + transpose(b_layout, &b_t, b); + } Ok(()) } } }; } // impl_tridiagonal! -impl_tridiagonal!(f64, lapacke::dgttrf, lapacke::dgtcon, lapacke::dgttrs); -impl_tridiagonal!(f32, lapacke::sgttrf, lapacke::sgtcon, lapacke::sgttrs); -impl_tridiagonal!(c64, lapacke::zgttrf, lapacke::zgtcon, lapacke::zgttrs); -impl_tridiagonal!(c32, lapacke::cgttrf, lapacke::cgtcon, lapacke::cgttrs); +impl_tridiagonal!(@real, f64, lapack::dgttrf, lapack::dgtcon, lapack::dgttrs); +impl_tridiagonal!(@real, f32, lapack::sgttrf, lapack::sgtcon, lapack::sgttrs); +impl_tridiagonal!(@complex, c64, lapack::zgttrf, lapack::zgtcon, lapack::zgttrs); +impl_tridiagonal!(@complex, c32, lapack::cgttrf, lapack::cgtcon, lapack::cgttrs);