Skip to content

Commit 207f55e

Browse files
Add gradient for svd
1 parent b63ee0c commit 207f55e

File tree

2 files changed

+107
-9
lines changed

2 files changed

+107
-9
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import warnings
2+
from collections.abc import Sequence
23
from functools import partial
4+
from typing import Union, cast
35

46
import numpy as np
57

8+
from pytensor import Variable
69
from pytensor import scalar as ps
710
from pytensor.gradient import DisconnectedType
811
from pytensor.graph.basic import Apply
@@ -582,26 +585,87 @@ def infer_shape(self, fgraph, node, shapes):
582585
else:
583586
return [s_shape]
584587

588+
def L_op(
    self,
    inputs: Sequence[Variable],
    outputs: Sequence[Variable],
    output_grads: Sequence[Variable],
) -> list[Variable]:
    """
    Reverse-mode gradient of the singular value decomposition.

    The formulas follow the reverse-mode SVD rule used by autograd's
    ``grad_svd`` (see also Seeger et al., "Auto-differentiating linear
    algebra"). With ``full_matrices=False`` the Op returns ``U`` of shape
    (M, K), ``s`` of shape (K,) and ``V^H`` of shape (K, N), K = min(M, N).

    Parameters
    ----------
    inputs
        Single-element sequence holding the decomposed matrix ``A``.
    outputs
        Outputs of this Op: ``[s]`` when ``compute_uv`` is False, otherwise
        ``[U, s, V^H]``.
    output_grads
        Gradients of the cost with respect to each entry of ``outputs``.
        Entries may be of ``DisconnectedType`` when the cost does not depend
        on the corresponding output.

    Returns
    -------
    list of Variable
        Single-element list with the gradient of the cost with respect to ``A``.

    Raises
    ------
    NotImplementedError
        If the Op was built with ``compute_uv=True`` and ``full_matrices=True``.

    Notes
    -----
    NOTE(review): ``.T`` is used for transposition, so this assumes ``A`` is a
    plain matrix (ndim == 2) -- confirm before applying to batched inputs.
    NOTE(review): the complex-dtype terms mirror the autograd reference and
    have not been verified against PyTensor's complex gradient convention.
    The gradient is undefined when singular values repeat (division by zero
    in ``s_i**2 - s_j**2``).
    """
    (A,) = inputs
    A = cast(ptb.TensorVariable, A)

    if not self.compute_uv:
        (S_grad,) = output_grads
        S_grad = cast(ptb.TensorVariable, S_grad)

        # Need U and V so do the whole svd anyway...
        u, _, v = svd(A, full_matrices=False, compute_uv=True)  # type: ignore
        u = cast(ptb.TensorVariable, u)

        return [(u.conj() * S_grad[..., None, :]) @ v]

    if self.full_matrices:
        raise NotImplementedError(
            "Gradient of svd not implemented for full_matrices=True"
        )

    u, s, vh = (cast(ptb.TensorVariable, x) for x in outputs)

    # A cost that only uses some of (U, s, V^H) yields DisconnectedType
    # gradients for the unused outputs; those cannot enter tensor arithmetic.
    disconnected = [isinstance(g.type, DisconnectedType) for g in output_grads]
    if all(disconnected):
        return [DisconnectedType()()]
    if disconnected == [True, False, True]:
        # Only the singular values are used: same rule as compute_uv=False,
        # but U and V^H are already available as outputs.
        gs_only = cast(ptb.TensorVariable, output_grads[1])
        return [(u.conj() * gs_only[..., None, :]) @ vh]

    gu, gs, gvh = (
        cast(ptb.TensorVariable, out.zeros_like() if missing else grad)
        for missing, grad, out in zip(disconnected, output_grads, (u, s, vh))
    )

    # The reference formulas are written in terms of V (N, K) and its
    # cotangent, while the Op returns V^H (K, N).
    v = vh.conj().T
    gv = gvh.conj().T

    m = A.shape[-2]
    n = A.shape[-1]
    k = ptm.min((m, n))
    eye = ptb.eye(k)

    # f[i, j] = 1 / (s_j**2 - s_i**2) off the diagonal; adding the identity
    # keeps the diagonal finite (it is multiplied by zero-diagonal terms).
    f = 1 / (s[..., None, :] ** 2 - s[..., :, None] ** 2 + eye)

    utgu = u.T @ gu
    vtgv = v.T @ gv

    inner = (f * (utgu - utgu.conj().T)) * s[..., None, :]
    inner = inner + eye * gs[..., :, None]
    inner = inner + s[..., :, None] * (f * (vtgv - vtgv.conj().T))

    if u.dtype.startswith("complex"):
        # Diagonal matrix built from imag(diag(U^T gU)), not the extracted vector.
        inner = inner + 1j * (eye * utgu.imag) / s[..., None, :]

    A_bar = u.conj() @ inner @ v.T

    # Corrections for rectangular inputs. m and n are symbolic, so the case
    # cannot be picked with a Python `if`; both expressions always have shape
    # (M, N), which makes an elementwise switch on the scalar conditions valid.
    # (A lazy `ifelse` would avoid evaluating the unused correction.)
    wide_term = (
        (u / s[..., None, :]) @ gv.T @ (ptb.eye(n) - v @ v.conj().T)
    ).conj()
    tall_term = ((v / s[..., None, :]) @ gu.T @ (ptb.eye(m) - u @ u.conj().T)).T

    A_bar = ptb.switch(
        ptm.lt(m, n),
        A_bar + wide_term,
        ptb.switch(ptm.gt(m, n), A_bar + tall_term, A_bar),
    )

    return [cast(ptb.TensorVariable, A_bar)]
