Skip to content

Commit 6f80846

Browse files
committed
Fix jacobian dimensionality alignment in ChainedTransform
The existing code only considered mixing of scalar and vector transforms, but not potentially higher dimensionality transforms (e.g., matrix)
1 parent 83cd926 commit 6f80846

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pymc/logprob/transforms.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,8 @@ def log_jac_det(self, value, *inputs):
895895
det = 0.0
896896
for det_ in det_list:
897897
if det_.ndim > ndim0:
898-
det += det_.sum(axis=-1)
898+
ndim_diff = det_.ndim - ndim0
899+
det += det_.sum(axis=tuple(range(-ndim_diff, 0)))
899900
else:
900901
det += det_
901902
return det

0 commit comments

Comments
 (0)