Commit 3f1f902

Limited verify_grad support for multiple output Ops

1 parent: b6c79fd

2 files changed (+40, -15 lines)

pytensor/gradient.py

Lines changed: 39 additions & 15 deletions
@@ -1675,6 +1675,7 @@ def verify_grad(
     mode: Optional[Union["Mode", str]] = None,
     cast_to_output_type: bool = False,
     no_debug_ref: bool = True,
+    sum_outputs=False,
 ):
     """Test a gradient by Finite Difference Method. Raise error on failure.
 
@@ -1722,7 +1723,9 @@ def verify_grad(
         float16 is not handled here.
     no_debug_ref
         Don't use `DebugMode` for the numerical gradient function.
-
+    sum_outputs: bool, default False
+        If True, the gradient of the sum of all outputs is verified. If False,
+        an error is raised if the function has multiple outputs.
     Notes
     -----
     This function does not support multiple outputs. In `tests.scan.test_basic`
@@ -1782,7 +1785,7 @@ def verify_grad(
     # fun can be either a function or an actual Op instance
     o_output = fun(*tensor_pt)
 
-    if isinstance(o_output, list):
+    if isinstance(o_output, list) and not sum_outputs:
         raise NotImplementedError(
             "Can't (yet) auto-test the gradient of a function with multiple outputs"
         )
@@ -1793,7 +1796,7 @@ def verify_grad(
     o_fn = fn_maker(tensor_pt, o_output, name="gradient.py fwd")
     o_fn_out = o_fn(*[p.copy() for p in pt])
 
-    if isinstance(o_fn_out, tuple) or isinstance(o_fn_out, list):
+    if (isinstance(o_fn_out, tuple) or isinstance(o_fn_out, list)) and not sum_outputs:
         raise TypeError(
             "It seems like you are trying to use verify_grad "
             "on an Op or a function which outputs a list: there should"
@@ -1802,33 +1805,45 @@
 
     # random_projection should not have elements too small,
     # otherwise too much precision is lost in numerical gradient
-    def random_projection():
-        plain = rng.random(o_fn_out.shape) + 0.5
-        if cast_to_output_type and o_output.dtype == "float32":
-            return np.array(plain, o_output.dtype)
+    def random_projection(shape, dtype):
+        plain = rng.random(shape) + 0.5
+        if cast_to_output_type and dtype == "float32":
+            return np.array(plain, dtype)
         return plain
 
-    t_r = shared(random_projection(), borrow=True)
-    t_r.name = "random_projection"
-
     # random projection of o onto t_r
     # This sum() is defined above, it's not the builtin sum.
-    cost = pytensor.tensor.sum(t_r * o_output)
+    if sum_outputs:
+        t_rs = [
+            shared(
+                value=random_projection(o.shape, o.dtype),
+                borrow=True,
+                name=f"random_projection_{i}",
+            )
+            for i, o in enumerate(o_fn_out)
+        ]
+        cost = pytensor.tensor.sum(
+            [pytensor.tensor.sum(x * y) for x, y in zip(t_rs, o_output)]
+        )
+    else:
+        t_r = shared(
+            value=random_projection(o_fn_out.shape, o_fn_out.dtype),
+            borrow=True,
+            name="random_projection",
+        )
+        cost = pytensor.tensor.sum(t_r * o_output)
 
     if no_debug_ref:
         mode_for_cost = mode_not_slow(mode)
     else:
         mode_for_cost = mode
 
     cost_fn = fn_maker(tensor_pt, cost, name="gradient.py cost", mode=mode_for_cost)
-
     symbolic_grad = grad(cost, tensor_pt, disconnected_inputs="ignore")
-
     grad_fn = fn_maker(tensor_pt, symbolic_grad, name="gradient.py symbolic grad")
 
     for test_num in range(n_tests):
         num_grad = numeric_grad(cost_fn, [p.copy() for p in pt], eps, out_type)
-
         analytic_grad = grad_fn(*[p.copy() for p in pt])
 
         # Since `tensor_pt` is a list, `analytic_grad` should be one too.
@@ -1853,7 +1868,16 @@ def random_projection():
 
         # get new random projection for next test
         if test_num < n_tests - 1:
-            t_r.set_value(random_projection(), borrow=True)
+            if sum_outputs:
+                for r in t_rs:
+                    r.set_value(
+                        random_projection(r.get_value().shape, r.get_value().dtype)
+                    )
+            else:
+                t_r.set_value(
+                    random_projection(t_r.get_value().shape, t_r.get_value().dtype),
+                    borrow=True,
+                )
 
 
 class GradientError(Exception):
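
The sum_outputs path reduces the multi-output check to the same scalar-cost trick verify_grad already used for one output: project each output onto its own random tensor, sum the projections, and compare the finite-difference gradient of that scalar against the symbolic one. A minimal NumPy sketch of the idea (the toy function f, its hand-derived gradient, and the tolerances are illustrative, not pytensor internals):

import numpy as np

rng = np.random.default_rng(42)


def f(x):
    # Toy multi-output function standing in for a multi-output Op.
    return x**2, np.sin(x)


x = rng.random(3)

# One random projection per output, mirroring t_rs in the diff; elements
# are kept >= 0.5 so the numerical gradient loses little precision.
t_rs = [rng.random(o.shape) + 0.5 for o in f(x)]


def cost(x):
    # Scalar cost: sum over outputs of sum(projection * output).
    return sum((t * o).sum() for t, o in zip(t_rs, f(x)))


# Central finite differences of the scalar cost...
eps = 1e-6
num_grad = np.array(
    [(cost(x + eps * e) - cost(x - eps * e)) / (2 * eps) for e in np.eye(x.size)]
)
# ...versus the hand-derived gradient of the same cost.
analytic_grad = 2 * t_rs[0] * x + t_rs[1] * np.cos(x)
assert np.allclose(num_grad, analytic_grad, atol=1e-5)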

tests/tensor/test_nlinalg.py

Lines changed: 1 addition & 0 deletions
@@ -250,6 +250,7 @@ def test_grad(self, compute_uv, full_matrices, shape, batched):
                 partial(svd, compute_uv=compute_uv, full_matrices=full_matrices),
                 [A_v],
                 rng=rng,
+                sum_outputs=True,
             )
 
         else:
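
Following the test change above, any multi-output graph can be verified the same way. A minimal usage sketch, assuming a pytensor build that includes this commit (the two-output function multi_out and the input shape are made up for illustration):

import numpy as np

import pytensor.tensor as pt
from pytensor.gradient import verify_grad

rng = np.random.default_rng(0)


def multi_out(x):
    # Hypothetical two-output function; before this commit, verify_grad
    # raised NotImplementedError for any function returning a list.
    return [pt.exp(x), pt.tanh(x)]


# Checks the gradient of the summed random projections of both outputs.
verify_grad(multi_out, [rng.random((4, 3))], rng=rng, sum_outputs=True)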
