Skip to content

Commit e39bb0d

Browse files
juan.lopez.arriazamorganstrom
authored andcommitted
Adding moments for AsymmetricLaplace and SkewNormal and corresponding tests
1 parent 6811e0e commit e39bb0d

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

pymc/distributions/continuous.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1599,6 +1599,13 @@ def dist(cls, b, kappa, mu=0, *args, **kwargs):
15991599

16001600
return super().dist([b, kappa, mu], *args, **kwargs)
16011601

1602+
def get_moment(rv, size, b, kappa, mu):
1603+
mean = mu - (kappa - 1 / kappa) / b
1604+
1605+
if not rv_size_is_none(size):
1606+
mean = at.full(size, mean)
1607+
return mean
1608+
16021609
def logp(value, b, kappa, mu):
16031610
"""
16041611
Calculate log-probability of Asymmetric-Laplace distribution at specified value.
@@ -3008,6 +3015,12 @@ def dist(cls, alpha=1, mu=0.0, sigma=None, tau=None, sd=None, *args, **kwargs):
30083015

30093016
return super().dist([mu, sigma, alpha], *args, **kwargs)
30103017

3018+
def get_moment(rv, size, mu, sigma, alpha):
3019+
mean = mu + sigma * (2 / np.pi) ** 0.5 * alpha / (1 + alpha ** 2) ** 0.5
3020+
if not rv_size_is_none(size):
3021+
mean = at.full(size, mean)
3022+
return mean
3023+
30113024
def logp(value, mu, sigma, alpha):
30123025
"""
30133026
Calculate log-probability of SkewNormal distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from scipy import special
55

66
from pymc.distributions import (
7+
AsymmetricLaplace,
78
Bernoulli,
89
Beta,
910
BetaBinomial,
@@ -36,6 +37,7 @@
3637
Normal,
3738
Pareto,
3839
Poisson,
40+
SkewNormal,
3941
StudentT,
4042
Triangular,
4143
TruncatedNormal,
@@ -785,3 +787,59 @@ def test_moyal_moment(mu, sigma, size, expected):
785787
with Model() as model:
786788
Moyal("x", mu=mu, sigma=sigma, size=size)
787789
assert_moment_is_expected(model, expected)
790+
791+
792+
@pytest.mark.parametrize(
793+
"alpha, mu, sigma, size, expected",
794+
[
795+
(1.0, 1.0, 1.0, None, 1.56418958),
796+
(1, np.ones(5), 1, None, np.full(5, 1.56418958)),
797+
(np.ones(5), 1, np.ones(5), None, np.full(5, 1.56418958)),
798+
(
799+
np.arange(5),
800+
np.arange(1, 6),
801+
np.arange(1, 6),
802+
None,
803+
(1.0, 3.12837917, 5.14094894, 7.02775903, 8.87030861),
804+
),
805+
(
806+
np.arange(5),
807+
np.arange(1, 6),
808+
np.arange(1, 6),
809+
(2, 5),
810+
np.full((2, 5), (1.0, 3.12837917, 5.14094894, 7.02775903, 8.87030861)),
811+
),
812+
],
813+
)
814+
def test_skewnormal_moment(alpha, mu, sigma, size, expected):
815+
with Model() as model:
816+
SkewNormal("x", alpha=alpha, mu=mu, sigma=sigma, size=size)
817+
assert_moment_is_expected(model, expected)
818+
819+
820+
@pytest.mark.parametrize(
821+
"b, kappa, mu, size, expected",
822+
[
823+
(1.0, 1.0, 1.0, None, 1.0),
824+
(1.0, np.ones(5), 1.0, None, np.full(5, 1.0)),
825+
(np.arange(1, 6), 1.0, np.ones(5), None, np.full(5, 1.0)),
826+
(
827+
np.arange(1, 6),
828+
np.arange(1, 6),
829+
np.arange(1, 6),
830+
None,
831+
(1.0, 1.25, 2.111111111111111, 3.0625, 4.04),
832+
),
833+
(
834+
np.arange(1, 6),
835+
np.arange(1, 6),
836+
np.arange(1, 6),
837+
(2, 5),
838+
np.full((2, 5), (1.0, 1.25, 2.111111111111111, 3.0625, 4.04)),
839+
),
840+
],
841+
)
842+
def test_asymmetriclaplace_moment(b, kappa, mu, size, expected):
843+
with Model() as model:
844+
AsymmetricLaplace("x", b=b, kappa=kappa, mu=mu, size=size)
845+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)