diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6064e80af8..7a0a87dc92 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -90,10 +90,10 @@ jobs: tests/backends/test_mcbackend.py tests/distributions/test_truncated.py tests/logprob/test_abstract.py + tests/logprob/test_basic.py tests/logprob/test_censoring.py tests/logprob/test_composite_logprob.py tests/logprob/test_cumsum.py - tests/logprob/test_joint_logprob.py tests/logprob/test_mixture.py tests/logprob/test_rewriting.py tests/logprob/test_scan.py diff --git a/pymc/distributions/bound.py b/pymc/distributions/bound.py index e8eb698637..a28b3647ad 100644 --- a/pymc/distributions/bound.py +++ b/pymc/distributions/bound.py @@ -25,7 +25,7 @@ from pymc.distributions.distribution import Continuous, Discrete from pymc.distributions.shape_utils import to_tuple from pymc.distributions.transforms import _default_transform -from pymc.logprob.joint_logprob import logp +from pymc.logprob.basic import logp from pymc.logprob.utils import ignore_logprob from pymc.model import modelcontext from pymc.pytensorf import floatX, intX diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index b45eaeb96f..8688cb04c2 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -56,7 +56,7 @@ from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.var import TensorConstant -from pymc.logprob.abstract import _logprob, logcdf, logprob +from pymc.logprob.abstract import _logcdf_helper, _logprob_helper try: from polyagamma import polyagamma_cdf, polyagamma_pdf, random_polyagamma @@ -722,7 +722,7 @@ def logp(value, mu, sigma, lower, upper): else: norm = 0.0 - logp = _logprob(normal, (value,), None, None, None, mu, sigma) - norm + logp = _logprob_helper(Normal.dist(mu, sigma), value) - norm if is_lower_bounded: logp = pt.switch(value < lower, -np.inf, logp) @@ -2033,7 +2033,7 @@ def moment(rv, size, loc, beta): return beta def logp(value, loc, beta): - res = pt.log(2) + logprob(Cauchy.dist(loc, beta), value) + res = pt.log(2) + _logprob_helper(Cauchy.dist(loc, beta), value) res = pt.switch(pt.ge(value, loc), res, -np.inf) return check_parameters( res, @@ -2342,10 +2342,10 @@ def moment(rv, size, nu): return moment def logp(value, nu): - return logprob(Gamma.dist(alpha=nu / 2, beta=0.5), value) + return _logprob_helper(Gamma.dist(alpha=nu / 2, beta=0.5), value) def logcdf(value, nu): - return logcdf(Gamma.dist(alpha=nu / 2, beta=0.5), value) + return _logcdf_helper(Gamma.dist(alpha=nu / 2, beta=0.5), value) # TODO: Remove this once logp for multiplication is working! diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index fa39ac0370..8e52c812d1 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -48,7 +48,7 @@ from pymc.distributions.distribution import Discrete from pymc.distributions.mixture import Mixture from pymc.distributions.shape_utils import rv_size_is_none -from pymc.logprob.joint_logprob import logp +from pymc.logprob.basic import logp from pymc.math import sigmoid from pymc.pytensorf import floatX, intX from pymc.vartypes import continuous_types diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py index 92045de68c..769dd3a671 100644 --- a/pymc/distributions/mixture.py +++ b/pymc/distributions/mixture.py @@ -32,8 +32,7 @@ ) from pymc.distributions.shape_utils import _change_dist_size, change_dist_size from pymc.distributions.transforms import _default_transform -from pymc.logprob.abstract import _logcdf, _logprob, logcdf -from pymc.logprob.joint_logprob import logp +from pymc.logprob.abstract import _logcdf, _logcdf_helper, _logprob, _logprob_helper from pymc.logprob.transforms import IntervalTransform from pymc.logprob.utils import ignore_logprob from pymc.util import check_dist_not_registered @@ -337,10 +336,10 @@ def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs): if len(components) == 1: # Need to broadcast value across mixture axis mix_axis = -components[0].owner.op.ndim_supp - 1 - components_logp = logp(components[0], pt.expand_dims(value, mix_axis)) + components_logp = _logprob_helper(components[0], pt.expand_dims(value, mix_axis)) else: components_logp = pt.stack( - [logp(component, value) for component in components], + [_logprob_helper(component, value) for component in components], axis=-1, ) @@ -363,10 +362,10 @@ def marginal_mixture_logcdf(op, value, rng, weights, *components, **kwargs): if len(components) == 1: # Need to broadcast value across mixture axis mix_axis = -components[0].owner.op.ndim_supp - 1 - components_logcdf = logcdf(components[0], pt.expand_dims(value, mix_axis)) + components_logcdf = _logcdf_helper(components[0], pt.expand_dims(value, mix_axis)) else: components_logcdf = pt.stack( - [logcdf(component, value) for component in components], + [_logcdf_helper(component, value) for component in components], axis=-1, ) diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index a8e9f566a9..1cee128881 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -42,7 +42,7 @@ ) from pymc.exceptions import NotConstantValueError from pymc.logprob.abstract import _logprob -from pymc.logprob.joint_logprob import logp +from pymc.logprob.basic import logp from pymc.logprob.utils import ignore_logprob, reconsider_logprob from pymc.pytensorf import constant_fold, floatX, intX from pymc.util import check_dist_not_registered diff --git a/pymc/distributions/truncated.py b/pymc/distributions/truncated.py index d989e9dd0a..05616fbe4c 100644 --- a/pymc/distributions/truncated.py +++ b/pymc/distributions/truncated.py @@ -37,7 +37,8 @@ from pymc.distributions.shape_utils import _change_dist_size, change_dist_size, to_tuple from pymc.distributions.transforms import _default_transform from pymc.exceptions import TruncationError -from pymc.logprob.abstract import MeasurableVariable, _logcdf, _logprob, icdf, logcdf +from pymc.logprob.abstract import MeasurableVariable, _logcdf, _logprob +from pymc.logprob.basic import icdf, logcdf from pymc.math import logdiffexp from pymc.util import check_dist_not_registered diff --git a/pymc/logprob/__init__.py b/pymc/logprob/__init__.py index 06b0c78edd..0b4bfc3385 100644 --- a/pymc/logprob/__init__.py +++ b/pymc/logprob/__init__.py @@ -34,9 +34,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from pymc.logprob.abstract import logprob, logcdf # isort: split - -from pymc.logprob.joint_logprob import factorized_joint_logprob, joint_logp, logp +from pymc.logprob.basic import factorized_joint_logprob, icdf, joint_logp, logcdf, logp # isort: off # Add rewrites to the DBs @@ -49,4 +47,8 @@ # isort: on -__all__ = ("logp", "logcdf") +__all__ = ( + "logp", + "logcdf", + "icdf", +) diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index ff308773b8..52ba5149d5 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -48,35 +48,6 @@ from pytensor.tensor.random.op import RandomVariable -def logprob(rv_var, *rv_values, **kwargs): - """Create a graph for the log-probability of a ``RandomVariable``.""" - logprob = _logprob(rv_var.owner.op, rv_values, *rv_var.owner.inputs, **kwargs) - - for rv_var in rv_values: - if rv_var.name: - logprob.name = f"{rv_var.name}_logprob" - - return logprob - - -def logcdf(rv_var, rv_value, **kwargs): - """Create a graph for the logcdf of a ``RandomVariable``.""" - logcdf = _logcdf(rv_var.owner.op, rv_value, *rv_var.owner.inputs, name=rv_var.name, **kwargs) - - if rv_var.name: - logcdf.name = f"{rv_var.name}_logcdf" - - return logcdf - - -def icdf(rv, value, **kwargs): - """Create a graph for the inverse CDF of a `RandomVariable`.""" - rv_icdf = _icdf(rv.owner.op, value, *rv.owner.inputs, **kwargs) - if rv.name: - rv_icdf.name = f"{rv.name}_icdf" - return rv_icdf - - @singledispatch def _logprob( op: Op, @@ -94,6 +65,18 @@ def _logprob( raise NotImplementedError(f"Logprob method not implemented for {op}") +def _logprob_helper(rv, *values, **kwargs): + """Helper that calls `_logprob` dispatcher.""" + logprob = _logprob(rv.owner.op, values, *rv.owner.inputs, **kwargs) + + for rv in values: + if rv.name: + logprob.name = f"{rv.name}_logprob" + break + + return logprob + + @singledispatch def _logcdf( op: Op, @@ -107,7 +90,17 @@ def _logcdf( of ``RandomVariable``. If you want to implement new logcdf graphs for a ``RandomVariable``, register a new function on this dispatcher. """ - raise NotImplementedError(f"Logcdf method not implemented for {op}") + raise NotImplementedError(f"LogCDF method not implemented for {op}") + + +def _logcdf_helper(rv, value, **kwargs): + """Helper that calls `_logcdf` dispatcher.""" + logcdf = _logcdf(rv.owner.op, value, *rv.owner.inputs, name=rv.name, **kwargs) + + if rv.name: + logcdf.name = f"{rv.name}_logcdf" + + return logcdf @singledispatch @@ -122,7 +115,17 @@ def _icdf( This function dispatches on the type of `op`, which should be a subclass of `RandomVariable`. """ - raise NotImplementedError(f"icdf not implemented for {op}") + raise NotImplementedError(f"Inverse CDF method not implemented for {op}") + + +def _icdf_helper(rv, value, **kwargs): + """Helper that calls `_icdf` dispatcher.""" + rv_icdf = _icdf(rv.owner.op, value, *rv.owner.inputs, **kwargs) + + if rv.name: + rv_icdf.name = f"{rv.name}_icdf" + + return rv_icdf class MeasurableVariable(abc.ABC): diff --git a/pymc/logprob/joint_logprob.py b/pymc/logprob/basic.py similarity index 90% rename from pymc/logprob/joint_logprob.py rename to pymc/logprob/basic.py index 0c5de128f6..f65e72f7c1 100644 --- a/pymc/logprob/joint_logprob.py +++ b/pymc/logprob/basic.py @@ -39,40 +39,54 @@ from collections import deque from typing import Dict, List, Optional, Sequence, Union +import numpy as np import pytensor import pytensor.tensor as pt from pytensor import config -from pytensor.graph.basic import graph_inputs, io_toposort +from pytensor.graph.basic import Variable, graph_inputs, io_toposort from pytensor.graph.op import compute_test_value from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.var import TensorVariable - -from pymc.logprob.abstract import _logprob, get_measurable_outputs -from pymc.logprob.abstract import logprob as logp_logprob +from typing_extensions import TypeAlias + +from pymc.logprob.abstract import ( + _icdf_helper, + _logcdf_helper, + _logprob, + _logprob_helper, + get_measurable_outputs, +) from pymc.logprob.rewriting import construct_ir_fgraph from pymc.logprob.transforms import RVTransform, TransformValuesRewrite from pymc.logprob.utils import rvs_to_value_vars +TensorLike: TypeAlias = Union[Variable, float, np.ndarray] + -def logp(rv: TensorVariable, value) -> TensorVariable: +def logp(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: """Return the log-probability graph of a Random Variable""" value = pt.as_tensor_variable(value, dtype=rv.dtype) try: - return logp_logprob(rv, value) + return _logprob_helper(rv, value, **kwargs) except NotImplementedError: - try: - value = rv.type.filter_variable(value) - except TypeError as exc: - raise TypeError( - "When RV is not a pure distribution, value variable must have the same type" - ) from exc - try: - return factorized_joint_logprob({rv: value}, warn_missing_rvs=False)[value] - except Exception as exc: - raise NotImplementedError("PyMC could not infer logp of input variable.") from exc + fgraph, _, _ = construct_ir_fgraph({rv: value}) + [(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items() + return _logprob_helper(ir_rv, ir_value, **kwargs) + + +def logcdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: + """Create a graph for the log-CDF of a Random Variable.""" + value = pt.as_tensor_variable(value, dtype=rv.dtype) + return _logcdf_helper(rv, value, **kwargs) + + +def icdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: + """Create a graph for the inverse CDF of a Random Variable.""" + value = pt.as_tensor_variable(value) + return _icdf_helper(rv, value, **kwargs) def factorized_joint_logprob( diff --git a/pymc/logprob/cumsum.py b/pymc/logprob/cumsum.py index bd3c5ba341..1ae8a0b60c 100644 --- a/pymc/logprob/cumsum.py +++ b/pymc/logprob/cumsum.py @@ -41,7 +41,7 @@ from pytensor.graph.rewriting.basic import node_rewriter from pytensor.tensor.extra_ops import CumOp -from pymc.logprob.abstract import MeasurableVariable, _logprob, logprob +from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import ignore_logprob @@ -72,7 +72,7 @@ def logprob_cumsum(op, values, base_rv, **kwargs): axis=op.axis, ) - cumsum_logp = logprob(base_rv, value_diff) + cumsum_logp = _logprob_helper(base_rv, value_diff) return cumsum_logp diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index ec9b8970d0..af331ce637 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -69,7 +69,7 @@ from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType from pytensor.tensor.var import TensorVariable -from pymc.logprob.abstract import MeasurableVariable, _logprob, logprob +from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper from pymc.logprob.rewriting import ( local_lift_DiracDelta, logprob_rewrites_db, @@ -445,7 +445,7 @@ def logprob_MixtureRV( # this intentional one-off? rv_m = rv_pull_down(rv[m_indices] if m_indices else rv) val_m = value[idx_m_on_axis] - logp_m = logprob(rv_m, val_m) + logp_m = _logprob_helper(rv_m, val_m) logp_val = pt.set_subtensor(logp_val[idx_m_on_axis], logp_m) else: @@ -463,7 +463,7 @@ def logprob_MixtureRV( logp_val = 0.0 for i, comp_rv in enumerate(comp_rvs): - comp_logp = logprob(comp_rv, value) + comp_logp = _logprob_helper(comp_rv, value) if join_axis_val is not None: comp_logp = pt.squeeze(comp_logp, axis=join_axis_val) logp_val += ifelse( @@ -540,10 +540,10 @@ def logprob_ifelse(op, values, if_var, *base_rvs, **kwargs): rvs_to_values_else = {else_rv: value for else_rv, value in zip(base_rvs[len(values) :], values)} logps_then = [ - logprob(rv_then, value, **kwargs) for rv_then, value in rvs_to_values_then.items() + _logprob_helper(rv_then, value, **kwargs) for rv_then, value in rvs_to_values_then.items() ] logps_else = [ - logprob(rv_else, value, **kwargs) for rv_else, value in rvs_to_values_else.items() + _logprob_helper(rv_else, value, **kwargs) for rv_else, value in rvs_to_values_else.items() ] # If the multiple variables depend on each other, we have to replace them diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index 2298b9d038..aaa2d69f73 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -59,7 +59,7 @@ _logprob, get_measurable_outputs, ) -from pymc.logprob.joint_logprob import factorized_joint_logprob +from pymc.logprob.basic import factorized_joint_logprob from pymc.logprob.rewriting import ( construct_ir_fgraph, inc_subtensor_ops, @@ -351,6 +351,8 @@ def create_inner_out_logp(value_map: Dict[TensorVariable, TensorVariable]) -> Te # Return only the logp outputs, not any potentially carried states logp_outputs = logp_scan_out[-len(values) :] + if len(logp_outputs) == 1: + return logp_outputs[0] return logp_outputs diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index be5b05e2f3..6ca11b65f4 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -50,7 +50,7 @@ local_rv_size_lift, ) -from pymc.logprob.abstract import MeasurableVariable, _logprob, logprob +from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import ignore_logprob, ignore_logprob_multiple_vars @@ -137,7 +137,7 @@ def logprob_make_vector(op, values, *base_rvs, **kwargs): base_rv.name = f"base_rv[{i}]" value.name = f"value[{i}]" - logps = [logprob(base_rv, value) for base_rv, value in base_rvs_to_values.items()] + logps = [_logprob_helper(base_rv, value) for base_rv, value in base_rvs_to_values.items()] # If the stacked variables depend on each other, we have to replace them by the respective values logps = replace_rvs_by_values(logps, rvs_to_values=base_rvs_to_values) @@ -174,7 +174,8 @@ def logprob_join(op, values, axis, *base_rvs, **kwargs): base_rvs_to_split_values = {base_rv: value for base_rv, value in zip(base_rvs, split_values)} logps = [ - logprob(base_var, split_value) for base_var, split_value in base_rvs_to_split_values.items() + _logprob_helper(base_var, split_value) + for base_var, split_value in base_rvs_to_split_values.items() ] if len({logp.ndim for logp in logps}) != 1: @@ -271,7 +272,7 @@ def logprob_dimshuffle(op, values, base_var, **kwargs): undo_ds = [original_shuffle.index(i) for i in range(len(original_shuffle))] value = value.dimshuffle(undo_ds) - raw_logp = logprob(base_var, value) + raw_logp = _logprob_helper(base_var, value) # Re-apply original dimshuffle, ignoring any support dimensions consumed by # the logprob function. This assumes that support dimensions are always in diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 82aac0fa88..1190d7ce03 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -78,7 +78,7 @@ MeasurableVariable, _get_measurable_outputs, _logprob, - logprob, + _logprob_helper, ) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import ignore_logprob, walk_model @@ -369,10 +369,13 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa # Some transformations, like squaring may produce multiple backward values if isinstance(backward_value, tuple): input_logprob = pt.logaddexp( - *(logprob(measurable_input, backward_val, **kwargs) for backward_val in backward_value) + *( + _logprob_helper(measurable_input, backward_val, **kwargs) + for backward_val in backward_value + ) ) else: - input_logprob = logprob(measurable_input, backward_value) + input_logprob = _logprob_helper(measurable_input, backward_value) if input_logprob.ndim < value.ndim: # Do we just need to sum the jacobian terms across the support dims? diff --git a/pymc/model.py b/pymc/model.py index 08a75db4c6..2187de1da4 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -60,7 +60,7 @@ ShapeWarning, ) from pymc.initial_point import make_initial_point_fn -from pymc.logprob.joint_logprob import joint_logp +from pymc.logprob.basic import joint_logp from pymc.pytensorf import ( PointFunc, SeedSequenceSeed, diff --git a/pymc/testing.py b/pymc/testing.py index 3bb222222f..86f2910a2b 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -34,11 +34,11 @@ import pymc as pm -from pymc import Distribution, logcdf, logp +from pymc.distributions.distribution import Distribution from pymc.distributions.shape_utils import change_dist_size from pymc.initial_point import make_initial_point_fn -from pymc.logprob import joint_logp -from pymc.logprob.abstract import MeasurableVariable, icdf +from pymc.logprob.abstract import MeasurableVariable +from pymc.logprob.basic import icdf, joint_logp, logcdf, logp from pymc.logprob.utils import ParameterValueError from pymc.pytensorf import ( compile_pymc, diff --git a/pymc/variational/minibatch_rv.py b/pymc/variational/minibatch_rv.py index c5d2a85aca..0ee4b060ca 100644 --- a/pymc/variational/minibatch_rv.py +++ b/pymc/variational/minibatch_rv.py @@ -19,8 +19,12 @@ from pytensor.graph import Apply, Op from pytensor.tensor import NoneConst, TensorVariable, as_tensor_variable -from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob -from pymc.logprob.abstract import logprob as logprob_logprob +from pymc.logprob.abstract import ( + MeasurableVariable, + _get_measurable_outputs, + _logprob, + _logprob_helper, +) from pymc.logprob.utils import ignore_logprob @@ -110,4 +114,4 @@ def _get_measurable_outputs_minibatch_random_variable(op, node): def minibatch_rv_logprob(op, values, *inputs, **kwargs): [value] = values rv, *total_size = inputs - return logprob_logprob(rv, value, **kwargs) * get_scaling(total_size, value.shape) + return _logprob_helper(rv, value, **kwargs) * get_scaling(total_size, value.shape) diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 3774ad8333..2ad9c8a6f9 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -30,7 +30,7 @@ pymc/distributions/truncated.py pymc/initial_point.py pymc/logprob/censoring.py -pymc/logprob/joint_logprob.py +pymc/logprob/basic.py pymc/logprob/mixture.py pymc/logprob/rewriting.py pymc/logprob/scan.py diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 8b4484a66c..8857a569fa 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -28,8 +28,7 @@ from pymc.distributions.continuous import Normal, Uniform, get_tau_sigma, interpolated from pymc.distributions.dist_math import clipped_beta_rvs -from pymc.logprob.abstract import icdf, logcdf -from pymc.logprob.joint_logprob import logp +from pymc.logprob.basic import icdf, logcdf, logp from pymc.logprob.utils import ParameterValueError from pymc.pytensorf import floatX from pymc.testing import ( diff --git a/tests/distributions/test_discrete.py b/tests/distributions/test_discrete.py index 78dbd7999b..6233adf3a4 100644 --- a/tests/distributions/test_discrete.py +++ b/tests/distributions/test_discrete.py @@ -29,8 +29,7 @@ import pymc as pm from pymc.distributions.discrete import Geometric, _OrderedLogistic, _OrderedProbit -from pymc.logprob.abstract import icdf, logcdf -from pymc.logprob.joint_logprob import logp +from pymc.logprob.basic import icdf, logcdf, logp from pymc.logprob.utils import ParameterValueError from pymc.pytensorf import floatX from pymc.testing import ( diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index cc8c505e38..50e0b30a62 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -46,8 +46,8 @@ from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple from pymc.distributions.transforms import log from pymc.exceptions import BlockModelAccessError -from pymc.logprob.abstract import get_measurable_outputs, logcdf -from pymc.logprob.joint_logprob import logp +from pymc.logprob.abstract import get_measurable_outputs +from pymc.logprob.basic import logcdf, logp from pymc.model import Model from pymc.sampling import draw, sample from pymc.testing import assert_moment_is_expected diff --git a/tests/distributions/test_mixture.py b/tests/distributions/test_mixture.py index c0eff7679f..df41f8b071 100644 --- a/tests/distributions/test_mixture.py +++ b/tests/distributions/test_mixture.py @@ -50,7 +50,7 @@ from pymc.distributions.mixture import MixtureTransformWarning from pymc.distributions.shape_utils import change_dist_size, to_tuple from pymc.distributions.transforms import _default_transform -from pymc.logprob.joint_logprob import logp +from pymc.logprob.basic import logp from pymc.logprob.transforms import IntervalTransform, LogTransform, SimplexTransform from pymc.math import expand_packed_triangular from pymc.model import Model diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index a1ee84f98b..671fe4fa79 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -37,7 +37,7 @@ quaddist_matrix, ) from pymc.distributions.shape_utils import change_dist_size, to_tuple -from pymc.logprob.joint_logprob import logp +from pymc.logprob.basic import logp from pymc.logprob.utils import ParameterValueError from pymc.math import kronecker from pymc.pytensorf import compile_pymc, floatX, intX diff --git a/tests/distributions/test_timeseries.py b/tests/distributions/test_timeseries.py index 3a04577800..4e19b0bc0c 100644 --- a/tests/distributions/test_timeseries.py +++ b/tests/distributions/test_timeseries.py @@ -39,7 +39,7 @@ MvStudentTRandomWalk, RandomWalk, ) -from pymc.logprob.joint_logprob import logp +from pymc.logprob.basic import logp from pymc.model import Model from pymc.pytensorf import floatX from pymc.sampling.forward import draw, sample_posterior_predictive diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index 816f429ea0..a29ab16679 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -25,7 +25,7 @@ import pymc as pm import pymc.distributions.transforms as tr -from pymc.logprob.joint_logprob import joint_logp +from pymc.logprob.basic import joint_logp from pymc.pytensorf import floatX, jacobian from pymc.testing import ( Circ, diff --git a/tests/distributions/test_truncated.py b/tests/distributions/test_truncated.py index 4cab6d7d86..7502260dc8 100644 --- a/tests/distributions/test_truncated.py +++ b/tests/distributions/test_truncated.py @@ -26,7 +26,7 @@ from pymc.distributions.truncated import Truncated, TruncatedRV, _truncated from pymc.exceptions import TruncationError from pymc.logprob.abstract import _icdf -from pymc.logprob.joint_logprob import logp +from pymc.logprob.basic import logp from pymc.logprob.transforms import IntervalTransform from pymc.logprob.utils import ParameterValueError from pymc.testing import assert_moment_is_expected diff --git a/tests/logprob/test_abstract.py b/tests/logprob/test_abstract.py index 24a2fe0ef0..21ab0c82d8 100644 --- a/tests/logprob/test_abstract.py +++ b/tests/logprob/test_abstract.py @@ -51,10 +51,11 @@ MeasurableVariable, UnmeasurableVariable, _get_measurable_outputs, + _logcdf_helper, assign_custom_measurable_outputs, - logcdf, noop_measurable_outputs_fn, ) +from pymc.logprob.basic import logcdf def assert_equal_hash(classA, classB): @@ -158,10 +159,10 @@ def test_logcdf_helper(): value = pt.vector("value") x = pm.Normal.dist(0, 1) - x_logcdf = logcdf(x, value) + x_logcdf = _logcdf_helper(x, value) np.testing.assert_almost_equal(x_logcdf.eval({value: [0, 1]}), sp.norm(0, 1).logcdf([0, 1])) - x_logcdf = logcdf(x, [0, 1]) + x_logcdf = _logcdf_helper(x, [0, 1]) np.testing.assert_almost_equal(x_logcdf.eval(), sp.norm(0, 1).logcdf([0, 1])) diff --git a/tests/logprob/test_joint_logprob.py b/tests/logprob/test_basic.py similarity index 87% rename from tests/logprob/test_joint_logprob.py rename to tests/logprob/test_basic.py index f2d3276bd0..59abd5fecf 100644 --- a/tests/logprob/test_joint_logprob.py +++ b/tests/logprob/test_basic.py @@ -55,8 +55,7 @@ import pymc as pm -from pymc.logprob.abstract import logprob -from pymc.logprob.joint_logprob import factorized_joint_logprob, joint_logp +from pymc.logprob.basic import factorized_joint_logprob, icdf, joint_logp, logcdf, logp from pymc.logprob.utils import rvs_to_value_vars, walk_model from pymc.testing import assert_no_rvs from tests.logprob.utils import joint_logprob @@ -69,7 +68,7 @@ def test_joint_logprob_basic(): a_value_var = a.clone() a_logp = joint_logprob({a: a_value_var}, sum=False) - a_logp_exp = logprob(a, a_value_var) + a_logp_exp = logp(a, a_value_var) assert equal_computations([a_logp], [a_logp_exp]) @@ -84,12 +83,12 @@ def test_joint_logprob_basic(): # We need to replace the reference to `sigma` in `Y` with its value # variable - ll_Y = logprob(Y, y_value_var) + ll_Y = logp(Y, y_value_var) (ll_Y,), _ = rvs_to_value_vars( [ll_Y], initial_replacements={sigma: sigma_value_var}, ) - total_ll_exp = logprob(sigma, sigma_value_var) + ll_Y + total_ll_exp = logp(sigma, sigma_value_var) + ll_Y assert equal_computations([total_ll], [total_ll_exp]) @@ -122,10 +121,10 @@ def test_joint_logprob_multi_obs(): a_val = a.clone() b_val = b.clone() - logp = joint_logprob({a: a_val, b: b_val}, sum=False) - logp_exp = logprob(a, a_val) + logprob(b, b_val) + logp_res = joint_logprob({a: a_val, b: b_val}, sum=False) + logp_exp = logp(a, a_val) + logp(b, b_val) - assert equal_computations([logp], [logp_exp]) + assert equal_computations([logp_res], [logp_exp]) x = pt.random.normal(0, 1) y = pt.random.normal(x, 1) @@ -133,10 +132,10 @@ def test_joint_logprob_multi_obs(): x_val = x.clone() y_val = y.clone() - logp = joint_logprob({x: x_val, y: y_val}) + logp_res = joint_logprob({x: x_val, y: y_val}) exp_logp = joint_logprob({x: x_val, y: y_val}) - assert equal_computations([logp], [exp_logp]) + assert equal_computations([logp_res], [exp_logp]) def test_joint_logprob_diff_dims(): @@ -357,32 +356,6 @@ def test_joint_logp_incsubtensor(indices, size): np.testing.assert_almost_equal(logp_vals, exp_obs_logps) -def test_logp_helper(): - value = pt.vector("value") - x = pm.Normal.dist(0, 1) - - x_logp = pm.logp(x, value) - np.testing.assert_almost_equal(x_logp.eval({value: [0, 1]}), sp.norm(0, 1).logpdf([0, 1])) - - x_logp = pm.logp(x, [0, 1]) - np.testing.assert_almost_equal(x_logp.eval(), sp.norm(0, 1).logpdf([0, 1])) - - -def test_logp_helper_derived_rv(): - assert np.isclose( - pm.logp(pt.exp(pm.Normal.dist()), 5).eval(), - pm.logp(pm.LogNormal.dist(), 5).eval(), - ) - - -def test_logp_helper_exceptions(): - with pytest.raises(TypeError, match="When RV is not a pure distribution"): - pm.logp(pt.exp(pm.Normal.dist()), [1, 2]) - - with pytest.raises(NotImplementedError, match="PyMC could not infer logp of input variable"): - pm.logp(pt.cos(pm.Normal.dist()), 1) - - def test_model_unchanged_logprob_access(): # Issue #5007 with pm.Model() as model: @@ -430,3 +403,57 @@ def test_hierarchical_obs_logp(): ops = {a.owner.op for a in logp_ancestors if a.owner} assert len(ops) > 0 assert not any(isinstance(o, RandomVariable) for o in ops) + + +@pytest.mark.parametrize( + "func, scipy_func", + [ + (logp, "logpdf"), + (logcdf, "logcdf"), + (icdf, "ppf"), + ], +) +def test_probability_direct_dispatch(func, scipy_func): + value = pt.vector("value") + x = pm.Normal.dist(0, 1) + + np.testing.assert_almost_equal( + func(x, value).eval({value: [0, 1]}), + getattr(sp.norm(0, 1), scipy_func)([0, 1]), + ) + + np.testing.assert_almost_equal( + func(x, [0, 1]).eval(), + getattr(sp.norm(0, 1), scipy_func)([0, 1]), + ) + + +@pytest.mark.parametrize( + "func, scipy_func, test_value", + [ + (logp, "logpdf", 5.0), + pytest.param(logcdf, "logcdf", 5.0, marks=pytest.mark.xfail(raises=NotImplementedError)), + pytest.param(icdf, "ppf", 0.7, marks=pytest.mark.xfail(raises=NotImplementedError)), + ], +) +def test_probability_inference(func, scipy_func, test_value): + assert np.isclose( + func(pt.exp(pm.Normal.dist()), test_value).eval(), + getattr(sp.lognorm(s=1), scipy_func)(test_value), + ) + + +@pytest.mark.parametrize( + "func, func_name", + [ + (logp, "Logprob"), + (logcdf, "LogCDF"), + (icdf, "Inverse CDF"), + ], +) +def test_probability_inference_fails(func, func_name): + with pytest.raises( + NotImplementedError, + match=f"{func_name} method not implemented for Elemwise{{cos,no_inplace}}", + ): + func(pt.cos(pm.Normal.dist()), 1) diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index 4e587c2f70..6534b22d85 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -47,7 +47,7 @@ from pytensor.tensor.shape import shape_tuple from pytensor.tensor.subtensor import as_index_constant -from pymc.logprob.joint_logprob import factorized_joint_logprob +from pymc.logprob.basic import factorized_joint_logprob from pymc.logprob.mixture import MixtureRV, expand_indices from pymc.logprob.rewriting import construct_ir_fgraph from pymc.logprob.utils import dirac_delta diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py index 7259fa80c8..b617edb49a 100644 --- a/tests/logprob/test_scan.py +++ b/tests/logprob/test_scan.py @@ -44,8 +44,8 @@ from pytensor.scan.utils import ScanArgs from scipy import stats -from pymc.logprob.abstract import logprob -from pymc.logprob.joint_logprob import factorized_joint_logprob, logp +from pymc.logprob.abstract import _logprob_helper +from pymc.logprob.basic import factorized_joint_logprob, logp from pymc.logprob.scan import ( construct_scan, convert_outer_out_to_in, @@ -63,7 +63,7 @@ def create_inner_out_logp(value_map): """ res = [] for old_inner_out_var, new_inner_in_var in value_map.items(): - logp = logprob(old_inner_out_var, new_inner_in_var) + logp = _logprob_helper(old_inner_out_var, new_inner_in_var) if new_inner_in_var.name: logp.name = f"logp({new_inner_in_var.name})" res.append(logp) @@ -134,7 +134,7 @@ def output_step_fn(y_t, y_tm1, mu_tm1): y_tm1.name = "y_tm1" mu = mu_tm1 + y_tm1 + 1 mu.name = "mu_t" - logp = logprob(pt.random.normal(mu, 1.0), y_t) + logp = _logprob_helper(pt.random.normal(mu, 1.0), y_t) logp.name = "logp" return mu, logp @@ -233,7 +233,7 @@ def output_step_fn(y_t, y_tm1, y_tm2): y_t.name = "y_t" y_tm1.name = "y_tm1" y_tm2.name = "y_tm2" - logp = logprob(pt.random.normal(y_tm1 + y_tm2, 1.0), y_t) + logp = _logprob_helper(pt.random.normal(y_tm1 + y_tm2, 1.0), y_t) logp.name = "logp(y_t)" return logp @@ -359,7 +359,7 @@ def scan_fn(mus_t, sigma_t, Gamma_t): def scan_fn(mus_t, sigma_t, Y_t_val, S_t_val, Gamma_t): S_t = pt.random.categorical(Gamma_t[0], name="S_t") Y_t = pt.random.normal(mus_t[S_t_val], sigma_t, name="Y_t") - Y_t_logp, S_t_logp = logprob(Y_t, Y_t_val), logprob(S_t, S_t_val) + Y_t_logp, S_t_logp = _logprob_helper(Y_t, Y_t_val), _logprob_helper(S_t, S_t_val) Y_t_logp.name = "log(Y_t=y_t)" S_t_logp.name = "log(S_t=s_t)" return Y_t_logp, S_t_logp @@ -375,7 +375,7 @@ def scan_fn(mus_t, sigma_t, Y_t_val, S_t_val, Gamma_t): Y_rv_logp.name = "logp(Y=y)" S_rv_logp.name = "logp(S=s)" - Gamma_logp = logprob(Gamma_rv, Gamma_vv) + Gamma_logp = _logprob_helper(Gamma_rv, Gamma_vv) y_logp_ref = Y_rv_logp.sum() + S_rv_logp.sum() + Gamma_logp.sum() diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 29c8dc0ea7..54d59f39d2 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -49,7 +49,7 @@ from pymc.distributions.transforms import _default_transform, log, logodds from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob -from pymc.logprob.joint_logprob import factorized_joint_logprob +from pymc.logprob.basic import factorized_joint_logprob from pymc.logprob.transforms import ( ChainedTransform, ExpTransform, diff --git a/tests/logprob/test_utils.py b/tests/logprob/test_utils.py index ddee0c82cf..363e94c76e 100644 --- a/tests/logprob/test_utils.py +++ b/tests/logprob/test_utils.py @@ -47,8 +47,8 @@ import pymc as pm -from pymc.logprob.abstract import MeasurableVariable, get_measurable_outputs, logprob -from pymc.logprob.joint_logprob import joint_logp +from pymc.logprob.abstract import MeasurableVariable, get_measurable_outputs +from pymc.logprob.basic import joint_logp, logp from pymc.logprob.utils import ( ParameterValueError, dirac_delta, @@ -163,7 +163,7 @@ def test_CheckParameter(): sigma = pt.scalar("sigma") x_rv = pt.random.normal(mu, sigma, name="x") x_vv = pt.constant(0) - x_logp = logprob(x_rv, x_vv) + x_logp = logp(x_rv, x_vv) x_logp_fn = function([sigma], x_logp) with pytest.raises(ParameterValueError, match="sigma > 0"): diff --git a/tests/logprob/utils.py b/tests/logprob/utils.py index 5a1c9c1656..2218d06044 100644 --- a/tests/logprob/utils.py +++ b/tests/logprob/utils.py @@ -42,8 +42,8 @@ from pytensor.tensor.var import TensorVariable from scipy import stats as stats -from pymc.logprob import factorized_joint_logprob -from pymc.logprob.abstract import get_measurable_outputs, icdf, logcdf, logprob +from pymc.logprob import factorized_joint_logprob, icdf, logcdf, logp +from pymc.logprob.abstract import get_measurable_outputs from pymc.logprob.utils import ignore_logprob @@ -162,7 +162,7 @@ def scipy_logprob_tester( test_fn = getattr(stats, name) if test == "logprob": - pytensor_res = logprob(rv_var, pt.as_tensor(obs)) + pytensor_res = logp(rv_var, pt.as_tensor(obs)) elif test == "logcdf": pytensor_res = logcdf(rv_var, pt.as_tensor(obs)) elif test == "icdf": diff --git a/tests/test_model.py b/tests/test_model.py index 52e69103b2..a8c7dc2c54 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -42,7 +42,7 @@ from pymc.distributions import Normal, transforms from pymc.distributions.transforms import log from pymc.exceptions import ImputationWarning, ShapeError, ShapeWarning -from pymc.logprob.joint_logprob import joint_logp +from pymc.logprob.basic import joint_logp from pymc.logprob.transforms import IntervalTransform from pymc.model import Point, ValueGradFunction, modelcontext from pymc.testing import SeededTest