
Commit b00c897

Add gradient support for unused outputs
1 parent b75e0f6 commit b00c897

2 files changed (+52, -17 lines)

pytensor/tensor/nlinalg.py

Lines changed: 14 additions & 1 deletion
@@ -632,7 +632,20 @@ def L_op(
 
         else:
             U, s, VT = (cast(ptb.TensorVariable, x) for x in outputs)
-            (dU, ds, dVT) = (cast(ptb.TensorVariable, x) for x in output_grads)
+
+            # Handle disconnected inputs
+            # If a user asked for all the matrices but then only used a subset in the cost function, the unused outputs
+            # will be DisconnectedType, which have type(Variable). Matrices that are on the backwards compute path
+            # have type TensorVariable. Thus, we replace Variables with zero matrices of the correct shapes.
+            new_output_grads = []
+            for output_grad, output in zip(output_grads, outputs):
+                if not isinstance(output_grad, ptb.TensorVariable):
+                    new_output_grads.append(ptb.zeros_like(output))
+                else:
+                    new_output_grads.append(output_grad)
+
+            (dU, ds, dVT) = (cast(ptb.TensorVariable, x) for x in new_output_grads)
+
             V = VT.T
             dV = dVT.T

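Note on the change above: the new branch covers the case where a user requests all three SVD outputs but only uses a subset of them in the cost. The sketch below illustrates that scenario from the user side; it is written for this note rather than taken from the commit, and the names A, cost, A_val and the (4, 3) test value are arbitrary.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import svd

# Illustrative sketch (not from the commit): request U, s and V, but build the
# cost from the singular values alone.
A = pt.matrix("A")
U, s, V = svd(A, compute_uv=True, full_matrices=False)

# Only `s` reaches the cost, so the gradients arriving in L_op for U and V are
# disconnected; with this change they are replaced by zero matrices there.
cost = s.sum()
dA = pytensor.grad(cost, A)

f = pytensor.function([A], dA)
A_val = np.random.default_rng(0).normal(size=(4, 3))
print(f(A_val))  # gradient of the sum of singular values w.r.t. A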
tests/tensor/test_nlinalg.py

Lines changed: 38 additions & 16 deletions
@@ -216,21 +216,29 @@ def validate_shape(self, shape, compute_uv=True, full_matrices=True):
         self._compile_and_check([A], outputs, [A_v], self.op_class, warn=False)
 
     @pytest.mark.parametrize(
-        "compute_uv, full_matrices",
-        [(True, False), (False, False), (True, True)],
-        ids=[
-            "compute_uv=True, full_matrices=False",
-            "compute_uv=False, full_matrices=False",
-            "compute_uv=True, full_matrices=True",
-        ],
+        "compute_uv, full_matrices, gradient_test_case",
+        [(False, False, 0)]
+        + [(True, False, i) for i in range(7)]
+        + [(True, True, i) for i in range(7)],
+        ids=(
+            ["compute_uv=False, full_matrices=False"]
+            + [
+                f"compute_uv=True, full_matrices=False, gradient={grad}"
+                for grad in ["U", "s", "V", "U+s", "s+V", "U+V", "U+s+V"]
+            ]
+            + [
+                f"compute_uv=True, full_matrices=True, gradient={grad}"
+                for grad in ["U", "s", "V", "U+s", "s+V", "U+V", "U+s+V"]
+            ]
+        ),
     )
     @pytest.mark.parametrize(
         "shape", [(3, 3), (4, 3), (3, 4)], ids=["(3,3)", "(4,3)", "(3,4)"]
     )
     @pytest.mark.parametrize(
         "batched", [True, False], ids=["batched=True", "batched=False"]
     )
-    def test_grad(self, compute_uv, full_matrices, shape, batched):
+    def test_grad(self, compute_uv, full_matrices, gradient_test_case, shape, batched):
         rng = np.random.default_rng(utt.fetch_seed())
         if batched:
             shape = (4, *shape)
@@ -248,15 +256,29 @@ def test_grad(self, compute_uv, full_matrices, shape, batched):
 
         elif compute_uv:
 
-            def svd_fn(A):
+            def svd_fn(A, case=0):
                 U, s, V = svd(A, compute_uv=compute_uv, full_matrices=full_matrices)
-                return U.sum() + s.sum() + V.sum()
-
-            utt.verify_grad(
-                svd_fn,
-                [A_v],
-                rng=rng,
-            )
+                if case == 0:
+                    return U.sum()
+                elif case == 1:
+                    return s.sum()
+                elif case == 2:
+                    return V.sum()
+                elif case == 3:
+                    return U.sum() + s.sum()
+                elif case == 4:
+                    return s.sum() + V.sum()
+                elif case == 5:
+                    return U.sum() + V.sum()
+                elif case == 6:
+                    return U.sum() + s.sum() + V.sum()
+
+            for case in range(7):
+                utt.verify_grad(
+                    partial(svd_fn, case=gradient_test_case),
+                    [A_v],
+                    rng=rng,
+                )
 
         else:
             utt.verify_grad(

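The parametrized cases above rely on utt.verify_grad to compare each symbolic gradient against finite differences. A rough standalone version of that check for the s-only case is sketched below with plain NumPy; the seed, step size and tolerances are arbitrary choices for illustration and do not come from the test suite.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import svd

# Symbolic gradient of the sum of singular values w.r.t. A (U and V are
# requested but intentionally unused, which is the case this commit fixes).
A = pt.matrix("A")
U, s, V = svd(A, compute_uv=True, full_matrices=False)
grad_fn = pytensor.function([A], pytensor.grad(s.sum(), A))

rng = np.random.default_rng(42)
A_val = rng.normal(size=(3, 3))
eps = 1e-6

# Central finite differences of sum(singular values) as a reference.
num_grad = np.empty_like(A_val)
for i in range(A_val.shape[0]):
    for j in range(A_val.shape[1]):
        A_plus, A_minus = A_val.copy(), A_val.copy()
        A_plus[i, j] += eps
        A_minus[i, j] -= eps
        num_grad[i, j] = (
            np.linalg.svd(A_plus, compute_uv=False).sum()
            - np.linalg.svd(A_minus, compute_uv=False).sum()
        ) / (2 * eps)

np.testing.assert_allclose(grad_fn(A_val), num_grad, rtol=1e-4, atol=1e-6)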