Skip to content

Commit fa015e3

Browse files
committed
Implement default transform for Mixtures
1 parent 0b9f9cb commit fa015e3

File tree

2 files changed

+177
-0
lines changed

2 files changed

+177
-0
lines changed

pymc/distributions/mixture.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,14 @@
1919

2020
from aeppl.abstract import MeasurableVariable, _get_measurable_outputs
2121
from aeppl.logprob import _logcdf, _logprob
22+
from aeppl.transforms import IntervalTransform
2223
from aesara.compile.builders import OpFromGraph
24+
from aesara.graph.basic import equal_computations
2325
from aesara.tensor import TensorVariable
2426
from aesara.tensor.random.op import RandomVariable
2527

2628
from pymc.aesaraf import change_rv_size
29+
from pymc.distributions import transforms
2730
from pymc.distributions.continuous import Normal, get_tau_sigma
2831
from pymc.distributions.dist_math import check_parameters
2932
from pymc.distributions.distribution import (
@@ -35,6 +38,7 @@
3538
)
3639
from pymc.distributions.logprob import logcdf, logp
3740
from pymc.distributions.shape_utils import to_tuple
41+
from pymc.distributions.transforms import _default_transform
3842
from pymc.util import check_dist_not_registered
3943
from pymc.vartypes import discrete_types
4044

@@ -461,6 +465,83 @@ def marginal_mixture_moment(op, rv, rng, weights, *components):
461465
return mix_moment
462466

463467

468+
# Default transforms that Mixture is allowed to inherit from its components,
# either because they need no special handling or because custom logic below
# supports them. Keep this tuple (and the registered function underneath) in
# sync whenever a new default transform is introduced.
allowed_default_mixture_transforms = (
    transforms.CholeskyCovPacked,
    transforms.CircularTransform,
    transforms.IntervalTransform,
    transforms.LogTransform,
    transforms.LogExpM1,
    transforms.LogOddsTransform,
    transforms.Ordered,
    transforms.Simplex,
    transforms.SumTo1,
)
482+
483+
484+
class MixtureTransformWarning(UserWarning):
    """Emitted when no safe default transform can be assigned to a Mixture."""
486+
487+
488+
@_default_transform.register(MarginalMixtureRV)
def marginal_mixture_default_transform(op, rv):
    """Return a default transform for a Mixture, inherited from its components.

    The components' own default transforms are inspected:

    * If the components disagree on the transform type, or the shared type is
      not in ``allowed_default_mixture_transforms``, a
      ``MixtureTransformWarning`` is emitted and ``None`` is returned.
    * If no component has a default transform (e.g. discrete components),
      ``None`` is returned silently.
    * ``IntervalTransform`` gets special handling: the components' intervals
      must be equivalent, and a new transform is built whose ``args_fn``
      accepts the Mixture's own inputs.
    * Any other allowed transform is returned as-is.
    """

    def transform_warning():
        warnings.warn(
            f"No safe default transform found for Mixture distribution {rv}. This can "
            "happen when components have different supports or default transforms.\n"
            "If appropriate, you can specify a custom transform for more efficient sampling.",
            MixtureTransformWarning,
            stacklevel=2,
        )

    rng, weights, *components = rv.owner.inputs

    default_transforms = [
        _default_transform(component.owner.op, component) for component in components
    ]

    # If there are more than one type of default transforms, we do not apply any
    if len({type(transform) for transform in default_transforms}) != 1:
        transform_warning()
        return None

    default_transform = default_transforms[0]

    # All components lack a default transform (e.g. discrete distributions such
    # as Poisson): the mixture simply gets none as well — no warning warranted.
    if default_transform is None:
        return None

    if not isinstance(default_transform, allowed_default_mixture_transforms):
        transform_warning()
        return None

    if isinstance(default_transform, IntervalTransform):
        # If there are more than one component, we need to check the IntervalTransform
        # of the components are actually equivalent (e.g., we don't have an
        # Interval(0, 1), and an Interval(0, 2)).
        if len(default_transforms) > 1:
            value = rv.type()
            backward_expressions = [
                transform.backward(value, *component.owner.inputs)
                for transform, component in zip(default_transforms, components)
            ]
            for expr1, expr2 in zip(backward_expressions[:-1], backward_expressions[1:]):
                if not equal_computations([expr1], [expr2]):
                    transform_warning()
                    return None

        # We need to create a new IntervalTransform that expects the Mixture inputs
        args_fn = default_transform.args_fn

        def mixture_args_fn(rng, weights, *components):
            # We checked that the interval transforms of each component are equivalent,
            # so we can just pass the inputs of the first component
            return args_fn(*components[0].owner.inputs)

        return IntervalTransform(args_fn=mixture_args_fn)

    else:
        return default_transform
