
Commit b6c79fd

Add gradient for SVD
1 parent 453fb4d commit b6c79fd

2 files changed: 131 additions & 2 deletions

pytensor/tensor/nlinalg.py

Lines changed: 86 additions & 2 deletions
@@ -1,6 +1,7 @@
 import warnings
+from collections.abc import Sequence
 from functools import partial
-from typing import Callable, Literal, Optional, Union
+from typing import Callable, Literal, Optional, Union, cast

 import numpy as np
 from numpy.core.numeric import normalize_axis_tuple  # type: ignore
@@ -13,7 +14,7 @@
 from pytensor.tensor import math as ptm
 from pytensor.tensor.basic import as_tensor_variable, diagonal
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
+from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector


 class MatrixPinv(Op):
@@ -595,6 +596,89 @@ def infer_shape(self, fgraph, node, shapes):
         else:
             return [s_shape]

+    def L_op(
+        self,
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        output_grads: Sequence[Variable],
+    ) -> list[Variable]:
+        """
+        Reverse-mode gradient of the SVD function. Adapted from the autograd implementation here:
+        https://github.com/HIPS/autograd/blob/01eacff7a4f12e6f7aebde7c4cb4c1c2633f217d/autograd/numpy/linalg.py#L194
+
+        And the mxnet implementation described in .. [1]
+
+        References
+        ----------
+        .. [1] Seeger, Matthias, et al. "Auto-differentiating linear algebra." arXiv preprint arXiv:1710.08717 (2017).
+        """
+        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
+
+        if not self.compute_uv:
+            # We need all the components of the SVD to compute the gradient of A even if we only use the singular values
+            # in the cost function.
+            U, s, VT = svd(A, full_matrices=False, compute_uv=True)
+
+            (ds,) = (cast(ptb.TensorVariable, x) for x in output_grads)
+            A_bar = (U.conj() * ds[..., None, :]) @ VT
+
+            return [A_bar]
+
+        elif self.full_matrices:
+            raise NotImplementedError(
+                "Gradient of svd not implemented for full_matrices=True"
+            )
+
+        else:
+            U, s, VT = (cast(ptb.TensorVariable, x) for x in outputs)
+            (dU, ds, dVT) = (cast(ptb.TensorVariable, x) for x in output_grads)
+            V = VT.T
+            dV = dVT.T
+
+            m, n = A.shape[-2:]
+
+            k = ptm.min((m, n))
+            eye = ptb.eye(k)
+
+            def h(t):
+                """
+                Approximation of s_i ** 2 - s_j ** 2, from .. [1].
+                Robust to identical singular values (singular matrix input), although
+                gradients are still wrong in this case.
+                """
+                eps = 1e-8
+
+                # sign(0) = 0 in pytensor, which defeats the whole purpose of this function
+                sign_t = ptb.where(ptm.eq(t, 0), 1, ptm.sign(t))
+                return ptm.maximum(ptm.abs(t), eps) * sign_t
+
+            numer = ptb.ones((k, k)) - eye
+            denom = h(s[None] - s[:, None]) * h(s[None] + s[:, None])
+            E = numer / denom
+
+            utgu = U.T @ dU
+            vtgv = VT @ dV
+
+            A_bar = (E * (utgu - utgu.conj().T)) * s[..., None, :]
+            A_bar = A_bar + eye * ds[..., :, None]
+            A_bar = A_bar + s[..., :, None] * (E * (vtgv - vtgv.conj().T))
+            A_bar = U.conj() @ A_bar @ VT
+
+            A_bar = ptb.switch(
+                ptm.eq(m, n),
+                A_bar,
+                ptb.switch(
+                    ptm.lt(m, n),
+                    A_bar
+                    + (
+                        U / s[..., None, :] @ dVT @ (ptb.eye(n) - V @ V.conj().T)
+                    ).conj(),
+                    A_bar
+                    + (V / s[..., None, :] @ dU.T @ (ptb.eye(m) - U @ U.conj().T)).T,
+                ),
+            )
+            return [A_bar]
+

 def svd(a, full_matrices: bool = True, compute_uv: bool = True):
     """

tests/tensor/test_nlinalg.py

Lines changed: 45 additions & 0 deletions
@@ -214,6 +214,51 @@ def validate_shape(self, shape, compute_uv=True, full_matrices=True):
             outputs = [outputs]
         self._compile_and_check([A], outputs, [A_v], self.op_class, warn=False)

+    @pytest.mark.parametrize(
+        "compute_uv, full_matrices",
+        [(True, False), (False, False), (True, True)],
+        ids=[
+            "compute_uv=True, full_matrices=False",
+            "compute_uv=False, full_matrices=False",
+            "compute_uv=True, full_matrices=True",
+        ],
+    )
+    @pytest.mark.parametrize(
+        "shape", [(3, 3), (4, 3), (3, 4)], ids=["(3,3)", "(4,3)", "(3,4)"]
+    )
+    @pytest.mark.parametrize(
+        "batched", [True, False], ids=["batched=True", "batched=False"]
+    )
+    def test_grad(self, compute_uv, full_matrices, shape, batched):
+        rng = np.random.default_rng(utt.fetch_seed())
+        if batched:
+            shape = (4, *shape)
+
+        A_v = self.rng.normal(size=shape).astype(config.floatX)
+        if full_matrices:
+            with pytest.raises(
+                NotImplementedError,
+                match="Gradient of svd not implemented for full_matrices=True",
+            ):
+                U, s, V = svd(
+                    self.A, compute_uv=compute_uv, full_matrices=full_matrices
+                )
+                pytensor.grad(s.sum(), self.A)
+
+        elif compute_uv:
+            utt.verify_grad(
+                partial(svd, compute_uv=compute_uv, full_matrices=full_matrices),
+                [A_v],
+                rng=rng,
+            )
+
+        else:
+            utt.verify_grad(
+                partial(svd, compute_uv=compute_uv, full_matrices=full_matrices),
+                [A_v],
+                rng=rng,
+            )
+

 def test_tensorsolve():
     rng = np.random.default_rng(utt.fetch_seed())
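Outside the test suite, the kind of check the new test performs with utt.verify_grad can be reproduced by hand. The sketch below (not part of the commit; shapes and cost chosen arbitrarily) compares one entry of the symbolic gradient for the compute_uv=True, full_matrices=False branch against a central finite difference:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import svd

A = pt.matrix("A")
U, s, VT = svd(A, full_matrices=False, compute_uv=True)

# A scalar cost that touches all three outputs of the decomposition.
cost = pt.sum(U[:, 0]) + pt.sum(s) + pt.sum(VT[0, :])
cost_fn = pytensor.function([A], cost)
grad_fn = pytensor.function([A], pytensor.grad(cost, A))

rng = np.random.default_rng(0)
A_val = rng.normal(size=(4, 3)).astype(A.dtype)

# Central-difference estimate of d(cost)/dA[0, 0] for comparison.
eps = 1e-6
E = np.zeros_like(A_val)
E[0, 0] = eps
fd = (cost_fn(A_val + E) - cost_fn(A_val - E)) / (2 * eps)
print(grad_fn(A_val)[0, 0], fd)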
