Skip to content

Commit 23cffdc

Browse files
committed
Fix exponential and gamma logp / random link
1 parent c9fa127 commit 23cffdc

File tree

2 files changed

+68
-51
lines changed

2 files changed

+68
-51
lines changed

pymc3/distributions/continuous.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,49 +1377,53 @@ class Exponential(PositiveContinuous):
13771377
@classmethod
13781378
def dist(cls, lam, *args, **kwargs):
13791379
lam = at.as_tensor_variable(floatX(lam))
1380-
# mean = 1.0 / lam
1381-
# median = mean * at.log(2)
1382-
# mode = at.zeros_like(lam)
1383-
1384-
# variance = lam ** -2
13851380

13861381
assert_negative_support(lam, "lam", "Exponential")
1387-
return super().dist([lam], **kwargs)
13881382

1389-
def logp(value, lam):
1383+
# Aesara exponential op is parametrized in terms of mu (1/lam)
1384+
return super().dist([at.inv(lam)], **kwargs)
1385+
1386+
def logp(value, mu):
13901387
"""
13911388
Calculate log-probability of Exponential distribution at specified value.
13921389
13931390
Parameters
13941391
----------
13951392
value: numeric
1396-
Value(s) for which log-probability is calculated. If the log probabilities for multiple
1397-
values are desired the values must be provided in a numpy array or aesara tensor
1393+
Value(s) for which log-probability is calculated. If the log
1394+
probabilities for multiple values are desired the values must be
1395+
provided in a numpy array or aesara tensor
13981396
13991397
Returns
14001398
-------
14011399
TensorVariable
14021400
"""
1403-
return bound(at.log(lam) - lam * value, value >= 0, lam > 0)
1401+
lam = at.inv(mu)
1402+
return bound(
1403+
at.log(lam) - lam * value,
1404+
value >= 0,
1405+
lam > 0,
1406+
)
14041407

