Skip to content

Commit 5196445

Browse files
committed
impl Solve_Tridiagonal for Tridiagonal
1 parent 819c664 commit 5196445

File tree

2 files changed

+156
-0
lines changed

2 files changed

+156
-0
lines changed

src/tridiagonal.rs

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,57 @@ where
189189
}
190190
}
191191

192+
impl<A> SolveTriDiagonal<A, Ix2> for TriDiagonal<A>
193+
where
194+
A: Scalar + Lapack,
195+
{
196+
fn solve_tridiagonal<Sb: Data<Elem = A>>(
197+
&self,
198+
b: &ArrayBase<Sb, Ix2>,
199+
) -> Result<Array<A, Ix2>> {
200+
let mut b = replicate(b);
201+
self.solve_tridiagonal_inplace(&mut b)?;
202+
Ok(b)
203+
}
204+
fn solve_tridiagonal_into<Sb: DataMut<Elem = A>>(
205+
&self,
206+
mut b: ArrayBase<Sb, Ix2>,
207+
) -> Result<ArrayBase<Sb, Ix2>> {
208+
self.solve_tridiagonal_inplace(&mut b)?;
209+
Ok(b)
210+
}
211+
fn solve_t_tridiagonal<Sb: Data<Elem = A>>(
212+
&self,
213+
b: &ArrayBase<Sb, Ix2>,
214+
) -> Result<Array<A, Ix2>> {
215+
let mut b = replicate(b);
216+
self.solve_t_tridiagonal_inplace(&mut b)?;
217+
Ok(b)
218+
}
219+
fn solve_t_tridiagonal_into<Sb: DataMut<Elem = A>>(
220+
&self,
221+
mut b: ArrayBase<Sb, Ix2>,
222+
) -> Result<ArrayBase<Sb, Ix2>> {
223+
self.solve_t_tridiagonal_inplace(&mut b)?;
224+
Ok(b)
225+
}
226+
fn solve_h_tridiagonal<Sb: Data<Elem = A>>(
227+
&self,
228+
b: &ArrayBase<Sb, Ix2>,
229+
) -> Result<Array<A, Ix2>> {
230+
let mut b = replicate(b);
231+
self.solve_h_tridiagonal_inplace(&mut b)?;
232+
Ok(b)
233+
}
234+
fn solve_h_tridiagonal_into<Sb: DataMut<Elem = A>>(
235+
&self,
236+
mut b: ArrayBase<Sb, Ix2>,
237+
) -> Result<ArrayBase<Sb, Ix2>> {
238+
self.solve_h_tridiagonal_inplace(&mut b)?;
239+
Ok(b)
240+
}
241+
}
242+
192243
impl<A, S> SolveTriDiagonal<A, Ix2> for ArrayBase<S, Ix2>
193244
where
194245
A: Scalar + Lapack,
@@ -298,6 +349,42 @@ where
298349
}
299350
}
300351

352+
impl<A> SolveTriDiagonalInplace<A, Ix2> for TriDiagonal<A>
353+
where
354+
A: Scalar + Lapack,
355+
{
356+
fn solve_tridiagonal_inplace<'a, Sb>(
357+
&self,
358+
rhs: &'a mut ArrayBase<Sb, Ix2>,
359+
) -> Result<&'a mut ArrayBase<Sb, Ix2>>
360+
where
361+
Sb: DataMut<Elem = A>,
362+
{
363+
let f = self.factorize_tridiagonal()?;
364+
f.solve_tridiagonal_inplace(rhs)
365+
}
366+
fn solve_t_tridiagonal_inplace<'a, Sb>(
367+
&self,
368+
rhs: &'a mut ArrayBase<Sb, Ix2>,
369+
) -> Result<&'a mut ArrayBase<Sb, Ix2>>
370+
where
371+
Sb: DataMut<Elem = A>,
372+
{
373+
let f = self.factorize_tridiagonal()?;
374+
f.solve_t_tridiagonal_inplace(rhs)
375+
}
376+
fn solve_h_tridiagonal_inplace<'a, Sb>(
377+
&self,
378+
rhs: &'a mut ArrayBase<Sb, Ix2>,
379+
) -> Result<&'a mut ArrayBase<Sb, Ix2>>
380+
where
381+
Sb: DataMut<Elem = A>,
382+
{
383+
let f = self.factorize_tridiagonal()?;
384+
f.solve_h_tridiagonal_inplace(rhs)
385+
}
386+
}
387+
301388
impl<A, S> SolveTriDiagonalInplace<A, Ix2> for ArrayBase<S, Ix2>
302389
where
303390
A: Scalar + Lapack,
@@ -383,6 +470,60 @@ where
383470
}
384471
}
385472

