Skip to content

Commit d52ae50

Browse files
committed
Refactor BetaBinomial
1 parent faa0600 commit d52ae50

File tree

3 files changed

+27
-84
lines changed

3 files changed

+27
-84
lines changed

pymc3/distributions/discrete.py

Lines changed: 15 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from aesara.tensor.random.basic import (
2020
RandomVariable,
2121
bernoulli,
22+
betabinom,
2223
binomial,
2324
categorical,
2425
geometric,
@@ -41,7 +42,7 @@
4142
normal_lcdf,
4243
)
4344
from pymc3.distributions.distribution import Discrete
44-
from pymc3.math import log1mexp, logaddexp, logsumexp, sigmoid, tround
45+
from pymc3.math import log1mexp, logaddexp, logsumexp, sigmoid
4546

4647
__all__ = [
4748
"Binomial",
@@ -227,58 +228,16 @@ def BetaBinom(a, b, n, x):
227228
beta > 0.
228229
"""
229230

230-
def __init__(self, alpha, beta, n, *args, **kwargs):
231-
super().__init__(*args, **kwargs)
232-
self.alpha = alpha = at.as_tensor_variable(floatX(alpha))
233-
self.beta = beta = at.as_tensor_variable(floatX(beta))
234-
self.n = n = at.as_tensor_variable(intX(n))
235-
self.mode = at.cast(tround(alpha / (alpha + beta)), "int8")
236-
237-
def _random(self, alpha, beta, n, size=None):
238-
size = size or ()
239-
p = stats.beta.rvs(a=alpha, b=beta, size=size).flatten()
240-
# Sometimes scipy.beta returns nan. Ugh.
241-
while np.any(np.isnan(p)):
242-
i = np.isnan(p)
243-
p[i] = stats.beta.rvs(a=alpha, b=beta, size=np.sum(i))
244-
# Sigh...
245-
_n, _p, _size = np.atleast_1d(n).flatten(), p.flatten(), p.shape[0]
246-
247-
quotient, remainder = divmod(_p.shape[0], _n.shape[0])
248-
if remainder != 0:
249-
raise TypeError(
250-
"n has a bad size! Was cast to {}, must evenly divide {}".format(
251-
_n.shape[0], _p.shape[0]
252-
)
253-
)
254-
if quotient != 1:
255-
_n = np.tile(_n, quotient)
256-
samples = np.reshape(stats.binom.rvs(n=_n, p=_p, size=_size), size)
257-
return samples
258-
259-
def random(self, point=None, size=None):
260-
r"""
261-
Draw random values from BetaBinomial distribution.
231+
rv_op = betabinom
262232

263-
Parameters
264-
----------
265-
point: dict, optional
266-
Dict of variable values on which random values are to be
267-
conditioned (uses default point if not specified).
268-
size: int, optional
269-
Desired size of random sample (returns one sample if not
270-
specified).
271-
272-
Returns
273-
-------
274-
array
275-
"""
276-
# alpha, beta, n = draw_values([self.alpha, self.beta, self.n], point=point, size=size)
277-
# return generate_samples(
278-
# self._random, alpha=alpha, beta=beta, n=n, dist_shape=self.shape, size=size
279-
# )
233+
@classmethod
234+
def dist(cls, alpha, beta, n, *args, **kwargs):
235+
alpha = at.as_tensor_variable(floatX(alpha))
236+
beta = at.as_tensor_variable(floatX(beta))
237+
n = at.as_tensor_variable(intX(n))
238+
return super().dist([n, alpha, beta], **kwargs)
280239

281-
def logp(self, value):
240+
def logp(value, n, alpha, beta):
282241
r"""
283242
Calculate log-probability of BetaBinomial distribution at specified value.
284243
@@ -292,9 +251,6 @@ def logp(self, value):
292251
-------
293252
TensorVariable
294253
"""
295-
alpha = self.alpha
296-
beta = self.beta
297-
n = self.n
298254
return bound(
299255
binomln(n, value) + betaln(value + alpha, n - value + beta) - betaln(alpha, beta),
300256
value >= 0,
@@ -303,7 +259,7 @@ def logp(self, value):
303259
beta > 0,
304260
)
305261

306-
def logcdf(self, value):
262+
def logcdf(value, n, alpha, beta):
307263
"""
308264
Compute the log of the cumulative distribution function for BetaBinomial distribution
309265
at the specified value.
@@ -323,15 +279,15 @@ def logcdf(self, value):
323279
f"BetaBinomial.logcdf expects a scalar value but received a {np.ndim(value)}-dimensional object."
324280
)
325281

326-
alpha = self.alpha
327-
beta = self.beta
328-
n = self.n
329282
safe_lower = at.switch(at.lt(value, 0), value, 0)
330283

