Skip to content

Commit fde52a4

Browse files
aloctavodiaJunpeng Lao
authored and
Junpeng Lao
committed
Add LogitNormal distribution and tests (#2877)
* Add LogitNormal distribution and tests Co-authored-by: denadai2 <[email protected]> Co-authored-by: aloctavodia <[email protected]> * reduce precision for float32
1 parent dcdd543 commit fde52a4

File tree

6 files changed

+117
-3
lines changed

6 files changed

+117
-3
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
- Add new 'pairplot' function, for plotting scatter or hexbin matrices of sampled parameters.
1515
Optionally it can plot divergences.
1616
- Plots of discrete distributions in the docstrings
17+
- Add logitnormal distribution
1718

1819
### Fixes
1920

docs/source/api/distributions/continuous.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Continuous
3030
Triangular
3131
Gumbel
3232
Logistic
33+
LogitNormal
3334
Interpolated
3435

3536
.. automodule:: pymc3.distributions.continuous

pymc3/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .continuous import Triangular
2727
from .continuous import Gumbel
2828
from .continuous import Logistic
29+
from .continuous import LogitNormal
2930
from .continuous import Interpolated
3031

3132
from .discrete import Binomial
@@ -145,6 +146,7 @@
145146
'DiscreteWeibull',
146147
'Gumbel',
147148
'Logistic',
149+
'LogitNormal',
148150
'Interpolated',
149151
'Bound',
150152
]

pymc3/distributions/continuous.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,23 @@
1010
import numpy as np
1111
import theano.tensor as tt
1212
from scipy import stats
13+
from scipy.special import expit
1314
from scipy.interpolate import InterpolatedUnivariateSpline
1415
import warnings
1516

1617
from pymc3.theanof import floatX
1718
from . import transforms
1819
from pymc3.util import get_variable_name
1920
from .special import log_i0
21+
from ..math import invlogit, logit
2022
from .dist_math import bound, logpow, gammaln, betaln, std_cdf, alltrue_elemwise, SplineWrapper
2123
from .distribution import Continuous, draw_values, generate_samples
2224

2325
__all__ = ['Uniform', 'Flat', 'HalfFlat', 'Normal', 'Beta', 'Exponential',
2426
'Laplace', 'StudentT', 'Cauchy', 'HalfCauchy', 'Gamma', 'Weibull',
2527
'HalfStudentT', 'Lognormal', 'ChiSquared', 'HalfNormal', 'Wald',
2628
'Pareto', 'InverseGamma', 'ExGaussian', 'VonMises', 'SkewNormal',
27-
'Triangular', 'Gumbel', 'Logistic', 'Interpolated']
29+
'Triangular', 'Gumbel', 'Logistic', 'LogitNormal', 'Interpolated']
2830

2931

3032
class PositiveContinuous(Continuous):
@@ -2262,6 +2264,86 @@ def _repr_latex_(self, name=None, dist=None):
22622264
get_variable_name(s))
22632265

22642266

