Skip to content

Commit 14aa3d0

Browse files
committed
Reorder categorical tests more logically
1 parent 67da585 commit 14aa3d0

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

pymc/tests/distributions/test_discrete.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,15 @@ def logcdf_fn(value, psi, n, p):
469469
{"n": NatSmall, "p": Unit, "psi": Unit},
470470
)
471471

472+
@pytest.mark.parametrize("n", [2, 3, 4])
473+
def test_categorical(self, n):
474+
check_logp(
475+
pm.Categorical,
476+
Domain(range(n), dtype="int64", edges=(0, n)),
477+
{"p": Simplex(n)},
478+
lambda value, p: categorical_logpdf(value, p),
479+
)
480+
472481
@aesara.config.change_flags(compute_test_value="raise")
473482
def test_categorical_bounds(self):
474483
with pm.Model():
@@ -495,6 +504,14 @@ def test_categorical_negative_p(self, p):
495504
with pm.Model():
496505
x = pm.Categorical("x", p=p)
497506

507+
def test_categorical_p_not_normalized(self):
508+
# test UserWarning is raised for p vals that sum to more than 1
509+
# and normaliation is triggered
510+
with pytest.warns(UserWarning, match="[5]"):
511+
with pm.Model() as m:
512+
x = pm.Categorical("x", p=[1, 1, 1, 1, 1])
513+
assert np.isclose(m.x.owner.inputs[3].sum().eval(), 1.0)
514+
498515
def test_categorical_negative_p_symbolic(self):
499516
with pytest.raises(ParameterValueError):
500517
value = np.array([[1, 1, 1]])
@@ -507,23 +524,6 @@ def test_categorical_p_not_normalized_symbolic(self):
507524
invalid_dist = pm.Categorical.dist(p=at.as_tensor_variable([2, 2, 2]))
508525
pm.logp(invalid_dist, value).eval()
509526

510-
@pytest.mark.parametrize("n", [2, 3, 4])
511-
def test_categorical(self, n):
512-
check_logp(
513-
pm.Categorical,
514-
Domain(range(n), dtype="int64", edges=(0, n)),
515-
{"p": Simplex(n)},
516-
lambda value, p: categorical_logpdf(value, p),
517-
)
518-
519-
def test_categorical_p_not_normalized(self):
520-
# test UserWarning is raised for p vals that sum to more than 1
521-
# and normaliation is triggered
522-
with pytest.warns(UserWarning, match="[5]"):
523-
with pm.Model() as m:
524-
x = pm.Categorical("x", p=[1, 1, 1, 1, 1])
525-
assert np.isclose(m.x.owner.inputs[3].sum().eval(), 1.0)
526-
527527
@pytest.mark.parametrize("n", [2, 3, 4])
528528
def test_orderedlogistic(self, n):
529529
with warnings.catch_warnings():

0 commit comments

Comments
 (0)