|
72 | 72 | TransformValuesMapping,
|
73 | 73 | TransformValuesRewrite,
|
74 | 74 | )
|
| 75 | +from pymc.logprob.utils import ParameterValueError |
75 | 76 | from pymc.testing import Rplusbig, Vector, assert_no_rvs
|
76 | 77 | from tests.distributions.test_transform import check_jacobian_det
|
77 | 78 |
|
@@ -1159,6 +1160,61 @@ def test_special_log_exp_transforms(transform):
|
1159 | 1160 | assert equal_computations([logp_test], [logp_ref])
|
1160 | 1161 |
|
1161 | 1162 |
|
| 1163 | +def test_measurable_power_exponent_with_constant_base(): |
| 1164 | + # test power(2, rv) = exp2(rv) |
| 1165 | + # test negative base fails |
| 1166 | + x_rv_pow = pt.pow(2, pt.random.normal()) |
| 1167 | + x_rv_exp2 = pt.exp2(pt.random.normal()) |
| 1168 | + |
| 1169 | + x_vv_pow = x_rv_pow.clone() |
| 1170 | + x_vv_exp2 = x_rv_exp2.clone() |
| 1171 | + |
| 1172 | + x_logp_fn_pow = pytensor.function([x_vv_pow], pt.sum(logp(x_rv_pow, x_vv_pow))) |
| 1173 | + x_logp_fn_exp2 = pytensor.function([x_vv_exp2], pt.sum(logp(x_rv_exp2, x_vv_exp2))) |
| 1174 | + |
| 1175 | + np.testing.assert_allclose(x_logp_fn_pow(0.1), x_logp_fn_exp2(0.1)) |
| 1176 | + |
| 1177 | + with pytest.raises(ParameterValueError, match="base >= 0"): |
| 1178 | + x_rv_neg = pt.pow(-2, pt.random.normal()) |
| 1179 | + x_vv_neg = x_rv_neg.clone() |
| 1180 | + logp(x_rv_neg, x_vv_neg) |
| 1181 | + |
| 1182 | + |
| 1183 | +def test_measurable_power_exponent_with_variable_base(): |
| 1184 | + # test with RV when logp(<0) we raise error |
| 1185 | + base_rv = pt.random.normal([2]) |
| 1186 | + x_raw_rv = pt.random.normal() |
| 1187 | + x_rv = pt.power(base_rv, x_raw_rv) |
| 1188 | + |
| 1189 | + x_rv.name = "x" |
| 1190 | + base_rv.name = "base" |
| 1191 | + base_vv = base_rv.clone() |
| 1192 | + x_vv = x_rv.clone() |
| 1193 | + |
| 1194 | + res = conditional_logp({base_rv: base_vv, x_rv: x_vv}) |
| 1195 | + x_logp = res[x_vv] |
| 1196 | + logp_vals_fn = pytensor.function([base_vv, x_vv], x_logp) |
| 1197 | + |
| 1198 | + with pytest.raises(ParameterValueError, match="base >= 0"): |
| 1199 | + logp_vals_fn(np.array([-2]), np.array([2])) |
| 1200 | + |
| 1201 | + |
| 1202 | +def test_base_exponent_non_measurable(): |
| 1203 | + # test dual sources of measuravility fails |
| 1204 | + base_rv = pt.random.normal([2]) |
| 1205 | + x_raw_rv = pt.random.normal() |
| 1206 | + x_rv = pt.power(base_rv, x_raw_rv) |
| 1207 | + x_rv.name = "x" |
| 1208 | + |
| 1209 | + x_vv = x_rv.clone() |
| 1210 | + |
| 1211 | + with pytest.raises( |
| 1212 | + RuntimeError, |
| 1213 | + match="The logprob terms of the following value variables could not be derived: {x}", |
| 1214 | + ): |
| 1215 | + conditional_logp({x_rv: x_vv}) |
| 1216 | + |
| 1217 | + |
1162 | 1218 | @pytest.mark.parametrize("shift", [1.5, np.array([-0.5, 1, 0.3])])
|
1163 | 1219 | @pytest.mark.parametrize("scale", [2.0, np.array([1.5, 3.3, 1.0])])
|
1164 | 1220 | def test_multivariate_rv_transform(shift, scale):
|
|
0 commit comments