5 | 5 |
6 | 6 | import numpy as np
7 | 7 |
| 8 | +import pytensor.tensor as pt
8 | 9 | from pytensor import scalar as ps
9 | 10 | from pytensor.compile.builders import OpFromGraph
10 | 11 | from pytensor.gradient import DisconnectedType
11 | 12 | from pytensor.graph.basic import Apply
12 | 13 | from pytensor.graph.op import Op
| 14 | +from pytensor.ifelse import ifelse
13 | 15 | from pytensor.npy_2_compat import normalize_axis_tuple
| 16 | +from pytensor.raise_op import Assert
14 | 17 | from pytensor.tensor import TensorLike
15 | 18 | from pytensor.tensor import basic as ptb
16 | 19 | from pytensor.tensor import math as ptm
@@ -512,6 +515,80 @@ def perform(self, node, inputs, outputs):
512 | 515 |         else:
513 | 516 |             outputs[0][0] = res
514 | 517 |
| 518 | +    def L_op(self, inputs, outputs, output_grads):
| 519 | +        """
| 520 | +        Reverse-mode gradient of the QR function.
| 521 | +
| 522 | +        References
| 523 | +        ----------
| 524 | +        .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/
| 525 | +        .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2
| 526 | +        """
| 527 | +
| 528 | +        from pytensor.tensor.slinalg import solve_triangular
| 529 | +
| 530 | +        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
| 531 | +        m, n = A.shape
| 532 | +
| 533 | +        def _H(x: ptb.TensorVariable):
| 534 | +            return x.conj().mT
| 535 | +
| 536 | +        def _copyltu(x: ptb.TensorVariable):
| 537 | +            return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1))
| 538 | +
| 539 | +        if self.mode == "raw":
| 540 | +            raise NotImplementedError("Gradient of qr not implemented for mode=raw")
| 541 | +
| 542 | +        elif self.mode == "r":
| 543 | +            # We need all the components of the QR to compute the gradient of A even if we only
| 544 | +            # use the upper triangular component in the cost function.
| 545 | +            Q, R = qr(A, mode="reduced")
| 546 | +            dQ = Q.zeros_like()
| 547 | +            dR = cast(ptb.TensorVariable, output_grads[0])
| 548 | +
| 549 | +        else:
| 550 | +            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
| 551 | +            if self.mode == "complete":
| 552 | +                qr_assert_op = Assert(
| 553 | +                    "Gradient of qr not implemented for m x n matrices with m > n and mode=complete"
| 554 | +                )
| 555 | +                R = qr_assert_op(R, ptm.le(m, n))
| 556 | +
| 557 | +            new_output_grads = []
| 558 | +            is_disconnected = [
| 559 | +                isinstance(x.type, DisconnectedType) for x in output_grads
| 560 | +            ]
| 561 | +            if all(is_disconnected):
| 562 | +                # This should never be reached by Pytensor
| 563 | +                return [DisconnectedType()()]  # pragma: no cover
| 564 | +
| 565 | +            for disconnected, output_grad, output in zip(
| 566 | +                is_disconnected, output_grads, [Q, R], strict=True
| 567 | +            ):
| 568 | +                if disconnected:
| 569 | +                    new_output_grads.append(output.zeros_like())
| 570 | +                else:
| 571 | +                    new_output_grads.append(output_grad)
| 572 | +
| 573 | +            (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)
| 574 | +
| 575 | +        # gradient expression when m >= n
| 576 | +        M = R @ _H(dR) - _H(dQ) @ Q
| 577 | +        K = dQ + Q @ _copyltu(M)
| 578 | +        A_bar_m_ge_n = _H(solve_triangular(R, _H(K)))
| 579 | +
| 580 | +        # gradient expression when m < n
| 581 | +        Y = A[:, m:]
| 582 | +        U = R[:, :m]
| 583 | +        dU, dV = dR[:, :m], dR[:, m:]
| 584 | +        dQ_Yt_dV = dQ + Y @ _H(dV)
| 585 | +        M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q
| 586 | +        X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M))))
| 587 | +        Y_bar = Q @ dV
| 588 | +        A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)
| 589 | +
| 590 | +        return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]
| 591 | +
515 | 592 |
516 | 593 | def qr(a, mode="reduced"):
517 | 594 |     """
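For context, a minimal usage sketch (not part of the diff; it assumes a PyTensor build that includes this change, and the cost function and test matrix are arbitrary illustrations): with the new L_op, pytensor.grad can differentiate straight through qr.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.nlinalg import qr

    A = pt.dmatrix("A")
    Q, R = qr(A, mode="reduced")
    # Any scalar cost built from the factors; the gradient flows through the new L_op.
    loss = Q.sum() + (R**2).sum()
    A_bar = pytensor.grad(loss, A)

    f = pytensor.function([A], A_bar)
    rng = np.random.default_rng(0)
    print(f(rng.normal(size=(4, 3))))  # m >= n here, so the first branch of the ifelse is evaluated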
|