|
12 | 12 | from pytensor import function
|
13 | 13 | from pytensor.compile import DeepCopyOp, shared
|
14 | 14 | from pytensor.compile.io import In
|
| 15 | +from pytensor.compile.mode import Mode |
15 | 16 | from pytensor.configdefaults import config
|
| 17 | +from pytensor.gradient import grad |
16 | 18 | from pytensor.graph.op import get_test_value
|
17 | 19 | from pytensor.graph.rewriting.utils import is_same_graph
|
18 | 20 | from pytensor.printing import pprint
|
|
22 | 24 | from pytensor.tensor.elemwise import DimShuffle
|
23 | 25 | from pytensor.tensor.math import exp, isinf
|
24 | 26 | from pytensor.tensor.math import sum as pt_sum
|
| 27 | +from pytensor.tensor.shape import specify_shape |
25 | 28 | from pytensor.tensor.subtensor import (
|
26 | 29 | AdvancedIncSubtensor,
|
27 | 30 | AdvancedIncSubtensor1,
|
@@ -1660,6 +1663,25 @@ def just_numeric_args(a, b):
|
1660 | 1663 | ),
|
1661 | 1664 | )
|
1662 | 1665 |
|
| 1666 | + def test_grad_broadcastable_specialization(self): |
| 1667 | + # Make sure gradient does not fail when gx has a more precise static_shape after indexing. |
| 1668 | + # This is a regression test for a bug reported in |
| 1669 | + # https://discourse.pymc.io/t/marginalized-mixture-wont-begin-sampling-throws-assertion-error/15969 |
| 1670 | + |
| 1671 | + x = vector("x") # Unknown write time shape = (2,) |
| 1672 | + out = x.zeros_like() |
| 1673 | + |
| 1674 | + # Update a subtensor of unknown write time shape = (1,) |
| 1675 | + out = out[1:].set(exp(x[1:])) |
| 1676 | + out = specify_shape(out, 2) |
| 1677 | + gx = grad(out.sum(), x) |
| 1678 | + |
| 1679 | + mode = Mode(linker="py", optimizer=None) |
| 1680 | + np.testing.assert_allclose( |
| 1681 | + gx.eval({x: [1, 1]}, mode=mode), |
| 1682 | + [0, np.e], |
| 1683 | + ) |
| 1684 | + |
1663 | 1685 |
|
1664 | 1686 | class TestIncSubtensor1:
|
1665 | 1687 | def setup_method(self):
|
|
0 commit comments