Skip to content

Commit aae9dfb

Browse files
juan.lopez.arriazaricardoV94
juan.lopez.arriaza
authored andcommitted
Adding moments for AsymmetricLaplace and SkewNormal and corresponding tests
1 parent b9b9efc commit aae9dfb

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

pymc/distributions/continuous.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1599,6 +1599,13 @@ def dist(cls, b, kappa, mu=0, *args, **kwargs):
15991599

16001600
return super().dist([b, kappa, mu], *args, **kwargs)
16011601

1602+
def get_moment(rv, size, b, kappa, mu):
1603+
mean = mu - (kappa - 1 / kappa) / b
1604+
1605+
if not rv_size_is_none(size):
1606+
mean = at.full(size, mean)
1607+
return mean
1608+
16021609
def logp(value, b, kappa, mu):
16031610
"""
16041611
Calculate log-probability of Asymmetric-Laplace distribution at specified value.
@@ -3012,6 +3019,12 @@ def dist(cls, alpha=1, mu=0.0, sigma=None, tau=None, sd=None, *args, **kwargs):
30123019

30133020
return super().dist([mu, sigma, alpha], *args, **kwargs)
30143021

3022+
def get_moment(rv, size, mu, sigma, alpha):
3023+
mean = mu + sigma * (2 / np.pi) ** 0.5 * alpha / (1 + alpha ** 2) ** 0.5
3024+
if not rv_size_is_none(size):
3025+
mean = at.full(size, mean)
3026+
return mean
3027+
30153028
def logp(value, mu, sigma, alpha):
30163029
"""
30173030
Calculate log-probability of SkewNormal distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from scipy import special
55

66
from pymc.distributions import (
7+
AsymmetricLaplace,
78
Bernoulli,
89
Beta,
910
BetaBinomial,
@@ -35,6 +36,7 @@
3536
Normal,
3637
Pareto,
3738
Poisson,
39+
SkewNormal,
3840
StudentT,
3941
Triangular,
4042
TruncatedNormal,
@@ -764,3 +766,59 @@ def test_moyal_moment(mu, sigma, size, expected):
764766
with Model() as model:
765767
Moyal("x", mu=mu, sigma=sigma, size=size)
766768
assert_moment_is_expected(model, expected)
769+
770+
771+
@pytest.mark.parametrize(
772+
"alpha, mu, sigma, size, expected",
773+
[
774+
(1.0, 1.0, 1.0, None, 1.56418958),
775+
(1, np.ones(5), 1, None, np.full(5, 1.56418958)),
776+
(np.ones(5), 1, np.ones(5), None, np.full(5, 1.56418958)),
777+
(
778+
np.arange(5),
779+
np.arange(1, 6),
780+
np.arange(1, 6),
781+
None,
782+
(1.0, 3.12837917, 5.14094894, 7.02775903, 8.87030861),
783+
),
784+
(
785+
np.arange(5),
786+
np.arange(1, 6),
787+
np.arange(1, 6),
788+
(2, 5),
789+
np.full((2, 5), (1.0, 3.12837917, 5.14094894, 7.02775903, 8.87030861)),
790+
),
791+
],
792+
)
793+
def test_skewnormal_moment(alpha, mu, sigma, size, expected):
794+
with Model() as model:
795+
SkewNormal("x", alpha=alpha, mu=mu, sigma=sigma, size=size)
796+
assert_moment_is_expected(model, expected)
797+
798+
799+
@pytest.mark.parametrize(
800+
"b, kappa, mu, size, expected",
801+
[
802+
(1.0, 1.0, 1.0, None, 1.0),
803+
(1.0, np.ones(5), 1.0, None, np.full(5, 1.0)),
804+
(np.arange(1, 6), 1.0, np.ones(5), None, np.full(5, 1.0)),
805+
(
806+
np.arange(1, 6),
807+
np.arange(1, 6),
808+
np.arange(1, 6),
809+
None,
810+
(1.0, 1.25, 2.111111111111111, 3.0625, 4.04),
811+
),
812+
(
813+
np.arange(1, 6),
814+
np.arange(1, 6),
815+
np.arange(1, 6),
816+
(2, 5),
817+
np.full((2, 5), (1.0, 1.25, 2.111111111111111, 3.0625, 4.04)),
818+
),
819+
],
820+
)
821+
def test_asymmetriclaplace_moment(b, kappa, mu, size, expected):
822+
with Model() as model:
823+
AsymmetricLaplace("x", b=b, kappa=kappa, mu=mu, size=size)
824+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)