@@ -469,6 +469,15 @@ def logcdf_fn(value, psi, n, p):
469
469
{"n" : NatSmall , "p" : Unit , "psi" : Unit },
470
470
)
471
471
472
+ @pytest .mark .parametrize ("n" , [2 , 3 , 4 ])
473
+ def test_categorical (self , n ):
474
+ check_logp (
475
+ pm .Categorical ,
476
+ Domain (range (n ), dtype = "int64" , edges = (0 , n )),
477
+ {"p" : Simplex (n )},
478
+ lambda value , p : categorical_logpdf (value , p ),
479
+ )
480
+
472
481
@aesara .config .change_flags (compute_test_value = "raise" )
473
482
def test_categorical_bounds (self ):
474
483
with pm .Model ():
@@ -495,6 +504,14 @@ def test_categorical_negative_p(self, p):
495
504
with pm .Model ():
496
505
x = pm .Categorical ("x" , p = p )
497
506
507
+ def test_categorical_p_not_normalized (self ):
508
+ # test UserWarning is raised for p vals that sum to more than 1
509
+ # and normaliation is triggered
510
+ with pytest .warns (UserWarning , match = "[5]" ):
511
+ with pm .Model () as m :
512
+ x = pm .Categorical ("x" , p = [1 , 1 , 1 , 1 , 1 ])
513
+ assert np .isclose (m .x .owner .inputs [3 ].sum ().eval (), 1.0 )
514
+
498
515
def test_categorical_negative_p_symbolic (self ):
499
516
with pytest .raises (ParameterValueError ):
500
517
value = np .array ([[1 , 1 , 1 ]])
@@ -507,23 +524,6 @@ def test_categorical_p_not_normalized_symbolic(self):
507
524
invalid_dist = pm .Categorical .dist (p = at .as_tensor_variable ([2 , 2 , 2 ]))
508
525
pm .logp (invalid_dist , value ).eval ()
509
526
510
- @pytest .mark .parametrize ("n" , [2 , 3 , 4 ])
511
- def test_categorical (self , n ):
512
- check_logp (
513
- pm .Categorical ,
514
- Domain (range (n ), dtype = "int64" , edges = (0 , n )),
515
- {"p" : Simplex (n )},
516
- lambda value , p : categorical_logpdf (value , p ),
517
- )
518
-
519
- def test_categorical_p_not_normalized (self ):
520
- # test UserWarning is raised for p vals that sum to more than 1
521
- # and normaliation is triggered
522
- with pytest .warns (UserWarning , match = "[5]" ):
523
- with pm .Model () as m :
524
- x = pm .Categorical ("x" , p = [1 , 1 , 1 , 1 , 1 ])
525
- assert np .isclose (m .x .owner .inputs [3 ].sum ().eval (), 1.0 )
526
-
527
527
@pytest .mark .parametrize ("n" , [2 , 3 , 4 ])
528
528
def test_orderedlogistic (self , n ):
529
529
with warnings .catch_warnings ():
0 commit comments