19 | 19 | import pytest
20 | 20 | import scipy.stats as st
21 | 21 |
| 22 | +from aeppl.transforms import IntervalTransform, LogTransform
| 23 | +from aeppl.transforms import Simplex as SimplexTransform
22 | 24 | from aesara import tensor as at
| 25 | +from aesara.tensor import TensorVariable
23 | 26 | from aesara.tensor.random.op import RandomVariable
24 | 27 | from numpy.testing import assert_allclose
25 | 28 | from scipy.special import logsumexp

32 | 35 |     Exponential,
33 | 36 |     Gamma,
34 | 37 |     HalfNormal,
| 38 | +    HalfStudentT,
35 | 39 |     LKJCholeskyCov,
36 | 40 |     LogNormal,
37 | 41 |     Mixture,

40 | 44 |     Normal,
41 | 45 |     NormalMixture,
42 | 46 |     Poisson,
| 47 | +    StickBreakingWeights,
| 48 | +    Triangular,
| 49 | +    Uniform,
43 | 50 | )
44 | 51 | from pymc.distributions.logprob import logp
| 52 | +from pymc.distributions.mixture import MixtureTransformWarning
45 | 53 | from pymc.distributions.shape_utils import to_tuple
| 54 | +from pymc.distributions.transforms import _default_transform
46 | 55 | from pymc.math import expand_packed_triangular
47 | 56 | from pymc.model import Model
48 | 57 | from pymc.sampling import (
@@ -1216,3 +1225,90 @@ def test_list_multivariate_components(self, weights, comp_dists, size, expected)
1216 | 1225 |         with Model() as model:
1217 | 1226 |             Mixture("x", weights, comp_dists, size=size)
1218 | 1227 |         assert_moment_is_expected(model, expected, check_finite_logp=False)
| 1228 | +
| 1229 | +
| 1230 | +class TestMixtureDefaultTransforms:
| 1231 | +    @pytest.mark.parametrize(
| 1232 | +        "comp_dists, expected_transform_type",
| 1233 | +        [
| 1234 | +            (Poisson.dist(1, size=2), type(None)),
| 1235 | +            (Normal.dist(size=2), type(None)),
| 1236 | +            (Uniform.dist(size=2), IntervalTransform),
| 1237 | +            (HalfNormal.dist(size=2), LogTransform),
| 1238 | +            ([HalfNormal.dist(), Normal.dist()], type(None)),
| 1239 | +            ([HalfNormal.dist(1), Exponential.dist(1), HalfStudentT.dist(4, 1)], LogTransform),
| 1240 | +            ([Dirichlet.dist([1, 2, 3, 4]), StickBreakingWeights.dist(1, K=3)], SimplexTransform),
| 1241 | +            ([Uniform.dist(0, 1), Uniform.dist(0, 1), Triangular.dist(0, 1)], IntervalTransform),
| 1242 | +            ([Uniform.dist(0, 1), Uniform.dist(0, 2)], type(None)),
| 1243 | +        ],
| 1244 | +    )
| 1245 | +    def test_expected(self, comp_dists, expected_transform_type):
| 1246 | +        if isinstance(comp_dists, TensorVariable):
| 1247 | +            weights = np.ones(2) / 2
| 1248 | +        else:
| 1249 | +            weights = np.ones(len(comp_dists)) / len(comp_dists)
| 1250 | +        mix = Mixture.dist(weights, comp_dists)
| 1251 | +        assert isinstance(_default_transform(mix.owner.op, mix), expected_transform_type)
| 1252 | +
| 1253 | +    def test_hierarchical_interval_transform(self):
| 1254 | +        with Model() as model:
| 1255 | +            lower = Normal("lower", 0.5)
| 1256 | +            upper = Uniform("upper", 0, 1)
| 1257 | +            uniform = Uniform("uniform", -at.abs(lower), at.abs(upper), transform=None)
| 1258 | +            triangular = Triangular(
| 1259 | +                "triangular", -at.abs(lower), at.abs(upper), c=0.25, transform=None
| 1260 | +            )
| 1261 | +            comp_dists = [
| 1262 | +                Uniform.dist(-at.abs(lower), at.abs(upper)),
| 1263 | +                Triangular.dist(-at.abs(lower), at.abs(upper), c=0.25),
| 1264 | +            ]
| 1265 | +            mix1 = Mixture("mix1", [0.3, 0.7], comp_dists)
| 1266 | +            mix2 = Mixture("mix2", [0.3, 0.7][::-1], comp_dists[::-1])
| 1267 | +
| 1268 | +        ip = model.compute_initial_point()
| 1269 | +        # We want an informative moment, other than zero
| 1270 | +        assert ip["mix1_interval__"] != 0
| 1271 | +
| 1272 | +        expected_mix_ip = (
| 1273 | +            IntervalTransform(args_fn=lambda *args: (-0.5, 0.5))
| 1274 | +            .forward(0.3 * ip["uniform"] + 0.7 * ip["triangular"])
| 1275 | +            .eval()
| 1276 | +        )
| 1277 | +        assert np.isclose(ip["mix1_interval__"], ip["mix2_interval__"])
| 1278 | +        assert np.isclose(ip["mix1_interval__"], expected_mix_ip)
| 1279 | +
| 1280 | +    def test_logp(self):
| 1281 | +        with Model() as m:
| 1282 | +            halfnorm = HalfNormal("halfnorm")
| 1283 | +            comp_dists = [HalfNormal.dist(), HalfNormal.dist()]
| 1284 | +            mix_transf = Mixture("mix_transf", w=[0.5, 0.5], comp_dists=comp_dists)
| 1285 | +            mix = Mixture("mix", w=[0.5, 0.5], comp_dists=comp_dists, transform=None)
| 1286 | +
| 1287 | +        logp_fn = m.compile_logp(vars=[halfnorm, mix_transf, mix], sum=False)
| 1288 | +        test_point = {"halfnorm_log__": 1, "mix_transf_log__": 1, "mix": np.exp(1)}
| 1289 | +        logp_halfnorm, logp_mix_transf, logp_mix = logp_fn(test_point)
| 1290 | +        assert np.isclose(logp_halfnorm, logp_mix_transf)
| 1291 | +        assert np.isclose(logp_halfnorm, logp_mix + 1)
| 1292 | +
| 1293 | +    def test_warning(self):
| 1294 | +        with Model() as m:
| 1295 | +            comp_dists = [HalfNormal.dist(), Exponential.dist(1)]
| 1296 | +            with pytest.warns(None) as rec:
| 1297 | +                Mixture("mix1", w=[0.5, 0.5], comp_dists=comp_dists)
| 1298 | +            assert not rec
| 1299 | +
| 1300 | +            comp_dists = [Uniform.dist(0, 1), Uniform.dist(0, 2)]
| 1301 | +            with pytest.warns(MixtureTransformWarning):
| 1302 | +                Mixture("mix2", w=[0.5, 0.5], comp_dists=comp_dists)
| 1303 | +
| 1304 | +            comp_dists = [Normal.dist(), HalfNormal.dist()]
| 1305 | +            with pytest.warns(MixtureTransformWarning):
| 1306 | +                Mixture("mix3", w=[0.5, 0.5], comp_dists=comp_dists)
| 1307 | +
| 1308 | +            with pytest.warns(None) as rec:
| 1309 | +                Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, transform=None)
| 1310 | +            assert not rec
| 1311 | +
| 1312 | +            with pytest.warns(None) as rec:
| 1313 | +                Mixture("mix5", w=[0.5, 0.5], comp_dists=comp_dists, observed=1)
| 1314 | +            assert not rec