Skip to content

Commit 08a627b

Browse files
committed
Simplify single variable logp inference
1 parent beaf107 commit 08a627b

File tree

2 files changed

+7
-12
lines changed

2 files changed

+7
-12
lines changed

pymc/logprob/basic.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,23 +56,16 @@
5656
from pymc.logprob.utils import rvs_to_value_vars
5757

5858

59-
def logp(rv: TensorVariable, value) -> TensorVariable:
59+
def logp(rv: TensorVariable, value: TensorVariable, **kwargs) -> TensorVariable:
6060
"""Return the log-probability graph of a Random Variable"""
6161

6262
value = pt.as_tensor_variable(value, dtype=rv.dtype)
6363
try:
64-
return logp_logprob(rv, value)
64+
return logp_logprob(rv, value, **kwargs)
6565
except NotImplementedError:
66-
try:
67-
value = rv.type.filter_variable(value)
68-
except TypeError as exc:
69-
raise TypeError(
70-
"When RV is not a pure distribution, value variable must have the same type"
71-
) from exc
72-
try:
73-
return factorized_joint_logprob({rv: value}, warn_missing_rvs=False)[value]
74-
except Exception as exc:
75-
raise NotImplementedError("PyMC could not infer logp of input variable.") from exc
66+
fgraph, _, _ = construct_ir_fgraph({rv: value})
67+
[(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items()
68+
return logp_logprob(ir_rv, ir_value, **kwargs)
7669

7770

7871
def factorized_joint_logprob(

pymc/logprob/scan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,8 @@ def create_inner_out_logp(value_map: Dict[TensorVariable, TensorVariable]) -> Te
351351
# Return only the logp outputs, not any potentially carried states
352352
logp_outputs = logp_scan_out[-len(values) :]
353353

354+
if len(logp_outputs) == 1:
355+
return logp_outputs[0]
354356
return logp_outputs
355357

356358

0 commit comments

Comments
 (0)