648+
649+
650+
def svd(
    a, full_matrices: bool = True, compute_uv: bool = True
) -> Union[Variable, list[Variable]]:
    """
    Build the singular value decomposition of `a` (performed on CPU).

    Parameters
    ----------
    a
        Input handed to the SVD Op.
    full_matrices : bool, optional
        When True (default), u and v have shapes (M, M) and (N, N).
        When False, they have shapes (M, K) and (K, N), with K = min(M, N).
    compute_uv : bool, optional
        When True (default), u and v are computed in addition to s.

    Returns
    -------
    matrices: TensorVariable or list of TensorVariable
        The list [U, S, V] if compute_uv is True; otherwise only the
        singular values S.
    """
    decomposition = SVD(full_matrices, compute_uv)
    return decomposition(a)
607671

tests/tensor/test_nlinalg.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import functools as ft
2+
13
import numpy as np
24
import numpy.linalg
35
import pytest
@@ -189,6 +191,38 @@ def validate_shape(self, shape, compute_uv=True, full_matrices=True):
189191
outputs = [outputs]
190192
self._compile_and_check([A], outputs, [A_v], self.op_class, warn=False)
191193

194+
@pytest.mark.parametrize(
    "compute_uv, full_matrices",
    [(True, False), (False, False), (True, True)],
    ids=[
        "compute_uv=True, full_matrices=False",
        "compute_uv=False, full_matrices=False",
        "compute_uv=True, full_matrices=True",
    ],
)
def test_grad(self, compute_uv, full_matrices):
    """Check the SVD gradient, or the documented failure mode, per configuration.

    * ``full_matrices=True``: taking the gradient must raise NotImplementedError.
    * ``compute_uv=True`` (reduced): skipped, see the skip reason below.
    * ``compute_uv=False``: the gradient is verified numerically on a 4x4 input.
    """
    if full_matrices:
        # Build the graph outside the `raises` block so that only the
        # gradient call is allowed to produce the expected error.
        _, s, _ = svd(self.A, compute_uv=compute_uv, full_matrices=full_matrices)
        with pytest.raises(
            NotImplementedError,
            match="Gradient of svd not implemented for full_matrices=True",
        ):
            pytensor.grad(s.sum(), self.A)
        return

    if compute_uv:
        # `pytest.mark.skip(...)` only builds a marker object when called in
        # a test body; `pytest.skip(...)` is what actually skips at runtime.
        pytest.skip("Gradients of function with multiple outputs not testable")

    A_v = self.rng.random((4, 4)).astype(self.dtype)
    utt.verify_grad(
        ft.partial(svd, compute_uv=compute_uv, full_matrices=full_matrices),
        [A_v],
        # Use the seeded generator (not the global `np.random` module) so
        # the random projections are reproducible.
        rng=self.rng,
    )
225+
192226

193227
def test_tensorsolve():
194228
rng = np.random.default_rng(utt.fetch_seed())

0 commit comments

Comments
 (0)