Rename logprob/joint_logp to logprob/basic and move logcdf and icdf functions there #6599

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 3 commits · Mar 25, 2023
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -90,10 +90,10 @@ jobs:
           tests/backends/test_mcbackend.py
           tests/distributions/test_truncated.py
           tests/logprob/test_abstract.py
+          tests/logprob/test_basic.py
           tests/logprob/test_censoring.py
           tests/logprob/test_composite_logprob.py
           tests/logprob/test_cumsum.py
-          tests/logprob/test_joint_logprob.py
           tests/logprob/test_mixture.py
           tests/logprob/test_rewriting.py
           tests/logprob/test_scan.py
2 changes: 1 addition & 1 deletion pymc/distributions/bound.py
@@ -25,7 +25,7 @@
 from pymc.distributions.distribution import Continuous, Discrete
 from pymc.distributions.shape_utils import to_tuple
 from pymc.distributions.transforms import _default_transform
-from pymc.logprob.joint_logprob import logp
+from pymc.logprob.basic import logp
 from pymc.model import modelcontext
 from pymc.pytensorf import floatX, intX
10 changes: 5 additions & 5 deletions pymc/distributions/continuous.py
@@ -56,7 +56,7 @@
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.var import TensorConstant

-from pymc.logprob.abstract import _logprob, logcdf, logprob
+from pymc.logprob.abstract import _logcdf_helper, _logprob_helper

 try:
     from polyagamma import polyagamma_cdf, polyagamma_pdf, random_polyagamma
@@ -722,7 +722,7 @@ def logp(value, mu, sigma, lower, upper):
         else:
             norm = 0.0

-        logp = _logprob(normal, (value,), None, None, None, mu, sigma) - norm
+        logp = _logprob_helper(Normal.dist(mu, sigma), value) - norm

         if is_lower_bounded:
             logp = pt.switch(value < lower, -np.inf, logp)
@@ -2033,7 +2033,7 @@ def moment(rv, size, loc, beta):
         return beta

     def logp(value, loc, beta):
-        res = pt.log(2) + logprob(Cauchy.dist(loc, beta), value)
+        res = pt.log(2) + _logprob_helper(Cauchy.dist(loc, beta), value)
         res = pt.switch(pt.ge(value, loc), res, -np.inf)
         return check_parameters(
             res,
@@ -2342,10 +2342,10 @@ def moment(rv, size, nu):
         return moment

     def logp(value, nu):
-        return logprob(Gamma.dist(alpha=nu / 2, beta=0.5), value)
+        return _logprob_helper(Gamma.dist(alpha=nu / 2, beta=0.5), value)

     def logcdf(value, nu):
-        return logcdf(Gamma.dist(alpha=nu / 2, beta=0.5), value)
+        return _logcdf_helper(Gamma.dist(alpha=nu / 2, beta=0.5), value)


 # TODO: Remove this once logp for multiplication is working!
2 changes: 1 addition & 1 deletion pymc/distributions/discrete.py
@@ -48,7 +48,7 @@
 from pymc.distributions.distribution import Discrete
 from pymc.distributions.mixture import Mixture
 from pymc.distributions.shape_utils import rv_size_is_none
-from pymc.logprob.joint_logprob import logp
+from pymc.logprob.basic import logp
 from pymc.math import sigmoid
 from pymc.pytensorf import floatX, intX
 from pymc.vartypes import continuous_types
11 changes: 5 additions & 6 deletions pymc/distributions/mixture.py
@@ -32,8 +32,7 @@
 )
 from pymc.distributions.shape_utils import _change_dist_size, change_dist_size
 from pymc.distributions.transforms import _default_transform
-from pymc.logprob.abstract import _logcdf, _logprob, logcdf
-from pymc.logprob.joint_logprob import logp
+from pymc.logprob.abstract import _logcdf, _logcdf_helper, _logprob, _logprob_helper
 from pymc.logprob.transforms import IntervalTransform
 from pymc.logprob.utils import ignore_logprob
 from pymc.util import check_dist_not_registered
@@ -337,10 +336,10 @@ def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs):
     if len(components) == 1:
         # Need to broadcast value across mixture axis
         mix_axis = -components[0].owner.op.ndim_supp - 1
-        components_logp = logp(components[0], pt.expand_dims(value, mix_axis))
+        components_logp = _logprob_helper(components[0], pt.expand_dims(value, mix_axis))
     else:
         components_logp = pt.stack(
-            [logp(component, value) for component in components],
+            [_logprob_helper(component, value) for component in components],
             axis=-1,
         )

@@ -363,10 +362,10 @@ def marginal_mixture_logcdf(op, value, rng, weights, *components, **kwargs):
     if len(components) == 1:
         # Need to broadcast value across mixture axis
         mix_axis = -components[0].owner.op.ndim_supp - 1
-        components_logcdf = logcdf(components[0], pt.expand_dims(value, mix_axis))
+        components_logcdf = _logcdf_helper(components[0], pt.expand_dims(value, mix_axis))
     else:
         components_logcdf = pt.stack(
-            [logcdf(component, value) for component in components],
+            [_logcdf_helper(component, value) for component in components],
             axis=-1,
         )
2 changes: 1 addition & 1 deletion pymc/distributions/timeseries.py
@@ -42,7 +42,7 @@
 )
 from pymc.exceptions import NotConstantValueError
 from pymc.logprob.abstract import _logprob
-from pymc.logprob.joint_logprob import logp
+from pymc.logprob.basic import logp
 from pymc.logprob.utils import ignore_logprob, reconsider_logprob
 from pymc.pytensorf import constant_fold, floatX, intX
 from pymc.util import check_dist_not_registered
3 changes: 2 additions & 1 deletion pymc/distributions/truncated.py
@@ -37,7 +37,8 @@
 from pymc.distributions.shape_utils import _change_dist_size, change_dist_size, to_tuple
 from pymc.distributions.transforms import _default_transform
 from pymc.exceptions import TruncationError
-from pymc.logprob.abstract import MeasurableVariable, _logcdf, _logprob, icdf, logcdf
+from pymc.logprob.abstract import MeasurableVariable, _logcdf, _logprob
+from pymc.logprob.basic import icdf, logcdf
 from pymc.math import logdiffexp
 from pymc.util import check_dist_not_registered

10 changes: 6 additions & 4 deletions pymc/logprob/__init__.py
@@ -34,9 +34,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.

-from pymc.logprob.abstract import logprob, logcdf  # isort: split
-
-from pymc.logprob.joint_logprob import factorized_joint_logprob, joint_logp, logp
+from pymc.logprob.basic import factorized_joint_logprob, icdf, joint_logp, logcdf, logp

 # isort: off
 # Add rewrites to the DBs
@@ -49,4 +47,8 @@

 # isort: on

-__all__ = ("logp", "logcdf")
+__all__ = (
+    "logp",
+    "logcdf",
+    "icdf",
+)
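For orientation — an editor's sketch, not part of the diff: after this change, `logp`, `logcdf`, and `icdf` are all importable from `pymc.logprob`, and each returns a symbolic PyTensor graph. The standard-normal numbers in the comments are assumptions from the usual formulas, and the `icdf` call assumes the normal RV has a registered inverse-CDF implementation.

```python
import pymc as pm
from pymc.logprob import icdf, logcdf, logp

x = pm.Normal.dist(mu=0.0, sigma=1.0)

# Each call builds a PyTensor graph; .eval() compiles and runs it.
print(logp(x, 0.0).eval())    # log N(0 | 0, 1) = -0.5 * log(2 * pi) ~ -0.9189
print(logcdf(x, 0.0).eval())  # log Phi(0) = log(0.5) ~ -0.6931
print(icdf(x, 0.5).eval())    # Phi^-1(0.5) = 0.0
```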
65 changes: 34 additions & 31 deletions pymc/logprob/abstract.py
@@ -48,35 +48,6 @@
 from pytensor.tensor.random.op import RandomVariable


-def logprob(rv_var, *rv_values, **kwargs):
-    """Create a graph for the log-probability of a ``RandomVariable``."""
-    logprob = _logprob(rv_var.owner.op, rv_values, *rv_var.owner.inputs, **kwargs)
-
-    for rv_var in rv_values:
-        if rv_var.name:
-            logprob.name = f"{rv_var.name}_logprob"
-
-    return logprob
-
-
-def logcdf(rv_var, rv_value, **kwargs):
-    """Create a graph for the logcdf of a ``RandomVariable``."""
-    logcdf = _logcdf(rv_var.owner.op, rv_value, *rv_var.owner.inputs, name=rv_var.name, **kwargs)
-
-    if rv_var.name:
-        logcdf.name = f"{rv_var.name}_logcdf"
-
-    return logcdf
-
-
-def icdf(rv, value, **kwargs):
-    """Create a graph for the inverse CDF of a `RandomVariable`."""
-    rv_icdf = _icdf(rv.owner.op, value, *rv.owner.inputs, **kwargs)
-    if rv.name:
-        rv_icdf.name = f"{rv.name}_icdf"
-    return rv_icdf
-
-
 @singledispatch
 def _logprob(
     op: Op,
@@ -94,6 +65,18 @@ def _logprob(
     raise NotImplementedError(f"Logprob method not implemented for {op}")


+def _logprob_helper(rv, *values, **kwargs):
+    """Helper that calls `_logprob` dispatcher."""
+    logprob = _logprob(rv.owner.op, values, *rv.owner.inputs, **kwargs)
+
+    for rv in values:
+        if rv.name:
+            logprob.name = f"{rv.name}_logprob"
+            break
+
+    return logprob
+
+
 @singledispatch
 def _logcdf(
     op: Op,
@@ -107,7 +90,17 @@ def _logcdf(
     of ``RandomVariable``. If you want to implement new logcdf graphs
     for a ``RandomVariable``, register a new function on this dispatcher.
     """
-    raise NotImplementedError(f"Logcdf method not implemented for {op}")
+    raise NotImplementedError(f"LogCDF method not implemented for {op}")
+
+
+def _logcdf_helper(rv, value, **kwargs):
+    """Helper that calls `_logcdf` dispatcher."""
+    logcdf = _logcdf(rv.owner.op, value, *rv.owner.inputs, name=rv.name, **kwargs)
+
+    if rv.name:
+        logcdf.name = f"{rv.name}_logcdf"
+
+    return logcdf


 @singledispatch
@@ -122,7 +115,17 @@ def _icdf(
     This function dispatches on the type of `op`, which should be a subclass
     of `RandomVariable`.
     """
-    raise NotImplementedError(f"icdf not implemented for {op}")
+    raise NotImplementedError(f"Inverse CDF method not implemented for {op}")
+
+
+def _icdf_helper(rv, value, **kwargs):
+    """Helper that calls `_icdf` dispatcher."""
+    rv_icdf = _icdf(rv.owner.op, value, *rv.owner.inputs, **kwargs)
+
+    if rv.name:
+        rv_icdf.name = f"{rv.name}_icdf"
+
+    return rv_icdf


 class MeasurableVariable(abc.ABC):
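To make the dispatcher/helper split concrete, here is a hypothetical sketch (`MyDistRV` and `mydist_logprob` are illustrative names, not part of this PR) of registering a logprob for a custom `RandomVariable` and querying it through the new helper. It assumes the PyTensor of this era, where an RV node's inputs are `(rng, size, dtype, *params)` — the same convention visible in the old `_logprob(normal, (value,), None, None, None, mu, sigma)` call in continuous.py above.

```python
import pytensor.tensor as pt
from pytensor.tensor.random.op import RandomVariable

from pymc.logprob.abstract import _logprob, _logprob_helper


class MyDistRV(RandomVariable):
    """Toy distribution: a standard Laplace shifted by `loc`."""

    name = "mydist"
    ndim_supp = 0
    ndims_params = [0]
    dtype = "floatX"

    @classmethod
    def rng_fn(cls, rng, loc, size):
        return rng.laplace(loc=loc, size=size)


@_logprob.register(MyDistRV)
def mydist_logprob(op, values, rng, size, dtype, loc, **kwargs):
    # Dispatcher signature: (op, values, *rv_inputs, **kwargs).
    # _logprob_helper unpacks rv.owner.op and rv.owner.inputs for us.
    (value,) = values
    return -pt.abs(value - loc) - pt.log(2)


x = MyDistRV()(0.0)
print(_logprob_helper(x, pt.constant(1.0)).eval())  # -|1 - 0| - log(2) ~ -1.693
```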
46 changes: 30 additions & 16 deletions pymc/logprob/joint_logprob.py → pymc/logprob/basic.py
@@ -39,40 +39,54 @@
 from collections import deque
 from typing import Dict, List, Optional, Sequence, Union

+import numpy as np
 import pytensor
 import pytensor.tensor as pt

 from pytensor import config
-from pytensor.graph.basic import graph_inputs, io_toposort
+from pytensor.graph.basic import Variable, graph_inputs, io_toposort
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.var import TensorVariable
+from typing_extensions import TypeAlias

-from pymc.logprob.abstract import _logprob, get_measurable_outputs
-from pymc.logprob.abstract import logprob as logp_logprob
+from pymc.logprob.abstract import (
+    _icdf_helper,
+    _logcdf_helper,
+    _logprob,
+    _logprob_helper,
+    get_measurable_outputs,
+)
 from pymc.logprob.rewriting import construct_ir_fgraph
 from pymc.logprob.transforms import RVTransform, TransformValuesRewrite
 from pymc.logprob.utils import rvs_to_value_vars

+TensorLike: TypeAlias = Union[Variable, float, np.ndarray]
+

-def logp(rv: TensorVariable, value) -> TensorVariable:
+def logp(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
     """Return the log-probability graph of a Random Variable"""

     value = pt.as_tensor_variable(value, dtype=rv.dtype)
     try:
-        return logp_logprob(rv, value)
+        return _logprob_helper(rv, value, **kwargs)
     except NotImplementedError:

[Inline review thread on the `except NotImplementedError:` branch]

A reviewer (Member) asked:

> Let me see if I understand this correctly: we enter this NotImplementedError in
> cases where rv.owner.op has a type that was not registered anywhere with
> @_logprob.register(...), and so falls back to the default (non-)implementation
> in abstract.py. In these cases, why would the FunctionGraph approach work?

ricardoV94 (Member, Author) replied on Mar 16, 2023:

> It falls back to the intermediate representation (IR), which converts stuff like
> Exp(Beta) into MeasurableExp(Beta) (or whatever is deemed valid), which itself
> has a logp method dispatched. That's basically how inference for non-basic RVs
> is carried out.

-        try:
-            value = rv.type.filter_variable(value)
-        except TypeError as exc:
-            raise TypeError(
-                "When RV is not a pure distribution, value variable must have the same type"
-            ) from exc
-        try:
-            return factorized_joint_logprob({rv: value}, warn_missing_rvs=False)[value]
-        except Exception as exc:
-            raise NotImplementedError("PyMC could not infer logp of input variable.") from exc
+        fgraph, _, _ = construct_ir_fgraph({rv: value})
+        [(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items()
+        return _logprob_helper(ir_rv, ir_value, **kwargs)
+
+
+def logcdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
+    """Create a graph for the log-CDF of a Random Variable."""
+    value = pt.as_tensor_variable(value, dtype=rv.dtype)
+    return _logcdf_helper(rv, value, **kwargs)
+
+
+def icdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
+    """Create a graph for the inverse CDF of a Random Variable."""
+    value = pt.as_tensor_variable(value)
+    return _icdf_helper(rv, value, **kwargs)


 def factorized_joint_logprob(
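A minimal sketch of the fallback path discussed in the review thread above; the expression in the final comment is the standard change-of-variables result, stated here as an assumption rather than quoted from the PR.

```python
import pytensor.tensor as pt
import pymc as pm
from pymc.logprob.basic import logp

x = pm.Beta.dist(alpha=2.0, beta=2.0)
y = pt.exp(x)  # plain Elemwise exp: no @_logprob registration of its own

# _logprob_helper raises NotImplementedError for y, so logp falls back to
# construct_ir_fgraph, which rewrites exp(x) into a measurable op whose
# dispatched logprob is the density of x at log(y) plus the log-Jacobian.
y_logp = logp(y, 1.5)
print(y_logp.eval())  # log f_Beta(log 1.5; 2, 2) - log 1.5
```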
4 changes: 2 additions & 2 deletions pymc/logprob/cumsum.py
@@ -41,7 +41,7 @@
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.tensor.extra_ops import CumOp

-from pymc.logprob.abstract import MeasurableVariable, _logprob, logprob
+from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
 from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
 from pymc.logprob.utils import ignore_logprob

@@ -72,7 +72,7 @@ def logprob_cumsum(op, values, base_rv, **kwargs):
         axis=op.axis,
     )

-    cumsum_logp = logprob(base_rv, value_diff)
+    cumsum_logp = _logprob_helper(base_rv, value_diff)

     return cumsum_logp

10 changes: 5 additions & 5 deletions pymc/logprob/mixture.py
@@ -69,7 +69,7 @@
 from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType
 from pytensor.tensor.var import TensorVariable

-from pymc.logprob.abstract import MeasurableVariable, _logprob, logprob
+from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
 from pymc.logprob.rewriting import (
     local_lift_DiracDelta,
     logprob_rewrites_db,
@@ -445,7 +445,7 @@ def logprob_MixtureRV(
             # this intentional one-off?
             rv_m = rv_pull_down(rv[m_indices] if m_indices else rv)
             val_m = value[idx_m_on_axis]
-            logp_m = logprob(rv_m, val_m)
+            logp_m = _logprob_helper(rv_m, val_m)
             logp_val = pt.set_subtensor(logp_val[idx_m_on_axis], logp_m)

     else:
@@ -463,7 +463,7 @@

         logp_val = 0.0
         for i, comp_rv in enumerate(comp_rvs):
-            comp_logp = logprob(comp_rv, value)
+            comp_logp = _logprob_helper(comp_rv, value)
             if join_axis_val is not None:
                 comp_logp = pt.squeeze(comp_logp, axis=join_axis_val)
             logp_val += ifelse(
@@ -540,10 +540,10 @@ def logprob_ifelse(op, values, if_var, *base_rvs, **kwargs):
     rvs_to_values_else = {else_rv: value for else_rv, value in zip(base_rvs[len(values) :], values)}

     logps_then = [
-        logprob(rv_then, value, **kwargs) for rv_then, value in rvs_to_values_then.items()
+        _logprob_helper(rv_then, value, **kwargs) for rv_then, value in rvs_to_values_then.items()
     ]
     logps_else = [
-        logprob(rv_else, value, **kwargs) for rv_else, value in rvs_to_values_else.items()
+        _logprob_helper(rv_else, value, **kwargs) for rv_else, value in rvs_to_values_else.items()
     ]

     # If the multiple variables depend on each other, we have to replace them
4 changes: 3 additions & 1 deletion pymc/logprob/scan.py
@@ -59,7 +59,7 @@
     _logprob,
     get_measurable_outputs,
 )
-from pymc.logprob.joint_logprob import factorized_joint_logprob
+from pymc.logprob.basic import factorized_joint_logprob
 from pymc.logprob.rewriting import (
     construct_ir_fgraph,
     inc_subtensor_ops,
@@ -351,6 +351,8 @@ def create_inner_out_logp(value_map: Dict[TensorVariable, TensorVariable]) -> Te
     # Return only the logp outputs, not any potentially carried states
     logp_outputs = logp_scan_out[-len(values) :]

+    if len(logp_outputs) == 1:
+        return logp_outputs[0]
     return logp_outputs


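Finally, a usage sketch (again an editor's illustration, not from the PR) of `factorized_joint_logprob` at its new import path, the one scan.py switches to above.

```python
import pymc as pm
from pymc.logprob.basic import factorized_joint_logprob

x = pm.Normal.dist(0.0, 1.0)
x.name = "x"
x_value = x.type()  # a fresh value variable with the same type as x
x_value.name = "x_value"

# Returns a dict mapping each value variable to its logp graph.
logp_terms = factorized_joint_logprob({x: x_value})
print(logp_terms[x_value].eval({x_value: 0.0}))  # ~ -0.9189 = log N(0 | 0, 1)
```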