
Commit 1a0d12d

Implement Blockwise for SVD
1 parent ee7c946 commit 1a0d12d

File tree

2 files changed (+57, -20 lines)


pytensor/tensor/nlinalg.py

Lines changed: 13 additions & 2 deletions

@@ -523,6 +523,7 @@ def qr(a, mode="reduced"):
 
 class SVD(Op):
     """
+    Computes singular value decomposition of matrix A, into U, S, V such that A = U @ S @ V
 
     Parameters
     ----------
@@ -543,13 +544,23 @@ class SVD(Op):
     def __init__(self, full_matrices: bool = True, compute_uv: bool = True):
         self.full_matrices = bool(full_matrices)
         self.compute_uv = bool(compute_uv)
+        if self.compute_uv:
+            if self.full_matrices:
+                self.gufunc_signature = "(m,n)->(m,m),(k),(n,n)"
+            else:
+                self.gufunc_signature = "(m,n)->(m,k),(k),(k,n)"
+        else:
+            self.gufunc_signature = "(m,n)->(k)"
 
     def make_node(self, x):
         x = as_tensor_variable(x)
         assert x.ndim == 2, "The input of svd function should be a matrix."
 
         in_dtype = x.type.numpy_dtype
-        out_dtype = np.dtype(f"f{in_dtype.itemsize}")
+        if in_dtype.name.startswith("int"):
+            out_dtype = np.dtype(f"f{in_dtype.itemsize}")
+        else:
+            out_dtype = in_dtype
 
         s = vector(dtype=out_dtype)
 
@@ -603,7 +614,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True):
     U, V, D : matrices
 
     """
-    return SVD(full_matrices, compute_uv)(a)
+    return Blockwise(SVD(full_matrices, compute_uv))(a)
 
 
 class Lstsq(Op):
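
With this change the svd helper dispatches through Blockwise, so leading batch dimensions are vectorized over the core "(m,n)" matrix signature. A minimal usage sketch, not part of the commit, assuming the default float64 floatX and the public pytensor.function / pytensor.tensor APIs:

    import numpy as np
    import pytensor
    from pytensor.tensor import tensor
    from pytensor.tensor.nlinalg import svd

    # Batch of ten 4x3 matrices; Blockwise(SVD) maps the core Op over the leading axis.
    A = tensor("A", shape=(10, 4, 3))
    U, S, Vt = svd(A, full_matrices=False)   # core signature "(m,n)->(m,k),(k),(k,n)"
    fn = pytensor.function([A], [U, S, Vt])

    a = np.random.default_rng(0).random((10, 4, 3))
    u, s, vt = fn(a)
    print(u.shape, s.shape, vt.shape)        # expected: (10, 4, 3) (10, 3) (10, 3, 3)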

tests/tensor/test_nlinalg.py

Lines changed: 44 additions & 18 deletions

@@ -1,3 +1,5 @@
+from functools import partial
+
 import numpy as np
 import numpy.linalg
 import pytest
@@ -34,6 +36,7 @@
     lscalar,
     matrix,
     scalar,
+    tensor,
     tensor3,
     tensor4,
     vector,
@@ -150,29 +153,52 @@ def test_qr_modes():
 
 class TestSvd(utt.InferShapeTester):
     op_class = SVD
-    dtype = "float32"
 
     def setup_method(self):
         super().setup_method()
         self.rng = np.random.default_rng(utt.fetch_seed())
-        self.A = matrix(dtype=self.dtype)
+        self.A = matrix(dtype=config.floatX)
         self.op = svd
 
-    def test_svd(self):
-        A = matrix("A", dtype=self.dtype)
-        U, S, VT = svd(A)
-        fn = function([A], [U, S, VT])
-        a = self.rng.random((4, 4)).astype(self.dtype)
-        n_u, n_s, n_vt = np.linalg.svd(a)
-        t_u, t_s, t_vt = fn(a)
+    @pytest.mark.parametrize(
+        "core_shape", [(3, 3), (4, 3), (3, 4)], ids=["square", "tall", "wide"]
+    )
+    @pytest.mark.parametrize(
+        "full_matrix", [True, False], ids=["full=True", "full=False"]
+    )
+    @pytest.mark.parametrize(
+        "compute_uv", [True, False], ids=["compute_uv=True", "compute_uv=False"]
+    )
+    @pytest.mark.parametrize(
+        "batched", [True, False], ids=["batched=True", "batched=False"]
+    )
+    @pytest.mark.parametrize(
+        "test_imag", [True, False], ids=["test_imag=True", "test_imag=False"]
+    )
+    def test_svd(self, core_shape, full_matrix, compute_uv, batched, test_imag):
+        dtype = config.floatX
+        if test_imag:
+            dtype = "complex128" if dtype.endswith("64") else "complex64"
+        shape = core_shape if not batched else (10, *core_shape)
+        A = tensor("A", shape=shape, dtype=dtype)
+        a = self.rng.random(shape).astype(dtype)
+
+        outputs = svd(A, compute_uv=compute_uv, full_matrices=full_matrix)
+        outputs = outputs if isinstance(outputs, list) else [outputs]
+        fn = function(inputs=[A], outputs=outputs)
+
+        np_fn = np.vectorize(
+            partial(np.linalg.svd, compute_uv=compute_uv, full_matrices=full_matrix),
+            signature=outputs[0].owner.op.core_op.gufunc_signature,
+        )
+
+        np_outputs = np_fn(a)
+        pt_outputs = fn(a)
 
-        assert _allclose(n_u, t_u)
-        assert _allclose(n_s, t_s)
-        assert _allclose(n_vt, t_vt)
+        np_outputs = np_outputs if isinstance(np_outputs, tuple) else [np_outputs]
 
-        fn = function([A], svd(A, compute_uv=False))
-        t_s = fn(a)
-        assert _allclose(n_s, t_s)
+        for np_val, pt_val in zip(np_outputs, pt_outputs):
+            assert _allclose(np_val, pt_val)
 
     def test_svd_infer_shape(self):
         self.validate_shape((4, 4), full_matrices=True, compute_uv=True)
@@ -183,7 +209,7 @@ def test_svd_infer_shape(self):
 
     def validate_shape(self, shape, compute_uv=True, full_matrices=True):
         A = self.A
-        A_v = self.rng.random(shape).astype(self.dtype)
+        A_v = self.rng.random(shape).astype(config.floatX)
         outputs = self.op(A, full_matrices=full_matrices, compute_uv=compute_uv)
         if not compute_uv:
             outputs = [outputs]
@@ -451,8 +477,8 @@ def test_non_tensorial_input(self):
         norm(3, None)
 
     def test_tensor_input(self):
-        with pytest.raises(NotImplementedError):
-            norm(np.random.random((3, 4, 5)), None)
+        res = norm(np.random.random((3, 4, 5)), None)
+        assert res.shape.eval() == (3,)
 
     def test_numpy_compare(self):
         rng = np.random.default_rng(utt.fetch_seed())
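
The new test builds its reference values with NumPy itself, vectorized over the batch axis using the core Op's gufunc signature. A standalone sketch of that pattern, with illustrative shapes not taken from the commit:

    from functools import partial
    import numpy as np

    # np.vectorize with a gufunc-style signature applies np.linalg.svd per matrix
    # and stacks the results along the leading batch axis.
    np_svd = np.vectorize(
        partial(np.linalg.svd, compute_uv=True, full_matrices=False),
        signature="(m,n)->(m,k),(k),(k,n)",  # same signature SVD.__init__ now stores
    )

    a = np.random.default_rng(0).random((10, 4, 3))
    u, s, vt = np_svd(a)
    assert u.shape == (10, 4, 3) and s.shape == (10, 3) and vt.shape == (10, 3, 3)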
