-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Rename logprob/joint_logp
to logprob/basic
and move logcdf
and icdf
functions there
#6599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,40 +39,54 @@ | |
from collections import deque | ||
from typing import Dict, List, Optional, Sequence, Union | ||
|
||
import numpy as np | ||
import pytensor | ||
import pytensor.tensor as pt | ||
|
||
from pytensor import config | ||
from pytensor.graph.basic import graph_inputs, io_toposort | ||
from pytensor.graph.basic import Variable, graph_inputs, io_toposort | ||
from pytensor.graph.op import compute_test_value | ||
from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter | ||
from pytensor.tensor.random.op import RandomVariable | ||
from pytensor.tensor.var import TensorVariable | ||
|
||
from pymc.logprob.abstract import _logprob, get_measurable_outputs | ||
from pymc.logprob.abstract import logprob as logp_logprob | ||
from typing_extensions import TypeAlias | ||
|
||
from pymc.logprob.abstract import ( | ||
_icdf_helper, | ||
_logcdf_helper, | ||
_logprob, | ||
_logprob_helper, | ||
get_measurable_outputs, | ||
) | ||
from pymc.logprob.rewriting import construct_ir_fgraph | ||
from pymc.logprob.transforms import RVTransform, TransformValuesRewrite | ||
from pymc.logprob.utils import rvs_to_value_vars | ||
|
||
TensorLike: TypeAlias = Union[Variable, float, np.ndarray] | ||
|
||
|
||
def logp(rv: TensorVariable, value) -> TensorVariable: | ||
def logp(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: | ||
"""Return the log-probability graph of a Random Variable""" | ||
|
||
value = pt.as_tensor_variable(value, dtype=rv.dtype) | ||
try: | ||
return logp_logprob(rv, value) | ||
return _logprob_helper(rv, value, **kwargs) | ||
except NotImplementedError: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let me see if I understand this correctly: We enter this In these cases, why would the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It falls back to the intermediate representation (IR) which converts stuff like Exp(Beta) into MeasurableExp(Beta) (or whatever is deemed valid), which itself has a logp method dispatched. That's basically how inference for non basic RVs is carried out. |
||
try: | ||
value = rv.type.filter_variable(value) | ||
except TypeError as exc: | ||
raise TypeError( | ||
"When RV is not a pure distribution, value variable must have the same type" | ||
) from exc | ||
try: | ||
return factorized_joint_logprob({rv: value}, warn_missing_rvs=False)[value] | ||
except Exception as exc: | ||
raise NotImplementedError("PyMC could not infer logp of input variable.") from exc | ||
fgraph, _, _ = construct_ir_fgraph({rv: value}) | ||
[(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items() | ||
return _logprob_helper(ir_rv, ir_value, **kwargs) | ||
|
||
|
||
def logcdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: | ||
"""Create a graph for the log-CDF of a Random Variable.""" | ||
value = pt.as_tensor_variable(value, dtype=rv.dtype) | ||
return _logcdf_helper(rv, value, **kwargs) | ||
|
||
|
||
def icdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: | ||
"""Create a graph for the inverse CDF of a Random Variable.""" | ||
value = pt.as_tensor_variable(value) | ||
return _icdf_helper(rv, value, **kwargs) | ||
|
||
|
||
def factorized_joint_logprob( | ||
|
Uh oh!
There was an error while loading. Please reload this page.