473+
impl<A> SolveTriDiagonal<A, Ix1> for TriDiagonal<A>
474+
where
475+
A: Scalar + Lapack,
476+
{
477+
fn solve_tridiagonal<Sb: Data<Elem = A>>(
478+
&self,
479+
b: &ArrayBase<Sb, Ix1>,
480+
) -> Result<Array<A, Ix1>> {
481+
let b = b.to_owned();
482+
self.solve_tridiagonal_into(b)
483+
}
484+
fn solve_tridiagonal_into<Sb: DataMut<Elem = A>>(
485+
&self,
486+
b: ArrayBase<Sb, Ix1>,
487+
) -> Result<ArrayBase<Sb, Ix1>> {
488+
let b = into_col(b);
489+
let f = self.factorize_tridiagonal()?;
490+
let b = f.solve_tridiagonal_into(b)?;
491+
Ok(flatten(b))
492+
}
493+
fn solve_t_tridiagonal<Sb: Data<Elem = A>>(
494+
&self,
495+
b: &ArrayBase<Sb, Ix1>,
496+
) -> Result<Array<A, Ix1>> {
497+
let b = b.to_owned();
498+
self.solve_t_tridiagonal_into(b)
499+
}
500+
fn solve_t_tridiagonal_into<Sb: DataMut<Elem = A>>(
501+
&self,
502+
b: ArrayBase<Sb, Ix1>,
503+
) -> Result<ArrayBase<Sb, Ix1>> {
504+
let b = into_col(b);
505+
let f = self.factorize_tridiagonal()?;
506+
let b = f.solve_t_tridiagonal_into(b)?;
507+
Ok(flatten(b))
508+
}
509+
fn solve_h_tridiagonal<Sb: Data<Elem = A>>(
510+
&self,
511+
b: &ArrayBase<Sb, Ix1>,
512+
) -> Result<Array<A, Ix1>> {
513+
let b = b.to_owned();
514+
self.solve_h_tridiagonal_into(b)
515+
}
516+
fn solve_h_tridiagonal_into<Sb: DataMut<Elem = A>>(
517+
&self,
518+
b: ArrayBase<Sb, Ix1>,
519+
) -> Result<ArrayBase<Sb, Ix1>> {
520+
let b = into_col(b);
521+
let f = self.factorize_tridiagonal()?;
522+
let b = f.solve_h_tridiagonal_into(b)?;
523+
Ok(flatten(b))
524+
}
525+
}
526+
386527
impl<A, S> SolveTriDiagonal<A, Ix1> for ArrayBase<S, Ix2>
387528
where
388529
A: Scalar + Lapack,

tests/tridiagonal.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,21 @@ fn solve_tridiagonal_random_t() {
125125
assert_close_l2!(&y1, &y2, 1e-7);
126126
}
127127

128+
#[test]
129+
fn to_tridiagonal_solve_random() {
130+
let mut a: Array2<f64> = random((3, 3));
131+
a[[0, 2]] = 0.0;
132+
a[[2, 0]] = 0.0;
133+
let tridiag = a.to_tridiagonal().unwrap();
134+
let x: Array1<f64> = random(3);
135+
let b1 = a.dot(&x);
136+
let b2 = b1.clone();
137+
let y1 = tridiag.solve_tridiagonal_into(b1).unwrap();
138+
let y2 = a.solve_into(b2).unwrap();
139+
assert_close_l2!(&x, &y1, 1e-7);
140+
assert_close_l2!(&y1, &y2, 1e-7);
141+
}
142+
128143
#[test]
129144
fn det_tridiagonal_f64() {
130145
let a: Array2<f64> = arr2(&[[10.0, -9.0, 0.0], [7.0, -12.0, 11.0], [0.0, 10.0, 3.0]]);

0 commit comments

Comments
 (0)