Skip to content

Commit 48e0d98

Browse files
ricardoV94 authored and brandonwillard committed
Refactor ZeroInflatedNegativeBinomial
1 parent 7e8d112 commit 48e0d98

File tree

3 files changed

+107
-63
lines changed

3 files changed

+107
-63
lines changed

pymc3/distributions/discrete.py

Lines changed: 41 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ def NegBinom(a, m, x):
668668

669669
@classmethod
670670
def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
671-
n, p = cls.get_n_p(mu, alpha, p, n)
671+
n, p = cls.get_n_p(mu=mu, alpha=alpha, p=p, n=n)
672672
n = at.as_tensor_variable(floatX(n))
673673
p = at.as_tensor_variable(floatX(p))
674674
return super().dist([n, p], *args, **kwargs)
@@ -1482,6 +1482,21 @@ def logcdf(value, psi, n, p):
14821482
)
14831483

14841484

1485+
class ZeroInflatedNegBinomialRV(RandomVariable):
    """Random-variable op producing zero-inflated negative binomial draws."""

    name = "zero_inflated_neg_binomial"
    ndim_supp = 0
    # Three entries; presumably one per rng_fn parameter (psi, n, p) -- confirm
    # against the RandomVariable base class.
    ndims_params = [0, 0, 0]
    dtype = "int64"
    _print_name = ("ZeroInflatedNegBinom", "\\operatorname{ZeroInflatedNegBinom}")

    @classmethod
    def rng_fn(cls, rng, psi, n, p, size):
        """Sample negative binomial counts and zero out the inflated entries.

        A negative binomial draw is made first, then a uniform draw from the
        same ``rng``; a count is kept only where the uniform draw is below
        ``psi`` and is replaced by zero otherwise.
        """
        counts = rng.negative_binomial(n=n, p=p, size=size)
        keep = rng.random(size=size) < psi
        return counts * keep
1495+
1496+
1497+
zero_inflated_neg_binomial = ZeroInflatedNegBinomialRV()
1498+
1499+
14851500
class ZeroInflatedNegativeBinomial(Discrete):
14861501
R"""
14871502
Zero-Inflated Negative binomial log-likelihood.
@@ -1551,50 +1566,17 @@ def ZeroInfNegBinom(a, m, psi, x):
15511566
15521567
"""
15531568

1554-
def __init__(self, psi, mu, alpha, *args, **kwargs):
1555-
super().__init__(*args, **kwargs)
1556-
self.mu = mu = at.as_tensor_variable(floatX(mu))
1557-
self.alpha = alpha = at.as_tensor_variable(floatX(alpha))
1558-
self.psi = psi = at.as_tensor_variable(floatX(psi))
1559-
self.nb = NegativeBinomial.dist(mu, alpha)
1560-
self.mode = self.nb.mode
1569+
rv_op = zero_inflated_neg_binomial
15611570

1562-
def random(self, point=None, size=None):
1563-
r"""
1564-
Draw random values from ZeroInflatedNegativeBinomial distribution.
1565-
1566-
Parameters
1567-
----------
1568-
point: dict, optional
1569-
Dict of variable values on which random values are to be
1570-
conditioned (uses default point if not specified).
1571-
size: int, optional
1572-
Desired size of random sample (returns one sample if not
1573-
specified).
1574-
1575-
Returns
1576-
-------
1577-
array
1578-
"""
1579-
# mu, alpha, psi = draw_values([self.mu, self.alpha, self.psi], point=point, size=size)
1580-
# g = generate_samples(self._random, mu=mu, alpha=alpha, dist_shape=self.shape, size=size)
1581-
# g[g == 0] = np.finfo(float).eps # Just in case
1582-
# g, psi = broadcast_distribution_samples([g, psi], size=size)
1583-
# return stats.poisson.rvs(g) * (np.random.random(g.shape) < psi)
1584-
1585-
def _random(self, mu, alpha, size):
1586-
r"""Wrapper around stats.gamma.rvs that converts NegativeBinomial's
1587-
parametrization to scipy.gamma. All parameter arrays should have
1588-
been broadcasted properly by generate_samples at this point and size is
1589-
the scipy.rvs representation.
1590-
"""
1591-
return stats.gamma.rvs(
1592-
a=alpha,
1593-
scale=mu / alpha,
1594-
size=size,
1595-
)
1571+
@classmethod
def dist(cls, psi, mu, alpha, *args, **kwargs):
    """Build the distribution from ``psi`` and the ``mu``/``alpha`` parametrization.

    ``psi`` is converted to a floatX tensor, ``mu`` and ``alpha`` are mapped to
    ``(n, p)`` through ``NegativeBinomial.get_n_p``, and ``[psi, n, p]`` is
    forwarded to the parent ``dist`` together with any extra arguments.
    """
    psi = at.as_tensor_variable(floatX(psi))
    # get_n_p yields the (n, p) pair; each element gets the same conversion.
    n, p = (
        at.as_tensor_variable(floatX(param))
        for param in NegativeBinomial.get_n_p(mu=mu, alpha=alpha)
    )
    rv_params = [psi, n, p]
    return super().dist(rv_params, *args, **kwargs)
