Skip to content

Commit 6e57a08

Browse files
committed
Fix too strict type check in _sum_grad_over_bcasted_dims
1 parent 4c7b494 commit 6e57a08

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

pytensor/tensor/subtensor.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -2027,7 +2027,6 @@ def _sum_grad_over_bcasted_dims(x, gx):
20272027
if gx.broadcastable != x.broadcastable:
20282028
x_dim_added = gx.ndim - x.ndim
20292029
x_broad = (True,) * x_dim_added + x.broadcastable
2030-
assert sum(gx.broadcastable) <= sum(x_broad)
20312030
axis_to_sum = []
20322031
for i in range(gx.ndim):
20332032
if gx.broadcastable[i] is False and x_broad[i] is True:
@@ -2045,7 +2044,14 @@ def _sum_grad_over_bcasted_dims(x, gx):
20452044
for i in range(x_dim_added):
20462045
assert gx.broadcastable[i]
20472046
gx = gx.dimshuffle(*range(x_dim_added, gx.ndim))
2048-
assert gx.broadcastable == x.broadcastable
2047+
# Broadcastable flags of gx can be the same as, or more specific than, those of x.
2048+
# The only disallowed case is x_dim_b == True and gx_dim_b == False.
2049+
assert not any(
2050+
x_dim_b and not gx_dim_b
2051+
for x_dim_b, gx_dim_b in zip(
2052+
x.type.broadcastable, gx.type.broadcastable, strict=True
2053+
)
2054+
), (x.type, gx.type)
20492055
return gx
20502056

20512057

tests/tensor/test_subtensor.py

+22
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from pytensor import function
1313
from pytensor.compile import DeepCopyOp, shared
1414
from pytensor.compile.io import In
15+
from pytensor.compile.mode import Mode
1516
from pytensor.configdefaults import config
17+
from pytensor.gradient import grad
1618
from pytensor.graph.op import get_test_value
1719
from pytensor.graph.rewriting.utils import is_same_graph
1820
from pytensor.printing import pprint
@@ -22,6 +24,7 @@
2224
from pytensor.tensor.elemwise import DimShuffle
2325
from pytensor.tensor.math import exp, isinf
2426
from pytensor.tensor.math import sum as pt_sum
27+
from pytensor.tensor.shape import specify_shape
2528
from pytensor.tensor.subtensor import (
2629
AdvancedIncSubtensor,
2730
AdvancedIncSubtensor1,
@@ -1660,6 +1663,25 @@ def just_numeric_args(a, b):
16601663
),
16611664
)
16621665

1666+
def test_grad_broadcastable_specialization(self):
1667+
# Make sure the gradient does not fail when gx has a more precise static shape after indexing.
1668+
# This is a regression test for a bug reported in
1669+
# https://discourse.pymc.io/t/marginalized-mixture-wont-begin-sampling-throws-assertion-error/15969
1670+
1671+
x = vector("x") # Unknown write time shape = (2,)
1672+
out = x.zeros_like()
1673+
1674+
# Update a subtensor of unknown write time shape = (1,)
1675+
out = out[1:].set(exp(x[1:]))
1676+
out = specify_shape(out, 2)
1677+
gx = grad(out.sum(), x)
1678+
1679+
mode = Mode(linker="py", optimizer=None)
1680+
np.testing.assert_allclose(
1681+
gx.eval({x: [1, 1]}, mode=mode),
1682+
[0, np.e],
1683+
)
1684+
16631685

16641686
class TestIncSubtensor1:
16651687
def setup_method(self):

0 commit comments

Comments
 (0)