Add MixtureSameFamily distribution #4180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged (5 commits) on Oct 26, 2020
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
@@ -18,6 +18,7 @@
- Add MLDA, a new stepper for multilevel sampling. MLDA can be used when a hierarchy of approximate posteriors of varying accuracy is available, offering improved sampling efficiency especially in high-dimensional problems and/or where gradients are not available (see [#3926](https://github.com/pymc-devs/pymc3/pull/3926))
- Change SMC metropolis kernel to independent metropolis kernel [#4115](https://github.com/pymc-devs/pymc3/pull/3926)
- Add alternative parametrization to NegativeBinomial distribution in terms of n and p (see [#4126](https://github.com/pymc-devs/pymc3/issues/4126))
- Added a new `MixtureSameFamily` distribution to handle mixtures of arbitrary dimensions in vectorized form (see [#4185](https://github.com/pymc-devs/pymc3/issues/4185)).


## PyMC3 3.9.3 (11 August 2020)
1 change: 1 addition & 0 deletions docs/source/api/distributions/mixture.rst
@@ -6,6 +6,7 @@ Mixture
.. autosummary::
Mixture
NormalMixture
MixtureSameFamily

.. automodule:: pymc3.distributions.mixture
:members:
2 changes: 2 additions & 0 deletions pymc3/distributions/__init__.py
@@ -79,6 +79,7 @@

from .mixture import Mixture
from .mixture import NormalMixture
from .mixture import MixtureSameFamily

from .multivariate import MvNormal
from .multivariate import MatrixNormal
@@ -164,6 +165,7 @@
"SkewNormal",
"Mixture",
"NormalMixture",
"MixtureSameFamily",
"Triangular",
"DiscreteWeibull",
"Gumbel",
248 changes: 246 additions & 2 deletions pymc3/distributions/mixture.py
@@ -28,9 +28,15 @@
_DrawValuesContext,
_DrawValuesContextBlocker,
)
from .shape_utils import to_tuple, broadcast_distribution_samples
from .shape_utils import (
to_tuple,
broadcast_distribution_samples,
get_broadcastable_dist_samples,
)
from .continuous import get_tau_sigma, Normal
from ..theanof import _conversion_map
from ..theanof import _conversion_map, take_along_axis

__all__ = ["Mixture", "NormalMixture", "MixtureSameFamily"]


def all_discrete(comp_dists):
@@ -612,3 +618,241 @@ def __init__(self, w, mu, sigma=None, tau=None, sd=None, comp_shape=(), *args, **kwargs):

def _distr_parameters_for_repr(self):
return ["w", "mu", "sigma"]


class MixtureSameFamily(Distribution):
R"""
    Mixture Same Family log-likelihood

    This distribution handles mixtures of multivariate distributions in a
    vectorized manner. It is used instead of the ``Mixture`` distribution when
    the mixture components are not on the last axis of the components'
    distribution.

    .. math::

        f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i)

    where the sum is taken along ``mixture_axis``.

======== ============================================
Support :math:`\textrm{support}(f)`
Mean :math:`w\mu`
======== ============================================

Parameters
----------
    w: array of floats
        The mixture weights. Each weight satisfies 0 <= w <= 1, and the weights
        sum to 1 along their last axis, which indexes the mixture components.
    comp_dists: PyMC3 distribution (e.g. `pm.Multinomial.dist(...)`)
        The component distribution; it can be scalar or multidimensional.
        Assuming its shape is (i_0, ..., i_n, mixture_axis, i_n+1, ..., i_N),
        the `mixture_axis` is consumed, so the resulting mixture has shape
        (i_0, ..., i_n, i_n+1, ..., i_N).
    mixture_axis: int, default = -1
        Axis of `comp_dists` that indexes the mixture components and is
        reduced in the mixture. By default, the last axis is reduced.

Notes
-----
    The default behaviour resembles the ``Mixture`` distribution, wherein the
    last axis of the component distribution is reduced.
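
    Examples
    --------
    A minimal usage sketch (shapes chosen for illustration, loosely mirroring
    the test suite below): a mixture of five tri-variate Gaussian components,
    mixed over the leading axis of ``comp_dists``::

        import numpy as np
        import pymc3 as pm

        mu = np.random.randn(5, 3)
        w = np.ones(5) / 5

        with pm.Model():
            # comp_dists has shape (5, 3): 5 components, event shape (3,)
            comp_dists = pm.MvNormal.dist(mu=mu, cov=np.eye(3), shape=(5, 3))
            # mixture_axis=0 consumes the component axis, leaving shape (3,)
            mix = pm.MixtureSameFamily(
                "mix", w=w, comp_dists=comp_dists, mixture_axis=0, shape=(3,)
            )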
"""

def __init__(self, w, comp_dists, mixture_axis=-1, *args, **kwargs):
Contributor: Does mixture_axis=-1 mean the default is to reduce along the last axis? If yes, it'd be nice to add this precision in the doc string just above.

Member: Yes, we should say this in the docstring.
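For illustration, a minimal sketch of the default behaviour (shapes mirror test_broadcasting_in_shape further below; the fixed mu values are illustrative): with mixture_axis=-1, the last axis of comp_dists, here the axis holding the two Poisson components, is the one that is mixed over.

    import numpy as np
    import pymc3 as pm

    with pm.Model():
        # comp_dists has shape (2,): two Poisson components
        comp_dists = pm.Poisson.dist(mu=np.array([1.0, 10.0]), shape=(2,))
        # the default mixture_axis=-1 consumes that component axis;
        # shape=(1000,) broadcasts the mixture to 1000 iid draws
        mix = pm.MixtureSameFamily(
            "mix", w=np.ones(2) / 2, comp_dists=comp_dists, shape=(1000,)
        )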

self.w = tt.as_tensor_variable(w)
if not isinstance(comp_dists, Distribution):
raise TypeError(
"The MixtureSameFamily distribution only accepts Distribution "
f"instances as its components. Got {type(comp_dists)} instead."
)
self.comp_dists = comp_dists
if mixture_axis < 0:
mixture_axis = len(comp_dists.shape) + mixture_axis
Member: We need to ensure that after we do this, the mixture_axis is positive and that it lies within the number of dimensions of comp_dists.

if mixture_axis < 0:
                raise ValueError(
                    "`mixture_axis` must be a valid axis of the component distribution. "
                    f"Got axis {mixture_axis - len(comp_dists.shape)}, which is out of bounds."
                )
comp_shape = to_tuple(comp_dists.shape)
self.shape = comp_shape[:mixture_axis] + comp_shape[mixture_axis + 1 :]
self.mixture_axis = mixture_axis
kwargs.setdefault("dtype", self.comp_dists.dtype)

# Compute the mode so we don't always have to pass a testval
defaults = kwargs.pop("defaults", [])
event_shape = self.comp_dists.shape[mixture_axis + 1 :]
_w = tt.shape_padleft(
tt.shape_padright(w, len(event_shape)),
len(self.comp_dists.shape) - w.ndim - len(event_shape),
)
mode = take_along_axis(
self.comp_dists.mode,
tt.argmax(_w, keepdims=True),
axis=mixture_axis,
)
self.mode = mode[(..., 0) + (slice(None),) * len(event_shape)]
Member: I was kind of lazy and didn't add the self.mean. It's only necessary if the components are continuous distributions. If you want, you can try to copy the implementation from the regular Mixture here.


if not all_discrete(comp_dists):
mean = tt.as_tensor_variable(self.comp_dists.mean)
self.mean = (_w * mean).sum(axis=mixture_axis)
if "mean" not in defaults:
defaults.append("mean")
defaults.append("mode")

super().__init__(defaults=defaults, *args, **kwargs)
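For intuition, a NumPy sketch (hypothetical shapes) of the weighted mean assembled above: the weights are padded on the right so they broadcast against the event shape, multiplied with the component means, and summed over the mixture axis.

    import numpy as np

    comp_means = np.random.randn(10, 3)           # comp_dists.mean: 10 components, event shape (3,)
    w = np.ones(10) / 10                          # mixture weights
    _w = w[:, None]                               # pad right so w broadcasts with the event shape
    mixture_mean = (_w * comp_means).sum(axis=0)  # shape (3,), analogous to self.mean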

def logp(self, value):
"""
Calculate log-probability of defined ``MixtureSameFamily`` distribution at specified value.

Parameters
----------
value : numeric
Value(s) for which log-probability is calculated. If the log probabilities for multiple
values are desired the values must be provided in a numpy array or theano tensor

Returns
-------
TensorVariable
"""

comp_dists = self.comp_dists
w = self.w
mixture_axis = self.mixture_axis

event_shape = comp_dists.shape[mixture_axis + 1 :]

# To be able to broadcast the comp_dists.logp with w and value
# We first have to pad the shape of w to the right with ones
# so that it can broadcast with the event_shape.

w = tt.shape_padright(w, len(event_shape))

# Second, we have to add the mixture_axis to the value tensor
# To insert the mixture axis at the correct location, we use the
# negative number index. This way, we can also handle situations
        # in which value is an observed value with more batch dimensions
# than the ones present in the comp_dists.
comp_dists_ndim = len(comp_dists.shape)

value = tt.shape_padaxis(value, axis=mixture_axis - comp_dists_ndim)

comp_logp = comp_dists.logp(value)
return bound(
logsumexp(tt.log(w) + comp_logp, axis=mixture_axis, keepdims=False),
w >= 0,
w <= 1,
tt.allclose(w.sum(axis=mixture_axis - comp_dists_ndim), 1),
broadcast_conditions=False,
)
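As a sanity check, a small NumPy/SciPy sketch (hypothetical numbers) of the reduction performed above, which is also the identity the tests below verify: the mixture log-probability is a logsumexp of log(w) plus the per-component log-probabilities, taken over the mixture axis.

    import numpy as np
    from scipy.special import logsumexp
    from scipy.stats import multinomial

    n_comps, n_cats, n = 10, 3, 20
    p = np.random.dirichlet(np.ones(n_cats), size=n_comps)                # (10, 3)
    w = np.ones(n_comps) / n_comps
    value = np.array([7, 6, 7])                                           # counts summing to n

    comp_logp = np.array([multinomial.logpmf(value, n, pi) for pi in p])  # (10,)
    mix_logp = logsumexp(np.log(w) + comp_logp)                           # scalar: mixture axis consumed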

def random(self, point=None, size=None):
"""
Draw random values from defined ``MixtureSameFamily`` distribution.

Parameters
----------
point : dict, optional
Dict of variable values on which random values are to be
conditioned (uses default point if not specified).
size : int, optional
Desired size of random sample (returns one sample if not
specified).

Returns
-------
array
"""
sample_shape = to_tuple(size)
mixture_axis = self.mixture_axis

# First we draw values for the mixture component weights
(w,) = draw_values([self.w], point=point, size=size)

# We now draw random choices from those weights.
# However, we have to ensure that the number of choices has the
# sample_shape present.
w_shape = w.shape
batch_shape = self.comp_dists.shape[: mixture_axis + 1]
param_shape = np.broadcast(np.empty(w_shape), np.empty(batch_shape)).shape
event_shape = self.comp_dists.shape[mixture_axis + 1 :]

if np.asarray(self.shape).size != 0:
comp_dists_ndim = len(self.comp_dists.shape)

# If event_shape of both comp_dists and supplied shape matches,
# broadcast only batch_shape
# else broadcast the entire given shape with batch_shape.
if list(self.shape[mixture_axis - comp_dists_ndim + 1 :]) == list(event_shape):
dist_shape = np.broadcast(
np.empty(self.shape[:mixture_axis]), np.empty(param_shape[:mixture_axis])
).shape
else:
dist_shape = np.broadcast(
np.empty(self.shape), np.empty(param_shape[:mixture_axis])
).shape
else:
dist_shape = param_shape[:mixture_axis]

# Try to determine the size that must be used to get the mixture
# components (i.e. get random choices using w).
# 1. There must be size independent choices based on w.
# 2. There must also be independent draws for each non singleton axis
# of w.
# 3. There must also be independent draws for each dimension added by
# self.shape with respect to the w.ndim. These usually correspond to
# observed variables with batch shapes
wsh = (1,) * (len(dist_shape) - len(w_shape) + 1) + w_shape[:mixture_axis]
psh = (1,) * (len(dist_shape) - len(param_shape) + 1) + param_shape[:mixture_axis]
w_sample_size = []
# Loop through the dist_shape to get the conditions 2 and 3 first
for i in range(len(dist_shape)):
if dist_shape[i] != psh[i] and wsh[i] == 1:
# self.shape[i] is a non singleton dimension (usually caused by
# observed data)
sh = dist_shape[i]
else:
sh = wsh[i]
w_sample_size.append(sh)

if sample_shape is not None and w_sample_size[: len(sample_shape)] != sample_shape:
w_sample_size = sample_shape + tuple(w_sample_size)

choices = random_choice(p=w, size=w_sample_size)

# We now draw samples from the mixture components random method
comp_samples = self.comp_dists.random(point=point, size=size)
if comp_samples.shape[: len(sample_shape)] != sample_shape:
comp_samples = np.broadcast_to(
comp_samples,
shape=sample_shape + comp_samples.shape,
)

# At this point the shapes of the arrays involved are:
# comp_samples.shape = (sample_shape, batch_shape, mixture_axis, event_shape)
# choices.shape = (sample_shape, batch_shape)
#
# To be able to take the choices along the mixture_axis of the
# comp_samples, we have to add in dimensions to the right of the
# choices array.
# We also need to make sure that the batch_shapes of both the comp_samples
# and choices broadcast with each other.

choices = np.reshape(choices, choices.shape + (1,) * (1 + len(event_shape)))

choices, comp_samples = get_broadcastable_dist_samples([choices, comp_samples], size=size)

# We now take the choices of the mixture components along the mixture_axis
# but we use the negative index representation to be able to handle the
# sample_shape
samples = np.take_along_axis(
comp_samples, choices, axis=mixture_axis - len(self.comp_dists.shape)
)

# The `samples` array still has the `mixture_axis`, so we must remove it:
output = samples[(..., 0) + (slice(None),) * len(event_shape)]

# Final oddity: if size == 1, pymc3 defaults to reducing the sample_shape dimension
# We do this to stay consistent with the rest of the package even though
# we shouldn't have to do it.
if size == 1:
output = output[0]
return output

def _distr_parameters_for_repr(self):
return []
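A NumPy sketch (illustrative shapes) of the component selection that the random method above performs with take_along_axis: one component index is drawn per sample, the index array is padded so it lines up with the mixture_axis, and the selected axis is then dropped.

    import numpy as np

    # 7 draws from a mixture of 10 components, each with event shape (3,)
    comp_samples = np.random.randn(7, 10, 3)       # (sample_shape, mixture_axis, event_shape)
    choices = np.random.randint(0, 10, size=(7,))  # one component index per draw
    choices = choices.reshape(7, 1, 1)             # pad right: 1 mixture dim + 1 event dim
    picked = np.take_along_axis(comp_samples, choices, axis=-2)  # shape (7, 1, 3)
    samples = picked[..., 0, :]                    # drop the consumed mixture axis -> (7, 3)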
90 changes: 90 additions & 0 deletions pymc3/tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,3 +490,93 @@ def logp_matches(self, mixture, latent_mix, z, npop, model):
logps.append(z_logp + latent_mix.logp(test_point))
latent_mix_logp = logsumexp(np.array(logps), axis=0)
assert_allclose(mix_logp, latent_mix_logp, rtol=rtol)


class TestMixtureSameFamily(SeededTest):
@classmethod
def setup_class(cls):
super().setup_class()
cls.size = 50
cls.n_samples = 1000
cls.mixture_comps = 10

@pytest.mark.parametrize("batch_shape", [(3, 4), (20,)], ids=str)
def test_with_multinomial(self, batch_shape):
p = np.random.uniform(size=(*batch_shape, self.mixture_comps, 3))
n = 100 * np.ones((*batch_shape, 1))
w = np.ones(self.mixture_comps) / self.mixture_comps
mixture_axis = len(batch_shape)
with pm.Model() as model:
comp_dists = pm.Multinomial.dist(p=p, n=n, shape=(*batch_shape, self.mixture_comps, 3))
mixture = pm.MixtureSameFamily(
"mixture",
w=w,
comp_dists=comp_dists,
mixture_axis=mixture_axis,
shape=(*batch_shape, 3),
)
prior = pm.sample_prior_predictive(samples=self.n_samples)

assert prior["mixture"].shape == (self.n_samples, *batch_shape, 3)
assert mixture.random(size=self.size).shape == (self.size, *batch_shape, 3)

if theano.config.floatX == "float32":
rtol = 1e-4
else:
rtol = 1e-7

comp_logp = comp_dists.logp(model.test_point["mixture"].reshape(*batch_shape, 1, 3))
log_sum_exp = logsumexp(
comp_logp.eval() + np.log(w)[..., None], axis=mixture_axis, keepdims=True
).sum()
assert_allclose(
model.logp(model.test_point),
log_sum_exp,
rtol,
)

# TODO: Handle case when `batch_shape` == `sample_shape`.
# See https://github.com/pymc-devs/pymc3/issues/4185 for details.
def test_with_mvnormal(self):
# 10 batch, 3-variate Gaussian
mu = np.random.randn(self.mixture_comps, 3)
mat = np.random.randn(3, 3)
cov = mat @ mat.T
chol = np.linalg.cholesky(cov)
w = np.ones(self.mixture_comps) / self.mixture_comps

with pm.Model() as model:
comp_dists = pm.MvNormal.dist(mu=mu, chol=chol, shape=(self.mixture_comps, 3))
mixture = pm.MixtureSameFamily(
"mixture", w=w, comp_dists=comp_dists, mixture_axis=0, shape=(3,)
)
prior = pm.sample_prior_predictive(samples=self.n_samples)

assert prior["mixture"].shape == (self.n_samples, 3)
assert mixture.random(size=self.size).shape == (self.size, 3)

if theano.config.floatX == "float32":
rtol = 1e-4
else:
rtol = 1e-7

comp_logp = comp_dists.logp(model.test_point["mixture"].reshape(1, 3))
log_sum_exp = logsumexp(
comp_logp.eval() + np.log(w)[..., None], axis=0, keepdims=True
).sum()
assert_allclose(
model.logp(model.test_point),
log_sum_exp,
rtol,
)

def test_broadcasting_in_shape(self):
with pm.Model() as model:
mu = pm.Gamma("mu", 1.0, 1.0, shape=2)
comp_dists = pm.Poisson.dist(mu, shape=2)
mix = pm.MixtureSameFamily(
"mix", w=np.ones(2) / 2, comp_dists=comp_dists, shape=(1000,)
)
prior = pm.sample_prior_predictive(samples=self.n_samples)

assert prior["mix"].shape == (self.n_samples, 1000)