
Commit 196b5e4

Add gradient for SVD
1 parent eb18f0e commit 196b5e4

File tree: 2 files changed (+193, -3 lines)


pytensor/tensor/nlinalg.py

Lines changed: 118 additions & 3 deletions
@@ -1,7 +1,7 @@
 import warnings
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from functools import partial
-from typing import Literal
+from typing import Literal, cast

 import numpy as np
 from numpy.core.numeric import normalize_axis_tuple  # type: ignore
@@ -15,7 +15,7 @@
 from pytensor.tensor import math as ptm
 from pytensor.tensor.basic import as_tensor_variable, diagonal
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
+from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector


 class MatrixPinv(Op):
@@ -597,6 +597,121 @@ def infer_shape(self, fgraph, node, shapes):
         else:
             return [s_shape]

+    def L_op(
+        self,
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        output_grads: Sequence[Variable],
+    ) -> list[Variable]:
+        """
+        Reverse-mode gradient of the SVD function. Adapted from the autograd implementation here:
+        https://github.com/HIPS/autograd/blob/01eacff7a4f12e6f7aebde7c4cb4c1c2633f217d/autograd/numpy/linalg.py#L194
+
+        And the mxnet implementation described in ..[1]
+
+        References
+        ----------
+        .. [1] Seeger, Matthias, et al. "Auto-differentiating linear algebra." arXiv preprint arXiv:1710.08717 (2017).
+        """
+
+        def s_grad_only(
+            U: ptb.TensorVariable, VT: ptb.TensorVariable, ds: ptb.TensorVariable
+        ) -> list[Variable]:
+            A_bar = (U.conj() * ds[..., None, :]) @ VT
+            return [A_bar]
+
+        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
+
+        if not self.compute_uv:
+            # We need all the components of the SVD to compute the gradient of A even if we only use the singular values
+            # in the cost function.
+            U, _, VT = svd(A, full_matrices=False, compute_uv=True)
+            ds = cast(ptb.TensorVariable, output_grads[0])
+            return s_grad_only(U, VT, ds)
+
+        elif self.full_matrices:
+            raise NotImplementedError(
+                "Gradient of svd not implemented for full_matrices=True"
+            )
+
+        else:
+            U, s, VT = (cast(ptb.TensorVariable, x) for x in outputs)
+
+            # Handle disconnected inputs
+            # If a user asked for all the matrices but then only used a subset in the cost function, the unused outputs
+            # will be DisconnectedType. We replace DisconnectedTypes with zero matrices of the correct shapes.
+            new_output_grads = []
+            is_disconnected = [
+                isinstance(x.type, DisconnectedType) for x in output_grads
+            ]
+            if all(is_disconnected):
+                # This should never actually be reached by Pytensor -- the SVD Op should be pruned from the gradient
+                # graph if it's fully disconnected. It is included for completeness.
+                return [DisconnectedType()()]  # pragma: no cover
+
+            elif is_disconnected == [True, False, True]:
+                # This is the same as the compute_uv = False case, so we can drop back to that simpler computation,
+                # without needing to re-compute U and VT
+                ds = cast(ptb.TensorVariable, output_grads[1])
+                return s_grad_only(U, VT, ds)
+
+            for disconnected, output_grad, output in zip(
+                is_disconnected, output_grads, [U, s, VT]
+            ):
+                if disconnected:
+                    new_output_grads.append(output.zeros_like())
+                else:
+                    new_output_grads.append(output_grad)
+
+            (dU, ds, dVT) = (cast(ptb.TensorVariable, x) for x in new_output_grads)
+
+            V = VT.T
+            dV = dVT.T
+
+            m, n = A.shape[-2:]
+
+            k = ptm.min((m, n))
+            eye = ptb.eye(k)
+
+            def h(t):
+                """
+                Approximation of s_i ** 2 - s_j ** 2, from ..[1].
+                Robust to identical singular values (singular matrix input), although
+                gradients are still wrong in this case.
+                """
+                eps = 1e-8
+
+                # sign(0) = 0 in pytensor, which defeats the whole purpose of this function
+                sign_t = ptb.where(ptm.eq(t, 0), 1, ptm.sign(t))
+                return ptm.maximum(ptm.abs(t), eps) * sign_t
+
+            numer = ptb.ones((k, k)) - eye
+            denom = h(s[None] - s[:, None]) * h(s[None] + s[:, None])
+            E = numer / denom
+
+            utgu = U.T @ dU
+            vtgv = VT @ dV
+
+            A_bar = (E * (utgu - utgu.conj().T)) * s[..., None, :]
+            A_bar = A_bar + eye * ds[..., :, None]
+            A_bar = A_bar + s[..., :, None] * (E * (vtgv - vtgv.conj().T))
+            A_bar = U.conj() @ A_bar @ VT
+
+            A_bar = ptb.switch(
+                ptm.eq(m, n),
+                A_bar,
+                ptb.switch(
+                    ptm.lt(m, n),
+                    A_bar
+                    + (
+                        U / s[..., None, :] @ dVT @ (ptb.eye(n) - V @ V.conj().T)
+                    ).conj(),
+                    A_bar
+                    + (V / s[..., None, :] @ dU.T @ (ptb.eye(m) - U @ U.conj().T)).T,
+                ),
+            )
+        return [A_bar]
+

 def svd(a, full_matrices: bool = True, compute_uv: bool = True):
     """
tests/tensor/test_nlinalg.py

Lines changed: 75 additions & 0 deletions
@@ -8,6 +8,7 @@
 import pytensor
 from pytensor import function
 from pytensor.configdefaults import config
+from pytensor.tensor.basic import as_tensor_variable
 from pytensor.tensor.math import _allclose
 from pytensor.tensor.nlinalg import (
     SVD,
@@ -215,6 +216,80 @@ def validate_shape(self, shape, compute_uv=True, full_matrices=True):
             outputs = [outputs]
         self._compile_and_check([A], outputs, [A_v], self.op_class, warn=False)

+    @pytest.mark.parametrize(
+        "compute_uv, full_matrices, gradient_test_case",
+        [(False, False, 0)]
+        + [(True, False, i) for i in range(8)]
+        + [(True, True, i) for i in range(8)],
+        ids=(
+            ["compute_uv=False, full_matrices=False"]
+            + [
+                f"compute_uv=True, full_matrices=False, gradient={grad}"
+                for grad in ["U", "s", "V", "U+s", "s+V", "U+V", "U+s+V", "None"]
+            ]
+            + [
+                f"compute_uv=True, full_matrices=True, gradient={grad}"
+                for grad in ["U", "s", "V", "U+s", "s+V", "U+V", "U+s+V", "None"]
+            ]
+        ),
+    )
+    @pytest.mark.parametrize(
+        "shape", [(3, 3), (4, 3), (3, 4)], ids=["(3,3)", "(4,3)", "(3,4)"]
+    )
+    @pytest.mark.parametrize(
+        "batched", [True, False], ids=["batched=True", "batched=False"]
+    )
+    def test_grad(self, compute_uv, full_matrices, gradient_test_case, shape, batched):
+        rng = np.random.default_rng(utt.fetch_seed())
+        if batched:
+            shape = (4, *shape)
+
+        A_v = self.rng.normal(size=shape).astype(config.floatX)
+        if full_matrices:
+            with pytest.raises(
+                NotImplementedError,
+                match="Gradient of svd not implemented for full_matrices=True",
+            ):
+                U, s, V = svd(
+                    self.A, compute_uv=compute_uv, full_matrices=full_matrices
+                )
+                pytensor.grad(s.sum(), self.A)
+
+        elif compute_uv:
+
+            def svd_fn(A, case=0):
+                U, s, V = svd(A, compute_uv=compute_uv, full_matrices=full_matrices)
+                if case == 0:
+                    return U.sum()
+                elif case == 1:
+                    return s.sum()
+                elif case == 2:
+                    return V.sum()
+                elif case == 3:
+                    return U.sum() + s.sum()
+                elif case == 4:
+                    return s.sum() + V.sum()
+                elif case == 5:
+                    return U.sum() + V.sum()
+                elif case == 6:
+                    return U.sum() + s.sum() + V.sum()
+                elif case == 7:
+                    # All inputs disconnected
+                    return as_tensor_variable(3.0)
+
+            utt.verify_grad(
+                partial(svd_fn, case=gradient_test_case),
+                [A_v],
+                rng=rng,
+            )
+
+        else:
+            utt.verify_grad(
+                partial(svd, compute_uv=compute_uv, full_matrices=full_matrices),
+                [A_v],
+                rng=rng,
+            )
+

 def test_tensorsolve():
     rng = np.random.default_rng(utt.fetch_seed())
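
The gradient_test_case index selects which SVD outputs enter the scalar cost (0: U, 1: s, 2: V, 3: U+s, 4: s+V, 5: U+V, 6: U+s+V, 7: none), so the DisconnectedType handling in the new L_op is exercised for every subset of outputs under verify_grad. Assuming a standard development checkout, just these cases can be run with something like: pytest tests/tensor/test_nlinalg.py -k test_grad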
