Skip to content

Commit 4dd0538

Browse files
committed
Reintroduce logit_p argument in Bernoulli
1 parent 69a4e60 commit 4dd0538

File tree

4 files changed

+33
-21
lines changed

4 files changed

+33
-21
lines changed

pymc3/distributions/discrete.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,15 @@ class Bernoulli(Discrete):
333333

334334
@classmethod
335335
def dist(cls, p=None, logit_p=None, *args, **kwargs):
336+
if p is not None and logit_p is not None:
337+
raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
338+
elif p is None and logit_p is None:
339+
raise ValueError("Incompatible parametrization. Must specify either p or logit_p.")
340+
341+
if logit_p is not None:
342+
p = at.sigmoid(logit_p)
343+
336344
p = at.as_tensor_variable(floatX(p))
337-
# mode = at.cast(tround(p), "int8")
338345
return super().dist([p], **kwargs)
339346

340347
def logp(value, p):
@@ -351,12 +358,9 @@ def logp(value, p):
351358
-------
352359
TensorVariable
353360
"""
354-
# if self._is_logit:
355-
# lp = at.switch(value, self._logit_p, -self._logit_p)
356-
# return -log1pexp(-lp)
357-
# else:
361+
358362
return bound(
359-
at.switch(value, at.log(p), at.log(1 - p)),
363+
at.switch(value, at.log(p), at.log1p(-p)),
360364
value >= 0,
361365
value <= 1,
362366
p >= 0,

pymc3/tests/test_distributions.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1608,40 +1608,50 @@ def test_beta_binomial(self):
16081608
{"alpha": Rplus, "beta": Rplus, "n": NatSmall},
16091609
)
16101610

1611-
@pytest.mark.xfail(reason="Bernoulli logit_p not refactored yet")
1612-
def test_bernoulli_logit_p(self):
1611+
def test_bernoulli(self):
16131612
self.check_logp(
16141613
Bernoulli,
16151614
Bool,
1616-
{"logit_p": R},
1617-
lambda value, logit_p: sp.bernoulli.logpmf(value, scipy.special.expit(logit_p)),
1615+
{"p": Unit},
1616+
lambda value, p: sp.bernoulli.logpmf(value, p),
16181617
)
1619-
self.check_logcdf(
1618+
self.check_logp(
16201619
Bernoulli,
16211620
Bool,
16221621
{"logit_p": R},
1623-
lambda value, logit_p: sp.bernoulli.logcdf(value, scipy.special.expit(logit_p)),
1622+
lambda value, logit_p: sp.bernoulli.logpmf(value, scipy.special.expit(logit_p)),
16241623
)
1625-
1626-
def test_bernoulli(self):
1627-
self.check_logp(
1624+
self.check_logcdf(
16281625
Bernoulli,
16291626
Bool,
16301627
{"p": Unit},
1631-
lambda value, p: sp.bernoulli.logpmf(value, p),
1628+
lambda value, p: sp.bernoulli.logcdf(value, p),
16321629
)
16331630
self.check_logcdf(
16341631
Bernoulli,
16351632
Bool,
1636-
{"p": Unit},
1637-
lambda value, p: sp.bernoulli.logcdf(value, p),
1633+
{"logit_p": R},
1634+
lambda value, logit_p: sp.bernoulli.logcdf(value, scipy.special.expit(logit_p)),
16381635
)
16391636
self.check_selfconsistency_discrete_logcdf(
16401637
Bernoulli,
16411638
Bool,
16421639
{"p": Unit},
16431640
)
16441641

1642+
def test_bernoulli_wrong_arguments(self):
1643+
m = pm.Model()
1644+
1645+
msg = "Incompatible parametrization. Can't specify both p and logit_p"
1646+
with m:
1647+
with pytest.raises(ValueError, match=msg):
1648+
Bernoulli("x", p=0.5, logit_p=0)
1649+
1650+
msg = "Incompatible parametrization. Must specify either p or logit_p"
1651+
with m:
1652+
with pytest.raises(ValueError, match=msg):
1653+
Bernoulli("x")
1654+
16451655
def test_discrete_weibull(self):
16461656
self.check_logp(
16471657
DiscreteWeibull,

pymc3/tests/test_distributions_random.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,11 +1025,10 @@ class TestBernoulli(BaseTestDistribution):
10251025
]
10261026

10271027

1028-
@pytest.mark.skip("Still not implemented")
10291028
class TestBernoulliLogitP(BaseTestDistribution):
10301029
pymc_dist = pm.Bernoulli
10311030
pymc_dist_params = {"logit_p": 1.0}
1032-
expected_rv_op_params = {"mean": 0, "sigma": 10.0}
1031+
expected_rv_op_params = {"p": expit(1.0)}
10331032
tests_to_run = ["check_pymc_params_match_rv_op"]
10341033

10351034

pymc3/tests/test_examples.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ def get_city_data():
5151
return data.merge(unique, "inner", on="fips")
5252

5353

54-
@pytest.mark.xfail(reason="Bernoulli logitp distribution not refactored")
5554
class TestARM5_4(SeededTest):
5655
def build_model(self):
5756
data = pd.read_csv(

0 commit comments

Comments
 (0)