diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index 741bfa7216..9199eaa931 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -18,6 +18,7 @@
 - Add MLDA, a new stepper for multilevel sampling. MLDA can be used when a hierarchy of approximate posteriors of varying accuracy is available, offering improved sampling efficiency especially in high-dimensional problems and/or where gradients are not available (see [#3926](https://github.com/pymc-devs/pymc3/pull/3926))
 - Change SMC metropolis kernel to independent metropolis kernel [#4115](https://github.com/pymc-devs/pymc3/pull/4115)
 - Add alternative parametrization to NegativeBinomial distribution in terms of n and p (see [#4126](https://github.com/pymc-devs/pymc3/issues/4126))
+- Added a new `MixtureSameFamily` distribution to handle mixtures of arbitrary dimensions in vectorized form (see [#4185](https://github.com/pymc-devs/pymc3/issues/4185)).

 ## PyMC3 3.9.3 (11 August 2020)

diff --git a/docs/source/api/distributions/mixture.rst b/docs/source/api/distributions/mixture.rst
index 99c82e9f36..ee79f018a9 100644
--- a/docs/source/api/distributions/mixture.rst
+++ b/docs/source/api/distributions/mixture.rst
@@ -6,6 +6,7 @@ Mixture
 .. autosummary::
    Mixture
    NormalMixture
+   MixtureSameFamily

 .. automodule:: pymc3.distributions.mixture
    :members:

diff --git a/pymc3/distributions/__init__.py b/pymc3/distributions/__init__.py
index d396d61dd6..fce98766f0 100644
--- a/pymc3/distributions/__init__.py
+++ b/pymc3/distributions/__init__.py
@@ -79,6 +79,7 @@
 from .mixture import Mixture
 from .mixture import NormalMixture
+from .mixture import MixtureSameFamily
 from .multivariate import MvNormal
 from .multivariate import MatrixNormal

@@ -164,6 +165,7 @@
     "SkewNormal",
     "Mixture",
     "NormalMixture",
+    "MixtureSameFamily",
     "Triangular",
     "DiscreteWeibull",
     "Gumbel",

diff --git a/pymc3/distributions/mixture.py b/pymc3/distributions/mixture.py
index 2c284c2310..7f88acbb7e 100644
--- a/pymc3/distributions/mixture.py
+++ b/pymc3/distributions/mixture.py
@@ -28,9 +28,15 @@
     _DrawValuesContext,
     _DrawValuesContextBlocker,
 )
-from .shape_utils import to_tuple, broadcast_distribution_samples
+from .shape_utils import (
+    to_tuple,
+    broadcast_distribution_samples,
+    get_broadcastable_dist_samples,
+)
 from .continuous import get_tau_sigma, Normal
-from ..theanof import _conversion_map
+from ..theanof import _conversion_map, take_along_axis
+
+__all__ = ["Mixture", "NormalMixture", "MixtureSameFamily"]


 def all_discrete(comp_dists):
@@ -612,3 +618,241 @@ def __init__(self, w, mu, sigma=None, tau=None, sd=None, comp_shape=(), *args, *

     def _distr_parameters_for_repr(self):
         return ["w", "mu", "sigma"]
+
+
+class MixtureSameFamily(Distribution):
+    R"""
+    Mixture Same Family log-likelihood.
+
+    This distribution handles mixtures of multivariate distributions in a
+    vectorized manner. It is used instead of the ``Mixture`` distribution
+    when the mixture components are not on the last axis of the components'
+    distribution.
+
+    .. math::
+
+        f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i)
+        \quad \textrm{along } mixture\_axis
+
+    ========  ============================================
+    Support   :math:`\textrm{support}(f)`
+    Mean      :math:`w\mu`
+    ========  ============================================
+
+    Parameters
+    ----------
+    w: array of floats
+        The mixture weights, with w >= 0, w <= 1, and summing to 1 along the
+        mixture axis.
+    comp_dists: PyMC3 distribution (e.g. `pm.Multinomial.dist(...)`)
+        The component distribution; it can be scalar or multidimensional.
+        Assuming its shape is ``(i_0, ..., i_n, mixture_axis, i_n+1, ..., i_N)``,
+        the ``mixture_axis`` is consumed, resulting in a mixture of shape
+        ``(i_0, ..., i_n, i_n+1, ..., i_N)``.
+    mixture_axis: int, default = -1
+        The axis of ``comp_dists`` that indexes the mixture components and is
+        reduced in the mixture.
+
+    Notes
+    -----
+    The default behaviour resembles the ``Mixture`` distribution, wherein the
+    last axis of the component distribution is reduced.
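+
+    Examples
+    --------
+    A minimal usage sketch, mirroring the shapes exercised by this PR's
+    tests (the parameter values here are illustrative, not prescriptive):
+
+    .. code-block:: python
+
+        npop = 10  # number of mixture components
+        with pm.Model():
+            w = pm.Dirichlet("w", a=np.ones(npop))
+            # comp_dists has shape (npop, 3); axis 0 indexes the mixture
+            # components, so it is passed as mixture_axis below.
+            comp_dists = pm.MvNormal.dist(
+                mu=np.random.randn(npop, 3), cov=np.eye(3), shape=(npop, 3)
+            )
+            mixture = pm.MixtureSameFamily(
+                "mixture", w=w, comp_dists=comp_dists, mixture_axis=0, shape=(3,)
+            )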
+    """
+
+    def __init__(self, w, comp_dists, mixture_axis=-1, *args, **kwargs):
+        self.w = tt.as_tensor_variable(w)
+        if not isinstance(comp_dists, Distribution):
+            raise TypeError(
+                "The MixtureSameFamily distribution only accepts Distribution "
+                f"instances as its components. Got {type(comp_dists)} instead."
+            )
+        self.comp_dists = comp_dists
+        if mixture_axis < 0:
+            mixture_axis = len(comp_dists.shape) + mixture_axis
+            if mixture_axis < 0:
+                raise ValueError(
+                    "`mixture_axis` must be a valid axis of the components' "
+                    f"distribution's shape. Got axis {mixture_axis - len(comp_dists.shape)}, "
+                    "which is out of bounds."
+                )
+        comp_shape = to_tuple(comp_dists.shape)
+        self.shape = comp_shape[:mixture_axis] + comp_shape[mixture_axis + 1 :]
+        self.mixture_axis = mixture_axis
+        kwargs.setdefault("dtype", self.comp_dists.dtype)
+
+        # Compute the mode so we don't always have to pass a testval
+        defaults = kwargs.pop("defaults", [])
+        event_shape = self.comp_dists.shape[mixture_axis + 1 :]
+        _w = tt.shape_padleft(
+            tt.shape_padright(w, len(event_shape)),
+            len(self.comp_dists.shape) - w.ndim - len(event_shape),
+        )
+        mode = take_along_axis(
+            self.comp_dists.mode,
+            tt.argmax(_w, keepdims=True),
+            axis=mixture_axis,
+        )
+        self.mode = mode[(..., 0) + (slice(None),) * len(event_shape)]
+
+        if not all_discrete(comp_dists):
+            mean = tt.as_tensor_variable(self.comp_dists.mean)
+            self.mean = (_w * mean).sum(axis=mixture_axis)
+            if "mean" not in defaults:
+                defaults.append("mean")
+        defaults.append("mode")
+
+        super().__init__(defaults=defaults, *args, **kwargs)
+
+    def logp(self, value):
+        """
+        Calculate the log-probability of the defined ``MixtureSameFamily``
+        distribution at the specified value.
+
+        Parameters
+        ----------
+        value: numeric
+            Value(s) for which the log-probability is calculated. If the
+            log-probabilities for multiple values are desired, the values
+            must be provided in a numpy array or theano tensor.
+
+        Returns
+        -------
+        TensorVariable
+        """
+        comp_dists = self.comp_dists
+        w = self.w
+        mixture_axis = self.mixture_axis
+
+        event_shape = comp_dists.shape[mixture_axis + 1 :]
+
+        # To be able to broadcast the comp_dists.logp with w and value,
+        # we first have to pad the shape of w to the right with ones
+        # so that it can broadcast with the event_shape.
+        w = tt.shape_padright(w, len(event_shape))
+
+        # Second, we have to add the mixture_axis to the value tensor.
+        # To insert the mixture axis at the correct location, we use a
+        # negative axis index. This way, we can also handle situations in
+        # which value is an observed value with more batch dimensions
+        # than the ones present in the comp_dists.
+        comp_dists_ndim = len(comp_dists.shape)
+
+        value = tt.shape_padaxis(value, axis=mixture_axis - comp_dists_ndim)
+
+        comp_logp = comp_dists.logp(value)
+        return bound(
+            logsumexp(tt.log(w) + comp_logp, axis=mixture_axis, keepdims=False),
+            w >= 0,
+            w <= 1,
+            tt.allclose(w.sum(axis=mixture_axis - comp_dists_ndim), 1),
+            broadcast_conditions=False,
+        )
+
+    def random(self, point=None, size=None):
+        """
+        Draw random values from the defined ``MixtureSameFamily`` distribution.
+
+        Parameters
+        ----------
+        point: dict, optional
+            Dict of variable values on which random values are to be
+            conditioned (uses default point if not specified).
+        size: int, optional
+            Desired size of random sample (returns one sample if not
+            specified).
+
+        Returns
+        -------
+        array
+        """
+        sample_shape = to_tuple(size)
+        mixture_axis = self.mixture_axis
+
+        # First we draw values for the mixture component weights
+        (w,) = draw_values([self.w], point=point, size=size)
+
+        # We now draw random choices from those weights.
+        # However, we have to ensure that the number of choices has the
+        # sample_shape present.
+        w_shape = w.shape
+        batch_shape = self.comp_dists.shape[: mixture_axis + 1]
+        param_shape = np.broadcast(np.empty(w_shape), np.empty(batch_shape)).shape
+        event_shape = self.comp_dists.shape[mixture_axis + 1 :]
+
+        if np.asarray(self.shape).size != 0:
+            comp_dists_ndim = len(self.comp_dists.shape)
+
+            # If the event_shape of comp_dists matches the tail of the
+            # supplied shape, broadcast only the batch_shape;
+            # otherwise, broadcast the entire supplied shape with batch_shape.
+            if list(self.shape[mixture_axis - comp_dists_ndim + 1 :]) == list(event_shape):
+                dist_shape = np.broadcast(
+                    np.empty(self.shape[:mixture_axis]), np.empty(param_shape[:mixture_axis])
+                ).shape
+            else:
+                dist_shape = np.broadcast(
+                    np.empty(self.shape), np.empty(param_shape[:mixture_axis])
+                ).shape
+        else:
+            dist_shape = param_shape[:mixture_axis]
+
+        # Try to determine the size that must be used to get the mixture
+        # components (i.e. get random choices using w).
+        # 1. There must be `size` independent choices based on w.
+        # 2. There must be independent draws for each non-singleton axis
+        #    of w.
+        # 3. There must also be independent draws for each dimension added by
+        #    self.shape with respect to the w.ndim. These usually correspond
+        #    to observed variables with batch shapes.
+        wsh = (1,) * (len(dist_shape) - len(w_shape) + 1) + w_shape[:mixture_axis]
+        psh = (1,) * (len(dist_shape) - len(param_shape) + 1) + param_shape[:mixture_axis]
+        w_sample_size = []
+        # Loop through the dist_shape to get conditions 2 and 3 first
+        for i in range(len(dist_shape)):
+            if dist_shape[i] != psh[i] and wsh[i] == 1:
+                # self.shape[i] is a non-singleton dimension (usually caused
+                # by observed data)
+                sh = dist_shape[i]
+            else:
+                sh = wsh[i]
+            w_sample_size.append(sh)
+
+        if sample_shape is not None and tuple(w_sample_size[: len(sample_shape)]) != sample_shape:
+            w_sample_size = sample_shape + tuple(w_sample_size)
+
+        choices = random_choice(p=w, size=w_sample_size)
+
+        # We now draw samples from the mixture components' random method
+        comp_samples = self.comp_dists.random(point=point, size=size)
+        if comp_samples.shape[: len(sample_shape)] != sample_shape:
+            comp_samples = np.broadcast_to(
+                comp_samples,
+                shape=sample_shape + comp_samples.shape,
+            )
+
+        # At this point the shapes of the arrays involved are:
+        # comp_samples.shape = (sample_shape, batch_shape, mixture_axis, event_shape)
+        # choices.shape = (sample_shape, batch_shape)
+        #
+        # To be able to take the choices along the mixture_axis of the
+        # comp_samples, we have to add dimensions to the right of the
+        # choices array.
+        # We also need to make sure that the batch_shapes of both the
+        # comp_samples and choices broadcast with each other.
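+
+        # As a concrete illustration (hypothetical shapes, mirroring the
+        # MvNormal test in this PR): with comp_dists.shape == (10, 3),
+        # mixture_axis == 0, shape == (3,) and size == 50, comp_samples has
+        # shape (50, 10, 3) and choices has shape (50,). Below, choices is
+        # reshaped to (50, 1, 1), take_along_axis along axis -2 picks one
+        # component per draw giving (50, 1, 3), and dropping the mixture
+        # axis yields samples of shape (50, 3).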
+
+        choices = np.reshape(choices, choices.shape + (1,) * (1 + len(event_shape)))
+
+        choices, comp_samples = get_broadcastable_dist_samples([choices, comp_samples], size=size)
+
+        # We now take the choices of the mixture components along the
+        # mixture_axis, but we use the negative index representation to be
+        # able to handle the sample_shape
+        samples = np.take_along_axis(
+            comp_samples, choices, axis=mixture_axis - len(self.comp_dists.shape)
+        )
+
+        # The `samples` array still has the `mixture_axis`, so we must remove it:
+        output = samples[(..., 0) + (slice(None),) * len(event_shape)]
+
+        # Final oddity: if size == 1, pymc3 defaults to reducing the
+        # sample_shape dimension. We do this to stay consistent with the rest
+        # of the package even though we shouldn't have to do it.
+        if size == 1:
+            output = output[0]
+        return output
+
+    def _distr_parameters_for_repr(self):
+        return []

diff --git a/pymc3/tests/test_mixture.py b/pymc3/tests/test_mixture.py
index 4dfd9fb2ca..27914b6d74 100644
--- a/pymc3/tests/test_mixture.py
+++ b/pymc3/tests/test_mixture.py
@@ -490,3 +490,93 @@ def logp_matches(self, mixture, latent_mix, z, npop, model):
         logps.append(z_logp + latent_mix.logp(test_point))
     latent_mix_logp = logsumexp(np.array(logps), axis=0)
     assert_allclose(mix_logp, latent_mix_logp, rtol=rtol)
+
+
+class TestMixtureSameFamily(SeededTest):
+    @classmethod
+    def setup_class(cls):
+        super().setup_class()
+        cls.size = 50
+        cls.n_samples = 1000
+        cls.mixture_comps = 10
+
+    @pytest.mark.parametrize("batch_shape", [(3, 4), (20,)], ids=str)
+    def test_with_multinomial(self, batch_shape):
+        p = np.random.uniform(size=(*batch_shape, self.mixture_comps, 3))
+        n = 100 * np.ones((*batch_shape, 1))
+        w = np.ones(self.mixture_comps) / self.mixture_comps
+        mixture_axis = len(batch_shape)
+        with pm.Model() as model:
+            comp_dists = pm.Multinomial.dist(p=p, n=n, shape=(*batch_shape, self.mixture_comps, 3))
+            mixture = pm.MixtureSameFamily(
+                "mixture",
+                w=w,
+                comp_dists=comp_dists,
+                mixture_axis=mixture_axis,
+                shape=(*batch_shape, 3),
+            )
+            prior = pm.sample_prior_predictive(samples=self.n_samples)
+
+        assert prior["mixture"].shape == (self.n_samples, *batch_shape, 3)
+        assert mixture.random(size=self.size).shape == (self.size, *batch_shape, 3)
+
+        if theano.config.floatX == "float32":
+            rtol = 1e-4
+        else:
+            rtol = 1e-7
+
+        comp_logp = comp_dists.logp(model.test_point["mixture"].reshape(*batch_shape, 1, 3))
+        log_sum_exp = logsumexp(
+            comp_logp.eval() + np.log(w)[..., None], axis=mixture_axis, keepdims=True
+        ).sum()
+        assert_allclose(
+            model.logp(model.test_point),
+            log_sum_exp,
+            rtol,
+        )
+
+    # TODO: Handle the case when `batch_shape` == `sample_shape`.
+    # See https://github.com/pymc-devs/pymc3/issues/4185 for details.
+    def test_with_mvnormal(self):
+        # A mixture of 10 three-variate Gaussians
+        mu = np.random.randn(self.mixture_comps, 3)
+        mat = np.random.randn(3, 3)
+        cov = mat @ mat.T
+        chol = np.linalg.cholesky(cov)
+        w = np.ones(self.mixture_comps) / self.mixture_comps
+
+        with pm.Model() as model:
+            comp_dists = pm.MvNormal.dist(mu=mu, chol=chol, shape=(self.mixture_comps, 3))
+            mixture = pm.MixtureSameFamily(
+                "mixture", w=w, comp_dists=comp_dists, mixture_axis=0, shape=(3,)
+            )
+            prior = pm.sample_prior_predictive(samples=self.n_samples)
+
+        assert prior["mixture"].shape == (self.n_samples, 3)
+        assert mixture.random(size=self.size).shape == (self.size, 3)
+
+        if theano.config.floatX == "float32":
+            rtol = 1e-4
+        else:
+            rtol = 1e-7
+
+        comp_logp = comp_dists.logp(model.test_point["mixture"].reshape(1, 3))
+        log_sum_exp = logsumexp(
+            comp_logp.eval() + np.log(w)[..., None], axis=0, keepdims=True
+        ).sum()
+        assert_allclose(
+            model.logp(model.test_point),
+            log_sum_exp,
+            rtol,
+        )
+
+    def test_broadcasting_in_shape(self):
+        with pm.Model() as model:
+            mu = pm.Gamma("mu", 1.0, 1.0, shape=2)
+            comp_dists = pm.Poisson.dist(mu, shape=2)
+            mix = pm.MixtureSameFamily(
+                "mix", w=np.ones(2) / 2, comp_dists=comp_dists, shape=(1000,)
+            )
+            prior = pm.sample_prior_predictive(samples=self.n_samples)
+
+        assert prior["mix"].shape == (self.n_samples, 1000)