Skip to content

Commit 7821ebe

Browse files
committed
Separate logp-related dispatch helpers from user facing functions
1 parent 08a627b commit 7821ebe

20 files changed

+191
-131
lines changed

pymc/distributions/continuous.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from pytensor.tensor.random.op import RandomVariable
5757
from pytensor.tensor.var import TensorConstant
5858

59-
from pymc.logprob.abstract import _logprob, logcdf, logprob
59+
from pymc.logprob.abstract import _logcdf_helper, _logprob_helper
6060

6161
try:
6262
from polyagamma import polyagamma_cdf, polyagamma_pdf, random_polyagamma
@@ -722,7 +722,7 @@ def logp(value, mu, sigma, lower, upper):
722722
else:
723723
norm = 0.0
724724

725-
logp = _logprob(normal, (value,), None, None, None, mu, sigma) - norm
725+
logp = _logprob_helper(Normal.dist(mu, sigma), value) - norm
726726

727727
if is_lower_bounded:
728728
logp = pt.switch(value < lower, -np.inf, logp)
@@ -2033,7 +2033,7 @@ def moment(rv, size, loc, beta):
20332033
return beta
20342034

20352035
def logp(value, loc, beta):
2036-
res = pt.log(2) + logprob(Cauchy.dist(loc, beta), value)
2036+
res = pt.log(2) + _logprob_helper(Cauchy.dist(loc, beta), value)
20372037
res = pt.switch(pt.ge(value, loc), res, -np.inf)
20382038
return check_parameters(
20392039
res,
@@ -2342,10 +2342,10 @@ def moment(rv, size, nu):
23422342
return moment
23432343

23442344
def logp(value, nu):
2345-
return logprob(Gamma.dist(alpha=nu / 2, beta=0.5), value)
2345+
return _logprob_helper(Gamma.dist(alpha=nu / 2, beta=0.5), value)
23462346

23472347
def logcdf(value, nu):
2348-
return logcdf(Gamma.dist(alpha=nu / 2, beta=0.5), value)
2348+
return _logcdf_helper(Gamma.dist(alpha=nu / 2, beta=0.5), value)
23492349

23502350

23512351
# TODO: Remove this once logp for multiplication is working!

pymc/distributions/mixture.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@
3232
)
3333
from pymc.distributions.shape_utils import _change_dist_size, change_dist_size
3434
from pymc.distributions.transforms import _default_transform
35-
from pymc.logprob.abstract import _logcdf, _logprob, logcdf
36-
from pymc.logprob.basic import logp
35+
from pymc.logprob.abstract import _logcdf, _logcdf_helper, _logprob, _logprob_helper
3736
from pymc.logprob.transforms import IntervalTransform
3837
from pymc.logprob.utils import ignore_logprob
3938
from pymc.util import check_dist_not_registered
@@ -337,10 +336,10 @@ def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs):
337336
if len(components) == 1:
338337
# Need to broadcast value across mixture axis
339338
mix_axis = -components[0].owner.op.ndim_supp - 1
340-
components_logp = logp(components[0], pt.expand_dims(value, mix_axis))
339+
components_logp = _logprob_helper(components[0], pt.expand_dims(value, mix_axis))
341340
else:
342341
components_logp = pt.stack(
343-
[logp(component, value) for component in components],
342+
[_logprob_helper(component, value) for component in components],
344343
axis=-1,
345344
)
346345

@@ -363,10 +362,10 @@ def marginal_mixture_logcdf(op, value, rng, weights, *components, **kwargs):
363362
if len(components) == 1:
364363
# Need to broadcast value across mixture axis
365364
mix_axis = -components[0].owner.op.ndim_supp - 1
366-
components_logcdf = logcdf(components[0], pt.expand_dims(value, mix_axis))
365+
components_logcdf = _logcdf_helper(components[0], pt.expand_dims(value, mix_axis))
367366
else:
368367
components_logcdf = pt.stack(
369-
[logcdf(component, value) for component in components],
368+
[_logcdf_helper(component, value) for component in components],
370369
axis=-1,
371370
)
372371

pymc/distributions/truncated.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737
from pymc.distributions.shape_utils import _change_dist_size, change_dist_size, to_tuple
3838
from pymc.distributions.transforms import _default_transform
3939
from pymc.exceptions import TruncationError
40-
from pymc.logprob.abstract import MeasurableVariable, _logcdf, _logprob, icdf, logcdf
40+
from pymc.logprob.abstract import MeasurableVariable, _logcdf, _logprob
41+
from pymc.logprob.basic import icdf, logcdf
4142
from pymc.math import logdiffexp
4243
from pymc.util import check_dist_not_registered
4344