331284
return bound(
332285
at.switch(
333286
at.lt(value, n),
334-
logsumexp(self.logp(at.arange(safe_lower, value + 1)), keepdims=False),
287+
logsumexp(
288+
BetaBinomial.logp(at.arange(safe_lower, value + 1), n, alpha, beta),
289+
keepdims=False,
290+
),
335291
0,
336292
),
337293
0 <= value,

pymc3/tests/test_distributions.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1496,8 +1496,7 @@ def test_binomial(self):
14961496
n_samples=10,
14971497
)
14981498

1499-
# Too lazy to propagate decimal parameter through the whole chain of deps
1500-
@pytest.mark.xfail(reason="Distribution not refactored yet")
1499+
@pytest.mark.xfail(reason="checkd tests has not been refactored")
15011500
@pytest.mark.xfail(condition=(aesara.config.floatX == "float32"), reason="Fails on float32")
15021501
def test_beta_binomial_distribution(self):
15031502
self.checkd(
@@ -1506,7 +1505,6 @@ def test_beta_binomial_distribution(self):
15061505
{"alpha": Rplus, "beta": Rplus, "n": NatSmall},
15071506
)
15081507

1509-
@pytest.mark.xfail(reason="Distribution not refactored yet")
15101508
@pytest.mark.skipif(
15111509
condition=(SCIPY_VERSION < parse("1.4.0")), reason="betabinom is new in Scipy 1.4.0"
15121510
)
@@ -1518,7 +1516,6 @@ def test_beta_binomial_logp(self):
15181516
lambda value, alpha, beta, n: sp.betabinom.logpmf(value, a=alpha, b=beta, n=n),
15191517
)
15201518

1521-
@pytest.mark.xfail(reason="Distribution not refactored yet")
15221519
@pytest.mark.xfail(condition=(aesara.config.floatX == "float32"), reason="Fails on float32")
15231520
@pytest.mark.skipif(
15241521
condition=(SCIPY_VERSION < parse("1.4.0")), reason="betabinom is new in Scipy 1.4.0"
@@ -1531,7 +1528,6 @@ def test_beta_binomial_logcdf(self):
15311528
lambda value, alpha, beta, n: sp.betabinom.logcdf(value, a=alpha, b=beta, n=n),
15321529
)
15331530

1534-
@pytest.mark.xfail(reason="Distribution not refactored yet")
15351531
def test_beta_binomial_selfconsistency(self):
15361532
self.check_selfconsistency_discrete_logcdf(
15371533
BetaBinomial,

pymc3/tests/test_distributions_random.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414
import functools
1515
import itertools
16-
import sys
1716

1817
from contextlib import ExitStack as does_not_raise
1918
from typing import Callable, List, Optional
@@ -312,12 +311,6 @@ class TestLogitNormal(BaseTestCases.BaseTestCase):
312311
params = {"mu": 0.0, "sigma": 1.0}
313312

314313

315-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
316-
class TestBetaBinomial(BaseTestCases.BaseTestCase):
317-
distribution = pm.BetaBinomial
318-
params = {"n": 5, "alpha": 1.0, "beta": 1.0}
319-
320-
321314
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
322315
class TestConstant(BaseTestCases.BaseTestCase):
323316
distribution = pm.Constant
@@ -893,6 +886,17 @@ def seeded_weibul_rng_fn(self):
893886
]
894887

895888

889+
class TestBetaBinomial(BaseTestDistribution):
890+
pymc_dist = pm.BetaBinomial
891+
pymc_dist_params = {"alpha": 2.0, "beta": 1.0, "n": 5}
892+
expected_rv_op_params = {"n": 5, "alpha": 2.0, "beta": 1.0}
893+
reference_dist_params = {"n": 5, "a": 2.0, "b": 1.0}
894+
tests_to_run = [
895+
"check_pymc_params_match_rv_op",
896+
"check_rv_size",
897+
]
898+
899+
896900
class TestScalarParameterSamples(SeededTest):
897901
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
898902
def test_bounded(self):
@@ -1002,19 +1006,6 @@ def test_half_flat(self):
10021006
with pytest.raises(ValueError):
10031007
f.random(1)
10041008

1005-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
1006-
@pytest.mark.xfail(
1007-
sys.platform.startswith("win"),
1008-
reason="Known issue: https://github.com/pymc-devs/pymc3/pull/4269",
1009-
)
1010-
def test_beta_binomial(self):
1011-
pymc3_random_discrete(
1012-
pm.BetaBinomial, {"n": Nat, "alpha": Rplus, "beta": Rplus}, ref_rand=self._beta_bin
1013-
)
1014-
1015-
def _beta_bin(self, n, alpha, beta, size=None):
1016-
return st.binom.rvs(n, st.beta.rvs(a=alpha, b=beta, size=size))
1017-
10181009
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
10191010
def test_discrete_uniform(self):
10201011
def ref_rand(size, lower, upper):

0 commit comments

Comments
 (0)