|
1 | 1 | import warnings
|
| 2 | +from collections.abc import Sequence |
2 | 3 | from functools import partial
|
| 4 | +from typing import Union, cast |
3 | 5 |
|
4 | 6 | import numpy as np
|
5 | 7 |
|
| 8 | +from pytensor import Variable |
6 | 9 | from pytensor import scalar as ps
|
7 | 10 | from pytensor.gradient import DisconnectedType
|
8 | 11 | from pytensor.graph.basic import Apply
|
@@ -582,26 +585,87 @@ def infer_shape(self, fgraph, node, shapes):
|
582 | 585 | else:
|
583 | 586 | return [s_shape]
|
584 | 587 |
|
| 588 | + def L_op( |
| 589 | + self, |
| 590 | + inputs: Sequence[Variable], |
| 591 | + outputs: Sequence[Variable], |
| 592 | + output_grads: Sequence[Variable], |
| 593 | + ) -> list[Variable]: |
| 594 | + (A,) = inputs |
| 595 | + A = cast(ptb.TensorVariable, A) |
585 | 596 |
|
586 |
| -def svd(a, full_matrices: bool = True, compute_uv: bool = True): |
| 597 | + if not self.compute_uv: |
| 598 | + (S_grad,) = output_grads |
| 599 | + S_grad = cast(ptb.TensorVariable, S_grad) |
| 600 | + |
| 601 | + # Need U and V so do the whole svd anyway... |
| 602 | + [u, s, v] = svd(A, full_matrices=False, compute_uv=True) # type: ignore |
| 603 | + u = cast(ptb.TensorVariable, u) |
| 604 | + |
| 605 | + return [(u.conj() * S_grad[..., None, :]) @ v] |
| 606 | + |
| 607 | + elif self.full_matrices: |
| 608 | + raise NotImplementedError( |
| 609 | + "Gradient of svd not implemented for full_matrices=True" |
| 610 | + ) |
| 611 | + |
| 612 | + else: |
| 613 | + u, s, v = (cast(ptb.TensorVariable, x) for x in outputs) |
| 614 | + gu, gs, gv = (cast(ptb.TensorVariable, x) for x in output_grads) |
| 615 | + |
| 616 | + m, n = A.shape[-2:] |
| 617 | + |
| 618 | + k = ptm.min((m, n)) |
| 619 | + # broadcastable identity array with shape (1, 1, ..., 1, k, k) |
| 620 | + # i = anp.reshape(anp.eye(k), anp.concatenate((anp.ones(a.ndim - 2, dtype=int), (k, k)))) |
| 621 | + |
| 622 | + eye = ptb.eye(k) |
| 623 | + f = 1 / (s[..., None, :] ** 2 - s[..., :, None] ** 2 + eye) |
| 624 | + |
| 625 | + utgu = u.T @ gu |
| 626 | + vtgv = v.T @ gv |
| 627 | + t1 = f * (utgu - utgu.conj().T * s[..., None, :]) |
| 628 | + t1 = t1 + eye * gs[..., :, None] |
| 629 | + t1 = t1 + s[..., :, None] * (f * (vtgv - vtgv.conj().T)) |
| 630 | + |
| 631 | + if u.dtype.startswith("complex"): |
| 632 | + t1 = t1 + 1j * ptb.diag(utgu.imag) / s[..., None, :] |
| 633 | + |
| 634 | + t1 = u.conj() @ t1 @ v.T |
| 635 | + t1 = cast(ptb.TensorVariable, t1) |
| 636 | + |
| 637 | + if m < n: |
| 638 | + eye_n = ptb.eye(n) |
| 639 | + i_minus_vtt = eye_n - (v @ v.conj().T) |
| 640 | + t1 = t1 + (u / s[..., None, :] @ gv.T @ i_minus_vtt).conj() |
| 641 | + |
| 642 | + elif m > n: |
| 643 | + eye_m = ptb.eye(n) |
| 644 | + i_minus_uut = eye_m - u @ u.conj().T |
| 645 | + t1 = t1 + v / s[..., None, :] @ gu.T @ i_minus_uut |
| 646 | + |
| 647 | + return [t1] |
| 648 | + |
| 649 | + |
def svd(
    a, full_matrices: bool = True, compute_uv: bool = True
) -> Union[Variable, list[Variable]]:
    """Compute the singular value decomposition on CPU.

    Parameters
    ----------
    a : TensorLike
        The matrix to be decomposed.
    full_matrices : bool, optional
        If True (default), u and v have the shapes (M, M) and (N, N), respectively. Otherwise, the shapes are (M, K)
        and (K, N), respectively, where K = min(M, N).
    compute_uv : bool, optional
        Whether or not to compute u and v in addition to s. True by default.

    Returns
    -------
    matrices: TensorVariable or list of TensorVariable
        Result of singular value decomposition. If compute_uv is True, return a list of TensorVariable [U, S, V].
        Otherwise, returns only singular values S.
    """
    # Instantiate the Op with the requested options, then apply it to `a`.
    svd_op = SVD(full_matrices, compute_uv)
    return svd_op(a)
|
607 | 671 |
|
|
0 commit comments