@@ -1,6 +1,7 @@
 import warnings
+from collections.abc import Sequence
 from functools import partial
-from typing import Callable, Literal, Optional, Union
+from typing import Callable, Literal, Optional, Union, cast
 
 import numpy as np
 from numpy.core.numeric import normalize_axis_tuple  # type: ignore
@@ -13,7 +14,7 @@
 from pytensor.tensor import math as ptm
 from pytensor.tensor.basic import as_tensor_variable, diagonal
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
+from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector
 
 
 class MatrixPinv(Op):
@@ -595,6 +596,89 @@ def infer_shape(self, fgraph, node, shapes):
         else:
             return [s_shape]
 
+    def L_op(
+        self,
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        output_grads: Sequence[Variable],
+    ) -> list[Variable]:
+        """
+        Reverse-mode gradient of the SVD function. Adapted from the autograd implementation here:
+        https://github.com/HIPS/autograd/blob/01eacff7a4f12e6f7aebde7c4cb4c1c2633f217d/autograd/numpy/linalg.py#L194
+
+        and from the MXNet implementation described in [1]_.
+
+        References
+        ----------
+        .. [1] Seeger, Matthias, et al. "Auto-differentiating linear algebra." arXiv preprint arXiv:1710.08717 (2017).
+        """
+        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
+
+        if not self.compute_uv:
+            # We need all the components of the SVD to compute the gradient of A even if we only use the singular values
+            # in the cost function.
+            U, s, VT = svd(A, full_matrices=False, compute_uv=True)
+
+            (ds,) = (cast(ptb.TensorVariable, x) for x in output_grads)
+            A_bar = (U.conj() * ds[..., None, :]) @ VT
+
+            return [A_bar]
+
+        elif self.full_matrices:
+            raise NotImplementedError(
+                "Gradient of svd not implemented for full_matrices=True"
+            )
+
+        else:
+            U, s, VT = (cast(ptb.TensorVariable, x) for x in outputs)
+            (dU, ds, dVT) = (cast(ptb.TensorVariable, x) for x in output_grads)
+            V = VT.T
+            dV = dVT.T
+
+            m, n = A.shape[-2:]
+
+            k = ptm.min((m, n))
+            eye = ptb.eye(k)
+
+            def h(t):
+                """
+                Approximation of s_i ** 2 - s_j ** 2, from [1]_.
+                Robust to identical singular values (singular matrix input), although
+                gradients are still wrong in this case.
+                """
+                eps = 1e-8
+
+                # sign(0) = 0 in pytensor, which defeats the whole purpose of this function
+                sign_t = ptb.where(ptm.eq(t, 0), 1, ptm.sign(t))
+                return ptm.maximum(ptm.abs(t), eps) * sign_t
+
+            numer = ptb.ones((k, k)) - eye
+            denom = h(s[None] - s[:, None]) * h(s[None] + s[:, None])
+            E = numer / denom
+
+            utgu = U.T @ dU
+            vtgv = VT @ dV
+
+            A_bar = (E * (utgu - utgu.conj().T)) * s[..., None, :]
+            A_bar = A_bar + eye * ds[..., :, None]
+            A_bar = A_bar + s[..., :, None] * (E * (vtgv - vtgv.conj().T))
+            A_bar = U.conj() @ A_bar @ VT
+
+            A_bar = ptb.switch(
+                ptm.eq(m, n),
+                A_bar,
+                ptb.switch(
+                    ptm.lt(m, n),
+                    A_bar
+                    + (
+                        U / s[..., None, :] @ dVT @ (ptb.eye(n) - V @ V.conj().T)
+                    ).conj(),
+                    A_bar
+                    + (V / s[..., None, :] @ dU.T @ (ptb.eye(m) - U @ U.conj().T)).T,
+                ),
+            )
+            return [A_bar]
+
 
 def svd(a, full_matrices: bool = True, compute_uv: bool = True):
     """
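
For reference, the compute_uv=True / full_matrices=False branch above follows the thin-SVD adjoint from Seeger et al. [1]_. A sketch of the formula the code encodes, in the real-valued case with conjugates elided (E in the code is the regularized version of F, with h keeping the denominators away from zero):

    F_{ij} = \frac{1 - \delta_{ij}}{h(s_j - s_i)\,h(s_j + s_i)} \approx \frac{1}{s_j^{2} - s_i^{2}}

    \bar{A} = U \Big[ \big(F \circ (U^{\top}\bar{U} - \bar{U}^{\top}U)\big)\,\operatorname{diag}(s)
              + \operatorname{diag}(\bar{s})
              + \operatorname{diag}(s)\,\big(F \circ (V^{\top}\bar{V} - \bar{V}^{\top}V)\big) \Big] V^{\top}

    \bar{A} \mathrel{+}= U\,\operatorname{diag}(s)^{-1}\,\bar{V}^{\top}\,(I_n - V V^{\top}) \quad (m < n),
    \qquad
    \bar{A} \mathrel{+}= \big(V\,\operatorname{diag}(s)^{-1}\,\bar{U}^{\top}\,(I_m - U U^{\top})\big)^{\top} \quad (m > n)

When compute_uv=False only \bar{s} is nonzero, and the expression collapses to \bar{A} = U \operatorname{diag}(\bar{s}) V^{\top}, which is exactly what the first branch computes directly.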
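
A minimal usage sketch (not part of the diff; the variable names, matrix shape and the quadratic cost are illustrative, and it assumes the default float64 floatX). It pushes a scalar cost built from the singular values through pytensor.grad, which builds the reverse-mode graph from the compute_uv=False branch of the new L_op:

    import numpy as np

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.nlinalg import svd

    A = pt.matrix("A")
    s = svd(A, full_matrices=False, compute_uv=False)  # singular values only
    cost = (s**2).sum()                                # squared Frobenius norm of A
    g = pytensor.grad(cost, A)                         # reverse-mode graph built via SVD.L_op

    f = pytensor.function([A], g)
    A_val = np.random.default_rng(0).normal(size=(4, 3))
    print(f(A_val))  # sum(s_i**2) == ||A||_F**2, so this should print 2 * A_val

Because the cost only touches the singular values, the adjoint reduces to U diag(2 s) V^T = 2 A, which gives a quick end-to-end sanity check without finite differences.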
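
And a hedged finite-difference check in the spirit of the project's test suite, using pytensor.gradient.verify_grad (the scalar wrapper and the 5x3 test matrix are assumptions made here, not part of this PR):

    import numpy as np

    from pytensor.gradient import verify_grad
    from pytensor.tensor.nlinalg import svd

    rng = np.random.RandomState(42)
    A_val = rng.normal(size=(5, 3))

    # Singular values only: exercises the compute_uv=False branch of L_op.
    verify_grad(lambda A: svd(A, full_matrices=False, compute_uv=False), [A_val], rng=rng)

    # Thin SVD: collapse U, s and VT into one scalar so a single cotangent
    # flows back through all three outputs of the compute_uv=True branch.
    def thin_svd_cost(A):
        U, s, VT = svd(A, full_matrices=False, compute_uv=True)
        return U.sum() + s.sum() + VT.sum()

    verify_grad(thin_svd_cost, [A_val], rng=rng)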