Skip to content

Commit 140dab0

Browse files
add categorical moment (#5176)
Co-authored-by: Farhan Reynaldo <[email protected]>
1 parent 7485ccc commit 140dab0

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

pymc/distributions/discrete.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,13 +1167,14 @@ class Categorical(Discrete):
11671167
def dist(cls, p, **kwargs):
11681168

11691169
p = at.as_tensor_variable(floatX(p))
1170-
1171-
# mode = at.argmax(p, axis=-1)
1172-
# if mode.ndim == 1:
1173-
# mode = at.squeeze(mode)
1174-
11751170
return super().dist([p], **kwargs)
11761171

1172+
def get_moment(rv, size, p):
1173+
mode = at.argmax(p, axis=-1)
1174+
if not rv_size_is_none(size):
1175+
mode = at.full(size, mode)
1176+
return mode
1177+
11771178
def logp(value, p):
11781179
r"""
11791180
Calculate log-probability of Categorical distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Beta,
99
BetaBinomial,
1010
Binomial,
11+
Categorical,
1112
Cauchy,
1213
ChiSquared,
1314
Constant,
@@ -728,3 +729,22 @@ def test_logitnormal_moment(mu, sigma, size, expected):
728729
with Model() as model:
729730
LogitNormal("x", mu=mu, sigma=sigma, size=size)
730731
assert_moment_is_expected(model, expected)
732+
733+
734+
@pytest.mark.parametrize(
735+
"p, size, expected",
736+
[
737+
(np.array([0.1, 0.3, 0.6]), None, 2),
738+
(np.array([0.6, 0.1, 0.3]), 5, np.full(5, 0)),
739+
(np.full((2, 3), np.array([0.6, 0.1, 0.3])), None, [0, 0]),
740+
(
741+
np.full((2, 3), np.array([0.1, 0.3, 0.6])),
742+
(3, 2),
743+
np.full((3, 2), [2, 2]),
744+
),
745+
],
746+
)
747+
def test_categorical_moment(p, size, expected):
748+
with Model() as model:
749+
Categorical("x", p=p, size=size)
750+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)