@@ -265,9 +265,9 @@ def _logp(
265
265
The default assumes that the log-likelihood of a term is a zero.
266
266
267
267
"""
268
- value_var = rvs_to_values .get (var , var )
269
- return at .zeros_like (value_var )
270
- # raise NotImplementedError(f"Logp cannot be computed for op {op}")
268
+ # value_var = rvs_to_values.get(var, var)
269
+ # return at.zeros_like(value_var)
270
+ raise NotImplementedError (f"Logp cannot be computed for op { op } " )
271
271
272
272
273
273
@_logp .register (Elemwise )
@@ -287,7 +287,7 @@ def elemwise_logp(op, *args, **kwargs):
287
287
288
288
@_logp .register (Add )
289
289
@_logp .register (Mul )
290
- def linear_logp (op , var , rvs_to_values , * linear_inputs , ** kwargs ):
290
+ def linear_logp (op , var , rvs_to_values , * linear_inputs , transformed = True , ** kwargs ):
291
291
292
292
if len (linear_inputs ) != 2 :
293
293
raise ValueError (f"Expected 2 inputs but got: { len (linear_inputs )} " )
@@ -319,34 +319,26 @@ def linear_logp(op, var, rvs_to_values, *linear_inputs, **kwargs):
319
319
constant = constant [0 ]
320
320
var_value = rvs_to_values .get (var , var )
321
321
322
- # Get logp of base_rv
323
- base_value = base_rv .type ()
324
- logp_base_rv = logpt (base_rv , {base_rv : base_value }, ** kwargs )
325
- fgraph = FunctionGraph (
326
- [i for i in graph_inputs ((logp_base_rv ,)) if not isinstance (i , Constant )],
327
- outputs = [logp_base_rv ],
328
- clone = False ,
329
- )
330
-
331
- # Transform base_rv and apply jacobian correction (for continuous rvs)
322
+ # Get logp of base_rv with transformed input
332
323
if isinstance (op , Add ):
333
- fgraph .replace (base_value , var_value - constant , import_missing = True )
334
- logp_linear_rv = fgraph .outputs [0 ]
335
- elif isinstance (op , Mul ):
336
- fgraph .replace (base_value , var_value / constant , import_missing = True )
337
- logp_linear_rv = fgraph .outputs [0 ]
338
- if "float" in base_rv .dtype :
339
- logp_linear_rv -= at .log (at .abs_ (constant ))
324
+ base_value = var_value - constant
325
+ else :
326
+ base_value = var_value / constant
327
+ var_logp = logpt (base_rv , {base_rv : base_value }, transformed = transformed , ** kwargs )
328
+
329
+ # Apply product jacobian correction for continuous rvs
330
+ if isinstance (op , Mul ) and "float" in base_rv .dtype :
331
+ var_logp -= at .log (at .abs_ (constant ))
340
332
341
333
# Replace rvs in graph
342
- (logp_linear_rv ,), _ = rvs_to_value_vars (
343
- (logp_linear_rv ,),
344
- apply_transforms = kwargs . get ( " transformed" , True ) ,
334
+ (var_logp ,), _ = rvs_to_value_vars (
335
+ (var_logp ,),
336
+ apply_transforms = transformed ,
345
337
initial_replacements = None ,
346
338
)
347
339
348
- logp_linear_rv .name = f"__logp_{ var .name } "
349
- return logp_linear_rv
340
+ var_logp .name = f"__logp_{ var .name } "
341
+ return var_logp
350
342
351
343
352
344
def convert_indices (indices , entry ):
@@ -405,8 +397,7 @@ def subtensor_logp(op, var, rvs_to_values, indexed_rv_var, *indices, **kwargs):
405
397
# subset of variables per the index.
406
398
var_copy = var .owner .clone ().default_output ()
407
399
fgraph = FunctionGraph (
408
- [i for i in graph_inputs ((indexed_rv_var ,)) if not isinstance (i , Constant )],
409
- [var_copy ],
400
+ outputs = [var_copy ],
410
401
clone = False ,
411
402
)
412
403
0 commit comments