Skip to content

Commit a1d9b7a

Browse files
ricardoV94twiecki
authored andcommitted
Fix exponential and gamma logp / random link (#4576)
1 parent 8873806 commit a1d9b7a

File tree

2 files changed

+68
-51
lines changed

2 files changed

+68
-51
lines changed

pymc3/distributions/continuous.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,49 +1382,53 @@ class Exponential(PositiveContinuous):
13821382
@classmethod
13831383
def dist(cls, lam, *args, **kwargs):
13841384
lam = at.as_tensor_variable(floatX(lam))
1385-
# mean = 1.0 / lam
1386-
# median = mean * at.log(2)
1387-
# mode = at.zeros_like(lam)
1388-
1389-
# variance = lam ** -2
13901385

13911386
assert_negative_support(lam, "lam", "Exponential")
1392-
return super().dist([lam], **kwargs)
13931387

1394-
def logp(value, lam):
1388+
# Aesara exponential op is parametrized in terms of mu (1/lam)
1389+
return super().dist([at.inv(lam)], **kwargs)
1390+
1391+
def logp(value, mu):
13951392
"""
13961393
Calculate log-probability of Exponential distribution at specified value.
13971394
13981395
Parameters
13991396
----------
14001397
value: numeric
1401-
Value(s) for which log-probability is calculated. If the log probabilities for multiple
1402-
values are desired the values must be provided in a numpy array or aesara tensor
1398+
Value(s) for which log-probability is calculated. If the log
1399+
probabilities for multiple values are desired the values must be
1400+
provided in a numpy array or aesara tensor
14031401
14041402
Returns
14051403
-------
14061404
TensorVariable
14071405
"""
1408-
return bound(at.log(lam) - lam * value, value >= 0, lam > 0)
1406+
lam = at.inv(mu)
1407+
return bound(
1408+
at.log(lam) - lam * value,
1409+
value >= 0,
1410+
lam > 0,
1411+
)
14091412

1410-
def logcdf(value, lam):
1413+
def logcdf(value, mu):
14111414
r"""
14121415
Compute the log of cumulative distribution function for the Exponential distribution
14131416
at the specified value.
14141417
14151418
Parameters
14161419
----------
14171420
value: numeric or np.ndarray or aesara.tensor
1418-
Value(s) for which log CDF is calculated. If the log CDF for multiple
1419-
values are desired the values must be provided in a numpy array or aesara tensor.
1421+
Value(s) for which log CDF is calculated. If the log CDF for
1422+
multiple values are desired the values must be provided in a numpy
1423+
array or aesara tensor.
14201424
14211425
Returns
14221426
-------
14231427
TensorVariable
14241428
"""
1425-
a = lam * value
1429+
lam = at.inv(mu)
14261430
return bound(
1427-
log1mexp(a),
1431+
log1mexp(lam * value),
14281432
0 <= value,
14291433
0 <= lam,
14301434
)
@@ -2376,15 +2380,13 @@ def dist(cls, alpha=None, beta=None, mu=None, sigma=None, sd=None, no_assert=Fal
23762380
alpha, beta = cls.get_alpha_beta(alpha, beta, mu, sigma)
23772381
alpha = at.as_tensor_variable(floatX(alpha))
23782382
beta = at.as_tensor_variable(floatX(beta))
2379-
# mean = alpha / beta
2380-
# mode = at.maximum((alpha - 1) / beta, 0)
2381-
# variance = alpha / beta ** 2
23822383

23832384
if not no_assert:
23842385
assert_negative_support(alpha, "alpha", "Gamma")
23852386
assert_negative_support(beta, "beta", "Gamma")
23862387

2387-
return super().dist([alpha, at.inv(beta)], **kwargs)
2388+
# The Aesara `GammaRV` `Op` will invert the `beta` parameter itself
2389+
return super().dist([alpha, beta], **kwargs)
23882390

23892391
@classmethod
23902392
def get_alpha_beta(cls, alpha=None, beta=None, mu=None, sigma=None):
@@ -2402,45 +2404,47 @@ def get_alpha_beta(cls, alpha=None, beta=None, mu=None, sigma=None):
24022404

24032405
return alpha, beta
24042406

2405-
def _distr_parameters_for_repr(self):
2406-
return ["alpha", "beta"]
2407-
2408-
def logp(value, alpha, beta):
2407+
def logp(value, alpha, inv_beta):
24092408
"""
24102409
Calculate log-probability of Gamma distribution at specified value.
24112410
24122411
Parameters
24132412
----------
24142413
value: numeric
2415-
Value(s) for which log-probability is calculated. If the log probabilities for multiple
2416-
values are desired the values must be provided in a numpy array or `TensorVariable`.
2414+
Value(s) for which log-probability is calculated. If the log
2415+
probabilities for multiple values are desired the values must be
2416+
provided in a numpy array or `TensorVariable`.
24172417
24182418
Returns
24192419
-------
24202420
TensorVariable
24212421
"""
2422+
beta = at.inv(inv_beta)
24222423
return bound(
24232424
-gammaln(alpha) + logpow(beta, alpha) - beta * value + logpow(value, alpha - 1),
24242425
value >= 0,
24252426
alpha > 0,
24262427
beta > 0,
24272428
)
24282429

2429-
def logcdf(value, alpha, beta):
2430+
def logcdf(value, alpha, inv_beta):
24302431
"""
24312432
Compute the log of the cumulative distribution function for Gamma distribution
24322433
at the specified value.
24332434
24342435
Parameters
24352436
----------
24362437
value: numeric or np.ndarray or `TensorVariable`
2437-
Value(s) for which log CDF is calculated. If the log CDF for multiple
2438-
values are desired the values must be provided in a numpy array or `TensorVariable`.
2438+
Value(s) for which log CDF is calculated. If the log CDF for
2439+
multiple values are desired the values must be provided in a numpy
2440+
array or `TensorVariable`.
24392441
24402442
Returns
24412443
-------
24422444
TensorVariable
24432445
"""
2446+
beta = at.inv(inv_beta)
2447+
24442448
# Avoid C-assertion when the gammainc function is called with invalid values (#4340)
24452449
safe_alpha = at.switch(at.lt(alpha, 0), 0, alpha)
24462450
safe_beta = at.switch(at.lt(beta, 0), 0, beta)

pymc3/tests/test_distributions_random.py

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import aesara
2121
import numpy as np
2222
import numpy.random as nr
23+
import numpy.testing as npt
2324
import pytest
2425
import scipy.stats as st
2526

@@ -32,7 +33,7 @@
3233
from pymc3.distributions.dist_math import clipped_beta_rvs
3334
from pymc3.distributions.shape_utils import to_tuple
3435
from pymc3.exceptions import ShapeError
35-
from pymc3.tests.helpers import SeededTest
36+
from pymc3.tests.helpers import SeededTest, select_by_precision
3637
from pymc3.tests.test_distributions import (
3738
Domain,
3839
I,
@@ -626,13 +627,6 @@ def ref_rand(size, alpha, beta):
626627

627628
pymc3_random(pm.Beta, {"alpha": Rplus, "beta": Rplus}, ref_rand=ref_rand)
628629

629-
@pytest.mark.skip(reason="This test is covered by Aesara")
630-
def test_exponential(self):
631-
def ref_rand(size, lam):
632-
return nr.exponential(scale=1.0 / lam, size=size)
633-
634-
pymc3_random(pm.Exponential, {"lam": Rplus}, ref_rand=ref_rand)
635-
636630
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
637631
def test_laplace(self):
638632
def ref_rand(size, mu, b):
@@ -680,20 +674,6 @@ def ref_rand(size, beta):
680674

681675
pymc3_random(pm.HalfCauchy, {"beta": Rplusbig}, ref_rand=ref_rand)
682676

683-
@pytest.mark.skip(reason="This test is covered by Aesara")
684-
def test_gamma_alpha_beta(self):
685-
def ref_rand(size, alpha, beta):
686-
return st.gamma.rvs(alpha, scale=1.0 / beta, size=size)
687-
688-
pymc3_random(pm.Gamma, {"alpha": Rplusbig, "beta": Rplusbig}, ref_rand=ref_rand)
689-
690-
@pytest.mark.skip(reason="This test is covered by Aesara")
691-
def test_gamma_mu_sigma(self):
692-
def ref_rand(size, mu, sigma):
693-
return st.gamma.rvs(mu ** 2 / sigma ** 2, scale=sigma ** 2 / mu, size=size)
694-
695-
pymc3_random(pm.Gamma, {"mu": Rplusbig, "sigma": Rplusbig}, ref_rand=ref_rand)
696-
697677
@pytest.mark.skip(reason="This test is covered by Aesara")
698678
def test_inverse_gamma(self):
699679
def ref_rand(size, alpha, beta):
@@ -1787,7 +1767,7 @@ def test_issue_3758(self):
17871767

17881768
for var in "bcd":
17891769
std = np.std(samples[var] - samples["a"])
1790-
np.testing.assert_allclose(std, 1, rtol=1e-2)
1770+
npt.assert_allclose(std, 1, rtol=1e-2)
17911771

17921772
def test_issue_3829(self):
17931773
with pm.Model() as model:
@@ -1884,3 +1864,36 @@ def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape):
18841864
prior = pm.sample_prior_predictive(samples=sample_shape)
18851865

18861866
assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape
1867+
1868+
1869+
def test_exponential_parameterization():
1870+
test_lambda = floatX(10.0)
1871+
1872+
exp_pymc = pm.Exponential.dist(lam=test_lambda)
1873+
(rv_scale,) = exp_pymc.owner.inputs[3:]
1874+
1875+
npt.assert_almost_equal(rv_scale.eval(), 1 / test_lambda)
1876+
1877+
1878+
def test_gamma_parameterization():
1879+
1880+
test_alpha = floatX(10.0)
1881+
test_beta = floatX(100.0)
1882+
1883+
gamma_pymc = pm.Gamma.dist(alpha=test_alpha, beta=test_beta)
1884+
rv_alpha, rv_inv_beta = gamma_pymc.owner.inputs[3:]
1885+
1886+
assert np.array_equal(rv_alpha.eval(), test_alpha)
1887+
1888+
decimal = select_by_precision(float64=6, float32=3)
1889+
1890+
npt.assert_almost_equal(rv_inv_beta.eval(), 1.0 / test_beta, decimal)
1891+
1892+
test_mu = test_alpha / test_beta
1893+
test_sigma = np.sqrt(test_mu / test_beta)
1894+
1895+
gamma_pymc = pm.Gamma.dist(mu=test_mu, sigma=test_sigma)
1896+
rv_alpha, rv_inv_beta = gamma_pymc.owner.inputs[3:]
1897+
1898+
npt.assert_almost_equal(rv_alpha.eval(), test_alpha, decimal)
1899+
npt.assert_almost_equal(rv_inv_beta.eval(), 1.0 / test_beta, decimal)

0 commit comments

Comments
 (0)