diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index 1ea75ad4ec..86e7f1e9b9 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -926,6 +926,13 @@ def dist(cls, N, k, n, *args, **kwargs): n = at.as_tensor_variable(intX(n)) return super().dist([good, bad, n], *args, **kwargs) + def get_moment(rv, size, good, bad, n): + N, k = good + bad, good + mode = at.floor((n + 1) * (k + 1) / (N + 2)) + if not rv_size_is_none(size): + mode = at.full(size, mode) + return mode + def logp(value, good, bad, n): r""" Calculate log-probability of HyperGeometric distribution at specified value. @@ -1060,6 +1067,12 @@ def dist(cls, lower, upper, *args, **kwargs): upper = intX(at.floor(upper)) return super().dist([lower, upper], **kwargs) + def get_moment(rv, size, lower, upper): + mode = at.maximum(at.floor((upper + lower) / 2.0), lower) + if not rv_size_is_none(size): + mode = at.full(size, mode) + return mode + def logp(value, lower, upper): r""" Calculate log-probability of DiscreteUniform distribution at specified value. diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index f091e12c7f..dd9884eea5 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -10,6 +10,7 @@ Cauchy, ChiSquared, Constant, + DiscreteUniform, Exponential, Flat, Gamma, @@ -18,6 +19,7 @@ HalfFlat, HalfNormal, HalfStudentT, + HyperGeometric, Kumaraswamy, Laplace, Logistic, @@ -417,7 +419,12 @@ def test_poisson_moment(mu, size, expected): (10, 0.7, None, 4), (10, 0.7, 5, np.full(5, 4)), (np.full(3, 10), np.arange(1, 4) / 10, None, np.array([90, 40, 23])), - (10, np.arange(1, 4) / 10, (2, 3), np.full((2, 3), np.array([90, 40, 23]))), + ( + 10, + np.arange(1, 4) / 10, + (2, 3), + np.full((2, 3), np.array([90, 40, 23])), + ), ], ) def test_negative_binomial_moment(n, p, size, expected): @@ -461,7 +468,13 @@ def test_zero_inflated_poisson_moment(psi, theta, size, expected): (0.2, 7, 0.7, None, 4), (0.2, 7, 0.3, 5, np.full(5, 2)), (0.6, 25, np.arange(1, 6) / 10, None, np.arange(1, 6)), - (0.6, 25, np.arange(1, 6) / 10, (2, 5), np.full((2, 5), np.arange(1, 6))), + ( + 0.6, + 25, + np.arange(1, 6) / 10, + (2, 5), + np.full((2, 5), np.arange(1, 6)), + ), ], ) def test_zero_inflated_binomial_moment(psi, n, p, size, expected): @@ -503,3 +516,44 @@ def test_geometric_moment(p, size, expected): with Model() as model: Geometric("x", p=p, size=size) assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "N, k, n, size, expected", + [ + (50, 10, 20, None, 4), + (50, 10, 23, 5, np.full(5, 5)), + (50, 10, np.arange(23, 28), None, np.full(5, 5)), + ( + 50, + 10, + np.arange(18, 23), + (2, 5), + np.full((2, 5), 4), + ), + ], +) +def test_hyper_geometric_moment(N, k, n, size, expected): + with Model() as model: + HyperGeometric("x", N=N, k=k, n=n, size=size) + assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "lower, upper, size, expected", + [ + (1, 5, None, 3), + (1, 5, 5, np.full(5, 3)), + (1, np.arange(5, 22, 4), None, np.arange(3, 13, 2)), + ( + 1, + np.arange(5, 22, 4), + (2, 5), + np.full((2, 5), np.arange(3, 13, 2)), + ), + ], +) +def test_discrete_uniform_moment(lower, upper, size, expected): + with Model() as model: + DiscreteUniform("x", lower=lower, upper=upper, size=size) + assert_moment_is_expected(model, expected)