@@ -1,7 +1,7 @@
 import warnings
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from functools import partial
-from typing import Literal
+from typing import Literal, cast
 
 import numpy as np
 from numpy.core.numeric import normalize_axis_tuple  # type: ignore
@@ -15,7 +15,7 @@
 from pytensor.tensor import math as ptm
 from pytensor.tensor.basic import as_tensor_variable, diagonal
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
+from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector
 
 
 class MatrixPinv(Op):
@@ -597,6 +597,118 @@ def infer_shape(self, fgraph, node, shapes):
         else:
             return [s_shape]
 
+    def L_op(
+        self,
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        output_grads: Sequence[Variable],
+    ) -> list[Variable]:
+        """
+        Reverse-mode gradient of the SVD function. Adapted from the autograd implementation here:
+        https://github.com/HIPS/autograd/blob/01eacff7a4f12e6f7aebde7c4cb4c1c2633f217d/autograd/numpy/linalg.py#L194
+
+        And the mxnet implementation described in [1]_
+
+        References
+        ----------
+        .. [1] Seeger, Matthias, et al. "Auto-differentiating linear algebra." arXiv preprint arXiv:1710.08717 (2017).
+        """
+
+        def s_grad_only(
+            U: ptb.TensorVariable, VT: ptb.TensorVariable, ds: ptb.TensorVariable
+        ) -> list[Variable]:
+            A_bar = (U.conj() * ds[..., None, :]) @ VT
+            return [A_bar]
+
+        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
+
+        if not self.compute_uv:
+            # We need all the components of the SVD to compute the gradient of A even if we only use the singular values
+            # in the cost function.
+            U, _, VT = svd(A, full_matrices=False, compute_uv=True)
+            ds = cast(ptb.TensorVariable, output_grads[0])
+            return s_grad_only(U, VT, ds)
+
+        elif self.full_matrices:
+            raise NotImplementedError(
+                "Gradient of svd not implemented for full_matrices=True"
+            )
+
+        else:
+            U, s, VT = (cast(ptb.TensorVariable, x) for x in outputs)
+
+            # Handle disconnected inputs
+            # If a user asked for all the matrices but then only used a subset in the cost function, the unused outputs
+            # will be DisconnectedType. We replace DisconnectedTypes with zero matrices of the correct shapes.
+            new_output_grads = []
+            is_disconnected = [
+                isinstance(x.type, DisconnectedType) for x in output_grads
+            ]
+            if all(is_disconnected):
+                return [DisconnectedType()()]
+            elif is_disconnected == [True, False, True]:
+                # This is the same as the compute_uv = False case, so we can fall back to that simpler computation,
+                # without needing to re-compute U and VT
+                ds = cast(ptb.TensorVariable, output_grads[1])
+                return s_grad_only(U, VT, ds)
+
+            for disconnected, output_grad, output in zip(
+                is_disconnected, output_grads, [U, s, VT]
+            ):
+                if disconnected:
+                    new_output_grads.append(output.zeros_like())
+                else:
+                    new_output_grads.append(output_grad)
+
+            (dU, ds, dVT) = (cast(ptb.TensorVariable, x) for x in new_output_grads)
+
+            V = VT.T
+            dV = dVT.T
+
+            m, n = A.shape[-2:]
+
+            k = ptm.min((m, n))
+            eye = ptb.eye(k)
+
+            def h(t):
+                """
+                Approximation of s_i ** 2 - s_j ** 2, from [1]_.
+                Robust to identical singular values (singular matrix input), although
+                gradients are still wrong in this case.
+                """
+                eps = 1e-8
+
+                # sign(0) = 0 in pytensor, which defeats the whole purpose of this function
+                sign_t = ptb.where(ptm.eq(t, 0), 1, ptm.sign(t))
+                return ptm.maximum(ptm.abs(t), eps) * sign_t
+
+            numer = ptb.ones((k, k)) - eye
+            denom = h(s[None] - s[:, None]) * h(s[None] + s[:, None])
+            E = numer / denom
+
+            utgu = U.T @ dU
+            vtgv = VT @ dV
+
+            A_bar = (E * (utgu - utgu.conj().T)) * s[..., None, :]
+            A_bar = A_bar + eye * ds[..., :, None]
+            A_bar = A_bar + s[..., :, None] * (E * (vtgv - vtgv.conj().T))
+            A_bar = U.conj() @ A_bar @ VT
+
+            A_bar = ptb.switch(
+                ptm.eq(m, n),
+                A_bar,
+                ptb.switch(
+                    ptm.lt(m, n),
+                    A_bar
+                    + (
+                        U / s[..., None, :] @ dVT @ (ptb.eye(n) - V @ V.conj().T)
+                    ).conj(),
+                    A_bar
+                    + (V / s[..., None, :] @ dU.T @ (ptb.eye(m) - U @ U.conj().T)).T,
+                ),
+            )
+            return [A_bar]
+
 
 def svd(a, full_matrices: bool = True, compute_uv: bool = True):
     """
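For reviewers who want to exercise the new gradient locally, here is a minimal sketch (not part of the diff) covering both the compute_uv=False path and the disconnected-output fallback. It assumes this branch of pytensor is installed; the variable names, random seed, and tolerances are illustrative only.

# Minimal sketch; names like A, cost, f are illustrative and not from the PR.
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import svd

A = pt.matrix("A")
rng = np.random.default_rng(0)
A_val = rng.normal(size=(4, 3))

# compute_uv=False path: the L_op recomputes U and VT internally.
s = svd(A, full_matrices=False, compute_uv=False)
cost = (s**2).sum()
f = pytensor.function([A], pytensor.grad(cost, A))

# sum(s**2) is the squared Frobenius norm of A, so its gradient is 2 * A.
np.testing.assert_allclose(f(A_val), 2 * A_val, rtol=1e-8, atol=1e-8)

# compute_uv=True path where only the singular values reach the cost: the
# gradients for U and VT arrive disconnected and the L_op falls back to the
# singular-value-only formula.
U, s2, VT = svd(A, full_matrices=False, compute_uv=True)
g = pytensor.grad(s2.sum(), A)

# For distinct singular values, d sum(s) / dA = U @ VT.
U_np, _, VT_np = np.linalg.svd(A_val, full_matrices=False)
np.testing.assert_allclose(g.eval({A: A_val}), U_np @ VT_np, rtol=1e-8, atol=1e-8)

Both identities give closed-form targets, so they serve as quick spot checks that do not depend on finite-difference comparisons.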