Skip to content

Commit 7afae2a

Browse files
larryshamalamamichaelosthege
authored andcommitted
Fix pm.Interpolated moment
1 parent 60e2648 commit 7afae2a

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

pymc/distributions/continuous.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3734,8 +3734,11 @@ def dist(cls, x_points, pdf_points, *args, **kwargs):
37343734
return super().dist([x_points, pdf_points, cdf_points], **kwargs)
37353735

37363736
def moment(rv, size, x_points, pdf_points, cdf_points):
3737-
# cdf_points argument is unused
3738-
moment = at.sum(at.mul(x_points, pdf_points))
3737+
"""
3738+
Estimates the expectation integral using the trapezoid rule; cdf_points are not used.
3739+
"""
3740+
x_fx = at.mul(x_points, pdf_points) # x_i * f(x_i) for all xi's in x_points
3741+
moment = at.sum(at.mul(at.diff(x_points), x_fx[1:] + x_fx[:-1])) / 2
37393742

37403743
if not rv_size_is_none(size):
37413744
moment = at.full(size, moment)

pymc/tests/test_distributions_moments.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -892,19 +892,32 @@ def test_categorical_moment(p, size, expected):
892892
np.array([-4, -1, 3, 9, 19]),
893893
np.array([0.1, 0.15, 0.2, 0.25, 0.3]),
894894
None,
895-
1.5458937198067635,
895+
9.34782609,
896896
),
897897
(
898898
np.array([-22, -4, 0, 8, 13]),
899899
np.tile(1 / 5, 5),
900900
(5, 3),
901-
np.full((5, 3), -0.14285714285714296),
901+
np.full((5, 3), -4.5),
902902
),
903903
(
904904
np.arange(-100, 10),
905905
np.arange(1, 111) / 6105,
906906
(2, 5, 3),
907-
np.full((2, 5, 3), -27.584097859327223),
907+
np.full((2, 5, 3), -27.65765766),
908+
),
909+
(
910+
# from https://github.com/pymc-devs/pymc/issues/5959
911+
np.linspace(0, 10, 10),
912+
st.norm.pdf(np.linspace(0, 10, 10), loc=2.5, scale=1),
913+
None,
914+
2.5270134,
915+
),
916+
(
917+
np.linspace(0, 10, 100),
918+
st.norm.pdf(np.linspace(0, 10, 100), loc=2.5, scale=1),
919+
None,
920+
2.51771721,
908921
),
909922
],
910923
)

0 commit comments

Comments
 (0)