Skip to content

Commit d860f63

Browse files
committed
Fix broadcasting issue in ScaleTransform jacobian
1 parent 6f80846 commit d860f63

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

pymc/logprob/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def backward(self, value, *inputs):
677677

678678
def log_jac_det(self, value, *inputs):
679679
scale = self.transform_args_fn(*inputs)
680-
return -pt.log(pt.abs(scale))
680+
return -pt.log(pt.abs(pt.broadcast_to(scale, value.shape)))
681681

682682

683683
class LogTransform(RVTransform):

tests/logprob/test_transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -626,8 +626,8 @@ def test_chained_transform():
626626

627627
log_jac_det = ch.log_jac_det(x_val_forward, *x.owner.inputs, scale, loc)
628628
assert np.isclose(
629-
log_jac_det.eval(),
630-
-np.log(scale) - np.sum(np.log(x_val_forward - loc)),
629+
pt.sum(log_jac_det).eval(),
630+
np.sum(-np.log(scale) - np.log(x_val_forward - loc)),
631631
)
632632

633633

0 commit comments

Comments
 (0)