Skip to content

Commit 88b4abc

Browse files
Add support for discrete rvs
1 parent 322ea87 commit 88b4abc

File tree

2 files changed

+29
-24
lines changed

2 files changed

+29
-24
lines changed

pymc/logprob/order.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
from pytensor.tensor.random.op import RandomVariable
5050
from pytensor.tensor.variable import TensorVariable
5151

52+
import pymc as pm
53+
5254
from pymc.logprob.abstract import (
5355
MeasurableVariable,
5456
_logcdf_helper,
@@ -67,7 +69,7 @@ class MeasurableMax(Max):
6769

6870

6971
class MeasurableMaxDiscrete(Max):
70-
"""A placeholder used to specify a log-likelihood for a cmax sub-graph."""
72+
"""A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables."""
7173

7274

7375
MeasurableVariable.register(MeasurableMaxDiscrete)
@@ -105,14 +107,14 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
105107
if axis != base_var_dims:
106108
return None
107109

108-
# logprob for discrete distribution
109-
if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"):
110-
measurable_max = MeasurableMaxDiscrete(list(axis))
111-
max_rv_node = measurable_max.make_node(base_var)
112-
max_rv = max_rv_node.outputs
110+
# distinguish measurable discrete and continuous (because logprob is different)
111+
if base_var.owner.op.dtype.startswith("int"):
112+
if isinstance(base_var.owner.op, RandomVariable):
113+
measurable_max = MeasurableMaxDiscrete(list(axis))
114+
max_rv_node = measurable_max.make_node(base_var)
115+
max_rv = max_rv_node.outputs
113116

114-
return max_rv
115-
# logprob for continuous distribution
117+
return max_rv
116118
else:
117119
measurable_max = MeasurableMax(list(axis))
118120
max_rv_node = measurable_max.make_node(base_var)
@@ -148,17 +150,17 @@ def max_logprob_discrete(op, values, base_rv, **kwargs):
148150
r"""Compute the log-likelihood graph for the `Max` operation.
149151
150152
The formula that we use here is :
151-
\ln(f_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n)
152-
where f(x) represents the p.d.f and F(x) represents the c.d.f of the distrivution respectively.
153+
.. math::
154+
\ln(P_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n)
155+
where :math:`P_{(n)}(x)` represents the p.m.f. of the maximum statistic and :math:`F(x)` represents the c.d.f. of the i.i.d. variables.
153156
"""
154157
(value,) = values
155-
logprob = _logprob_helper(base_rv, value)
156158
logcdf = _logcdf_helper(base_rv, value)
157159
logcdf_prev = _logcdf_helper(base_rv, value - 1)
158160

159-
n = base_rv.size
161+
[n] = constant_fold([base_rv.size])
160162

161-
logprob = pt.log((pt.exp(logcdf)) ** n - (pt.exp(logcdf_prev)) ** n)
163+
logprob = pm.math.logdiffexp(n * logcdf, n * logcdf_prev)
162164

163165
return logprob
164166

tests/logprob/test_order.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import numpy as np
4040
import pytensor.tensor as pt
4141
import pytest
42+
import scipy.stats as sp
4243

4344
import pymc as pm
4445

@@ -232,23 +233,25 @@ def test_min_non_mul_elemwise_fails():
232233
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
233234
x_min_logprob = logp(x_min, x_min_value)
234235

235-
def test_max_discrete():
236-
x = pm.DiscreteUniform.dist(0, 1, size=(3,))
237-
x.name = "x"
238-
x_max = pt.max(x, axis=-1)
236+
237+
@pytest.mark.parametrize(
238+
"mu, size, value, axis",
239+
[(2, 3, 0.85, -1), (2, 3, 0.01, 0), (1, 2, 0.2, None), (0, 4, 0, 0)],
240+
)
241+
def test_max_discrete(mu, size, value, axis):
242+
x = pm.Poisson.dist(name="x", mu=mu, size=(size))
243+
x_max = pt.max(x, axis=axis)
239244
x_max_value = pt.scalar("x_max_value")
240245
x_max_logprob = logp(x_max, x_max_value)
241246

242-
discrete_logprob = _logprob_helper(x, x_max_value)
243-
discrete_logcdf = _logcdf_helper(x, x_max_value)
244-
discrete_logcdf_prev = _logcdf_helper(x, x_max_value - 1)
245-
n = x.size
246-
discrete_logprob = pt.log((pt.exp(discrete_logcdf)) ** n - (pt.exp(discrete_logcdf_prev)) ** n)
247+
test_value = value
247248

248-
test_value = 0.85
249+
n = size
250+
exp_rv = np.exp(sp.poisson(mu).logcdf(test_value)) ** n
251+
exp_rv_prev = np.exp(sp.poisson(mu).logcdf(test_value - 1)) ** n
249252

250253
np.testing.assert_allclose(
251-
discrete_logprob.eval({x_max_value: test_value}),
254+
np.log(exp_rv - exp_rv_prev),
252255
(x_max_logprob.eval({x_max_value: test_value})),
253256
rtol=1e-06,
254257
)

0 commit comments

Comments
 (0)