@@ -1675,6 +1675,7 @@ def verify_grad(
1675
1675
mode : Optional [Union ["Mode" , str ]] = None ,
1676
1676
cast_to_output_type : bool = False ,
1677
1677
no_debug_ref : bool = True ,
1678
+ sum_outputs : bool = False ,
1678
1679
):
1679
1680
"""Test a gradient by Finite Difference Method. Raise error on failure.
1680
1681
@@ -1722,7 +1723,9 @@ def verify_grad(
1722
1723
float16 is not handled here.
1723
1724
no_debug_ref
1724
1725
Don't use `DebugMode` for the numerical gradient function.
1725
-
1726
+ sum_outputs: bool, default False
1727
+ If True, the gradient of the sum of all outputs is verified. If False, an error is raised if the function has
1728
+ multiple outputs.
1726
1729
Notes
1727
1730
-----
1728
1731
By default this function does not support multiple outputs (see ``sum_outputs``). In `tests.scan.test_basic`
@@ -1782,7 +1785,7 @@ def verify_grad(
1782
1785
# fun can be either a function or an actual Op instance
1783
1786
o_output = fun (* tensor_pt )
1784
1787
1785
- if isinstance (o_output , list ):
1788
+ if isinstance (o_output , list ) and not sum_outputs :
1786
1789
raise NotImplementedError (
1787
1790
"Can't (yet) auto-test the gradient of a function with multiple outputs"
1788
1791
)
@@ -1793,7 +1796,7 @@ def verify_grad(
1793
1796
o_fn = fn_maker (tensor_pt , o_output , name = "gradient.py fwd" )
1794
1797
o_fn_out = o_fn (* [p .copy () for p in pt ])
1795
1798
1796
- if isinstance (o_fn_out , tuple ) or isinstance (o_fn_out , list ):
1799
+ if isinstance (o_fn_out , (tuple , list )) and not sum_outputs :
1797
1800
raise TypeError (
1798
1801
"It seems like you are trying to use verify_grad "
1799
1802
"on an Op or a function which outputs a list: there should"
@@ -1802,33 +1805,40 @@ def verify_grad(
1802
1805
1803
1806
# random_projection should not have elements too small,
1804
1807
# otherwise too much precision is lost in numerical gradient
1805
- def random_projection ():
1806
- plain = rng .random (o_fn_out . shape ) + 0.5
1807
- if cast_to_output_type and o_output . dtype == "float32" :
1808
- return np .array (plain , o_output . dtype )
1808
+ def random_projection (shape , dtype ):
1809
+ plain = rng .random (shape ) + 0.5
1810
+ if cast_to_output_type and dtype == "float32" :
1811
+ return np .array (plain , dtype )
1809
1812
return plain
1810
1813
1811
- t_r = shared (random_projection (), borrow = True )
1812
- t_r .name = "random_projection"
1813
-
1814
1814
# random projection of o onto t_r
1815
1815
# This sum() is defined above, it's not the builtin sum.
1816
- cost = pytensor .tensor .sum (t_r * o_output )
1816
+ if sum_outputs :
1817
+ t_rs = [
1818
+ shared (random_projection (o .shape , o .dtype ), borrow = True ) for o in o_fn_out
1819
+ ]
1820
+ for i , x in enumerate (t_rs ):
1821
+ x .name = f"random_projection_{i}"
1822
+ cost = pytensor .tensor .sum (
1823
+ [pytensor .tensor .sum (x * y ) for x , y in zip (t_rs , o_output )]
1824
+ )
1825
+ else :
1826
+ t_r = shared (random_projection (o_fn_out .shape , o_fn_out .dtype ), borrow = True )
1827
+ t_r .name = "random_projection"
1828
+
1829
+ cost = pytensor .tensor .sum (t_r * o_output )
1817
1830
1818
1831
if no_debug_ref :
1819
1832
mode_for_cost = mode_not_slow (mode )
1820
1833
else :
1821
1834
mode_for_cost = mode
1822
1835
1823
1836
cost_fn = fn_maker (tensor_pt , cost , name = "gradient.py cost" , mode = mode_for_cost )
1824
-
1825
1837
symbolic_grad = grad (cost , tensor_pt , disconnected_inputs = "ignore" )
1826
-
1827
1838
grad_fn = fn_maker (tensor_pt , symbolic_grad , name = "gradient.py symbolic grad" )
1828
1839
1829
1840
for test_num in range (n_tests ):
1830
1841
num_grad = numeric_grad (cost_fn , [p .copy () for p in pt ], eps , out_type )
1831
-
1832
1842
analytic_grad = grad_fn (* [p .copy () for p in pt ])
1833
1843
1834
1844
# Since `tensor_pt` is a list, `analytic_grad` should be one too.
@@ -1853,7 +1863,16 @@ def random_projection():
1853
1863
1854
1864
# get new random projection for next test
1855
1865
if test_num < n_tests - 1 :
1856
- t_r .set_value (random_projection (), borrow = True )
1866
+ if sum_outputs :
1867
+ for r in t_rs :
1868
+ r .set_value (
1869
+ random_projection (r .get_value ().shape , r .get_value ().dtype ), borrow = True
1870
+ )
1871
+ else :
1872
+ t_r .set_value (
1873
+ random_projection (t_r .get_value ().shape , t_r .get_value ().dtype ),
1874
+ borrow = True ,
1875
+ )
1857
1876
1858
1877
1859
1878
class GradientError (Exception ):
0 commit comments