@@ -1675,6 +1675,7 @@ def verify_grad(
1675
1675
mode : Optional [Union ["Mode" , str ]] = None ,
1676
1676
cast_to_output_type : bool = False ,
1677
1677
no_debug_ref : bool = True ,
1678
+ sum_outputs : bool = False ,
1678
1679
):
1679
1680
"""Test a gradient by Finite Difference Method. Raise error on failure.
1680
1681
@@ -1722,7 +1723,9 @@ def verify_grad(
1722
1723
float16 is not handled here.
1723
1724
no_debug_ref
1724
1725
Don't use `DebugMode` for the numerical gradient function.
1725
-
1726
+ sum_outputs: bool, default False
1727
+ If True, the gradient of the sum of all outputs is verified. If False, an error is raised if the function has
1728
+ multiple outputs.
1726
1729
Notes
1727
1730
-----
1728
1731
This function only supports multiple outputs when ``sum_outputs=True``. In `tests.scan.test_basic`
@@ -1782,7 +1785,7 @@ def verify_grad(
1782
1785
# fun can be either a function or an actual Op instance
1783
1786
o_output = fun (* tensor_pt )
1784
1787
1785
- if isinstance (o_output , list ):
1788
+ if isinstance (o_output , list ) and not sum_outputs :
1786
1789
raise NotImplementedError (
1787
1790
"Can't (yet) auto-test the gradient of a function with multiple outputs"
1788
1791
)
@@ -1793,7 +1796,7 @@ def verify_grad(
1793
1796
o_fn = fn_maker (tensor_pt , o_output , name = "gradient.py fwd" )
1794
1797
o_fn_out = o_fn (* [p .copy () for p in pt ])
1795
1798
1796
- if isinstance (o_fn_out , tuple ) or isinstance (o_fn_out , list ):
1799
+ if isinstance (o_fn_out , (tuple , list )) and not sum_outputs :
1797
1800
raise TypeError (
1798
1801
"It seems like you are trying to use verify_grad "
1799
1802
"on an Op or a function which outputs a list: there should"
@@ -1802,33 +1805,45 @@ def verify_grad(
1802
1805
1803
1806
# random_projection should not have elements too small,
1804
1807
# otherwise too much precision is lost in numerical gradient
1805
- def random_projection ():
1806
- plain = rng .random (o_fn_out . shape ) + 0.5
1807
- if cast_to_output_type and o_output . dtype == "float32" :
1808
- return np .array (plain , o_output . dtype )
1808
+ def random_projection (shape , dtype ):
1809
+ plain = rng .random (shape ) + 0.5
1810
+ if cast_to_output_type and dtype == "float32" :
1811
+ return np .array (plain , dtype )
1809
1812
return plain
1810
1813
1811
- t_r = shared (random_projection (), borrow = True )
1812
- t_r .name = "random_projection"
1813
-
1814
1814
# random projection of o onto t_r
1815
1815
# This sum() is defined above, it's not the builtin sum.
1816
- cost = pytensor .tensor .sum (t_r * o_output )
1816
+ if sum_outputs :
1817
+ t_rs = [
1818
+ shared (
1819
+ value = random_projection (o .shape , o .dtype ),
1820
+ borrow = True ,
1821
+ name = f"random_projection_{ i } " ,
1822
+ )
1823
+ for i , o in enumerate (o_fn_out )
1824
+ ]
1825
+ cost = pytensor .tensor .sum (
1826
+ [pytensor .tensor .sum (x * y ) for x , y in zip (t_rs , o_output )]
1827
+ )
1828
+ else :
1829
+ t_r = shared (
1830
+ value = random_projection (o_fn_out .shape , o_fn_out .dtype ),
1831
+ borrow = True ,
1832
+ name = "random_projection" ,
1833
+ )
1834
+ cost = pytensor .tensor .sum (t_r * o_output )
1817
1835
1818
1836
if no_debug_ref :
1819
1837
mode_for_cost = mode_not_slow (mode )
1820
1838
else :
1821
1839
mode_for_cost = mode
1822
1840
1823
1841
cost_fn = fn_maker (tensor_pt , cost , name = "gradient.py cost" , mode = mode_for_cost )
1824
-
1825
1842
symbolic_grad = grad (cost , tensor_pt , disconnected_inputs = "ignore" )
1826
-
1827
1843
grad_fn = fn_maker (tensor_pt , symbolic_grad , name = "gradient.py symbolic grad" )
1828
1844
1829
1845
for test_num in range (n_tests ):
1830
1846
num_grad = numeric_grad (cost_fn , [p .copy () for p in pt ], eps , out_type )
1831
-
1832
1847
analytic_grad = grad_fn (* [p .copy () for p in pt ])
1833
1848
1834
1849
# Since `tensor_pt` is a list, `analytic_grad` should be one too.
@@ -1853,7 +1868,16 @@ def random_projection():
1853
1868
1854
1869
# get new random projection for next test
1855
1870
if test_num < n_tests - 1 :
1856
- t_r .set_value (random_projection (), borrow = True )
1871
+ if sum_outputs :
1872
+ for r in t_rs :
1873
+ r .set_value (
1874
+ random_projection (r .get_value ().shape , r .get_value ().dtype ), borrow = True
1875
+ )
1876
+ else :
1877
+ t_r .set_value (
1878
+ random_projection (t_r .get_value ().shape , t_r .get_value ().dtype ),
1879
+ borrow = True ,
1880
+ )
1857
1881
1858
1882
1859
1883
class GradientError (Exception ):
0 commit comments