Skip to content

Commit 91966e8

Browse files
committed
Fix bug in grad of discrete Switch
1 parent: 848ce19 · commit: 91966e8

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

pytensor/scalar/basic.py

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -1598,8 +1598,8 @@ def L_op(self, inputs, outputs, gout):
15981598
second_part = switch(cond, 0.0, gz)
15991599

16001600
if outputs[0].type in discrete_types:
1601-
first_part = 0.0
1602-
second_part = 0.0
1601+
first_part = ift.zeros_like(config.floatX)
1602+
second_part = iff.zeros_like(config.floatX)
16031603

16041604
# cond does affect the elements of the output so it is connected.
16051605
# For the sake of making the gradient convenient we assume that

tests/scalar/test_basic.py

Lines changed: 7 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -428,6 +428,13 @@ def test_grad_switch():
428428

429429
pytensor.gradient.grad(l, x)
430430

431+
# Bug reported in https://github.com/pymc-devs/pytensor/issues/331
432+
x = matrix(dtype=int)
433+
s = pytensor.tensor.switch(0, x, -x)
434+
l = s.sum()
435+
436+
pytensor.gradient.grad(l, x)
437+
431438

432439
def test_grad_identity():
433440
# Check that the grad method of Identity correctly handles int dytpes

0 commit comments

Comments (0)