Skip to content

Commit d8021fc

Browse files
committed
Fix bug in gradient of set_subtensor
1 parent 91046b6 commit d8021fc

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

pytensor/tensor/subtensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1902,7 +1902,7 @@ def _sum_grad_over_bcasted_dims(x, gx):
19021902
if gx.broadcastable != x.broadcastable:
19031903
x_dim_added = gx.ndim - x.ndim
19041904
x_broad = (True,) * x_dim_added + x.broadcastable
1905-
assert sum(gx.broadcastable) < sum(x_broad)
1905+
assert sum(gx.broadcastable) <= sum(x_broad)
19061906
axis_to_sum = []
19071907
for i in range(gx.ndim):
19081908
if gx.broadcastable[i] is False and x_broad[i] is True:

tests/tensor/test_subtensor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,6 +1593,15 @@ def just_numeric_args(a, b):
15931593
),
15941594
)
15951595

1596+
# Broadcastable leading dim
1597+
utt.verify_grad(
1598+
f_slice(slice(None, None), slice(1, 3)),
1599+
(
1600+
np.asarray([0, 1, 2, 3, 4, 5.0])[None, ...],
1601+
np.asarray([9, 9.0]),
1602+
),
1603+
)
1604+
15961605

15971606
class TestIncSubtensor1:
15981607
def setup_method(self):

0 commit comments

Comments
 (0)