@@ -843,7 +843,7 @@ class TestPoisson(BaseTestDistribution):
843
843
tests_to_run = ["check_pymc_params_match_rv_op" ]
844
844
845
845
846
- class TestMvNormal (BaseTestDistribution ):
846
+ class TestMvNormalCov (BaseTestDistribution ):
847
847
pymc_dist = pm .MvNormal
848
848
pymc_dist_params = {
849
849
"mu" : np .array ([1.0 , 2.0 ]),
@@ -893,6 +893,63 @@ class TestMvNormalTau(BaseTestDistribution):
893
893
tests_to_run = ["check_pymc_params_match_rv_op" ]
894
894
895
895
896
+ class TestMvStudentTCov (BaseTestDistribution ):
897
+ pymc_dist = pm .MvStudentT
898
+ pymc_dist_params = {
899
+ "nu" : 5 ,
900
+ "mu" : np .array ([1.0 , 2.0 ]),
901
+ "cov" : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
902
+ }
903
+ expected_rv_op_params = {
904
+ "nu" : 5 ,
905
+ "mu" : np .array ([1.0 , 2.0 ]),
906
+ "cov" : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
907
+ }
908
+ sizes_to_check = [None , (1 ), (2 , 3 )]
909
+ sizes_expected = [(2 ,), (1 , 2 ), (2 , 3 , 2 )]
910
+ reference_dist_params = {
911
+ "df" : 5 ,
912
+ "loc" : np .array ([1.0 , 2.0 ]),
913
+ "shape" : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
914
+ }
915
+ reference_dist = seeded_scipy_distribution_builder ("multivariate_t" )
916
+ tests_to_run = [
917
+ "check_pymc_params_match_rv_op" ,
918
+ "check_pymc_draws_match_reference" ,
919
+ "check_rv_size" ,
920
+ ]
921
+
922
+
923
+ class TestMvStudentTChol (BaseTestDistribution ):
924
+ pymc_dist = pm .MvStudentT
925
+ pymc_dist_params = {
926
+ "nu" : 5 ,
927
+ "mu" : np .array ([1.0 , 2.0 ]),
928
+ "chol" : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
929
+ }
930
+ expected_rv_op_params = {
931
+ "nu" : 5 ,
932
+ "mu" : np .array ([1.0 , 2.0 ]),
933
+ "cov" : quaddist_matrix (chol = pymc_dist_params ["chol" ]).eval (),
934
+ }
935
+ tests_to_run = ["check_pymc_params_match_rv_op" ]
936
+
937
+
938
+ class TestMvStudentTTau (BaseTestDistribution ):
939
+ pymc_dist = pm .MvStudentT
940
+ pymc_dist_params = {
941
+ "nu" : 5 ,
942
+ "mu" : np .array ([1.0 , 2.0 ]),
943
+ "tau" : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
944
+ }
945
+ expected_rv_op_params = {
946
+ "nu" : 5 ,
947
+ "mu" : np .array ([1.0 , 2.0 ]),
948
+ "cov" : quaddist_matrix (tau = pymc_dist_params ["tau" ]).eval (),
949
+ }
950
+ tests_to_run = ["check_pymc_params_match_rv_op" ]
951
+
952
+
896
953
class TestDirichlet (BaseTestDistribution ):
897
954
pymc_dist = pm .Dirichlet
898
955
pymc_dist_params = {"a" : np .array ([1.0 , 2.0 ])}
@@ -1402,21 +1459,6 @@ def ref_rand_evd(size, mu, evds, sigma):
1402
1459
model_args = evd_args ,
1403
1460
)
1404
1461
1405
- def test_mv_t (self ):
1406
- def ref_rand (size , nu , Sigma , mu ):
1407
- normal = st .multivariate_normal .rvs (cov = Sigma , size = size )
1408
- chi2 = st .chi2 .rvs (df = nu , size = size )[..., None ]
1409
- return mu + (normal / np .sqrt (chi2 / nu ))
1410
-
1411
- for n in [2 , 3 ]:
1412
- pymc3_random (
1413
- pm .MvStudentT ,
1414
- {"nu" : Domain ([5 , 10 , 25 , 50 ]), "Sigma" : PdMatrix (n ), "mu" : Vector (R , n )},
1415
- size = 100 ,
1416
- valuedomain = Vector (R , n ),
1417
- ref_rand = ref_rand ,
1418
- )
1419
-
1420
1462
@pytest .mark .xfail (reason = "This distribution has not been refactored for v4" )
1421
1463
def test_dirichlet_multinomial (self ):
1422
1464
def ref_rand (size , a , n ):
0 commit comments