2267+
class LogitNormal(UnitContinuous):
2268+
R"""
2269+
Logit-Normal log-likelihood.
2270+
2271+
The pdf of this distribution is
2272+
2273+
.. math::
2274+
f(x \mid \mu, \tau) =
2275+
\frac{1}{x(1-x)} \sqrt{\frac{\tau}{2\pi}}
2276+
\exp\left\{ -\frac{\tau}{2} (logit(x)-\mu)^2 \right\}
2277+
2278+
2279+
.. plot::
2280+
2281+
import matplotlib.pyplot as plt
2282+
import numpy as np
2283+
import scipy.stats as st
2284+
from scipy.special import logit
2285+
plt.style.use('seaborn-darkgrid')
2286+
x = np.linspace(0.0001, 0.9999, 500)
2287+
mus = [0., 0., 0., 1.]
2288+
sds = [0.3, 1., 2., 1.]
2289+
for mu, sd in zip(mus, sds):
2290+
pdf = st.norm.pdf(logit(x), loc=mu, scale=sd) * 1/(x * (1-x))
2291+
plt.plot(x, pdf, label=r'$\mu$ = {}, $\sigma$ = {}'.format(mu, sd))
2292+
plt.legend(loc=1)
2293+
plt.show()
2294+
2295+
======== ==========================================
2296+
Support :math:`x \in (0, 1)`
2297+
Mean no analytical solution
2298+
Variance no analytical solution
2299+
======== ==========================================
2300+
2301+
Parameters
2302+
----------
2303+
mu : float
2304+
Location parameter.
2305+
sd : float
2306+
Scale parameter (sd > 0).
2307+
tau : float
2308+
Precision parameter (tau > 0).
2309+
"""
2310+
2311+
def __init__(self, mu=0, sd=None, tau=None, **kwargs):
2312+
self.mu = mu = tt.as_tensor_variable(mu)
2313+
tau, sd = get_tau_sd(tau=tau, sd=sd)
2314+
self.sd = tt.as_tensor_variable(sd)
2315+
self.tau = tau = tt.as_tensor_variable(tau)
2316+
2317+
self.median = invlogit(mu)
2318+
assert_negative_support(sd, 'sd', 'LogitNormal')
2319+
assert_negative_support(tau, 'tau', 'LogitNormal')
2320+
2321+
super(LogitNormal, self).__init__(**kwargs)
2322+
2323+
def random(self, point=None, size=None, repeat=None):
2324+
mu, _, sd = draw_values([self.mu, self.tau, self.sd], point=point)
2325+
return expit(generate_samples(stats.norm.rvs, loc=mu, scale=sd, dist_shape=self.shape,
2326+
size=size))
2327+
2328+
def logp(self, value):
2329+
sd = self.sd
2330+
mu = self.mu
2331+
tau = self.tau
2332+
return bound(-0.5 * tau * (logit(value) - mu) ** 2
2333+
+ 0.5 * tt.log(tau / (2. * np.pi))
2334+
- tt.log(value * (1 - value)), value > 0, value < 1, tau > 0)
2335+
2336+
def _repr_latex_(self, name=None, dist=None):
2337+
if dist is None:
2338+
dist = self
2339+
sd = dist.sd
2340+
mu = dist.mu
2341+
name = r'\text{%s}' % name
2342+
return r'${} \sim \text{{LogitNormal}}(\mathit{{mu}}={},~\mathit{{sd}}={})$'.format(name,
2343+
get_variable_name(mu),
2344+
get_variable_name(sd))
2345+
2346+
22652347
class Interpolated(Continuous):
22662348
R"""
22672349
Univariate probability distribution defined as a linear interpolation

pymc3/tests/test_distributions.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
NegativeBinomial, Geometric, Exponential, ExGaussian, Normal,
1515
Flat, LKJCorr, Wald, ChiSquared, HalfNormal, DiscreteUniform,
1616
Bound, Uniform, Triangular, Binomial, SkewNormal, DiscreteWeibull,
17-
Gumbel, Logistic, Interpolated, ZeroInflatedBinomial, HalfFlat, AR1,
18-
KroneckerNormal)
17+
Gumbel, Logistic, LogitNormal, Interpolated, ZeroInflatedBinomial,
18+
HalfFlat, AR1, KroneckerNormal)
1919
from ..distributions import continuous
2020
from pymc3.theanof import floatX
2121
from numpy import array, inf, log, exp
@@ -27,6 +27,7 @@
2727
from scipy import integrate
2828
import scipy.stats.distributions as sp
2929
import scipy.stats
30+
from scipy.special import logit
3031
import theano
3132
import theano.tensor as tt
3233
from ..math import kronecker
@@ -1035,6 +1036,12 @@ def test_logistic(self):
10351036
lambda value, mu, s: sp.logistic.logpdf(value, mu, s),
10361037
decimal=select_by_precision(float64=6, float32=1))
10371038

1039+
def test_logitnormal(self):
1040+
self.pymc3_matches_scipy(LogitNormal, Unit, {'mu': R, 'sd': Rplus},
1041+
lambda value, mu, sd: (sp.norm.logpdf(logit(value), mu, sd)
1042+
- (np.log(value) + np.log1p(-value))),
1043+
decimal=select_by_precision(float64=6, float32=1))
1044+
10381045
def test_multidimensional_beta_construction(self):
10391046
with Model():
10401047
Beta('beta', alpha=1., beta=1., shape=(10, 20))

pymc3/tests/test_distributions_random.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
import numpy.testing as npt
66
import scipy.stats as st
7+
from scipy.special import expit
78
from scipy import linalg
89
import numpy.random as nr
910
import theano
@@ -299,6 +300,16 @@ class TestGumbel(BaseTestCases.BaseTestCase):
299300
params = {'mu': 0., 'beta': 1.}
300301

301302

303+
class TestLogistic(BaseTestCases.BaseTestCase):
304+
distribution = pm.Logistic
305+
params = {'mu': 0., 's': 1.}
306+
307+
308+
class TestLogitNormal(BaseTestCases.BaseTestCase):
309+
distribution = pm.LogitNormal
310+
params = {'mu': 0., 'sd': 1.}
311+
312+
302313
class TestBinomial(BaseTestCases.BaseTestCase):
303314
distribution = pm.Binomial
304315
params = {'n': 5, 'p': 0.5}
@@ -668,6 +679,16 @@ def ref_rand(size, mu, beta):
668679
return st.gumbel_r.rvs(loc=mu, scale=beta, size=size)
669680
pymc3_random(pm.Gumbel, {'mu': R, 'beta': Rplus}, ref_rand=ref_rand)
670681

682+
def test_logistic(self):
683+
def ref_rand(size, mu, s):
684+
return st.logistic.rvs(loc=mu, scale=s, size=size)
685+
pymc3_random(pm.Logistic, {'mu': R, 's': Rplus}, ref_rand=ref_rand)
686+
687+
def test_logitnormal(self):
688+
def ref_rand(size, mu, sd):
689+
return expit(st.norm.rvs(loc=mu, scale=sd, size=size))
690+
pymc3_random(pm.LogitNormal, {'mu': R, 'sd': Rplus}, ref_rand=ref_rand)
691+
671692
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
672693
def test_interpolated(self):
673694
for mu in R.vals:

0 commit comments

Comments
 (0)