
Commit 9373d5a

Sayam753 and lucianopaz authored
Add MixtureSameFamily distribution (#4180)
* Added MixtureSameFamily distribution and its tests
  Co-authored-by: lucianopaz <[email protected]>
* Fixed pyupgrade error
* Addressed review suggestions
* Wrote tests for broadcasting
  Handled broadcasting in case observed data has more batch dimensions
  Wrote tests for MvNormal
* Added MixtureSameFamily name in rst files
  Added a mention in RELEASE-NOTES.md

Co-authored-by: lucianopaz <[email protected]>
1 parent bd0b3c0 commit 9373d5a

File tree

5 files changed: 340 additions & 2 deletions

    RELEASE-NOTES.md
    docs/source/api/distributions/mixture.rst
    pymc3/distributions/__init__.py
    pymc3/distributions/mixture.py
    pymc3/tests/test_mixture.py

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 - Add MLDA, a new stepper for multilevel sampling. MLDA can be used when a hierarchy of approximate posteriors of varying accuracy is available, offering improved sampling efficiency especially in high-dimensional problems and/or where gradients are not available (see [#3926](https://github.com/pymc-devs/pymc3/pull/3926))
 - Change SMC metropolis kernel to independent metropolis kernel [#4115](https://github.com/pymc-devs/pymc3/pull/3926))
 - Add alternative parametrization to NegativeBinomial distribution in terms of n and p (see [#4126](https://github.com/pymc-devs/pymc3/issues/4126))
+- Added a new `MixtureSameFamily` distribution to handle mixtures of arbitrary dimensions in vectorized form (see [#4185](https://github.com/pymc-devs/pymc3/issues/4185)).

 ## PyMC3 3.9.3 (11 August 2020)

docs/source/api/distributions/mixture.rst

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ Mixture
 .. autosummary::
    Mixture
    NormalMixture
+   MixtureSameFamily

 .. automodule:: pymc3.distributions.mixture
    :members:

pymc3/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -79,6 +79,7 @@

 from .mixture import Mixture
 from .mixture import NormalMixture
+from .mixture import MixtureSameFamily

 from .multivariate import MvNormal
 from .multivariate import MatrixNormal
@@ -164,6 +165,7 @@
     "SkewNormal",
     "Mixture",
     "NormalMixture",
+    "MixtureSameFamily",
     "Triangular",
     "DiscreteWeibull",
     "Gumbel",

pymc3/distributions/mixture.py

Lines changed: 246 additions & 2 deletions
@@ -28,9 +28,15 @@
     _DrawValuesContext,
     _DrawValuesContextBlocker,
 )
-from .shape_utils import to_tuple, broadcast_distribution_samples
+from .shape_utils import (
+    to_tuple,
+    broadcast_distribution_samples,
+    get_broadcastable_dist_samples,
+)
 from .continuous import get_tau_sigma, Normal
-from ..theanof import _conversion_map
+from ..theanof import _conversion_map, take_along_axis
+
+__all__ = ["Mixture", "NormalMixture", "MixtureSameFamily"]


 def all_discrete(comp_dists):
@@ -612,3 +618,241 @@ def __init__(self, w, mu, sigma=None, tau=None, sd=None, comp_shape=(), *args, **kwargs):

     def _distr_parameters_for_repr(self):
         return ["w", "mu", "sigma"]
+
+
+class MixtureSameFamily(Distribution):
+    R"""
+    Mixture Same Family log-likelihood
+    This distribution handles mixtures of multivariate distributions in a vectorized
+    manner. It is used over Mixture distribution when the mixture components are not
+    present on the last axis of components' distribution.
+
+    .. math::f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i)\textrm{ Along mixture\_axis}
+
+    ========  ============================================
+    Support   :math:`\textrm{support}(f)`
+    Mean      :math:`w\mu`
+    ========  ============================================
+
+    Parameters
+    ----------
+    w: array of floats
+        w >= 0 and w <= 1
+        the mixture weights
+    comp_dists: PyMC3 distribution (e.g. `pm.Multinomial.dist(...)`)
+        The `comp_dists` can be scalar or multidimensional distribution.
+        Assuming its shape to be - (i_0, ..., i_n, mixture_axis, i_n+1, ..., i_N),
+        the `mixture_axis` is consumed resulting in the shape of mixture as -
+        (i_0, ..., i_n, i_n+1, ..., i_N).
+    mixture_axis: int, default = -1
+        Axis representing the mixture components to be reduced in the mixture.
+
+    Notes
+    -----
+    The default behaviour resembles Mixture distribution wherein the last axis of component
+    distribution is reduced.
+    """
+
+    def __init__(self, w, comp_dists, mixture_axis=-1, *args, **kwargs):
+        self.w = tt.as_tensor_variable(w)
+        if not isinstance(comp_dists, Distribution):
+            raise TypeError(
+                "The MixtureSameFamily distribution only accepts Distribution "
+                f"instances as its components. Got {type(comp_dists)} instead."
+            )
+        self.comp_dists = comp_dists
+        if mixture_axis < 0:
+            mixture_axis = len(comp_dists.shape) + mixture_axis
+            if mixture_axis < 0:
+                raise ValueError(
+                    "`mixture_axis` is supposed to be in shape of components' distribution. "
+                    f"Got {mixture_axis + len(comp_dists.shape)} axis instead out of the bounds."
+                )
+        comp_shape = to_tuple(comp_dists.shape)
+        self.shape = comp_shape[:mixture_axis] + comp_shape[mixture_axis + 1 :]
+        self.mixture_axis = mixture_axis
+        kwargs.setdefault("dtype", self.comp_dists.dtype)
+
+        # Compute the mode so we don't always have to pass a testval
+        defaults = kwargs.pop("defaults", [])
+        event_shape = self.comp_dists.shape[mixture_axis + 1 :]
+        _w = tt.shape_padleft(
+            tt.shape_padright(w, len(event_shape)),
+            len(self.comp_dists.shape) - w.ndim - len(event_shape),
+        )
+        mode = take_along_axis(
+            self.comp_dists.mode,
+            tt.argmax(_w, keepdims=True),
+            axis=mixture_axis,
+        )
+        self.mode = mode[(..., 0) + (slice(None),) * len(event_shape)]
+
+        if not all_discrete(comp_dists):
+            mean = tt.as_tensor_variable(self.comp_dists.mean)
+            self.mean = (_w * mean).sum(axis=mixture_axis)
+            if "mean" not in defaults:
+                defaults.append("mean")
+        defaults.append("mode")
+
+        super().__init__(defaults=defaults, *args, **kwargs)
+
+    def logp(self, value):
+        """
+        Calculate log-probability of defined ``MixtureSameFamily`` distribution at specified value.
+
+        Parameters
+        ----------
+        value : numeric
+            Value(s) for which log-probability is calculated. If the log probabilities for multiple
+            values are desired the values must be provided in a numpy array or theano tensor
+
+        Returns
+        -------
+        TensorVariable
+        """
+
+        comp_dists = self.comp_dists
+        w = self.w
+        mixture_axis = self.mixture_axis
+
+        event_shape = comp_dists.shape[mixture_axis + 1 :]
+
+        # To be able to broadcast the comp_dists.logp with w and value
+        # We first have to pad the shape of w to the right with ones
+        # so that it can broadcast with the event_shape.
+
+        w = tt.shape_padright(w, len(event_shape))
+
+        # Second, we have to add the mixture_axis to the value tensor
+        # To insert the mixture axis at the correct location, we use the
+        # negative number index. This way, we can also handle situations
+        # in which, value is an observed value with more batch dimensions
+        # than the ones present in the comp_dists.
+        comp_dists_ndim = len(comp_dists.shape)
+
+        value = tt.shape_padaxis(value, axis=mixture_axis - comp_dists_ndim)
+
+        comp_logp = comp_dists.logp(value)
+        return bound(
+            logsumexp(tt.log(w) + comp_logp, axis=mixture_axis, keepdims=False),
+            w >= 0,
+            w <= 1,
+            tt.allclose(w.sum(axis=mixture_axis - comp_dists_ndim), 1),
+            broadcast_conditions=False,
+        )
+
+    def random(self, point=None, size=None):
+        """
+        Draw random values from defined ``MixtureSameFamily`` distribution.
+
+        Parameters
+        ----------
+        point : dict, optional
+            Dict of variable values on which random values are to be
+            conditioned (uses default point if not specified).
+        size : int, optional
+            Desired size of random sample (returns one sample if not
+            specified).
+
+        Returns
+        -------
+        array
+        """
+        sample_shape = to_tuple(size)
+        mixture_axis = self.mixture_axis
+
+        # First we draw values for the mixture component weights
+        (w,) = draw_values([self.w], point=point, size=size)
+
+        # We now draw random choices from those weights.
+        # However, we have to ensure that the number of choices has the
+        # sample_shape present.
+        w_shape = w.shape
+        batch_shape = self.comp_dists.shape[: mixture_axis + 1]
+        param_shape = np.broadcast(np.empty(w_shape), np.empty(batch_shape)).shape
+        event_shape = self.comp_dists.shape[mixture_axis + 1 :]
+
+        if np.asarray(self.shape).size != 0:
+            comp_dists_ndim = len(self.comp_dists.shape)
+
+            # If event_shape of both comp_dists and supplied shape matches,
+            # broadcast only batch_shape
+            # else broadcast the entire given shape with batch_shape.
+            if list(self.shape[mixture_axis - comp_dists_ndim + 1 :]) == list(event_shape):
+                dist_shape = np.broadcast(
+                    np.empty(self.shape[:mixture_axis]), np.empty(param_shape[:mixture_axis])
+                ).shape
+            else:
+                dist_shape = np.broadcast(
+                    np.empty(self.shape), np.empty(param_shape[:mixture_axis])
+                ).shape
+        else:
+            dist_shape = param_shape[:mixture_axis]
+
+        # Try to determine the size that must be used to get the mixture
+        # components (i.e. get random choices using w).
+        # 1. There must be size independent choices based on w.
+        # 2. There must also be independent draws for each non singleton axis
+        # of w.
+        # 3. There must also be independent draws for each dimension added by
+        # self.shape with respect to the w.ndim. These usually correspond to
+        # observed variables with batch shapes
+        wsh = (1,) * (len(dist_shape) - len(w_shape) + 1) + w_shape[:mixture_axis]
+        psh = (1,) * (len(dist_shape) - len(param_shape) + 1) + param_shape[:mixture_axis]
+        w_sample_size = []
+        # Loop through the dist_shape to get the conditions 2 and 3 first
+        for i in range(len(dist_shape)):
+            if dist_shape[i] != psh[i] and wsh[i] == 1:
+                # self.shape[i] is a non singleton dimension (usually caused by
+                # observed data)
+                sh = dist_shape[i]
+            else:
+                sh = wsh[i]
+            w_sample_size.append(sh)
+
+        if sample_shape is not None and w_sample_size[: len(sample_shape)] != sample_shape:
+            w_sample_size = sample_shape + tuple(w_sample_size)
+
+        choices = random_choice(p=w, size=w_sample_size)
+
+        # We now draw samples from the mixture components random method
+        comp_samples = self.comp_dists.random(point=point, size=size)
+        if comp_samples.shape[: len(sample_shape)] != sample_shape:
+            comp_samples = np.broadcast_to(
+                comp_samples,
+                shape=sample_shape + comp_samples.shape,
+            )
+
+        # At this point the shapes of the arrays involved are:
+        # comp_samples.shape = (sample_shape, batch_shape, mixture_axis, event_shape)
+        # choices.shape = (sample_shape, batch_shape)
+        #
+        # To be able to take the choices along the mixture_axis of the
+        # comp_samples, we have to add in dimensions to the right of the
+        # choices array.
+        # We also need to make sure that the batch_shapes of both the comp_samples
+        # and choices broadcast with each other.
+
+        choices = np.reshape(choices, choices.shape + (1,) * (1 + len(event_shape)))
+
+        choices, comp_samples = get_broadcastable_dist_samples([choices, comp_samples], size=size)
+
+        # We now take the choices of the mixture components along the mixture_axis
+        # but we use the negative index representation to be able to handle the
+        # sample_shape
+        samples = np.take_along_axis(
+            comp_samples, choices, axis=mixture_axis - len(self.comp_dists.shape)
+        )
+
+        # The `samples` array still has the `mixture_axis`, so we must remove it:
+        output = samples[(..., 0) + (slice(None),) * len(event_shape)]
+
+        # Final oddity: if size == 1, pymc3 defaults to reducing the sample_shape dimension
+        # We do this to stay consistent with the rest of the package even though
+        # we shouldn't have to do it.
+        if size == 1:
+            output = output[0]
+        return output
+
+    def _distr_parameters_for_repr(self):
+        return []
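For reference, a minimal usage sketch of the distribution added above, mirroring the Multinomial test in pymc3/tests/test_mixture.py below; the batch shape (5,), the number of components, and the variable names are illustrative assumptions, not part of the commit:

    # Illustrative sketch: mix 10 Multinomial components along mixture_axis=1.
    # comp_dists has shape (batch, components, event) = (5, 10, 3); the
    # mixture_axis is consumed, so the mixture itself has shape (5, 3).
    import numpy as np
    import pymc3 as pm

    batch_shape = (5,)
    mixture_comps = 10
    p = np.random.uniform(size=(*batch_shape, mixture_comps, 3))  # per-component event probabilities, as in the commit's test
    n = 100 * np.ones((*batch_shape, 1))
    w = np.ones(mixture_comps) / mixture_comps  # equal mixture weights, summing to 1

    with pm.Model():
        comp_dists = pm.Multinomial.dist(p=p, n=n, shape=(*batch_shape, mixture_comps, 3))
        mix = pm.MixtureSameFamily(
            "mix",
            w=w,
            comp_dists=comp_dists,
            mixture_axis=len(batch_shape),
            shape=(*batch_shape, 3),
        )
        prior = pm.sample_prior_predictive(samples=100)

    # prior["mix"] has shape (100, 5, 3): the mixture axis of comp_dists is gone.

The same pattern appears in TestMixtureSameFamily below with batch shapes (3, 4) and (20,).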

pymc3/tests/test_mixture.py

Lines changed: 90 additions & 0 deletions
@@ -490,3 +490,93 @@ def logp_matches(self, mixture, latent_mix, z, npop, model):
             logps.append(z_logp + latent_mix.logp(test_point))
         latent_mix_logp = logsumexp(np.array(logps), axis=0)
         assert_allclose(mix_logp, latent_mix_logp, rtol=rtol)
+
+
+class TestMixtureSameFamily(SeededTest):
+    @classmethod
+    def setup_class(cls):
+        super().setup_class()
+        cls.size = 50
+        cls.n_samples = 1000
+        cls.mixture_comps = 10
+
+    @pytest.mark.parametrize("batch_shape", [(3, 4), (20,)], ids=str)
+    def test_with_multinomial(self, batch_shape):
+        p = np.random.uniform(size=(*batch_shape, self.mixture_comps, 3))
+        n = 100 * np.ones((*batch_shape, 1))
+        w = np.ones(self.mixture_comps) / self.mixture_comps
+        mixture_axis = len(batch_shape)
+        with pm.Model() as model:
+            comp_dists = pm.Multinomial.dist(p=p, n=n, shape=(*batch_shape, self.mixture_comps, 3))
+            mixture = pm.MixtureSameFamily(
+                "mixture",
+                w=w,
+                comp_dists=comp_dists,
+                mixture_axis=mixture_axis,
+                shape=(*batch_shape, 3),
+            )
+            prior = pm.sample_prior_predictive(samples=self.n_samples)
+
+        assert prior["mixture"].shape == (self.n_samples, *batch_shape, 3)
+        assert mixture.random(size=self.size).shape == (self.size, *batch_shape, 3)
+
+        if theano.config.floatX == "float32":
+            rtol = 1e-4
+        else:
+            rtol = 1e-7
+
+        comp_logp = comp_dists.logp(model.test_point["mixture"].reshape(*batch_shape, 1, 3))
+        log_sum_exp = logsumexp(
+            comp_logp.eval() + np.log(w)[..., None], axis=mixture_axis, keepdims=True
+        ).sum()
+        assert_allclose(
+            model.logp(model.test_point),
+            log_sum_exp,
+            rtol,
+        )
+
+    # TODO: Handle case when `batch_shape` == `sample_shape`.
+    # See https://github.com/pymc-devs/pymc3/issues/4185 for details.
+    def test_with_mvnormal(self):
+        # 10 batch, 3-variate Gaussian
+        mu = np.random.randn(self.mixture_comps, 3)
+        mat = np.random.randn(3, 3)
+        cov = mat @ mat.T
+        chol = np.linalg.cholesky(cov)
+        w = np.ones(self.mixture_comps) / self.mixture_comps
+
+        with pm.Model() as model:
+            comp_dists = pm.MvNormal.dist(mu=mu, chol=chol, shape=(self.mixture_comps, 3))
+            mixture = pm.MixtureSameFamily(
+                "mixture", w=w, comp_dists=comp_dists, mixture_axis=0, shape=(3,)
+            )
+            prior = pm.sample_prior_predictive(samples=self.n_samples)
+
+        assert prior["mixture"].shape == (self.n_samples, 3)
+        assert mixture.random(size=self.size).shape == (self.size, 3)
+
+        if theano.config.floatX == "float32":
+            rtol = 1e-4
+        else:
+            rtol = 1e-7
+
+        comp_logp = comp_dists.logp(model.test_point["mixture"].reshape(1, 3))
+        log_sum_exp = logsumexp(
+            comp_logp.eval() + np.log(w)[..., None], axis=0, keepdims=True
+        ).sum()
+        assert_allclose(
+            model.logp(model.test_point),
+            log_sum_exp,
+            rtol,
+        )
+
+    def test_broadcasting_in_shape(self):
+        with pm.Model() as model:
+            mu = pm.Gamma("mu", 1.0, 1.0, shape=2)
+            comp_dists = pm.Poisson.dist(mu, shape=2)
+            mix = pm.MixtureSameFamily(
+                "mix", w=np.ones(2) / 2, comp_dists=comp_dists, shape=(1000,)
+            )
+            prior = pm.sample_prior_predictive(samples=self.n_samples)
+
+        assert prior["mix"].shape == (self.n_samples, 1000)