Skip to content

Reintroduce Bernoulli logitp parametrization #4620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 47 additions & 19 deletions pymc3/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@
import aesara.tensor as at
import numpy as np

from aesara.tensor.random.basic import bernoulli, binomial, categorical, nbinom, poisson
from aesara.tensor.random.basic import (
BernoulliRV,
binomial,
categorical,
nbinom,
poisson,
)
from scipy import stats
from scipy.special import expit

from pymc3.aesaraf import floatX, intX, take_along_axis
from pymc3.distributions.dist_math import (
Expand All @@ -32,7 +39,7 @@
normal_lcdf,
)
from pymc3.distributions.distribution import Discrete
from pymc3.math import log1mexp, logaddexp, logsumexp, sigmoid, tround
from pymc3.math import log1mexp, log1pexp, logaddexp, logit, logsumexp, sigmoid, tround

__all__ = [
"Binomial",
Expand Down Expand Up @@ -332,6 +339,19 @@ def logcdf(self, value):
)


class BernoulliLogitRV(BernoulliRV):
    """Bernoulli random variable parametrized by the logit of the success
    probability rather than the probability itself.

    Inherits sampling machinery from ``BernoulliRV``; only the draw function
    is overridden to map ``logitp`` through the inverse-logit (``expit``)
    before sampling.
    """

    # NOTE(review): distinct name/_print_name distinguish this op from the
    # plain Bernoulli RV in graph printouts — confirm this is desired.
    name = "bernoulli_logit"
    _print_name = ("BernLogit", "\\operatorname{BernLogit}")

    @classmethod
    def rng_fn(cls, rng, logitp, size=None):
        """Draw samples: convert logit-probability to probability, then
        delegate to scipy's Bernoulli sampler with the given RNG."""
        p = expit(logitp)
        return stats.bernoulli.rvs(p, size=size, random_state=rng)


# Singleton op instance used as `rv_op` by the Bernoulli distribution below.
bernoulli_logit = BernoulliLogitRV()


class Bernoulli(Discrete):
R"""Bernoulli log-likelihood

Expand Down Expand Up @@ -368,16 +388,29 @@ class Bernoulli(Discrete):
----------
p: float
Probability of success (0 < p < 1).
logit_p: float
Alternative logit of success probability.
"""
rv_op = bernoulli
rv_op = bernoulli_logit

@classmethod
def dist(cls, p=None, logit_p=None, *args, **kwargs):
    """Create the distribution from either `p` or `logit_p`.

    Exactly one of the two parametrizations must be given; whichever it
    is, it is normalized to a `logit_p` tensor before being handed to the
    parent constructor.
    """
    resolved = cls.get_logitp(p=p, logit_p=logit_p)
    logit_p = at.as_tensor_variable(floatX(resolved))
    return super().dist([logit_p], **kwargs)

def logp(value, p):
@classmethod
def get_logitp(cls, p=None, logit_p=None):
    """Resolve the `p`/`logit_p` parametrizations into a single logit value.

    Raises
    ------
    ValueError
        If both or neither of `p` and `logit_p` are given.
    """
    if p is not None:
        if logit_p is not None:
            raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
        # Only `p` was supplied — convert it to the logit scale.
        return logit(p)
    if logit_p is None:
        raise ValueError("Incompatible parametrization. Must specify either p or logit_p.")
    return logit_p

def logp(value, logit_p):
    r"""
    Calculate log-probability of Bernoulli distribution at specified value.

    Parameters
    ----------
    value: numeric
        Value(s) for which log-probability is calculated. If the log
        probabilities for multiple values are desired the values must be
        provided in a numpy array or Aesara tensor.
    logit_p: tensor
        Logit of the success probability.

    Returns
    -------
    TensorVariable
    """
    # log P(value) = -log(1 + exp(-lp)) where lp = logit_p if value == 1
    # and lp = -logit_p if value == 0.
    lp = at.switch(value, -logit_p, logit_p)
    return bound(
        -log1pexp(lp),
        0 <= value,
        value <= 1,
        # Invalid p gets converted to nan by logit(), so this is how an
        # out-of-range probability is detected inside bound().
        ~at.isnan(logit_p),
    )