1405-
def logcdf(value, lam):
1408+
def logcdf(value, mu):
14061409
r"""
14071410
Compute the log of cumulative distribution function for the Exponential distribution
14081411
at the specified value.
14091412
14101413
Parameters
14111414
----------
14121415
value: numeric or np.ndarray or aesara.tensor
1413-
Value(s) for which log CDF is calculated. If the log CDF for multiple
1414-
values are desired the values must be provided in a numpy array or aesara tensor.
1416+
Value(s) for which log CDF is calculated. If the log CDF for
1417+
multiple values are desired the values must be provided in a numpy
1418+
array or aesara tensor.
14151419
14161420
Returns
14171421
-------
14181422
TensorVariable
14191423
"""
1420-
a = lam * value
1424+
lam = at.inv(mu)
14211425
return bound(
1422-
log1mexp(a),
1426+
log1mexp(lam * value),
14231427
0 <= value,
14241428
0 <= lam,
14251429
)
@@ -2371,15 +2375,13 @@ def dist(cls, alpha=None, beta=None, mu=None, sigma=None, sd=None, no_assert=Fal
23712375
alpha, beta = cls.get_alpha_beta(alpha, beta, mu, sigma)
23722376
alpha = at.as_tensor_variable(floatX(alpha))
23732377
beta = at.as_tensor_variable(floatX(beta))
2374-
# mean = alpha / beta
2375-
# mode = at.maximum((alpha - 1) / beta, 0)
2376-
# variance = alpha / beta ** 2
23772378

23782379
if not no_assert:
23792380
assert_negative_support(alpha, "alpha", "Gamma")
23802381
assert_negative_support(beta, "beta", "Gamma")
23812382

2382-
return super().dist([alpha, at.inv(beta)], **kwargs)
2383+
# The Aesara `GammaRV` `Op` will invert the `beta` parameter itself
2384+
return super().dist([alpha, beta], **kwargs)
23832385

23842386
@classmethod
23852387
def get_alpha_beta(cls, alpha=None, beta=None, mu=None, sigma=None):
@@ -2397,45 +2399,47 @@ def get_alpha_beta(cls, alpha=None, beta=None, mu=None, sigma=None):
23972399

23982400
return alpha, beta
23992401

2400-
def _distr_parameters_for_repr(self):
2401-
return ["alpha", "beta"]
2402-
2403-
def logp(value, alpha, beta):
2402+
def logp(value, alpha, inv_beta):
24042403
"""
24052404
Calculate log-probability of Gamma distribution at specified value.
24062405
24072406
Parameters
24082407
----------
24092408
value: numeric
2410-
Value(s) for which log-probability is calculated. If the log probabilities for multiple
2411-
values are desired the values must be provided in a numpy array or `TensorVariable`.
2409+
Value(s) for which log-probability is calculated. If the log
2410+
probabilities for multiple values are desired the values must be
2411+
provided in a numpy array or `TensorVariable`.
24122412
24132413
Returns
24142414
-------
24152415
TensorVariable
24162416
"""
2417+
beta = at.inv(inv_beta)
24172418
return bound(
24182419
-gammaln(alpha) + logpow(beta, alpha) - beta * value + logpow(value, alpha - 1),
24192420
value >= 0,
24202421
alpha > 0,
24212422
beta > 0,
24222423
)
24232424

2424-
def logcdf(value, alpha, beta):
2425+
def logcdf(value, alpha, inv_beta):
24252426
"""
24262427
Compute the log of the cumulative distribution function for Gamma distribution
24272428
at the specified value.
24282429
24292430
Parameters
24302431
----------
24312432
value: numeric or np.ndarray or `TensorVariable`
2432-
Value(s) for which log CDF is calculated. If the log CDF for multiple
2433-
values are desired the values must be provided in a numpy array or `TensorVariable`.
2433+
Value(s) for which log CDF is calculated. If the log CDF for
2434+
multiple values are desired the values must be provided in a numpy
2435+
array or `TensorVariable`.
24342436
24352437
Returns
24362438
-------
24372439
TensorVariable
24382440
"""
2441+
beta = at.inv(inv_beta)
2442+
24392443
# Avoid C-assertion when the gammainc function is called with invalid values (#4340)
24402444
safe_alpha = at.switch(at.lt(alpha, 0), 0, alpha)
24412445
safe_beta = at.switch(at.lt(beta, 0), 0, beta)

pymc3/tests/test_distributions_random.py

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import aesara
2121
import numpy as np
2222
import numpy.random as nr
23+
import numpy.testing as npt
2324
import pytest
2425
import scipy.stats as st
2526

@@ -32,7 +33,7 @@
3233
from pymc3.distributions.dist_math import clipped_beta_rvs
3334
from pymc3.distributions.shape_utils import to_tuple
3435
from pymc3.exceptions import ShapeError
35-
from pymc3.tests.helpers import SeededTest
36+
from pymc3.tests.helpers import SeededTest, select_by_precision
3637
from pymc3.tests.test_distributions import (
3738
Domain,
3839
I,
@@ -626,13 +627,6 @@ def ref_rand(size, alpha, beta):
626627

627628
pymc3_random(pm.Beta, {"alpha": Rplus, "beta": Rplus}, ref_rand=ref_rand)
628629

629-
@pytest.mark.skip(reason="This test is covered by Aesara")
630-
def test_exponential(self):
631-
def ref_rand(size, lam):
632-
return nr.exponential(scale=1.0 / lam, size=size)
633-
634-
pymc3_random(pm.Exponential, {"lam": Rplus}, ref_rand=ref_rand)
635-
636630
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
637631
def test_laplace(self):
638632
def ref_rand(size, mu, b):
@@ -680,20 +674,6 @@ def ref_rand(size, beta):
680674

681675
pymc3_random(pm.HalfCauchy, {"beta": Rplusbig}, ref_rand=ref_rand)
682676

683-
@pytest.mark.skip(reason="This test is covered by Aesara")
684-
def test_gamma_alpha_beta(self):
685-
def ref_rand(size, alpha, beta):
686-
return st.gamma.rvs(alpha, scale=1.0 / beta, size=size)
687-
688-
pymc3_random(pm.Gamma, {"alpha": Rplusbig, "beta": Rplusbig}, ref_rand=ref_rand)
689-
690-
@pytest.mark.skip(reason="This test is covered by Aesara")
691-
def test_gamma_mu_sigma(self):
692-
def ref_rand(size, mu, sigma):
693-
return st.gamma.rvs(mu ** 2 / sigma ** 2, scale=sigma ** 2 / mu, size=size)
694-
695-
pymc3_random(pm.Gamma, {"mu": Rplusbig, "sigma": Rplusbig}, ref_rand=ref_rand)
696-
697677
@pytest.mark.skip(reason="This test is covered by Aesara")
698678
def test_inverse_gamma(self):
699679
def ref_rand(size, alpha, beta):
@@ -1787,7 +1767,7 @@ def test_issue_3758(self):
17871767

17881768
for var in "bcd":
17891769
std = np.std(samples[var] - samples["a"])
1790-
np.testing.assert_allclose(std, 1, rtol=1e-2)
1770+
npt.assert_allclose(std, 1, rtol=1e-2)
17911771

17921772
def test_issue_3829(self):
17931773
with pm.Model() as model:
@@ -1884,3 +1864,36 @@ def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape):
18841864
prior = pm.sample_prior_predictive(samples=sample_shape)
18851865

18861866
assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape
1867+
1868+
1869+
def test_exponential_parameterization():
1870+
test_lambda = floatX(10.0)
1871+
1872+
exp_pymc = pm.Exponential.dist(lam=test_lambda)
1873+
(rv_scale,) = exp_pymc.owner.inputs[3:]
1874+
1875+
npt.assert_almost_equal(rv_scale.eval(), 1 / test_lambda)
1876+
1877+
1878+
def test_gamma_parameterization():
1879+
1880+
test_alpha = floatX(10.0)
1881+
test_beta = floatX(100.0)
1882+
1883+
gamma_pymc = pm.Gamma.dist(alpha=test_alpha, beta=test_beta)
1884+
rv_alpha, rv_inv_beta = gamma_pymc.owner.inputs[3:]
1885+
1886+
assert np.array_equal(rv_alpha.eval(), test_alpha)
1887+
1888+
decimal = select_by_precision(float64=6, float32=3)
1889+
1890+
npt.assert_almost_equal(rv_inv_beta.eval(), 1.0 / test_beta, decimal)
1891+
1892+
test_mu = test_alpha / test_beta
1893+
test_sigma = np.sqrt(test_mu / test_beta)
1894+
1895+
gamma_pymc = pm.Gamma.dist(mu=test_mu, sigma=test_sigma)
1896+
rv_alpha, rv_inv_beta = gamma_pymc.owner.inputs[3:]
1897+
1898+
npt.assert_almost_equal(rv_alpha.eval(), test_alpha, decimal)
1899+
npt.assert_almost_equal(rv_inv_beta.eval(), 1.0 / test_beta, decimal)

0 commit comments

Comments
 (0)