Skip to content

Commit e8d5aee

Browse files
committed
Avoid unnecessary fgraph
1 parent 69f5caa commit e8d5aee

File tree

1 file changed

+19
-28
lines changed

1 file changed

+19
-28
lines changed

pymc3/distributions/logp.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,9 @@ def _logp(
265265
The default assumes that the log-likelihood of a term is a zero.
266266
267267
"""
268-
value_var = rvs_to_values.get(var, var)
269-
return at.zeros_like(value_var)
270-
# raise NotImplementedError(f"Logp cannot be computed for op {op}")
268+
# value_var = rvs_to_values.get(var, var)
269+
# return at.zeros_like(value_var)
270+
raise NotImplementedError(f"Logp cannot be computed for op {op}")
271271

272272

273273
@_logp.register(Elemwise)
@@ -287,7 +287,7 @@ def elemwise_logp(op, *args, **kwargs):
287287

288288
@_logp.register(Add)
289289
@_logp.register(Mul)
290-
def linear_logp(op, var, rvs_to_values, *linear_inputs, **kwargs):
290+
def linear_logp(op, var, rvs_to_values, *linear_inputs, transformed=True, **kwargs):
291291

292292
if len(linear_inputs) != 2:
293293
raise ValueError(f"Expected 2 inputs but got: {len(linear_inputs)}")
@@ -319,34 +319,26 @@ def linear_logp(op, var, rvs_to_values, *linear_inputs, **kwargs):
319319
constant = constant[0]
320320
var_value = rvs_to_values.get(var, var)
321321

322-
# Get logp of base_rv
323-
base_value = base_rv.type()
324-
logp_base_rv = logpt(base_rv, {base_rv: base_value}, **kwargs)
325-
fgraph = FunctionGraph(
326-
[i for i in graph_inputs((logp_base_rv,)) if not isinstance(i, Constant)],
327-
outputs=[logp_base_rv],
328-
clone=False,
329-
)
330-
331-
# Transform base_rv and apply jacobian correction (for continuous rvs)
322+
# Get logp of base_rv with transformed input
332323
if isinstance(op, Add):
333-
fgraph.replace(base_value, var_value - constant, import_missing=True)
334-
logp_linear_rv = fgraph.outputs[0]
335-
elif isinstance(op, Mul):
336-
fgraph.replace(base_value, var_value / constant, import_missing=True)
337-
logp_linear_rv = fgraph.outputs[0]
338-
if "float" in base_rv.dtype:
339-
logp_linear_rv -= at.log(at.abs_(constant))
324+
base_value = var_value - constant
325+
else:
326+
base_value = var_value / constant
327+
var_logp = logpt(base_rv, {base_rv: base_value}, transformed=transformed, **kwargs)
328+
329+
# Apply product jacobian correction for continuous rvs
330+
if isinstance(op, Mul) and "float" in base_rv.dtype:
331+
var_logp -= at.log(at.abs_(constant))
340332

341333
# Replace rvs in graph
342-
(logp_linear_rv,), _ = rvs_to_value_vars(
343-
(logp_linear_rv,),
344-
apply_transforms=kwargs.get("transformed", True),
334+
(var_logp,), _ = rvs_to_value_vars(
335+
(var_logp,),
336+
apply_transforms=transformed,
345337
initial_replacements=None,
346338
)
347339

348-
logp_linear_rv.name = f"__logp_{var.name}"
349-
return logp_linear_rv
340+
var_logp.name = f"__logp_{var.name}"
341+
return var_logp
350342

351343

352344
def convert_indices(indices, entry):
@@ -405,8 +397,7 @@ def subtensor_logp(op, var, rvs_to_values, indexed_rv_var, *indices, **kwargs):
405397
# subset of variables per the index.
406398
var_copy = var.owner.clone().default_output()
407399
fgraph = FunctionGraph(
408-
[i for i in graph_inputs((indexed_rv_var,)) if not isinstance(i, Constant)],
409-
[var_copy],
400+
outputs=[var_copy],
410401
clone=False,
411402
)
412403

0 commit comments

Comments (0)