File tree Expand file tree Collapse file tree 1 file changed +3
-5
lines changed Expand file tree Collapse file tree 1 file changed +3
-5
lines changed Original file line number Diff line number Diff line change @@ -313,22 +313,20 @@ def logpt(
313
313
# (signified by having a `transform` value in their tags), then we apply
314
314
# the their transforms and add their Jacobians (when enabled)
315
315
if transform :
316
- logp_var = _logp (rv_node .op , transform .backward (rv_value_var ), * dist_params , ** kwargs )
316
+ logp_var = _logp (rv_node .op , transform .backward (rv_value ), * dist_params , ** kwargs )
317
317
logp_var = transform_logp (
318
318
logp_var ,
319
319
tuple (replacements .values ()),
320
320
)
321
321
322
322
if jacobian :
323
- transformed_jacobian = transform .jacobian_det (rv_value_var )
323
+ transformed_jacobian = transform .jacobian_det (rv_value )
324
324
if transformed_jacobian :
325
325
if logp_var .ndim > transformed_jacobian .ndim :
326
326
logp_var = logp_var .sum (axis = - 1 )
327
327
logp_var += transformed_jacobian
328
328
else :
329
- logp_var = _logp (rv_node .op , rv_value_var , * dist_params , ** kwargs )
330
-
331
- (logp_var ,) = clone_replace ([logp_var ], replace = {rv_value_var : rv_value })
329
+ logp_var = _logp (rv_node .op , rv_value , * dist_params , ** kwargs )
332
330
333
331
if scaling :
334
332
logp_var *= _get_scaling (
You can’t perform that action at this time.
0 commit comments