@@ -979,16 +979,21 @@ def test_univariate(self, symbolic_rv):
         np.testing.assert_allclose(obs_logp, st.norm([1, 2]).logpdf([0.25, 0.5]))
         np.testing.assert_allclose(unobs_logp, st.norm([3]).logpdf([0.25]))
 
+    @pytest.mark.parametrize("mutable_shape", (False, True))
     @pytest.mark.parametrize("obs_component_selected", (True, False))
-    def test_multivariate_constant_mask_separable(self, obs_component_selected):
+    def test_multivariate_constant_mask_separable(self, obs_component_selected, mutable_shape):
         if obs_component_selected:
             mask = np.zeros((1, 4), dtype=bool)
         else:
             mask = np.ones((1, 4), dtype=bool)
         obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
         unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])
 
-        rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
+        if mutable_shape:
+            shape = (1, pytensor.shared(np.array(4, dtype=int)))
+        else:
+            shape = (1, 4)
+        rv = pm.Dirichlet.dist(pt.arange(shape[-1]) + 1, shape=shape)
         (obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)
 
         # Test types
@@ -1023,6 +1028,10 @@ def test_multivariate_constant_mask_separable(self, obs_component_selected):
         np.testing.assert_allclose(obs_logp, expected_obs_logp)
         np.testing.assert_allclose(unobs_logp, expected_unobs_logp)
 
+        if mutable_shape:
+            shape[-1].set_value(7)
+            assert tuple(joined_rv.shape.eval()) == (1, 7)
+
     def test_multivariate_constant_mask_unseparable(self):
         mask = pt.constant(np.array([[True, True, False, False]]))
         obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
@@ -1097,14 +1106,19 @@ def test_multivariate_shared_mask_separable(self):
         np.testing.assert_almost_equal(obs_logp, new_expected_logp)
         np.testing.assert_array_equal(unobs_logp, [])
 
-    def test_multivariate_shared_mask_unseparable(self):
+    @pytest.mark.parametrize("mutable_shape", (False, True))
+    def test_multivariate_shared_mask_unseparable(self, mutable_shape):
         # Even if the mask is initially not mixing support dims,
         # it could later be changed in a way that does!
         mask = shared(np.array([[True, True, True, True]]))
         obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
         unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])
 
-        rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
+        if mutable_shape:
+            shape = mask.shape
+        else:
+            shape = (1, 4)
+        rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=shape)
         (obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)
 
         # Test types
@@ -1134,16 +1148,22 @@ def test_multivariate_shared_mask_unseparable(self):
 
         # Test that we can update a shared mask
         mask.set_value(np.array([[False, False, True, True]]))
+        equivalent_value = np.array([0.1, 0.4, 0.4, 0.1])
 
         assert tuple(obs_rv.shape.eval()) == (2,)
         assert tuple(unobs_rv.shape.eval()) == (2,)
 
-        new_expected_logp = pm.logp(rv, [0.1, 0.4, 0.4, 0.1]).eval()
+        new_expected_logp = pm.logp(rv, equivalent_value).eval()
         assert not np.isclose(expected_logp, new_expected_logp)  # Otherwise test is weak
         obs_logp, unobs_logp = logp_fn()
         np.testing.assert_almost_equal(obs_logp, new_expected_logp)
         np.testing.assert_array_equal(unobs_logp, [])
 
+        if mutable_shape:
+            mask.set_value(np.array([[False, False, True, False], [False, False, False, True]]))
+            assert tuple(obs_rv.shape.eval()) == (6,)
+            assert tuple(unobs_rv.shape.eval()) == (2,)
+
     def test_support_point(self):
         x = pm.GaussianRandomWalk.dist(init_dist=pm.Normal.dist(-5), mu=1, steps=9)
         ref_support_point = support_point(x).eval()