Skip to content

Commit f41450d

Browse files
committed
Add test that DM is normalized for n=1 case.
1 parent d16677c commit f41450d

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

pymc3/tests/test_distributions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,22 @@ def test_dirichlet_multinomial_mode(self, alpha, n):
10651065
shape=alpha.shape)
10661066
assert_allclose(m.distribution.mode.eval().sum(axis=-1), n)
10671067

1068+
@pytest.mark.parametrize('alpha,n,enum', [
1069+
[[[.25, .25, .25, .25]], [1], [[1, 0, 0, 0],
1070+
[0, 1, 0, 0],
1071+
[0, 0, 1, 0],
1072+
[0, 0, 0, 1]]]
1073+
])
1074+
def test_dirichlet_multinomial_pmf(self, alpha, n, enum):
1075+
alpha = np.array(alpha)
1076+
n = np.array(n)
1077+
with Model() as model:
1078+
m = DirichletMultinomial('m', n=n, alpha=alpha,
1079+
shape=alpha.shape)
1080+
logp = lambda x: m.distribution.logp(np.array([x])).eval()
1081+
p_all_poss = [np.exp(logp(x)) for x in enum]
1082+
assert_almost_equal(np.sum(p_all_poss), 1)
1083+
10681084
@pytest.mark.parametrize('alpha,n', [
10691085
[[[.25, .25, .25, .25]], [1]],
10701086
[[[.3, .6, .05, .05]], [2]],

0 commit comments

Comments
 (0)