@@ -758,15 +758,34 @@ def test_categorical_moment(p, size, expected):
758
758
(np .ones (1 ), np .identity (1 ), None , np .ones (1 )),
759
759
(np .ones (10 ), np .identity (10 ), None , np .ones (10 )),
760
760
(np .ones (2 ), np .identity (2 ), 4 , np .ones ((4 , 2 ))),
761
- (np .ones (2 ), np .identity (2 ), (4 , 2 ), np .ones ((4 , 2 , 2 ))),
762
- (np .ones ((2 , 2 )), np .identity (2 ), None , np .ones ((2 , 2 ))),
763
- (np .ones ((2 , 2 )), np .identity (2 ), 4 , np .ones ((4 , 2 , 2 ))),
764
- (np .ones ((2 , 2 )), np .identity (2 ), (4 , 2 ), np .ones ((4 , 2 , 2 , 2 ))),
761
+ (np .ones (2 ), np .identity (2 ), (4 , 3 ), np .ones ((4 , 3 , 2 ))),
762
+ (np .array ([1 , 0 , 3.0 ]), np .identity (3 ), None , np .array ([1 , 0 , 3.0 ])),
763
+ (np .array ([1 , 0 , 3.0 ]), np .identity (3 ), 4 , np .full ((4 , 3 ), [1 , 0 , 3.0 ])),
764
+ (np .array ([1 , 0 , 3.0 ]), np .identity (3 ), (4 , 2 ), np .full ((4 , 2 , 3 ), [1 , 0 , 3.0 ])),
765
+ (
766
+ np .array ([1 , 3.0 ]),
767
+ np .identity (2 ),
768
+ (4 , 5 ),
769
+ np .full ((4 , 5 , 2 ), [1 , 3.0 ]),
770
+ ),
771
+ (
772
+ np .array ([1 , 3.0 ]),
773
+ np .array ([[1.0 , 0.5 ], [0.5 , 2 ]]),
774
+ (4 , 5 ),
775
+ np .full ((4 , 5 , 2 ), [1 , 3.0 ]),
776
+ ),
777
+ (
778
+ np .array ([1 , 3 , 0.0 ]),
779
+ np .array ([[1.0 , 0.5 , 0.1 ], [0.5 , 2 , 0.5 ], [0.1 , 0.5 , 5 ]]),
780
+ (4 , 5 ),
781
+ np .full ((4 , 5 , 3 ), [1 , 3 , 0.0 ]),
782
+ ),
765
783
],
766
784
)
767
785
def test_mv_normal_moment (mu , cov , size , expected ):
768
786
with Model () as model :
769
787
MvNormal ("x" , mu = mu , cov = cov , size = size )
788
+ assert_moment_is_expected (model , expected )
770
789
771
790
772
791
@pytest .mark .parametrize (
0 commit comments