15961578

1597-
def logp(self, value):
1579+
def logp(value, psi, n, p):
15981580
r"""
15991581
Calculate log-probability of ZeroInflatedNegativeBinomial distribution at specified value.
16001582
@@ -1608,20 +1590,22 @@ def logp(self, value):
16081590
-------
16091591
TensorVariable
16101592
"""
1611-
alpha = self.alpha
1612-
mu = self.mu
1613-
psi = self.psi
16141593

1615-
logp_other = at.log(psi) + self.nb.logp(value)
1616-
logp_0 = logaddexp(
1617-
at.log1p(-psi), at.log(psi) + alpha * (at.log(alpha) - at.log(alpha + mu))
1594+
return bound(
1595+
at.switch(
1596+
at.gt(value, 0),
1597+
at.log(psi) + NegativeBinomial.logp(value, n, p),
1598+
logaddexp(at.log1p(-psi), at.log(psi) + n * at.log(p)),
1599+
),
1600+
0 <= value,
1601+
0 <= psi,
1602+
psi <= 1,
1603+
0 < n,
1604+
0 <= p,
1605+
p <= 1,
16181606
)
16191607

1620-
logp_val = at.switch(at.gt(value, 0), logp_other, logp_0)
1621-
1622-
return bound(logp_val, 0 <= value, 0 <= psi, psi <= 1, mu > 0, alpha > 0)
1623-
1624-
def logcdf(self, value):
1608+
def logcdf(value, psi, n, p):
16251609
"""
16261610
Compute the log of the cumulative distribution function for ZeroInflatedNegativeBinomial distribution
16271611
at the specified value.
@@ -1640,13 +1624,14 @@ def logcdf(self, value):
16401624
raise TypeError(
16411625
f"ZeroInflatedNegativeBinomial.logcdf expects a scalar value but received a {np.ndim(value)}-dimensional object."
16421626
)
1643-
psi = self.psi
16441627

16451628
return bound(
1646-
logaddexp(at.log1p(-psi), at.log(psi) + self.nb.logcdf(value)),
1629+
logaddexp(at.log1p(-psi), at.log(psi) + NegativeBinomial.logcdf(value, n, p)),
16471630
0 <= value,
16481631
0 <= psi,
16491632
psi <= 1,
1633+
0 < p,
1634+
p <= 1,
16501635
)
16511636

16521637

pymc3/tests/test_distributions.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1660,8 +1660,7 @@ def logcdf_fn(value, psi, theta):
16601660
{"theta": Rplus, "psi": Unit},
16611661
)
16621662

1663-
# Too lazy to propagate decimal parameter through the whole chain of deps
1664-
@pytest.mark.xfail(reason="Distribution not refactored yet")
1663+
@pytest.mark.xfail(reason="Test not refactored yet")
16651664
@pytest.mark.xfail(
16661665
condition=(aesara.config.floatX == "float32"),
16671666
reason="Fails on float32 due to inf issues",
@@ -1673,12 +1672,37 @@ def test_zeroinflatednegativebinomial_distribution(self):
16731672
{"mu": Rplusbig, "alpha": Rplusbig, "psi": Unit},
16741673
)
16751674

1676-
@pytest.mark.xfail(reason="Distribution not refactored yet")
1677-
def test_zeroinflatednegativebinomial_logcdf(self):
1675+
def test_zeroinflatednegativebinomial(self):
1676+
def logp_fn(value, psi, mu, alpha):
    """Reference log-probability of the zero-inflated negative binomial.

    Parameters
    ----------
    value : int
        Count at which the log-probability is evaluated.
    psi : float
        Weight of the negative binomial component; ``1 - psi`` is the extra
        probability mass placed on zero.
    mu, alpha : float
        Parameters converted to ``(n, p)`` via ``NegativeBinomial.get_n_p``.

    Returns
    -------
    float
        ``log P(value)`` under the mixture.
    """
    n, p = NegativeBinomial.get_n_p(mu=mu, alpha=alpha)
    if value == 0:
        # A zero can come from the inflation component (prob. 1 - psi) OR from
        # the negative binomial itself (prob. psi * pmf(0)); the two masses add.
        return np.log((1 - psi) + psi * sp.nbinom.pmf(0, n, p))
    return np.log(psi * sp.nbinom.pmf(value, n, p))
1682+
1683+
def logcdf_fn(value, psi, mu, alpha):
    """Reference log-CDF of the zero-inflated negative binomial.

    Combines the extra zero mass ``1 - psi`` with ``psi`` times the negative
    binomial CDF at ``value``, after mapping ``mu``/``alpha`` to ``(n, p)``
    through ``NegativeBinomial.get_n_p``.
    """
    n, p = NegativeBinomial.get_n_p(mu=mu, alpha=alpha)
    nb_cdf = sp.nbinom.cdf(value, n, p)
    return np.log((1 - psi) + psi * nb_cdf)
1686+
1687+
self.check_logp(
1688+
ZeroInflatedNegativeBinomial,
1689+
Nat,
1690+
{"psi": Unit, "mu": Rplusbig, "alpha": Rplusbig},
1691+
logp_fn,
1692+
)
1693+
1694+
self.check_logcdf(
1695+
ZeroInflatedNegativeBinomial,
1696+
Nat,
1697+
{"psi": Unit, "mu": Rplusbig, "alpha": Rplusbig},
1698+
logcdf_fn,
1699+
n_samples=10,
1700+
)
1701+
16781702
self.check_selfconsistency_discrete_logcdf(
16791703
ZeroInflatedNegativeBinomial,
16801704
Nat,
1681-
{"mu": Rplusbig, "alpha": Rplusbig, "psi": Unit},
1705+
{"psi": Unit, "mu": Rplusbig, "alpha": Rplusbig},
16821706
n_samples=10,
16831707
)
16841708

pymc3/tests/test_distributions_random.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -955,7 +955,7 @@ def seeded_zero_inflated_poisson_rng_fn(self):
955955

956956

957957
class TestZeroInflatedBinomial(BaseTestDistribution):
958-
def zero_inflated_poisson_rng_fn(self, size, psi, n, p, binomial_rng_fct, random_rng_fct):
958+
def zero_inflated_binomial_rng_fn(self, size, psi, n, p, binomial_rng_fct, random_rng_fct):
959959
return binomial_rng_fct(n, p, size=size) * (random_rng_fct(size=size) < psi)
960960

961961
def seeded_zero_inflated_binomial_rng_fn(self):
@@ -968,7 +968,7 @@ def seeded_zero_inflated_binomial_rng_fn(self):
968968
)
969969

970970
return functools.partial(
971-
self.zero_inflated_poisson_rng_fn,
971+
self.zero_inflated_binomial_rng_fn,
972972
binomial_rng_fct=binomial_rng_fct,
973973
random_rng_fct=random_rng_fct,
974974
)
@@ -985,6 +985,41 @@ def seeded_zero_inflated_binomial_rng_fn(self):
985985
]
986986

987987

988+
class TestZeroInflatedNegativeBinomial(BaseTestDistribution):
    """Random-draw checks for ``pm.ZeroInflatedNegativeBinomial``.

    The reference sampler multiplies a negative binomial draw by a boolean mask
    (uniform draw below ``psi``), each drawn through a ``RandomState`` method
    bound to a state returned by ``get_random_state``.
    """

    def zero_inflated_negbinomial_rng_fn(
        self, size, psi, n, p, negbinomial_rng_fct, random_rng_fct
    ):
        """Return negative binomial draws with entries zeroed where uniform >= psi."""
        # The negative binomial draw is made before the uniform draw.
        counts = negbinomial_rng_fct(n, p, size=size)
        keep = random_rng_fct(size=size) < psi
        return counts * keep

    def seeded_zero_inflated_negbinomial_rng_fn(self):
        """Bind the reference sampler to RandomState-backed draw functions."""
        seeded_negbinomial = functools.partial(
            np.random.RandomState.negative_binomial, self.get_random_state()
        )
        seeded_uniform = functools.partial(
            np.random.RandomState.random, self.get_random_state()
        )
        return functools.partial(
            self.zero_inflated_negbinomial_rng_fn,
            negbinomial_rng_fct=seeded_negbinomial,
            random_rng_fct=seeded_uniform,
        )

    # (n, p) equivalent of the mu/alpha values passed to the PyMC distribution.
    n, p = pm.NegativeBinomial.get_n_p(mu=3, alpha=5)

    pymc_dist = pm.ZeroInflatedNegativeBinomial
    pymc_dist_params = {"psi": 0.9, "mu": 3, "alpha": 5}
    expected_rv_op_params = {"psi": 0.9, "n": n, "p": p}
    reference_dist_params = {"psi": 0.9, "n": n, "p": p}
    reference_dist = seeded_zero_inflated_negbinomial_rng_fn
    tests_to_run = [
        "check_pymc_params_match_rv_op",
        "check_pymc_draws_match_reference",
        "check_rv_size",
    ]
1021+
1022+
9881023
class TestOrderedLogistic(BaseTestDistribution):
9891024
pymc_dist = pm.OrderedLogistic
9901025
pymc_dist_params = {"eta": 0, "cutpoints": np.array([-2, 0, 2])}

0 commit comments

Comments
 (0)