Skip to content

Commit d55ca39

Browse files
committed
Replace ZeroInflated distributions with Mixtures
1 parent 67b3d37 commit d55ca39

File tree

1 file changed

+48
-284
lines changed

1 file changed

+48
-284
lines changed

pymc/distributions/discrete.py

Lines changed: 48 additions & 284 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@
4343
normal_lcdf,
4444
)
4545
from pymc.distributions.distribution import Discrete
46-
from pymc.distributions.logprob import logcdf, logp
46+
from pymc.distributions.logprob import logp
47+
from pymc.distributions.mixture import Mixture
4748
from pymc.distributions.shape_utils import rv_size_is_none
4849
from pymc.math import sigmoid
4950
from pymc.vartypes import continuous_types
@@ -1386,22 +1387,24 @@ def logcdf(value, c):
13861387
)
13871388

13881389

1389-
class ZeroInflatedPoissonRV(RandomVariable):
1390-
name = "zero_inflated_poisson"
1391-
ndim_supp = 0
1392-
ndims_params = [0, 0]
1393-
dtype = "int64"
1394-
_print_name = ("ZeroInflatedPois", "\\operatorname{ZeroInflatedPois}")
1395-
1396-
@classmethod
1397-
def rng_fn(cls, rng, psi, lam, size):
1398-
return rng.poisson(lam, size=size) * (rng.random(size=size) < psi)
1399-
1390+
def _zero_inflated_mixture(*, name, nonzero_p, nonzero_dist, **kwargs):
1391+
"""Helper function to create a zero-inflated mixture
14001392
1401-
zero_inflated_poisson = ZeroInflatedPoissonRV()
1402-
1403-
1404-
class ZeroInflatedPoisson(Discrete):
1393+
If name is `None`, this function returns an unregistered variable
1394+
"""
1395+
nonzero_p = at.as_tensor_variable(floatX(nonzero_p))
1396+
weights = at.stack([1 - nonzero_p, nonzero_p], axis=-1)
1397+
comp_dists = [
1398+
Constant.dist(0),
1399+
nonzero_dist,
1400+
]
1401+
if name is not None:
1402+
return Mixture(name, weights, comp_dists, **kwargs)
1403+
else:
1404+
return Mixture.dist(weights, comp_dists, **kwargs)
1405+
1406+
1407+
class ZeroInflatedPoisson:
14051408
R"""
14061409
Zero-inflated Poisson log-likelihood.
14071410
@@ -1452,97 +1455,19 @@ class ZeroInflatedPoisson(Discrete):
14521455
(theta >= 0).
14531456
"""
14541457

