Skip to content

Commit 75c2e1e

Browse files
Move content of distributions.special into distributions.dist_math (#4760)
* Disable student-t test temporarily Co-authored-by: Ricardo <[email protected]>
1 parent cb6f5b2 commit 75c2e1e

File tree

9 files changed

+87
-125
lines changed

9 files changed

+87
-125
lines changed

.github/workflows/pytest.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ jobs:
4747
--ignore=pymc3/tests/test_minibatches.py
4848
--ignore=pymc3/tests/test_pickling.py
4949
--ignore=pymc3/tests/test_plots.py
50-
--ignore=pymc3/tests/test_special_functions.py
5150
--ignore=pymc3/tests/test_updates.py
5251
--ignore=pymc3/tests/test_examples.py
5352
--ignore=pymc3/tests/test_gp.py
@@ -67,7 +66,6 @@ jobs:
6766
pymc3/tests/test_minibatches.py
6867
pymc3/tests/test_pickling.py
6968
pymc3/tests/test_plots.py
70-
pymc3/tests/test_special_functions.py
7169
pymc3/tests/test_updates.py
7270
7371
- |

pymc3/distributions/continuous.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import numpy as np
2525

2626
from aesara.assert_op import Assert
27+
from aesara.tensor import gammaln
2728
from aesara.tensor.random.basic import (
2829
BetaRV,
2930
WeibullRV,
@@ -57,17 +58,16 @@
5758
betaln,
5859
bound,
5960
clipped_beta_rvs,
60-
gammaln,
6161
i0e,
6262
incomplete_beta,
63+
log_i0,
6364
log_normal,
6465
logpow,
6566
normal_lccdf,
6667
normal_lcdf,
6768
zvalue,
6869
)
6970
from pymc3.distributions.distribution import Continuous
70-
from pymc3.distributions.special import log_i0
7171
from pymc3.math import log1mexp, log1pexp, logdiffexp, logit
7272
from pymc3.util import UNSET
7373

pymc3/distributions/dist_math.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@
2929
from aesara.graph.op import Op
3030
from aesara.scalar import UnaryScalarOp, upgrade_to_float_no_complex
3131
from aesara.scan import until
32+
from aesara.tensor import gammaln
3233
from aesara.tensor.elemwise import Elemwise
3334
from aesara.tensor.slinalg import Cholesky, Solve
3435

3536
from pymc3.aesaraf import floatX
3637
from pymc3.distributions.shape_utils import to_tuple
37-
from pymc3.distributions.special import gammaln
3838

3939
f = floatX
4040
c = -0.5 * np.log(2.0 * np.pi)
@@ -634,3 +634,41 @@ def clipped_beta_rvs(a, b, size=None, random_state=None, dtype="float64"):
634634
out = scipy.stats.beta.rvs(a, b, size=size, random_state=random_state).astype(dtype)
635635
lower, upper = _beta_clip_values[dtype]
636636
return np.maximum(np.minimum(out, upper), lower)
637+
638+
639+
def multigammaln(a, p):
640+
"""Multivariate Log Gamma
641+
642+
Parameters
643+
----------
644+
a: tensor like
645+
p: int
646+
degrees of freedom. p > 0
647+
"""
648+
i = at.arange(1, p + 1)
649+
return p * (p - 1) * at.log(np.pi) / 4.0 + at.sum(gammaln(a + (1.0 - i) / 2.0), axis=0)
650+
651+
652+
def log_i0(x):
653+
"""
654+
Calculates the logarithm of the 0 order modified Bessel function of the first kind""
655+
"""
656+
return at.switch(
657+
at.lt(x, 5),
658+
at.log1p(
659+
x ** 2.0 / 4.0
660+
+ x ** 4.0 / 64.0
661+
+ x ** 6.0 / 2304.0
662+
+ x ** 8.0 / 147456.0
663+
+ x ** 10.0 / 14745600.0
664+
+ x ** 12.0 / 2123366400.0
665+
),
666+
x
667+
- 0.5 * at.log(2.0 * np.pi * x)
668+
+ at.log1p(
669+
1.0 / (8.0 * x)
670+
+ 9.0 / (128.0 * x ** 2.0)
671+
+ 225.0 / (3072.0 * x ** 3.0)
672+
+ 11025.0 / (98304.0 * x ** 4.0)
673+
),
674+
)

pymc3/distributions/multivariate.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from aesara.graph.basic import Apply
2626
from aesara.graph.op import Op
27+
from aesara.tensor import gammaln
2728
from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
2829
from aesara.tensor.random.basic import MultinomialRV, dirichlet, multivariate_normal
2930
from aesara.tensor.random.utils import broadcast_params
@@ -41,9 +42,8 @@
4142
from pymc3.aesaraf import floatX, intX
4243
from pymc3.distributions import transforms
4344
from pymc3.distributions.continuous import ChiSquared, Normal
44-
from pymc3.distributions.dist_math import bound, factln, logpow
45+
from pymc3.distributions.dist_math import bound, factln, logpow, multigammaln
4546
from pymc3.distributions.distribution import Continuous, Discrete
46-
from pymc3.distributions.special import gammaln, multigammaln
4747
from pymc3.math import kron_diag, kron_dot, kron_solve_lower, kronecker
4848

4949
__all__ = [
@@ -154,7 +154,7 @@ def quaddist_tau(delta, chol_mat):
154154

155155

156156
class MvNormal(Continuous):
157-
R"""
157+
r"""
158158
Multivariate normal log-likelihood.
159159
160160
.. math::
@@ -249,7 +249,7 @@ def _distr_parameters_for_repr(self):
249249

250250

251251
class MvStudentT(Continuous):
252-
R"""
252+
r"""
253253
Multivariate Student-T log-likelihood.
254254
255255
.. math::
@@ -362,7 +362,7 @@ def _distr_parameters_for_repr(self):
362362

363363

364364
class Dirichlet(Continuous):
365-
R"""
365+
r"""
366366
Dirichlet log-likelihood.
367367
368368
.. math::
@@ -452,7 +452,7 @@ def rng_fn(cls, rng, n, p, size):
452452

453453

454454
class Multinomial(Discrete):
455-
R"""
455+
r"""
456456
Multinomial log-likelihood.
457457
458458
Generalizes binomial distribution, but instead of each trial resulting
@@ -525,7 +525,7 @@ def logp(value, n, p):
525525

526526

527527
class DirichletMultinomial(Discrete):
528-
R"""Dirichlet Multinomial log-likelihood.
528+
r"""Dirichlet Multinomial log-likelihood.
529529
530530
Dirichlet mixture of Multinomials distribution, with a marginalized PMF.
531531
@@ -729,7 +729,7 @@ def __str__(self):
729729

730730

731731
class Wishart(Continuous):
732-
R"""
732+
r"""
733733
Wishart log-likelihood.
734734
735735
The Wishart distribution is the probability distribution of the
@@ -946,7 +946,7 @@ def _lkj_normalizing_constant(eta, n):
946946

947947

948948
class _LKJCholeskyCov(Continuous):
949-
R"""Underlying class for covariance matrix with LKJ distributed correlations.
949+
r"""Underlying class for covariance matrix with LKJ distributed correlations.
950950
See docs for LKJCholeskyCov function for more details on how to use it in models.
951951
"""
952952

@@ -1126,7 +1126,7 @@ def _distr_parameters_for_repr(self):
11261126

11271127

11281128
def LKJCholeskyCov(name, eta, n, sd_dist, compute_corr=False, store_in_trace=True, *args, **kwargs):
1129-
R"""Wrapper function for covariance matrix with LKJ distributed correlations.
1129+
r"""Wrapper function for covariance matrix with LKJ distributed correlations.
11301130
11311131
This defines a distribution over Cholesky decomposed covariance
11321132
matrices, such that the underlying correlation matrices follow an
@@ -1279,7 +1279,7 @@ def LKJCholeskyCov(name, eta, n, sd_dist, compute_corr=False, store_in_trace=Tru
12791279

12801280

12811281
class LKJCorr(Continuous):
1282-
R"""
1282+
r"""
12831283
The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.
12841284
12851285
The LKJ distribution is a prior distribution for correlation matrices.
@@ -1435,7 +1435,7 @@ def _distr_parameters_for_repr(self):
14351435

14361436

14371437
class MatrixNormal(Continuous):
1438-
R"""
1438+
r"""
14391439
Matrix-valued normal log-likelihood.
14401440
14411441
.. math::
@@ -1694,7 +1694,7 @@ def _distr_parameters_for_repr(self):
16941694

16951695

16961696
class KroneckerNormal(Continuous):
1697-
R"""
1697+
r"""
16981698
Multivariate normal log-likelihood with Kronecker-structured covariance.
16991699
17001700
.. math::
@@ -1941,7 +1941,7 @@ def _distr_parameters_for_repr(self):
19411941

19421942

19431943
class CAR(Continuous):
1944-
R"""
1944+
r"""
19451945
Likelihood for a conditional autoregression. This is a special case of the
19461946
multivariate normal with an adjacency-structured covariance matrix.
19471947

pymc3/distributions/special.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

pymc3/tests/test_dist_math.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
import numpy as np
1717
import numpy.testing as npt
1818
import pytest
19+
import scipy.special
1920

21+
from aesara import config, function
2022
from aesara.tensor.random.basic import multinomial
2123
from scipy import interpolate, stats
2224

@@ -32,7 +34,9 @@
3234
clipped_beta_rvs,
3335
factln,
3436
i0e,
37+
multigammaln,
3538
)
39+
from pymc3.tests.checks import close_to
3640
from pymc3.tests.helpers import verify_grad
3741

3842

@@ -236,3 +240,25 @@ def test_clipped_beta_rvs(dtype):
236240
# equal to zero or one (issue #3898)
237241
values = clipped_beta_rvs(0.01, 0.01, size=1000000, dtype=dtype)
238242
assert not (np.any(values == 0) or np.any(values == 1))
243+
244+
245+
def check_vals(fn1, fn2, *args):
246+
v = fn1(*args)
247+
close_to(v, fn2(*args), 1e-6 if v.dtype == np.float64 else 1e-4)
248+
249+
250+
def test_multigamma():
251+
x = at.vector("x")
252+
p = at.scalar("p")
253+
254+
xvals = [np.array([v], dtype=config.floatX) for v in [0.1, 2, 5, 10, 50, 100]]
255+
256+
multigammaln_ = function([x, p], multigammaln(x, p), mode="FAST_COMPILE")
257+
258+
def ref_multigammaln(a, b):
259+
return np.array(scipy.special.multigammaln(a[0], b), config.floatX)
260+
261+
for p in [0, 1, 2, 3, 4, 100]:
262+
for x in xvals:
263+
if np.all(x > 0.5 * (p - 1)):
264+
check_vals(multigammaln_, ref_multigammaln, x, p)

pymc3/tests/test_distributions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,13 +1367,16 @@ def test_t(self):
13671367
lambda value, nu, mu, lam: sp.t.logcdf(value, nu, mu, lam ** -0.5),
13681368
n_samples=10, # relies on slow incomplete beta
13691369
)
1370+
# TODO: reenable when PR #4736 is merged
1371+
"""
13701372
self.check_logcdf(
13711373
StudentT,
13721374
R,
13731375
{"nu": Rplus, "mu": R, "sigma": Rplus},
13741376
lambda value, nu, mu, sigma: sp.t.logcdf(value, nu, mu, sigma),
13751377
n_samples=5, # Just testing alternative parametrization
13761378
)
1379+
"""
13771380

13781381
def test_cauchy(self):
13791382
self.check_logp(

pymc3/tests/test_ode.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -395,9 +395,9 @@ def system(y, t, p):
395395
ode_model = DifferentialEquation(func=system, t0=0, times=times, n_states=2, n_theta=2)
396396

397397
with pm.Model() as model:
398-
beta = pm.HalfCauchy("beta", 1)
399-
gamma = pm.HalfCauchy("gamma", 1)
400-
sigma = pm.HalfCauchy("sigma", 1, shape=2)
398+
beta = pm.HalfCauchy("beta", 1, initval=1)
399+
gamma = pm.HalfCauchy("gamma", 1, initval=1)
400+
sigma = pm.HalfCauchy("sigma", 1, shape=2, initval=[1, 1])
401401
forward = ode_model(theta=[beta, gamma], y0=[0.99, 0.01])
402402
y = pm.Lognormal("y", mu=pm.math.log(forward), sd=sigma, observed=yobs)
403403

0 commit comments

Comments
 (0)