Skip to content

Commit bb9b02b

Browse files
Limited verify_grad support for multiple output Ops
1 parent c067986 commit bb9b02b

File tree

2 files changed

+35
-15
lines changed

2 files changed

+35
-15
lines changed

pytensor/gradient.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1675,6 +1675,7 @@ def verify_grad(
16751675
mode: Optional[Union["Mode", str]] = None,
16761676
cast_to_output_type: bool = False,
16771677
no_debug_ref: bool = True,
1678+
sum_outputs: bool = False,
16781679
):
16791680
"""Test a gradient by Finite Difference Method. Raise error on failure.
16801681
@@ -1722,7 +1723,9 @@ def verify_grad(
17221723
float16 is not handled here.
17231724
no_debug_ref
17241725
Don't use `DebugMode` for the numerical gradient function.
1725-
1726+
sum_outputs: bool, default False
1727+
If True, the gradient of the sum of all outputs is verified. If False, an error is raised if the function has
1728+
multiple outputs.
17261729
Notes
17271730
-----
17281731
By default this function does not support multiple outputs; pass ``sum_outputs=True`` to verify the gradient of the sum of all outputs instead. In `tests.scan.test_basic`
@@ -1782,7 +1785,7 @@ def verify_grad(
17821785
# fun can be either a function or an actual Op instance
17831786
o_output = fun(*tensor_pt)
17841787

1785-
if isinstance(o_output, list):
1788+
if isinstance(o_output, list) and not sum_outputs:
17861789
raise NotImplementedError(
17871790
"Can't (yet) auto-test the gradient of a function with multiple outputs"
17881791
)
@@ -1793,7 +1796,7 @@ def verify_grad(
17931796
o_fn = fn_maker(tensor_pt, o_output, name="gradient.py fwd")
17941797
o_fn_out = o_fn(*[p.copy() for p in pt])
17951798

1796-
if isinstance(o_fn_out, tuple) or isinstance(o_fn_out, list):
1799+
if isinstance(o_fn_out, tuple) or isinstance(o_fn_out, list) and not sum_outputs:
17971800
raise TypeError(
17981801
"It seems like you are trying to use verify_grad "
17991802
"on an Op or a function which outputs a list: there should"
@@ -1802,33 +1805,40 @@ def verify_grad(
18021805

18031806
# random_projection should not have elements too small,
18041807
# otherwise too much precision is lost in numerical gradient
def random_projection(shape, dtype):
    """Return a random array of the given shape with values in [0.5, 1.5).

    The offset keeps entries away from zero: random_projection should
    not have elements too small, otherwise too much precision is lost
    in the numerical gradient.
    """
    sample = rng.random(shape) + 0.5
    # Down-cast only when the caller asked for output-typed projections
    # and the output is float32; otherwise keep rng.random's float64.
    if cast_to_output_type and dtype == "float32":
        sample = np.array(sample, dtype)
    return sample
18101813

# random projection of o onto t_r
# This sum() is defined above, it's not the builtin sum.
if sum_outputs:
    # One projection per output, matching each output's shape/dtype.
    t_rs = [
        shared(random_projection(o.shape, o.dtype), borrow=True) for o in o_fn_out
    ]
    # BUG FIX: the original assigned the literal string
    # "ranom_projection_{i}" (typo + missing f-prefix) to every shared
    # variable; use a real f-string so each projection gets its index.
    for i, x in enumerate(t_rs):
        x.name = f"random_projection_{i}"
    # Scalar cost: sum of the projections of every output.
    cost = pytensor.tensor.sum(
        [pytensor.tensor.sum(x * y) for x, y in zip(t_rs, o_output)]
    )
else:
    t_r = shared(random_projection(o_fn_out.shape, o_fn_out.dtype), borrow=True)
    t_r.name = "random_projection"

    cost = pytensor.tensor.sum(t_r * o_output)
18171830

18181831
if no_debug_ref:
18191832
mode_for_cost = mode_not_slow(mode)
18201833
else:
18211834
mode_for_cost = mode
18221835

18231836
cost_fn = fn_maker(tensor_pt, cost, name="gradient.py cost", mode=mode_for_cost)
1824-
18251837
symbolic_grad = grad(cost, tensor_pt, disconnected_inputs="ignore")
1826-
18271838
grad_fn = fn_maker(tensor_pt, symbolic_grad, name="gradient.py symbolic grad")
18281839

18291840
for test_num in range(n_tests):
18301841
num_grad = numeric_grad(cost_fn, [p.copy() for p in pt], eps, out_type)
1831-
18321842
analytic_grad = grad_fn(*[p.copy() for p in pt])
18331843

18341844
# Since `tensor_pt` is a list, `analytic_grad` should be one too.
@@ -1853,7 +1863,16 @@ def random_projection():
18531863

18541864
# get new random projection for next test
if test_num < n_tests - 1:
    if sum_outputs:
        for r in t_rs:
            # Hoist get_value(): the original called it twice per
            # variable. borrow=True matches the single-output branch;
            # the fresh array is not referenced elsewhere, so handing
            # ownership to the shared variable is safe.
            cur = r.get_value()
            r.set_value(random_projection(cur.shape, cur.dtype), borrow=True)
    else:
        cur = t_r.get_value()
        t_r.set_value(
            random_projection(cur.shape, cur.dtype),
            borrow=True,
        )
18571876

18581877

18591878
class GradientError(Exception):

tests/tensor/test_nlinalg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def test_grad(self, compute_uv, full_matrices, shape):
245245
partial(svd, compute_uv=compute_uv, full_matrices=full_matrices),
246246
[A_v],
247247
rng=rng,
248+
sum_outputs=True,
248249
)
249250

250251
else:

0 commit comments

Comments
 (0)