543+
544+
464545
class NormalMixture:
465546
R"""
466547
Normal mixture log-likelihood

pymc/tests/test_mixture.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
import pytest
2020
import scipy.stats as st
2121

22+
from aeppl.transforms import IntervalTransform, LogTransform
23+
from aeppl.transforms import Simplex as SimplexTransform
2224
from aesara import tensor as at
25+
from aesara.tensor import TensorVariable
2326
from aesara.tensor.random.op import RandomVariable
2427
from numpy.testing import assert_allclose
2528
from scipy.special import logsumexp
@@ -32,6 +35,7 @@
3235
Exponential,
3336
Gamma,
3437
HalfNormal,
38+
HalfStudentT,
3539
LKJCholeskyCov,
3640
LogNormal,
3741
Mixture,
@@ -40,9 +44,14 @@
4044
Normal,
4145
NormalMixture,
4246
Poisson,
47+
StickBreakingWeights,
48+
Triangular,
49+
Uniform,
4350
)
4451
from pymc.distributions.logprob import logp
52+
from pymc.distributions.mixture import MixtureTransformWarning
4553
from pymc.distributions.shape_utils import to_tuple
54+
from pymc.distributions.transforms import _default_transform
4655
from pymc.math import expand_packed_triangular
4756
from pymc.model import Model
4857
from pymc.sampling import (
@@ -1216,3 +1225,90 @@ def test_list_multivariate_components(self, weights, comp_dists, size, expected)
12161225
with Model() as model:
12171226
Mixture("x", weights, comp_dists, size=size)
12181227
assert_moment_is_expected(model, expected, check_finite_logp=False)
1228+
1229+
1230+
class TestMixtureDefaultTransforms:
    """Tests for the default transform Mixture inherits from its components."""

    @pytest.mark.parametrize(
        "comp_dists, expected_transform_type",
        [
            (Poisson.dist(1, size=2), type(None)),
            (Normal.dist(size=2), type(None)),
            (Uniform.dist(size=2), IntervalTransform),
            (HalfNormal.dist(size=2), LogTransform),
            ([HalfNormal.dist(), Normal.dist()], type(None)),
            ([HalfNormal.dist(1), Exponential.dist(1), HalfStudentT.dist(4, 1)], LogTransform),
            ([Dirichlet.dist([1, 2, 3, 4]), StickBreakingWeights.dist(1, K=3)], SimplexTransform),
            ([Uniform.dist(0, 1), Uniform.dist(0, 1), Triangular.dist(0, 1)], IntervalTransform),
            ([Uniform.dist(0, 1), Uniform.dist(0, 2)], type(None)),
        ],
    )
    def test_expected(self, comp_dists, expected_transform_type):
        if isinstance(comp_dists, TensorVariable):
            weights = np.ones(2) / 2
        else:
            weights = np.ones(len(comp_dists)) / len(comp_dists)
        mix = Mixture.dist(weights, comp_dists)
        assert isinstance(_default_transform(mix.owner.op, mix), expected_transform_type)

    def test_hierarchical_interval_transform(self):
        with Model() as model:
            lower = Normal("lower", 0.5)
            upper = Uniform("upper", 0, 1)
            uniform = Uniform("uniform", -at.abs(lower), at.abs(upper), transform=None)
            triangular = Triangular(
                "triangular", -at.abs(lower), at.abs(upper), c=0.25, transform=None
            )
            comp_dists = [
                Uniform.dist(-at.abs(lower), at.abs(upper)),
                Triangular.dist(-at.abs(lower), at.abs(upper), c=0.25),
            ]
            mix1 = Mixture("mix1", [0.3, 0.7], comp_dists)
            mix2 = Mixture("mix2", [0.3, 0.7][::-1], comp_dists[::-1])

            ip = model.compute_initial_point()
            # We want an informative moment, other than zero
            assert ip["mix1_interval__"] != 0

            expected_mix_ip = (
                IntervalTransform(args_fn=lambda *args: (-0.5, 0.5))
                .forward(0.3 * ip["uniform"] + 0.7 * ip["triangular"])
                .eval()
            )
            assert np.isclose(ip["mix1_interval__"], ip["mix2_interval__"])
            assert np.isclose(ip["mix1_interval__"], expected_mix_ip)

    def test_logp(self):
        with Model() as m:
            halfnorm = HalfNormal("halfnorm")
            comp_dists = [HalfNormal.dist(), HalfNormal.dist()]
            mix_transf = Mixture("mix_transf", w=[0.5, 0.5], comp_dists=comp_dists)
            mix = Mixture("mix", w=[0.5, 0.5], comp_dists=comp_dists, transform=None)

            logp_fn = m.compile_logp(vars=[halfnorm, mix_transf, mix], sum=False)
            test_point = {"halfnorm_log__": 1, "mix_transf_log__": 1, "mix": np.exp(1)}
            logp_halfnorm, logp_mix_transf, logp_mix = logp_fn(test_point)
            # The transformed mixture of identical HalfNormals must match the
            # transformed HalfNormal; the untransformed one differs exactly by
            # the log-Jacobian term (1 at this test point).
            assert np.isclose(logp_halfnorm, logp_mix_transf)
            assert np.isclose(logp_halfnorm, logp_mix + 1)

    def test_warning(self):
        # NOTE: ``pytest.warns(None)`` is deprecated since pytest 7 and raises
        # in pytest 8, so "no warning" is asserted by escalating warnings to
        # errors inside a catch_warnings block instead.
        import warnings

        with Model() as m:
            comp_dists = [HalfNormal.dist(), Exponential.dist(1)]
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                Mixture("mix1", w=[0.5, 0.5], comp_dists=comp_dists)

            comp_dists = [Uniform.dist(0, 1), Uniform.dist(0, 2)]
            with pytest.warns(MixtureTransformWarning):
                Mixture("mix2", w=[0.5, 0.5], comp_dists=comp_dists)

            comp_dists = [Normal.dist(), HalfNormal.dist()]
            with pytest.warns(MixtureTransformWarning):
                Mixture("mix3", w=[0.5, 0.5], comp_dists=comp_dists)

            with warnings.catch_warnings():
                warnings.simplefilter("error")
                Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, transform=None)

            with warnings.catch_warnings():
                warnings.simplefilter("error")
                Mixture("mix5", w=[0.5, 0.5], comp_dists=comp_dists, observed=1)

0 commit comments

Comments
 (0)