@@ -759,6 +759,9 @@ def test_categorical_moment(p, size, expected):
759
759
(np .ones (10 ), np .identity (10 ), None , np .ones (10 )),
760
760
(np .ones (2 ), np .identity (2 ), 4 , np .ones ((4 , 2 ))),
761
761
(np .ones (2 ), np .identity (2 ), (4 , 3 ), np .ones ((4 , 3 , 2 ))),
762
+ (np .ones ((2 , 2 )), np .identity (2 ), None , np .ones ((2 , 2 ))),
763
+ (np .ones ((2 , 2 )), np .identity (2 ), 4 , np .ones ((4 , 2 , 2 ))),
764
+ (np .ones ((2 , 2 )), np .identity (2 ), (4 , 2 ), np .ones ((4 , 2 , 2 , 2 ))),
762
765
(np .array ([1 , 0 , 3.0 ]), np .identity (3 ), None , np .array ([1 , 0 , 3.0 ])),
763
766
(np .array ([1 , 0 , 3.0 ]), np .identity (3 ), 4 , np .full ((4 , 3 ), [1 , 0 , 3.0 ])),
764
767
(np .array ([1 , 0 , 3.0 ]), np .identity (3 ), (4 , 2 ), np .full ((4 , 2 , 3 ), [1 , 0 , 3.0 ])),
@@ -780,6 +783,12 @@ def test_categorical_moment(p, size, expected):
780
783
(4 , 5 ),
781
784
np .full ((4 , 5 , 3 ), [1 , 3 , 0.0 ]),
782
785
),
786
+ (
787
+ np .array ([[3. , 5 ], [1 , 4 ]]),
788
+ np .identity (2 ),
789
+ (4 , 5 ),
790
+ np .full ((4 , 5 , 2 , 2 ), [[3. , 5 ], [1 , 4 ]])
791
+ ),
783
792
],
784
793
)
785
794
def test_mv_normal_moment (mu , cov , size , expected ):
0 commit comments