
Commit 2774599

Implement gradient for QR decomposition (#1303)
1 parent 8a7356c commit 2774599

2 files changed: +143 -0

pytensor/tensor/nlinalg.py (+77)

```diff
@@ -5,12 +5,15 @@
 
 import numpy as np
 
+import pytensor.tensor as pt
 from pytensor import scalar as ps
 from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
+from pytensor.ifelse import ifelse
 from pytensor.npy_2_compat import normalize_axis_tuple
+from pytensor.raise_op import Assert
 from pytensor.tensor import TensorLike
 from pytensor.tensor import basic as ptb
 from pytensor.tensor import math as ptm
@@ -512,6 +515,80 @@ def perform(self, node, inputs, outputs):
         else:
             outputs[0][0] = res
 
+    def L_op(self, inputs, outputs, output_grads):
+        """
+        Reverse-mode gradient of the QR function.
+
+        References
+        ----------
+        .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/
+        .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2
+        """
+
+        from pytensor.tensor.slinalg import solve_triangular
+
+        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
+        m, n = A.shape
+
+        def _H(x: ptb.TensorVariable):
+            return x.conj().mT
+
+        def _copyltu(x: ptb.TensorVariable):
+            return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1))
+
+        if self.mode == "raw":
+            raise NotImplementedError("Gradient of qr not implemented for mode=raw")
+
+        elif self.mode == "r":
+            # We need all the components of the QR to compute the gradient of A even if we only
+            # use the upper triangular component in the cost function.
+            Q, R = qr(A, mode="reduced")
+            dQ = Q.zeros_like()
+            dR = cast(ptb.TensorVariable, output_grads[0])
+
+        else:
+            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
+            if self.mode == "complete":
+                qr_assert_op = Assert(
+                    "Gradient of qr not implemented for m x n matrices with m > n and mode=complete"
+                )
+                R = qr_assert_op(R, ptm.le(m, n))
+
+            new_output_grads = []
+            is_disconnected = [
+                isinstance(x.type, DisconnectedType) for x in output_grads
+            ]
+            if all(is_disconnected):
+                # This should never be reached by Pytensor
+                return [DisconnectedType()()]  # pragma: no cover
+
+            for disconnected, output_grad, output in zip(
+                is_disconnected, output_grads, [Q, R], strict=True
+            ):
+                if disconnected:
+                    new_output_grads.append(output.zeros_like())
+                else:
+                    new_output_grads.append(output_grad)
+
+            (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)
+
+        # gradient expression when m >= n
+        M = R @ _H(dR) - _H(dQ) @ Q
+        K = dQ + Q @ _copyltu(M)
+        A_bar_m_ge_n = _H(solve_triangular(R, _H(K)))
+
+        # gradient expression when m < n
+        Y = A[:, m:]
+        U = R[:, :m]
+        dU, dV = dR[:, :m], dR[:, m:]
+        dQ_Yt_dV = dQ + Y @ _H(dV)
+        M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q
+        X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M))))
+        Y_bar = Q @ dV
+        A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)
+
+        return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]
+
 
 def qr(a, mode="reduced"):
     """
```
tests/tensor/test_nlinalg.py (+66)

```diff
@@ -152,6 +152,72 @@ def test_qr_modes():
     assert "name 'complete' is not defined" in str(e)
 
 
+@pytest.mark.parametrize(
+    "shape, gradient_test_case, mode",
+    (
+        [(s, c, "reduced") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
+        + [(s, c, "complete") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
+        + [(s, 0, "r") for s in [(3, 3), (6, 3), (3, 6)]]
+        + [((3, 3), 0, "raw")]
+    ),
+    ids=(
+        [
+            f"shape={s}, gradient_test_case={c}, mode=reduced"
+            for s in [(3, 3), (6, 3), (3, 6)]
+            for c in ["Q", "R", "both"]
+        ]
+        + [
+            f"shape={s}, gradient_test_case={c}, mode=complete"
+            for s in [(3, 3), (6, 3), (3, 6)]
+            for c in ["Q", "R", "both"]
+        ]
+        + [f"shape={s}, gradient_test_case=R, mode=r" for s in [(3, 3), (6, 3), (3, 6)]]
+        + ["shape=(3, 3), gradient_test_case=Q, mode=raw"]
+    ),
+)
+@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
+def test_qr_grad(shape, gradient_test_case, mode, is_complex):
+    rng = np.random.default_rng(utt.fetch_seed())
+
+    def _test_fn(x, case=2, mode="reduced"):
+        if case == 0:
+            return qr(x, mode=mode)[0].sum()
+        elif case == 1:
+            return qr(x, mode=mode)[1].sum()
+        elif case == 2:
+            Q, R = qr(x, mode=mode)
+            return Q.sum() + R.sum()
+
+    if is_complex:
+        pytest.xfail("Complex inputs currently not supported by verify_grad")
+
+    m, n = shape
+    a = rng.standard_normal(shape).astype(config.floatX)
+    if is_complex:
+        a += 1j * rng.standard_normal(shape).astype(config.floatX)
+
+    if mode == "raw":
+        with pytest.raises(NotImplementedError):
+            utt.verify_grad(
+                partial(_test_fn, case=gradient_test_case, mode=mode),
+                [a],
+                rng=np.random,
+            )
+
+    elif mode == "complete" and m > n:
+        with pytest.raises(AssertionError):
+            utt.verify_grad(
+                partial(_test_fn, case=gradient_test_case, mode=mode),
+                [a],
+                rng=np.random,
+            )
+
+    else:
+        utt.verify_grad(
+            partial(_test_fn, case=gradient_test_case, mode=mode), [a], rng=np.random
+        )
+
+
 class TestSvd(utt.InferShapeTester):
     op_class = SVD
```
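As a quick end-to-end sanity check outside the test suite, the new gradient can be exercised directly. This is a minimal sketch, not part of the commit; it assumes a PyTensor build that includes this change, and cross-checks `pytensor.grad` against a central finite-difference estimate (the test matrix size and `eps` are arbitrary choices):

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import qr

# Symbolic cost built from both QR factors (the "both" test case above).
A = pt.matrix("A")
Q, R = qr(A, mode="reduced")
cost = Q.sum() + R.sum()

cost_fn = pytensor.function([A], cost)
grad_fn = pytensor.function([A], pytensor.grad(cost, A))

rng = np.random.default_rng(0)
a = rng.standard_normal((6, 3))  # tall case, m >= n

# Central finite differences, one entry at a time.
eps = 1e-6
fd = np.zeros_like(a)
for i in range(a.shape[0]):
    for j in range(a.shape[1]):
        da = np.zeros_like(a)
        da[i, j] = eps
        fd[i, j] = (cost_fn(a + da) - cost_fn(a - da)) / (2 * eps)

np.testing.assert_allclose(grad_fn(a), fd, rtol=1e-5, atol=1e-5)
```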
