Skip to content

Commit 3773f16

Browse files
Use the value var's values directly in logpt
1 parent 9e92b74 commit 3773f16

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

pymc3/distributions/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -313,22 +313,20 @@ def logpt(
313313
# (signified by having a `transform` value in their tags), then we apply
314314
# the their transforms and add their Jacobians (when enabled)
315315
if transform:
316-
logp_var = _logp(rv_node.op, transform.backward(rv_value_var), *dist_params, **kwargs)
316+
logp_var = _logp(rv_node.op, transform.backward(rv_value), *dist_params, **kwargs)
317317
logp_var = transform_logp(
318318
logp_var,
319319
tuple(replacements.values()),
320320
)
321321

322322
if jacobian:
323-
transformed_jacobian = transform.jacobian_det(rv_value_var)
323+
transformed_jacobian = transform.jacobian_det(rv_value)
324324
if transformed_jacobian:
325325
if logp_var.ndim > transformed_jacobian.ndim:
326326
logp_var = logp_var.sum(axis=-1)
327327
logp_var += transformed_jacobian
328328
else:
329-
logp_var = _logp(rv_node.op, rv_value_var, *dist_params, **kwargs)
330-
331-
(logp_var,) = clone_replace([logp_var], replace={rv_value_var: rv_value})
329+
logp_var = _logp(rv_node.op, rv_value, *dist_params, **kwargs)
332330

333331
if scaling:
334332
logp_var *= _get_scaling(

0 commit comments

Comments
 (0)