1455-
rv_op = zero_inflated_poisson
1456-
1457-
@classmethod
1458-
def dist(cls, psi, theta, *args, **kwargs):
1459-
psi = at.as_tensor_variable(floatX(psi))
1460-
theta = at.as_tensor_variable(floatX(theta))
1461-
return super().dist([psi, theta], *args, **kwargs)
1462-
1463-
def get_moment(rv, size, psi, theta):
1464-
mean = at.floor(psi * theta)
1465-
if not rv_size_is_none(size):
1466-
mean = at.full(size, mean)
1467-
return mean
1468-
1469-
def logp(value, psi, theta):
1470-
r"""
1471-
Calculate log-probability of ZeroInflatedPoisson distribution at specified value.
1472-
1473-
Parameters
1474-
----------
1475-
value: numeric
1476-
Value(s) for which log-probability is calculated. If the log probabilities for multiple
1477-
values are desired the values must be provided in a numpy array or Aesara tensor
1478-
1479-
Returns
1480-
-------
1481-
TensorVariable
1482-
"""
1483-
1484-
res = at.switch(
1485-
at.gt(value, 0),
1486-
at.log(psi) + logp(Poisson.dist(mu=theta), value),
1487-
at.logaddexp(at.log1p(-psi), at.log(psi) - theta),
1488-
)
1489-
1490-
res = at.switch(at.lt(value, 0), -np.inf, res)
1491-
1492-
return check_parameters(
1493-
res,
1494-
0 <= psi,
1495-
psi <= 1,
1496-
0 <= theta,
1497-
msg="0 <= psi <= 1, theta >= 0",
1498-
)
1499-
1500-
def logcdf(value, psi, theta):
1501-
"""
1502-
Compute the log of the cumulative distribution function for ZeroInflatedPoisson distribution
1503-
at the specified value.
1504-
1505-
Parameters
1506-
----------
1507-
value: numeric or np.ndarray or aesara.tensor
1508-
Value(s) for which log CDF is calculated. If the log CDF for multiple
1509-
values are desired the values must be provided in a numpy array or Aesara tensor.
1510-
1511-
Returns
1512-
-------
1513-
TensorVariable
1514-
"""
1515-
1516-
res = at.switch(
1517-
at.lt(value, 0),
1518-
-np.inf,
1519-
at.logaddexp(
1520-
at.log1p(-psi),
1521-
at.log(psi) + logcdf(Poisson.dist(mu=theta), value),
1522-
),
1523-
)
1524-
1525-
return check_parameters(
1526-
res, 0 <= psi, psi <= 1, 0 <= theta, msg="0 <= psi <= 1, theta >= 0"
1458+
def __new__(cls, name, psi, theta, **kwargs):
1459+
return _zero_inflated_mixture(
1460+
name=name, nonzero_p=psi, nonzero_dist=Poisson.dist(mu=theta), **kwargs
15271461
)
15281462

1529-
1530-
class ZeroInflatedBinomialRV(RandomVariable):
1531-
name = "zero_inflated_binomial"
1532-
ndim_supp = 0
1533-
ndims_params = [0, 0, 0]
1534-
dtype = "int64"
1535-
_print_name = ("ZeroInflatedBinom", "\\operatorname{ZeroInflatedBinom}")
1536-
15371463
@classmethod
1538-
def rng_fn(cls, rng, psi, n, p, size):
1539-
return rng.binomial(n=n, p=p, size=size) * (rng.random(size=size) < psi)
1540-
1541-
1542-
zero_inflated_binomial = ZeroInflatedBinomialRV()
1464+
def dist(cls, psi, theta, **kwargs):
1465+
return _zero_inflated_mixture(
1466+
name=None, nonzero_p=psi, nonzero_dist=Poisson.dist(mu=theta), **kwargs
1467+
)
15431468

15441469

1545-
class ZeroInflatedBinomial(Discrete):
1470+
class ZeroInflatedBinomial:
15461471
R"""
15471472
Zero-inflated Binomial log-likelihood.
15481473
@@ -1594,110 +1519,19 @@ class ZeroInflatedBinomial(Discrete):
15941519
15951520
"""
15961521

1597-
rv_op = zero_inflated_binomial
1598-
1599-
@classmethod
1600-
def dist(cls, psi, n, p, *args, **kwargs):
1601-
psi = at.as_tensor_variable(floatX(psi))
1602-
n = at.as_tensor_variable(intX(n))
1603-
p = at.as_tensor_variable(floatX(p))
1604-
return super().dist([psi, n, p], *args, **kwargs)
1605-
1606-
def get_moment(rv, size, psi, n, p):
1607-
mean = at.round(psi * n * p)
1608-
if not rv_size_is_none(size):
1609-
mean = at.full(size, mean)
1610-
return mean
1611-
1612-
def logp(value, psi, n, p):
1613-
r"""
1614-
Calculate log-probability of ZeroInflatedBinomial distribution at specified value.
1615-
1616-
Parameters
1617-
----------
1618-
value: numeric
1619-
Value(s) for which log-probability is calculated. If the log probabilities for multiple
1620-
values are desired the values must be provided in a numpy array or Aesara tensor
1621-
1622-
Returns
1623-
-------
1624-
TensorVariable
1625-
"""
1626-
1627-
res = at.switch(
1628-
at.gt(value, 0),
1629-
at.log(psi) + logp(Binomial.dist(n=n, p=p), value),
1630-
at.logaddexp(at.log1p(-psi), at.log(psi) + n * at.log1p(-p)),
1631-
)
1632-
1633-
res = at.switch(
1634-
at.lt(value, 0),
1635-
-np.inf,
1636-
res,
1637-
)
1638-
1639-
return check_parameters(
1640-
res,
1641-
0 <= psi,
1642-
psi <= 1,
1643-
0 <= p,
1644-
p <= 1,
1645-
msg="0 <= psi <= 1, 0 <= p <= 1",
1646-
)
1647-
1648-
def logcdf(value, psi, n, p):
1649-
"""
1650-
Compute the log of the cumulative distribution function for ZeroInflatedBinomial distribution
1651-
at the specified value.
1652-
1653-
Parameters
1654-
----------
1655-
value: numeric or np.ndarray or aesara.tensor
1656-
Value(s) for which log CDF is calculated. If the log CDF for multiple
1657-
values are desired the values must be provided in a numpy array or Aesara tensor.
1658-
1659-
Returns
1660-
-------
1661-
TensorVariable
1662-
"""
1663-
res = at.switch(
1664-
at.or_(at.lt(value, 0), at.gt(value, n)),
1665-
-np.inf,
1666-
at.logaddexp(
1667-
at.log1p(-psi),
1668-
at.log(psi) + logcdf(Binomial.dist(n=n, p=p), value),
1669-
),
1670-
)
1671-
1672-
return check_parameters(
1673-
res,
1674-
0 <= psi,
1675-
psi <= 1,
1676-
0 <= p,
1677-
p <= 1,
1678-
msg="0 <= psi <= 1, 0 <= p <= 1",
1522+
def __new__(cls, name, psi, n, p, **kwargs):
1523+
return _zero_inflated_mixture(
1524+
name=name, nonzero_p=psi, nonzero_dist=Binomial.dist(n=n, p=p), **kwargs
16791525
)
16801526

1681-
1682-
class ZeroInflatedNegBinomialRV(RandomVariable):
1683-
name = "zero_inflated_neg_binomial"
1684-
ndim_supp = 0
1685-
ndims_params = [0, 0, 0]
1686-
dtype = "int64"
1687-
_print_name = (
1688-
"ZeroInflatedNegBinom",
1689-
"\\operatorname{ZeroInflatedNegBinom}",
1690-
)
1691-
16921527
@classmethod
1693-
def rng_fn(cls, rng, psi, n, p, size):
1694-
return rng.negative_binomial(n=n, p=p, size=size) * (rng.random(size=size) < psi)
1695-
1696-
1697-
zero_inflated_neg_binomial = ZeroInflatedNegBinomialRV()
1528+
def dist(cls, psi, n, p, **kwargs):
1529+
return _zero_inflated_mixture(
1530+
name=None, nonzero_p=psi, nonzero_dist=Binomial.dist(n=n, p=p), **kwargs
1531+
)
16981532

16991533

1700-
class ZeroInflatedNegativeBinomial(Discrete):
1534+
class ZeroInflatedNegativeBinomial:
17011535
R"""
17021536
Zero-Inflated Negative binomial log-likelihood.
17031537
@@ -1778,91 +1612,21 @@ def ZeroInfNegBinom(a, m, psi, x):
17781612
Alternative number of target success trials (n > 0)
17791613
"""
17801614

1781-
rv_op = zero_inflated_neg_binomial
1782-
1783-
@classmethod
1784-
def dist(cls, psi, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
1785-
psi = at.as_tensor_variable(floatX(psi))
1786-
n, p = NegativeBinomial.get_n_p(mu=mu, alpha=alpha, p=p, n=n)
1787-
n = at.as_tensor_variable(floatX(n))
1788-
p = at.as_tensor_variable(floatX(p))
1789-
return super().dist([psi, n, p], *args, **kwargs)
1790-
1791-
def get_moment(rv, size, psi, n, p):
1792-
mean = at.floor(psi * n * (1 - p) / p)
1793-
if not rv_size_is_none(size):
1794-
mean = at.full(size, mean)
1795-
return mean
1796-
1797-
def logp(value, psi, n, p):
1798-
r"""
1799-
Calculate log-probability of ZeroInflatedNegativeBinomial distribution at specified value.
1800-
1801-
Parameters
1802-
----------
1803-
value: numeric
1804-
Value(s) for which log-probability is calculated. If the log probabilities for multiple
1805-
values are desired the values must be provided in a numpy array or Aesara tensor
1806-
1807-
Returns
1808-
-------
1809-
TensorVariable
1810-
"""
1811-
1812-
res = at.switch(
1813-
at.gt(value, 0),
1814-
at.log(psi) + logp(NegativeBinomial.dist(n=n, p=p), value),
1815-
at.logaddexp(at.log1p(-psi), at.log(psi) + n * at.log(p)),
1816-
)
1817-
1818-
res = at.switch(
1819-
at.lt(value, 0),
1820-
-np.inf,
1821-
res,
1822-
)
1823-
1824-
return check_parameters(
1825-
res,
1826-
0 <= psi,
1827-
psi <= 1,
1828-
0 < n,
1829-
0 <= p,
1830-
p <= 1,
1831-
msg="0 <= psi <= 1, n > 0, 0 <= p <= 1",
1832-
)
1833-
1834-
def logcdf(value, psi, n, p):
1835-
"""
1836-
Compute the log of the cumulative distribution function for ZeroInflatedNegativeBinomial distribution
1837-
at the specified value.
1838-
1839-
Parameters
1840-
----------
1841-
value: numeric or np.ndarray or aesara.tensor
1842-
Value(s) for which log CDF is calculated. If the log CDF for multiple
1843-
values are desired the values must be provided in a numpy array or Aesara tensor.
1844-
1845-
Returns
1846-
-------
1847-
TensorVariable
1848-
"""
1849-
res = at.switch(
1850-
at.lt(value, 0),
1851-
-np.inf,
1852-
at.logaddexp(
1853-
at.log1p(-psi),
1854-
at.log(psi) + logcdf(NegativeBinomial.dist(n=n, p=p), value),
1855-
),
1615+
def __new__(cls, name, psi, mu=None, alpha=None, p=None, n=None, **kwargs):
1616+
return _zero_inflated_mixture(
1617+
name=name,
1618+
nonzero_p=psi,
1619+
nonzero_dist=NegativeBinomial.dist(mu=mu, alpha=alpha, p=p, n=n),
1620+
**kwargs,
18561621
)
18571622

1858-
return check_parameters(
1859-
res,
1860-
0 <= psi,
1861-
psi <= 1,
1862-
0 < n,
1863-
0 < p,
1864-
p <= 1,
1865-
msg="0 <= psi <= 1, n > 0, 0 < p <= 1",
1623+
@classmethod
1624+
def dist(cls, psi, mu=None, alpha=None, p=None, n=None, **kwargs):
1625+
return _zero_inflated_mixture(
1626+
name=None,
1627+
nonzero_p=psi,
1628+
nonzero_dist=NegativeBinomial.dist(mu=mu, alpha=alpha, p=p, n=n),
1629+
**kwargs,
18661630
)
18671631

18681632

0 commit comments

Comments
 (0)