Skip to content

Commit e8396bf

Browse files
zoj613ricardoV94
authored andcommitted
Add Polya-Gamma distribution
1 parent 8623cd4 commit e8396bf

File tree

7 files changed

+295
-0
lines changed

7 files changed

+295
-0
lines changed

.github/workflows/pytest.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ jobs:
134134
run: |
135135
conda activate pymc3-dev-py37
136136
pip install -e .
137+
pip install --pre -U polyagamma
137138
python --version
138139
- name: Run tests
139140
run: |
@@ -211,6 +212,7 @@ jobs:
211212
run: |
212213
conda activate pymc3-dev-py38
213214
pip install -e .
215+
pip install --pre -U polyagamma
214216
python --version
215217
- name: Run tests
216218
# This job uses a cmd shell, therefore the environment variable syntax is different!

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
- Add `logcdf` method to Kumaraswamy distribution (see [#4706](https://github.com/pymc-devs/pymc3/pull/4706)).
2222
- The `OrderedMultinomial` distribution has been added for use on ordinal data which are _aggregated_ by trial, like multinomial observations, whereas `OrderedLogistic` only accepts ordinal data in a _disaggregated_ format, like categorical
2323
observations (see [#4773](https://github.com/pymc-devs/pymc3/pull/4773)).
24+
- The `Polya-Gamma` distribution has been added (see [#4531](https://github.com/pymc-devs/pymc3/pull/4531)). To make use of this distribution, the [`polyagamma>=1.3.1`](https://pypi.org/project/polyagamma/) library must be installed and available in the user's environment.
2425
- ...
2526

2627
### Maintenance

docs/source/api/distributions/continuous.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ Continuous
3636
Logistic
3737
LogitNormal
3838
Interpolated
39+
PolyaGamma
3940

4041
.. automodule:: pymc3.distributions.continuous
4142
:members:

pymc3/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
Moyal,
4949
Normal,
5050
Pareto,
51+
PolyaGamma,
5152
Rice,
5253
SkewNormal,
5354
StudentT,
@@ -189,6 +190,7 @@
189190
"Simulator",
190191
"BART",
191192
"CAR",
193+
"PolyaGamma",
192194
"logpt",
193195
"logp",
194196
"_logp",

pymc3/distributions/continuous.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,15 @@
2020

2121
from typing import List, Optional, Tuple, Union
2222

23+
import aesara
2324
import aesara.tensor as at
2425
import numpy as np
2526

2627
from aesara.assert_op import Assert
28+
from aesara.graph.basic import Apply
29+
from aesara.graph.op import Op
2730
from aesara.tensor import gammaln
31+
from aesara.tensor.extra_ops import broadcast_shape
2832
from aesara.tensor.random.basic import (
2933
BetaRV,
3034
WeibullRV,
@@ -47,6 +51,21 @@
4751
)
4852
from aesara.tensor.random.op import RandomVariable
4953
from aesara.tensor.var import TensorConstant, TensorVariable
54+
55+
try:
56+
from polyagamma import polyagamma_cdf, polyagamma_pdf, random_polyagamma
57+
except ImportError: # pragma: no cover
58+
59+
def random_polyagamma(*args, **kwargs):
60+
raise RuntimeError("polyagamma package is not installed!")
61+
62+
def polyagamma_pdf(*args, **kwargs):
63+
raise RuntimeError("polyagamma package is not installed!")
64+
65+
def polyagamma_cdf(*args, **kwargs):
66+
raise RuntimeError("polyagamma package is not installed!")
67+
68+
5069
from scipy import stats
5170
from scipy.interpolate import InterpolatedUnivariateSpline
5271
from scipy.special import expit
@@ -103,6 +122,7 @@
103122
"Rice",
104123
"Moyal",
105124
"AsymmetricLaplace",
125+
"PolyaGamma",
106126
]
107127

108128

@@ -4007,3 +4027,201 @@ def logcdf(value, mu, sigma):
40074027
at.log(at.erfc(at.exp(-scaled / 2) * (2 ** -0.5))),
40084028
0 < sigma,
40094029
)
4030+
4031+
4032+
class PolyaGammaRV(RandomVariable):
4033+
"""Polya-Gamma random variable."""
4034+
4035+
name = "polyagamma"
4036+
ndim_supp = 0
4037+
ndims_params = [0, 0]
4038+
dtype = "floatX"
4039+
_print_name = ("PG", "\\operatorname{PG}")
4040+
4041+
def __call__(self, h=1.0, z=0.0, size=None, **kwargs):
4042+
return super().__call__(h, z, size=size, **kwargs)
4043+
4044+
@classmethod
4045+
def rng_fn(cls, rng, h, z, size=None):
4046+
"""
4047+
Generate a random sample from the distribution with the given parameters
4048+
4049+
Parameters
4050+
----------
4051+
rng : {None, int, array_like[ints], SeedSequence, BitGenerator, Generator}
4052+
A seed to initialize the random number generator. If None, then fresh,
4053+
unpredictable entropy will be pulled from the OS. If an ``int`` or
4054+
``array_like[ints]`` is passed, then it will be passed to
4055+
`SeedSequence` to derive the initial `BitGenerator` state. One may also
4056+
pass in a `SeedSequence` instance.
4057+
Additionally, when passed a `BitGenerator`, it will be wrapped by
4058+
`Generator`. If passed a `Generator`, it will be returned unaltered.
4059+
h : scalar or sequence
4060+
The shape parameter of the distribution.
4061+
z : scalar or sequence
4062+
The exponential tilting parameter.
4063+
size : int or tuple of ints, optional
4064+
The number of elements to draw from the distribution. If size is
4065+
``None`` (default) then a single value is returned. If a tuple of
4066+
integers is passed, the returned array will have the same shape.
4067+
If the element(s) of size is not an integer type, it will be truncated
4068+
to the largest integer smaller than its value (e.g (2.1, 1) -> (2, 1)).
4069+
This parameter only applies if `h` and `z` are scalars.
4070+
"""
4071+
# handle the kind of rng passed to the sampler
4072+
bg = rng._bit_generator if isinstance(rng, np.random.RandomState) else rng
4073+
return random_polyagamma(h, z, size=size, random_state=bg).astype(aesara.config.floatX)
4074+
4075+
4076+
polyagamma = PolyaGammaRV()
4077+
4078+
4079+
class _PolyaGammaLogDistFunc(Op):
4080+
__props__ = ("get_pdf",)
4081+
4082+
def __init__(self, get_pdf=False):
4083+
self.get_pdf = get_pdf
4084+
4085+
def make_node(self, x, h, z):
4086+
x = at.as_tensor_variable(floatX(x))
4087+
h = at.as_tensor_variable(floatX(h))
4088+
z = at.as_tensor_variable(floatX(z))
4089+
shape = broadcast_shape(x, h, z)
4090+
broadcastable = [] if not shape else [False] * len(shape)
4091+
return Apply(self, [x, h, z], [at.TensorType(aesara.config.floatX, broadcastable)()])
4092+
4093+
def perform(self, node, ins, outs):
4094+
x, h, z = ins[0], ins[1], ins[2]
4095+
outs[0][0] = (
4096+
polyagamma_pdf(x, h, z, return_log=True)
4097+
if self.get_pdf
4098+
else polyagamma_cdf(x, h, z, return_log=True)
4099+
).astype(aesara.config.floatX)
4100+
4101+
4102+
class PolyaGamma(PositiveContinuous):
4103+
r"""
4104+
The Polya-Gamma distribution.
4105+
4106+
The distribution is parametrized by ``h`` (shape parameter) and ``z``
4107+
(exponential tilting parameter). The pdf of this distribution is
4108+
4109+
.. math::
4110+
4111+
f(x \mid h, z) = cosh^h(\frac{z}{2})e^{-\frac{1}{2}xz^2}f(x \mid h, 0),
4112+
where :math:`f(x \mid h, 0)` is the pdf of a :math:`PG(h, 0)` variable.
4113+
Notice that the pdf of this distribution is expressed as an alternating-sign
4114+
sum of inverse-Gaussian densities.
4115+
4116+
.. math::
4117+
4118+
X = \Sigma_{k=1}^{\infty}\frac{Ga(h, 1)}{d_k},
4119+
4120+
where :math:`d_k = 2(k - 0.5)^2\pi^2 + z^2/2`, :math:`Ga(h, 1)` is a gamma
4121+
random variable with shape parameter ``h`` and scale parameter ``1``.
4122+
4123+
.. plot::
4124+
4125+
import matplotlib.pyplot as plt
4126+
import numpy as np
4127+
from polyagamma import polyagamma_pdf
4128+
plt.style.use('seaborn-darkgrid')
4129+
x = np.linspace(0.01, 5, 500);x.sort()
4130+
hs = [1., 5., 10., 15.]
4131+
zs = [0.] * 4
4132+
for h, z in zip(hs, zs):
4133+
pdf = polyagamma_pdf(x, h=h, z=z)
4134+
plt.plot(x, pdf, label=r'$h$ = {}, $z$ = {}'.format(h, z))
4135+
plt.xlabel('x', fontsize=12)
4136+
plt.ylabel('f(x)', fontsize=12)
4137+
plt.legend(loc=1)
4138+
plt.show()
4139+
4140+
======== =============================
4141+
Support :math:`x \in (0, \infty)`
4142+
Mean :math:`dfrac{h}{4} if :math:`z=0`, :math:`\dfrac{tanh(z/2)h}{2z}` otherwise.
4143+
Variance :math:`0.041666688h` if :math:`z=0`, :math:`\dfrac{h(sinh(z) - z)(1 - tanh^2(z/2))}{4z^3}` otherwise.
4144+
======== =============================
4145+
4146+
Parameters
4147+
----------
4148+
h: float, optional
4149+
The shape parameter of the distribution (h > 0).
4150+
z: float, optional
4151+
The exponential tilting parameter of the distribution.
4152+
4153+
Examples
4154+
--------
4155+
.. code-block:: python
4156+
4157+
rng = np.random.default_rng()
4158+
with pm.Model():
4159+
x = pm.PolyaGamma('x', h=1, z=5.5)
4160+
with pm.Model():
4161+
x = pm.PolyaGamma('x', h=25, z=-2.3, rng=rng, size=(100, 5))
4162+
4163+
References
4164+
----------
4165+
.. [1] Polson, Nicholas G., James G. Scott, and Jesse Windle.
4166+
"Bayesian inference for logistic models using Pólya–Gamma latent
4167+
variables." Journal of the American statistical Association
4168+
108.504 (2013): 1339-1349.
4169+
.. [2] Windle, Jesse, Nicholas G. Polson, and James G. Scott.
4170+
"Sampling Polya-Gamma random variates: alternate and approximate
4171+
techniques." arXiv preprint arXiv:1405.0506 (2014)
4172+
.. [3] Luc Devroye. "On exact simulation algorithms for some distributions
4173+
related to Jacobi theta functions." Statistics & Probability Letters,
4174+
Volume 79, Issue 21, (2009): 2251-2259.
4175+
.. [4] Windle, J. (2013). Forecasting high-dimensional, time-varying
4176+
variance-covariance matrices with high-frequency data and sampling
4177+
Pólya-Gamma random variates for posterior distributions derived
4178+
from logistic likelihoods.(PhD thesis). Retrieved from
4179+
http://hdl.handle.net/2152/21842
4180+
"""
4181+
rv_op = polyagamma
4182+
4183+
@classmethod
4184+
def dist(cls, h=1.0, z=0.0, **kwargs):
4185+
h = at.as_tensor_variable(floatX(h))
4186+
z = at.as_tensor_variable(floatX(z))
4187+
4188+
msg = f"The variable {h} specified for PolyaGamma has non-positive "
4189+
msg += "values, making it unsuitable for this parameter."
4190+
Assert(msg)(h, at.all(at.gt(h, 0.0)))
4191+
4192+
return super().dist([h, z], **kwargs)
4193+
4194+
def logp(value, h, z):
4195+
"""
4196+
Calculate log-probability of Polya-Gamma distribution at specified value.
4197+
4198+
Parameters
4199+
----------
4200+
value: numeric
4201+
Value(s) for which log-probability is calculated. If the log
4202+
probabilities for multiple values are desired the values must be
4203+
provided in a numpy array.
4204+
4205+
Returns
4206+
-------
4207+
TensorVariable
4208+
"""
4209+
4210+
return bound(_PolyaGammaLogDistFunc(True)(value, h, z), h > 0, value > 0)
4211+
4212+
def logcdf(value, h, z):
4213+
"""
4214+
Compute the log of the cumulative distribution function for the
4215+
Polya-Gamma distribution at the specified value.
4216+
4217+
Parameters
4218+
----------
4219+
value: numeric or np.ndarray or `TensorVariable`
4220+
Value(s) for which log CDF is calculated. If the log CDF for multiple
4221+
values are desired the values must be provided in a numpy array.
4222+
4223+
Returns
4224+
-------
4225+
TensorVariable
4226+
"""
4227+
return bound(_PolyaGammaLogDistFunc(False)(value, h, z), h > 0, value > 0)

pymc3/tests/test_distributions.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,22 @@
1919
import aesara.tensor as at
2020
import numpy as np
2121
import numpy.random as nr
22+
23+
try:
24+
from polyagamma import polyagamma_cdf, polyagamma_pdf
25+
26+
_polyagamma_not_installed = False
27+
except ImportError: # pragma: no cover
28+
29+
_polyagamma_not_installed = True
30+
31+
def polyagamma_pdf(*args, **kwargs):
32+
raise RuntimeError("polyagamma package is not installed!")
33+
34+
def polyagamma_cdf(*args, **kwargs):
35+
raise RuntimeError("polyagamma package is not installed!")
36+
37+
2238
import pytest
2339
import scipy.stats
2440
import scipy.stats.distributions as sp
@@ -954,6 +970,26 @@ def test_bound_normal(self):
954970
x = PositiveNormal("x", mu=0, sigma=1, transform=None)
955971
assert np.isinf(logp(x, -1).eval())
956972

973+
@pytest.mark.skipif(
974+
condition=_polyagamma_not_installed,
975+
reason="`polyagamma package is not available/installed.",
976+
)
977+
def test_polyagamma(self):
978+
self.check_logp(
979+
pm.PolyaGamma,
980+
Rplus,
981+
{"h": Rplus, "z": R},
982+
lambda value, h, z: polyagamma_pdf(value, h, z, return_log=True),
983+
decimal=select_by_precision(float64=6, float32=-1),
984+
)
985+
self.check_logcdf(
986+
pm.PolyaGamma,
987+
Rplus,
988+
{"h": Rplus, "z": R},
989+
lambda value, h, z: polyagamma_cdf(value, h, z, return_log=True),
990+
decimal=select_by_precision(float64=6, float32=-1),
991+
)
992+
957993
def test_discrete_unif(self):
958994
self.check_logp(
959995
DiscreteUniform,

pymc3/tests/test_distributions_random.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,19 @@
2424
import scipy.stats as st
2525

2626
from numpy.testing import assert_almost_equal, assert_array_almost_equal
27+
28+
try:
29+
from polyagamma import random_polyagamma
30+
31+
_polyagamma_not_installed = False
32+
except ImportError: # pragma: no cover
33+
34+
_polyagamma_not_installed = True
35+
36+
def random_polyagamma(*args, **kwargs):
37+
raise RuntimeError("polyagamma package is not installed!")
38+
39+
2740
from scipy.special import expit
2841

2942
import pymc3 as pm
@@ -1326,6 +1339,28 @@ class TestBetaBinomial(BaseTestDistribution):
13261339
]
13271340

13281341

1342+
@pytest.mark.skipif(
1343+
condition=_polyagamma_not_installed,
1344+
reason="`polyagamma package is not available/installed.",
1345+
)
1346+
class TestPolyaGamma(BaseTestDistribution):
1347+
def polyagamma_rng_fn(self, size, h, z, rng):
1348+
return random_polyagamma(h, z, size=size, random_state=rng._bit_generator)
1349+
1350+
pymc_dist = pm.PolyaGamma
1351+
pymc_dist_params = {"h": 1.0, "z": 0.0}
1352+
expected_rv_op_params = {"h": 1.0, "z": 0.0}
1353+
reference_dist_params = {"h": 1.0, "z": 0.0}
1354+
reference_dist = lambda self: functools.partial(
1355+
self.polyagamma_rng_fn, rng=self.get_random_state()
1356+
)
1357+
tests_to_run = [
1358+
"check_pymc_params_match_rv_op",
1359+
"check_pymc_draws_match_reference",
1360+
"check_rv_size",
1361+
]
1362+
1363+
13291364
class TestDiscreteUniform(BaseTestDistribution):
13301365
def discrete_uniform_rng_fn(self, size, lower, upper, rng):
13311366
return st.randint.rvs(lower, upper + 1, size=size, random_state=rng)

0 commit comments

Comments
 (0)