Skip to content

Commit 7c86ed7

Browse files
Make discrete moments robust against int-overflows
1 parent d1e868f commit 7c86ed7

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

pymc/distributions/discrete.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,11 @@ def dist(cls, N, k, n, *args, **kwargs):
878878
return super().dist([good, bad, n], *args, **kwargs)
879879

880880
def moment(rv, size, good, bad, n):
881-
N, k = good + bad, good
881+
# Cast to float because the intX can be int8
882+
# which could trigger an integer overflow below.
883+
n = floatX(n)
884+
k = floatX(good)
885+
N = k + floatX(bad)
882886
mode = at.floor((n + 1) * (k + 1) / (N + 2))
883887
if not rv_size_is_none(size):
884888
mode = at.full(size, mode)
@@ -1014,6 +1018,8 @@ def dist(cls, lower, upper, *args, **kwargs):
10141018
return super().dist([lower, upper], **kwargs)
10151019

10161020
def moment(rv, size, lower, upper):
1021+
upper = floatX(upper)
1022+
lower = floatX(lower)
10171023
mode = at.maximum(at.floor((upper + lower) / 2.0), lower)
10181024
if not rv_size_is_none(size):
10191025
mode = at.full(size, mode)

pymc/tests/distributions/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
579579

580580
assert moment.shape == expected.shape
581581
assert expected.shape == random_draw.shape
582-
assert np.allclose(moment, expected)
582+
np.testing.assert_allclose(moment, expected, atol=1e-10)
583583

584584
if check_finite_logp:
585585
logp_moment = (

0 commit comments

Comments
 (0)