Skip to content

Commit 2096416

Browse files
sagartomarricardoV94
authored andcommitted
Added moment for geometric distribution along with test
1 parent 35f9966 commit 2096416

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

pymc/distributions/discrete.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,12 @@ def dist(cls, p, *args, **kwargs):
819819
p = at.as_tensor_variable(floatX(p))
820820
return super().dist([p], *args, **kwargs)
821821

822+
def get_moment(rv, size, p):
823+
mean = at.round(1.0 / p)
824+
if not rv_size_is_none(size):
825+
mean = at.full(size, mean)
826+
return mean
827+
822828
def logp(value, p):
823829
r"""
824830
Calculate log-probability of Geometric distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Constant,
1313
Exponential,
1414
Gamma,
15+
Geometric,
1516
HalfCauchy,
1617
HalfNormal,
1718
HalfStudentT,
@@ -482,3 +483,18 @@ def test_logistic_moment(mu, s, size, expected):
482483
with Model() as model:
483484
Logistic("x", mu=mu, s=s, size=size)
484485
assert_moment_is_expected(model, expected)
486+
487+
488+
@pytest.mark.parametrize(
489+
"p, size, expected",
490+
[
491+
(0.5, None, 2),
492+
(0.2, 5, 5 * np.ones(5)),
493+
(np.linspace(0.25, 1, 4), None, [4, 2, 1, 1]),
494+
(np.linspace(0.25, 1, 4), (2, 4), np.full((2, 4), [4, 2, 1, 1])),
495+
],
496+
)
497+
def test_geometric_moment(p, size, expected):
498+
with Model() as model:
499+
Geometric("x", p=p, size=size)
500+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)