pymc/logprob/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@
3434
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
3535
# SOFTWARE.
3636

37-
from pymc.logprob.abstract import logprob, logcdf # isort: split
38-
39-
from pymc.logprob.basic import factorized_joint_logprob, joint_logp, logp
37+
from pymc.logprob.basic import factorized_joint_logprob, icdf, joint_logp, logcdf, logp
4038

4139
# isort: off
4240
# Add rewrites to the DBs
@@ -49,4 +47,8 @@
4947

5048
# isort: on
5149

52-
__all__ = ("logp", "logcdf")
50+
__all__ = (
51+
"logp",
52+
"logcdf",
53+
"icdf",
54+
)

pymc/logprob/abstract.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -48,35 +48,6 @@
4848
from pytensor.tensor.random.op import RandomVariable
4949

5050

51-
def logprob(rv_var, *rv_values, **kwargs):
52-
"""Create a graph for the log-probability of a ``RandomVariable``."""
53-
logprob = _logprob(rv_var.owner.op, rv_values, *rv_var.owner.inputs, **kwargs)
54-
55-
for rv_var in rv_values:
56-
if rv_var.name:
57-
logprob.name = f"{rv_var.name}_logprob"
58-
59-
return logprob
60-
61-
62-
def logcdf(rv_var, rv_value, **kwargs):
63-
"""Create a graph for the logcdf of a ``RandomVariable``."""
64-
logcdf = _logcdf(rv_var.owner.op, rv_value, *rv_var.owner.inputs, name=rv_var.name, **kwargs)
65-
66-
if rv_var.name:
67-
logcdf.name = f"{rv_var.name}_logcdf"
68-
69-
return logcdf
70-
71-
72-
def icdf(rv, value, **kwargs):
73-
"""Create a graph for the inverse CDF of a `RandomVariable`."""
74-
rv_icdf = _icdf(rv.owner.op, value, *rv.owner.inputs, **kwargs)
75-
if rv.name:
76-
rv_icdf.name = f"{rv.name}_icdf"
77-
return rv_icdf
78-
79-
8051
@singledispatch
8152
def _logprob(
8253
op: Op,
@@ -94,6 +65,18 @@ def _logprob(
9465
raise NotImplementedError(f"Logprob method not implemented for {op}")
9566

9667

68+
def _logprob_helper(rv, *values, **kwargs):
69+
"""Helper that calls `_logprob` dispatcher."""
70+
logprob = _logprob(rv.owner.op, values, *rv.owner.inputs, **kwargs)
71+
72+
for rv in values:
73+
if rv.name:
74+
logprob.name = f"{rv.name}_logprob"
75+
break
76+
77+
return logprob
78+
79+
9780
@singledispatch
9881
def _logcdf(
9982
op: Op,
@@ -107,7 +90,17 @@ def _logcdf(
10790
of ``RandomVariable``. If you want to implement new logcdf graphs
10891
for a ``RandomVariable``, register a new function on this dispatcher.
10992
"""
110-
raise NotImplementedError(f"Logcdf method not implemented for {op}")
93+
raise NotImplementedError(f"LogCDF method not implemented for {op}")
94+
95+
96+
def _logcdf_helper(rv, value, **kwargs):
97+
"""Helper that calls `_logcdf` dispatcher."""
98+
logcdf = _logcdf(rv.owner.op, value, *rv.owner.inputs, name=rv.name, **kwargs)
99+
100+
if rv.name:
101+
logcdf.name = f"{rv.name}_logcdf"
102+
103+
return logcdf
111104

112105

113106
@singledispatch
@@ -122,7 +115,17 @@ def _icdf(
122115
This function dispatches on the type of `op`, which should be a subclass
123116
of `RandomVariable`.
124117
"""
125-
raise NotImplementedError(f"icdf not implemented for {op}")
118+
raise NotImplementedError(f"Inverse CDF method not implemented for {op}")
119+
120+
121+
def _icdf_helper(rv, value, **kwargs):
122+
"""Helper that calls `_icdf` dispatcher."""
123+
rv_icdf = _icdf(rv.owner.op, value, *rv.owner.inputs, **kwargs)
124+
125+
if rv.name:
126+
rv_icdf.name = f"{rv.name}_icdf"
127+
128+
return rv_icdf
126129

127130

128131
class MeasurableVariable(abc.ABC):

pymc/logprob/basic.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,33 +39,54 @@
3939
from collections import deque
4040
from typing import Dict, List, Optional, Sequence, Union
4141

42+
import numpy as np
4243
import pytensor
4344
import pytensor.tensor as pt
4445

4546
from pytensor import config
46-
from pytensor.graph.basic import graph_inputs, io_toposort
47+
from pytensor.graph.basic import Variable, graph_inputs, io_toposort
4748
from pytensor.graph.op import compute_test_value
4849
from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter
4950
from pytensor.tensor.random.op import RandomVariable
5051
from pytensor.tensor.var import TensorVariable
51-
52-
from pymc.logprob.abstract import _logprob, get_measurable_outputs
53-
from pymc.logprob.abstract import logprob as logp_logprob
52+
from typing_extensions import TypeAlias
53+
54+
from pymc.logprob.abstract import (
55+
_icdf_helper,
56+
_logcdf_helper,
57+
_logprob,
58+
_logprob_helper,
59+
get_measurable_outputs,
60+
)
5461
from pymc.logprob.rewriting import construct_ir_fgraph
5562
from pymc.logprob.transforms import RVTransform, TransformValuesRewrite
5663
from pymc.logprob.utils import rvs_to_value_vars
5764

65+
TensorLike: TypeAlias = Union[Variable, float, np.ndarray]
66+
5867

59-
def logp(rv: TensorVariable, value: TensorVariable, **kwargs) -> TensorVariable:
68+
def logp(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
6069
"""Return the log-probability graph of a Random Variable"""
6170

6271
value = pt.as_tensor_variable(value, dtype=rv.dtype)
6372
try:
64-
return logp_logprob(rv, value, **kwargs)
73+
return _logprob_helper(rv, value, **kwargs)
6574
except NotImplementedError:
6675
fgraph, _, _ = construct_ir_fgraph({rv: value})
6776
[(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items()
68-
return logp_logprob(ir_rv, ir_value, **kwargs)
77+
return _logprob_helper(ir_rv, ir_value, **kwargs)
78+
79+
80+
def logcdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
81+
"""Create a graph for the log-CDF of a Random Variable."""
82+
value = pt.as_tensor_variable(value, dtype=rv.dtype)
83+
return _logcdf_helper(rv, value, **kwargs)
84+
85+
86+
def icdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
87+
"""Create a graph for the inverse CDF of a Random Variable."""
88+
value = pt.as_tensor_variable(value)
89+
return _icdf_helper(rv, value, **kwargs)
6990

7091

7192
def factorized_joint_logprob(

pymc/logprob/cumsum.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from pytensor.graph.rewriting.basic import node_rewriter
4242
from pytensor.tensor.extra_ops import CumOp
4343

44-
from pymc.logprob.abstract import MeasurableVariable, _logprob, logprob
44+
from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
4545
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
4646
from pymc.logprob.utils import ignore_logprob
4747

@@ -72,7 +72,7 @@ def logprob_cumsum(op, values, base_rv, **kwargs):
7272
axis=op.axis,
7373
)
7474

75-
cumsum_logp = logprob(base_rv, value_diff)
75+
cumsum_logp = _logprob_helper(base_rv, value_diff)
7676

7777
return cumsum_logp
7878

pymc/logprob/mixture.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType
7070
from pytensor.tensor.var import TensorVariable
7171

72-
from pymc.logprob.abstract import MeasurableVariable, _logprob, logprob
72+
from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
7373
from pymc.logprob.rewriting import (
7474
local_lift_DiracDelta,
7575
logprob_rewrites_db,
@@ -445,7 +445,7 @@ def logprob_MixtureRV(
445445
# this intentional one-off?
446446
rv_m = rv_pull_down(rv[m_indices] if m_indices else rv)
447447
val_m = value[idx_m_on_axis]
448-
logp_m = logprob(rv_m, val_m)
448+
logp_m = _logprob_helper(rv_m, val_m)
449449
logp_val = pt.set_subtensor(logp_val[idx_m_on_axis], logp_m)
450450

451451
else:
@@ -463,7 +463,7 @@ def logprob_MixtureRV(
463463

464464
logp_val = 0.0
465465
for i, comp_rv in enumerate(comp_rvs):
466-
comp_logp = logprob(comp_rv, value)
466+
comp_logp = _logprob_helper(comp_rv, value)
467467
if join_axis_val is not None:
468468
comp_logp = pt.squeeze(comp_logp, axis=join_axis_val)
469469
logp_val += ifelse(
@@ -540,10 +540,10 @@ def logprob_ifelse(op, values, if_var, *base_rvs, **kwargs):
540540
rvs_to_values_else = {else_rv: value for else_rv, value in zip(base_rvs[len(values) :], values)}
541541

542542
logps_then = [
543-
logprob(rv_then, value, **kwargs) for rv_then, value in rvs_to_values_then.items()
543+
_logprob_helper(rv_then, value, **kwargs) for rv_then, value in rvs_to_values_then.items()
544544
]
545545
logps_else = [
546-
logprob(rv_else, value, **kwargs) for rv_else, value in rvs_to_values_else.items()
546+
_logprob_helper(rv_else, value, **kwargs) for rv_else, value in rvs_to_values_else.items()
547547
]
548548

549549
# If the multiple variables depend on each other, we have to replace them

pymc/logprob/tensor.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
local_rv_size_lift,
5151
)
5252

53-
from pymc.logprob.abstract import MeasurableVariable, _logprob, logprob
53+
from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
5454
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
5555
from pymc.logprob.utils import ignore_logprob, ignore_logprob_multiple_vars
5656

@@ -137,7 +137,7 @@ def logprob_make_vector(op, values, *base_rvs, **kwargs):
137137
base_rv.name = f"base_rv[{i}]"
138138
value.name = f"value[{i}]"
139139

140-
logps = [logprob(base_rv, value) for base_rv, value in base_rvs_to_values.items()]
140+
logps = [_logprob_helper(base_rv, value) for base_rv, value in base_rvs_to_values.items()]
141141

142142
# If the stacked variables depend on each other, we have to replace them by the respective values
143143
logps = replace_rvs_by_values(logps, rvs_to_values=base_rvs_to_values)
@@ -174,7 +174,8 @@ def logprob_join(op, values, axis, *base_rvs, **kwargs):
174174

175175
base_rvs_to_split_values = {base_rv: value for base_rv, value in zip(base_rvs, split_values)}
176176
logps = [
177-
logprob(base_var, split_value) for base_var, split_value in base_rvs_to_split_values.items()
177+
_logprob_helper(base_var, split_value)
178+
for base_var, split_value in base_rvs_to_split_values.items()
178179
]
179180

180181
if len({logp.ndim for logp in logps}) != 1:
@@ -271,7 +272,7 @@ def logprob_dimshuffle(op, values, base_var, **kwargs):
271272
undo_ds = [original_shuffle.index(i) for i in range(len(original_shuffle))]
272273
value = value.dimshuffle(undo_ds)
273274

274-
raw_logp = logprob(base_var, value)
275+
raw_logp = _logprob_helper(base_var, value)
275276

276277
# Re-apply original dimshuffle, ignoring any support dimensions consumed by
277278
# the logprob function. This assumes that support dimensions are always in

pymc/logprob/transforms.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
MeasurableVariable,
7979
_get_measurable_outputs,
8080
_logprob,
81-
logprob,
81+
_logprob_helper,
8282
)
8383
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
8484
from pymc.logprob.utils import ignore_logprob, walk_model
@@ -369,10 +369,13 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
369369
# Some transformations, like squaring may produce multiple backward values
370370
if isinstance(backward_value, tuple):
371371
input_logprob = pt.logaddexp(
372-
*(logprob(measurable_input, backward_val, **kwargs) for backward_val in backward_value)
372+
*(
373+
_logprob_helper(measurable_input, backward_val, **kwargs)
374+
for backward_val in backward_value
375+
)
373376
)
374377
else:
375-
input_logprob = logprob(measurable_input, backward_value)
378+
input_logprob = _logprob_helper(measurable_input, backward_value)
376379

377380
if input_logprob.ndim < value.ndim:
378381
# Do we just need to sum the jacobian terms across the support dims?

pymc/testing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@
3434

3535
import pymc as pm
3636

37-
from pymc import Distribution, logcdf, logp
37+
from pymc.distributions.distribution import Distribution
3838
from pymc.distributions.shape_utils import change_dist_size
3939
from pymc.initial_point import make_initial_point_fn
40-
from pymc.logprob import joint_logp
41-
from pymc.logprob.abstract import MeasurableVariable, icdf
40+
from pymc.logprob.abstract import MeasurableVariable
41+
from pymc.logprob.basic import icdf, joint_logp, logcdf, logp
4242
from pymc.logprob.utils import ParameterValueError
4343
from pymc.pytensorf import (
4444
compile_pymc,

0 commit comments

Comments
 (0)