def logcdf(value, p):
def logcdf(value, logit_p):
    """
    Compute the log of the cumulative distribution function for Bernoulli
    distribution at the specified value.

    Parameters
    ----------
    value: numeric
        Value(s) for which log CDF is calculated. If the log CDF for
        multiple values are desired the values must be provided in a numpy
        array or Aesara tensor.
    logit_p: tensor
        Logit of the success probability.

    Returns
    -------
    TensorVariable
    """
    return bound(
        at.switch(
            at.lt(value, 1),
            # CDF below 1 is P(X=0) = 1 - p, i.e. -log1pexp(logit_p) in
            # log space; at value >= 1 the CDF is 1 (log 1 = 0).
            -log1pexp(logit_p),
            0,
        ),
        0 <= value,
        # nan logit_p flags an invalid underlying probability.
        ~at.isnan(logit_p),
    )

def _distr_parameters_for_repr(self):
Expand Down
42 changes: 26 additions & 16 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,7 +1221,6 @@ def scipy_mu_alpha_logcdf(value, mu, alpha):
n_samples=10,
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
@pytest.mark.parametrize(
"mu, p, alpha, n, expected",
[
Expand Down Expand Up @@ -1522,21 +1521,6 @@ def test_beta_binomial_selfconsistency(self):
{"alpha": Rplus, "beta": Rplus, "n": NatSmall},
)

@pytest.mark.xfail(reason="Bernoulli logit_p not refactored yet")
def test_bernoulli_logit_p(self):
self.check_logp(
Bernoulli,
Bool,
{"logit_p": R},
lambda value, logit_p: sp.bernoulli.logpmf(value, scipy.special.expit(logit_p)),
)
self.check_logcdf(
Bernoulli,
Bool,
{"logit_p": R},
lambda value, logit_p: sp.bernoulli.logcdf(value, scipy.special.expit(logit_p)),
)

def test_bernoulli(self):
self.check_logp(
Bernoulli,
Expand All @@ -1556,6 +1540,32 @@ def test_bernoulli(self):
{"p": Unit},
)

def test_bernoulli_logitp(self):
    """Check Bernoulli logp/logcdf under the logit_p parametrization
    against scipy references evaluated at expit(logit_p)."""
    ref_logpmf = lambda value, logit_p: sp.bernoulli.logpmf(value, scipy.special.expit(logit_p))
    ref_logcdf = lambda value, logit_p: sp.bernoulli.logcdf(value, scipy.special.expit(logit_p))
    self.check_logp(Bernoulli, Bool, {"logit_p": R}, ref_logpmf)
    self.check_logcdf(Bernoulli, Bool, {"logit_p": R}, ref_logcdf)

@pytest.mark.parametrize(
    "p, logit_p, expected",
    [
        (None, None, "Must specify either p or logit_p."),
        (0.5, 0.5, "Can't specify both p and logit_p."),
    ],
)
def test_bernoulli_init_fail(self, p, logit_p, expected):
    """Both or neither of p/logit_p must raise a ValueError."""
    message = f"Incompatible parametrization. {expected}"
    with Model(), pytest.raises(ValueError, match=message):
        Bernoulli("x", p=p, logit_p=logit_p)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_discrete_weibull(self):
self.check_logp(
Expand Down
9 changes: 7 additions & 2 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,10 +731,15 @@ def test_beta_binomial(self):
def _beta_bin(self, n, alpha, beta, size=None):
return st.binom.rvs(n, st.beta.rvs(a=alpha, b=beta, size=size))

@pytest.mark.skip(reason="This test is covered by Aesara")
def test_bernoulli(self):
    """Random-draw test for Bernoulli under both parametrizations,
    compared against scipy's sampler. Skipped: covered by Aesara."""
    pymc3_random_discrete(
        pm.Bernoulli, {"p": Unit}, ref_rand=lambda size, p: st.bernoulli.rvs(p, size=size)
    )

    # logit_p draws must match sampling at p = expit(logit_p).
    pymc3_random_discrete(
        pm.Bernoulli,
        {"logit_p": R},
        ref_rand=lambda size, logit_p: st.bernoulli.rvs(expit(logit_p), size=size),
    )

@pytest.mark.skip(reason="This test is covered by Aesara")
Expand Down
3 changes: 1 addition & 2 deletions pymc3/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def get_city_data():
return data.merge(unique, "inner", on="fips")


@pytest.mark.xfail(reason="Bernoulli distribution not refactored")
class TestARM5_4(SeededTest):
def build_model(self):
data = pd.read_csv(
Expand All @@ -68,7 +67,7 @@ def build_model(self):
P["1"] = 1

with pm.Model() as model:
effects = pm.Normal("effects", mu=0, sigma=100, shape=len(P.columns))
effects = pm.Normal("effects", mu=0, sigma=100, size=len(P.columns))
logit_p = at.dot(floatX(np.array(P)), effects)
pm.Bernoulli("s", logit_p=logit_p, observed=floatX(data.switch.values))
return model
Expand Down