Implement gradient for SVD #614

Merged: 1 commit, Apr 28, 2024

121 changes: 118 additions & 3 deletions pytensor/tensor/nlinalg.py
@@ -1,7 +1,7 @@
import warnings
from collections.abc import Callable
from collections.abc import Callable, Sequence
from functools import partial
from typing import Literal
from typing import Literal, cast

import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore
@@ -15,7 +15,7 @@
from pytensor.tensor import math as ptm
from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector


class MatrixPinv(Op):
@@ -597,6 +597,121 @@ def infer_shape(self, fgraph, node, shapes):
else:
return [s_shape]

def L_op(
self,
inputs: Sequence[Variable],
outputs: Sequence[Variable],
output_grads: Sequence[Variable],
) -> list[Variable]:
"""
Reverse-mode gradient of the SVD function. Adapted from the autograd implementation here:
https://github.com/HIPS/autograd/blob/01eacff7a4f12e6f7aebde7c4cb4c1c2633f217d/autograd/numpy/linalg.py#L194
and the MXNet implementation described in [1]_.

References
----------
.. [1] Seeger, Matthias, et al. "Auto-differentiating linear algebra." arXiv preprint arXiv:1710.08717 (2017).
"""

def s_grad_only(
U: ptb.TensorVariable, VT: ptb.TensorVariable, ds: ptb.TensorVariable
) -> list[Variable]:
A_bar = (U.conj() * ds[..., None, :]) @ VT
return [A_bar]

(A,) = (cast(ptb.TensorVariable, x) for x in inputs)

if not self.compute_uv:
# We need all the components of the SVD to compute the gradient of A even if we only use the singular values
# in the cost function.
U, _, VT = svd(A, full_matrices=False, compute_uv=True)
ds = cast(ptb.TensorVariable, output_grads[0])
return s_grad_only(U, VT, ds)

elif self.full_matrices:
raise NotImplementedError(
"Gradient of svd not implemented for full_matrices=True"
)

else:
U, s, VT = (cast(ptb.TensorVariable, x) for x in outputs)

# Handle disconnected inputs
# If a user asked for all the matrices but then only used a subset in the cost function, the unused outputs
# will be DisconnectedType. We replace DisconnectedTypes with zero matrices of the correct shapes.
new_output_grads = []
is_disconnected = [
isinstance(x.type, DisconnectedType) for x in output_grads
]
if all(is_disconnected):
# This should never actually be reached by PyTensor -- the SVD Op should be pruned from the gradient
# graph if it is fully disconnected. It is included for completeness.
return [DisconnectedType()()] # pragma: no cover

elif is_disconnected == [True, False, True]:
# This is the same as the compute_uv=False case, so we can fall back to that simpler computation,
# without needing to re-compute U and VT
ds = cast(ptb.TensorVariable, output_grads[1])
return s_grad_only(U, VT, ds)

for disconnected, output_grad, output in zip(
is_disconnected, output_grads, [U, s, VT]
):
if disconnected:
new_output_grads.append(output.zeros_like())
else:
new_output_grads.append(output_grad)

(dU, ds, dVT) = (cast(ptb.TensorVariable, x) for x in new_output_grads)

V = VT.T
dV = dVT.T

m, n = A.shape[-2:]

k = ptm.min((m, n))
eye = ptb.eye(k)

def h(t):
"""
Clip t away from zero while keeping its sign, so that the product
h(s_i - s_j) * h(s_i + s_j) robustly approximates s_i ** 2 - s_j ** 2, as in [1].
Robust to identical singular values (singular matrix input), although
gradients are still wrong in this case.
"""
eps = 1e-8

# sign(0) = 0 in pytensor, which defeats the whole purpose of this function
sign_t = ptb.where(ptm.eq(t, 0), 1, ptm.sign(t))
return ptm.maximum(ptm.abs(t), eps) * sign_t

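# E[i, j] (built just below) is approximately 1 / (s_j ** 2 - s_i ** 2) off the diagonal and 0 on it;
# h() keeps the denominator away from zero when singular values (nearly) coincide.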
numer = ptb.ones((k, k)) - eye
denom = h(s[None] - s[:, None]) * h(s[None] + s[:, None])
E = numer / denom

utgu = U.T @ dU
vtgv = VT @ dV

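# Assemble the square-case gradient from [1]:
#   A_bar = conj(U) @ ((E * skew(utgu)) @ diag(s) + diag(ds) + diag(s) @ (E * skew(vtgv))) @ VT
# where skew(X) = X - X.conj().T; the broadcast products with s below play the role of the diag(s) factors.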
A_bar = (E * (utgu - utgu.conj().T)) * s[..., None, :]
A_bar = A_bar + eye * ds[..., :, None]
A_bar = A_bar + s[..., :, None] * (E * (vtgv - vtgv.conj().T))
A_bar = U.conj() @ A_bar @ VT

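# For rectangular A the thin factors do not span the full space. The switch adds the correction from [1]:
# the component of dVT orthogonal to the columns of V (m < n), or the analogous component of dU
# orthogonal to the columns of U (m > n).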
A_bar = ptb.switch(
ptm.eq(m, n),
A_bar,
ptb.switch(
ptm.lt(m, n),
A_bar
+ (
U / s[..., None, :] @ dVT @ (ptb.eye(n) - V @ V.conj().T)
).conj(),
A_bar
+ (V / s[..., None, :] @ dU.T @ (ptb.eye(m) - U @ U.conj().T)).T,
),
)
return [A_bar]


def svd(a, full_matrices: bool = True, compute_uv: bool = True):
"""
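A quick usage sketch, not part of the diff, of what the new gradient enables. It assumes a PyTensor build that includes this change and a full-rank input with distinct singular values; the check uses the standard fact that the gradient of the nuclear norm (the sum of singular values) is U @ VT:

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import svd

A = pt.dmatrix("A")
s = svd(A, full_matrices=False, compute_uv=False)

# Gradient of the nuclear norm; for a full-rank A with distinct singular values this equals U @ VT.
grad_fn = pytensor.function([A], pytensor.grad(s.sum(), A))

A_val = np.random.default_rng(0).normal(size=(4, 3))
U_np, _, VT_np = np.linalg.svd(A_val, full_matrices=False)
np.testing.assert_allclose(grad_fn(A_val), U_np @ VT_np, rtol=1e-6, atol=1e-8)
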
75 changes: 75 additions & 0 deletions tests/tensor/test_nlinalg.py
@@ -8,6 +8,7 @@
import pytensor
from pytensor import function
from pytensor.configdefaults import config
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.math import _allclose
from pytensor.tensor.nlinalg import (
SVD,
@@ -215,6 +216,80 @@ def validate_shape(self, shape, compute_uv=True, full_matrices=True):
outputs = [outputs]
self._compile_and_check([A], outputs, [A_v], self.op_class, warn=False)

@pytest.mark.parametrize(
"compute_uv, full_matrices, gradient_test_case",
[(False, False, 0)]
+ [(True, False, i) for i in range(8)]
+ [(True, True, i) for i in range(8)],
ids=(
["compute_uv=False, full_matrices=False"]
+ [
f"compute_uv=True, full_matrices=False, gradient={grad}"
for grad in ["U", "s", "V", "U+s", "s+V", "U+V", "U+s+V", "None"]
]
+ [
f"compute_uv=True, full_matrices=True, gradient={grad}"
for grad in ["U", "s", "V", "U+s", "s+V", "U+V", "U+s+V", "None"]
]
),
)
@pytest.mark.parametrize(
"shape", [(3, 3), (4, 3), (3, 4)], ids=["(3,3)", "(4,3)", "(3,4)"]
)
@pytest.mark.parametrize(
"batched", [True, False], ids=["batched=True", "batched=False"]
)
def test_grad(self, compute_uv, full_matrices, gradient_test_case, shape, batched):
rng = np.random.default_rng(utt.fetch_seed())
if batched:
shape = (4, *shape)

A_v = self.rng.normal(size=shape).astype(config.floatX)
if full_matrices:
with pytest.raises(
NotImplementedError,
match="Gradient of svd not implemented for full_matrices=True",
):
U, s, V = svd(
self.A, compute_uv=compute_uv, full_matrices=full_matrices
)
pytensor.grad(s.sum(), self.A)

elif compute_uv:

def svd_fn(A, case=0):
U, s, V = svd(A, compute_uv=compute_uv, full_matrices=full_matrices)
if case == 0:
return U.sum()
elif case == 1:
return s.sum()
elif case == 2:
return V.sum()
elif case == 3:
return U.sum() + s.sum()
elif case == 4:
return s.sum() + V.sum()
elif case == 5:
return U.sum() + V.sum()
elif case == 6:
return U.sum() + s.sum() + V.sum()
elif case == 7:
# No SVD output is used in the cost, so all of them are disconnected
return as_tensor_variable(3.0)

utt.verify_grad(
partial(svd_fn, case=gradient_test_case),
[A_v],
rng=rng,
)

else:
utt.verify_grad(
partial(svd, compute_uv=compute_uv, full_matrices=full_matrices),
[A_v],
rng=rng,
)


def test_tensorsolve():
rng = np.random.default_rng(utt.fetch_seed())
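Finally, a sketch of the disconnected-output fallback exercised by the tests above (same assumptions as the previous example): with compute_uv=True but only s in the cost, dU and dVT arrive as DisconnectedType, the L_op drops back to the singular-value-only formula, and the gradient matches the compute_uv=False graph:

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import svd

A = pt.dmatrix("A")
s_only = svd(A, full_matrices=False, compute_uv=False)
U, s, VT = svd(A, full_matrices=False, compute_uv=True)

# Only s is used in the second cost, so dU and dVT are disconnected at the SVD node.
g_ref = pytensor.grad(s_only.sum(), A)
g_disc = pytensor.grad(s.sum(), A)

f = pytensor.function([A], [g_ref, g_disc])
ref, disc = f(np.random.default_rng(1).normal(size=(3, 5)))
np.testing.assert_allclose(ref, disc, rtol=1e-8, atol=1e-12)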