Add MixtureSameFamily distribution #4180
@@ -28,9 +28,15 @@
     _DrawValuesContext,
     _DrawValuesContextBlocker,
 )
-from .shape_utils import to_tuple, broadcast_distribution_samples
+from .shape_utils import (
+    to_tuple,
+    broadcast_distribution_samples,
+    get_broadcastable_dist_samples,
+)
 from .continuous import get_tau_sigma, Normal
-from ..theanof import _conversion_map
+from ..theanof import _conversion_map, take_along_axis
 
-__all__ = ["Mixture", "NormalMixture"]
+__all__ = ["Mixture", "NormalMixture", "MixtureSameFamily"]
 
 
 def all_discrete(comp_dists):
@@ -612,3 +618,241 @@ def __init__(self, w, mu, sigma=None, tau=None, sd=None, comp_shape=(), *args, **kwargs):

    def _distr_parameters_for_repr(self):
        return ["w", "mu", "sigma"]


class MixtureSameFamily(Distribution):
    R"""
    Mixture Same Family log-likelihood.

    This distribution handles mixtures of multivariate distributions in a
    vectorized manner. It is used over the Mixture distribution when the
    mixture components are not on the last axis of the components'
    distribution.

    .. math::

        f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i)
        \quad \textrm{along } mixture\_axis

    ========  ============================================
    Support   :math:`\textrm{support}(f)`
    Mean      :math:`w\mu`
    ========  ============================================

    Parameters
    ----------
    w: array of floats
        The mixture weights, with w >= 0 and w <= 1.
    comp_dists: PyMC3 distribution (e.g. `pm.Multinomial.dist(...)`)
        The component distribution; it can be scalar or multidimensional.
        Assuming its shape is (i_0, ..., i_n, mixture_axis, i_n+1, ..., i_N),
        the `mixture_axis` is consumed, resulting in a mixture of shape
        (i_0, ..., i_n, i_n+1, ..., i_N).
    mixture_axis: int, default = -1
        Axis representing the mixture components to be reduced in the
        mixture. By default the last axis of `comp_dists` is reduced.

    Notes
    -----
    The default behaviour resembles the Mixture distribution, wherein the
    last axis of the component distribution is reduced.
    """

    def __init__(self, w, comp_dists, mixture_axis=-1, *args, **kwargs):
        self.w = tt.as_tensor_variable(w)
        if not isinstance(comp_dists, Distribution):
            raise TypeError(
                "The MixtureSameFamily distribution only accepts Distribution "
                f"instances as its components. Got {type(comp_dists)} instead."
            )
        self.comp_dists = comp_dists
        if mixture_axis < 0:
            mixture_axis = len(comp_dists.shape) + mixture_axis
Review comment: We need to ensure that after we do this, the …
        if mixture_axis < 0:
            raise ValueError(
                "`mixture_axis` must be a valid axis of the components' "
                f"distribution. Got axis {mixture_axis - len(comp_dists.shape)}, "
                "which is out of bounds."
            )
        comp_shape = to_tuple(comp_dists.shape)
        self.shape = comp_shape[:mixture_axis] + comp_shape[mixture_axis + 1 :]
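        # For example, comp_shape == (B, K, E) with mixture_axis == 1
        # yields a mixture shape of (B, E): the K mixture components
        # along axis 1 are consumed.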
        self.mixture_axis = mixture_axis
        kwargs.setdefault("dtype", self.comp_dists.dtype)

        # Compute the mode so we don't always have to pass a testval
        defaults = kwargs.pop("defaults", [])
        event_shape = self.comp_dists.shape[mixture_axis + 1 :]
        _w = tt.shape_padleft(
            tt.shape_padright(w, len(event_shape)),
            len(self.comp_dists.shape) - w.ndim - len(event_shape),
        )
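        # _w now has the same number of dimensions as comp_dists: w is
        # padded with singleton axes on the right (over the event dims)
        # and on the left (over the leading batch dims), so the argmax
        # and take_along_axis below line up with the mixture axis.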
        mode = take_along_axis(
            self.comp_dists.mode,
            tt.argmax(_w, keepdims=True),
            axis=mixture_axis,
        )
        self.mode = mode[(..., 0) + (slice(None),) * len(event_shape)]
Review comment: I was kind of lazy and didn't add the …

        if not all_discrete(comp_dists):
            mean = tt.as_tensor_variable(self.comp_dists.mean)
            self.mean = (_w * mean).sum(axis=mixture_axis)
            if "mean" not in defaults:
                defaults.append("mean")
        defaults.append("mode")

        super().__init__(defaults=defaults, *args, **kwargs)

    def logp(self, value):
        """
        Calculate log-probability of defined ``MixtureSameFamily`` distribution at specified value.

        Parameters
        ----------
        value : numeric
            Value(s) for which log-probability is calculated. If the log
            probabilities for multiple values are desired the values must
            be provided in a numpy array or theano tensor.

        Returns
        -------
        TensorVariable
        """
        comp_dists = self.comp_dists
        w = self.w
        mixture_axis = self.mixture_axis

        event_shape = comp_dists.shape[mixture_axis + 1 :]

        # To be able to broadcast the comp_dists.logp with w and value,
        # we first have to pad the shape of w to the right with ones
        # so that it can broadcast with the event_shape.
        w = tt.shape_padright(w, len(event_shape))

        # Second, we have to add the mixture_axis to the value tensor.
        # To insert the mixture axis at the correct location, we use a
        # negative index. This way, we can also handle situations in
        # which value is an observed value with more batch dimensions
        # than the ones present in comp_dists.
        comp_dists_ndim = len(comp_dists.shape)

        value = tt.shape_padaxis(value, axis=mixture_axis - comp_dists_ndim)
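        # For example, with comp_dists.shape == (B, K, E) and
        # mixture_axis == 1, an observed value of shape (N, B, E)
        # becomes (N, B, 1, E) and broadcasts against the K components.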

        comp_logp = comp_dists.logp(value)
        return bound(
            logsumexp(tt.log(w) + comp_logp, axis=mixture_axis, keepdims=False),
            w >= 0,
            w <= 1,
            tt.allclose(w.sum(axis=mixture_axis - comp_dists_ndim), 1),
            broadcast_conditions=False,
        )

    def random(self, point=None, size=None):
        """
        Draw random values from defined ``MixtureSameFamily`` distribution.

        Parameters
        ----------
        point : dict, optional
            Dict of variable values on which random values are to be
            conditioned (uses default point if not specified).
        size : int, optional
            Desired size of random sample (returns one sample if not
            specified).

        Returns
        -------
        array
        """
        sample_shape = to_tuple(size)
        mixture_axis = self.mixture_axis

        # First we draw values for the mixture component weights
        (w,) = draw_values([self.w], point=point, size=size)

        # We now draw random choices from those weights.
        # However, we have to ensure that the number of choices has the
        # sample_shape present.
        w_shape = w.shape
        batch_shape = self.comp_dists.shape[: mixture_axis + 1]
        param_shape = np.broadcast(np.empty(w_shape), np.empty(batch_shape)).shape
        event_shape = self.comp_dists.shape[mixture_axis + 1 :]

        if np.asarray(self.shape).size != 0:
            comp_dists_ndim = len(self.comp_dists.shape)

            # If the event_shape of comp_dists and the supplied shape match,
            # broadcast only the batch_shape; otherwise broadcast the entire
            # given shape with the batch_shape.
            if list(self.shape[mixture_axis - comp_dists_ndim + 1 :]) == list(event_shape):
                dist_shape = np.broadcast(
                    np.empty(self.shape[:mixture_axis]), np.empty(param_shape[:mixture_axis])
                ).shape
            else:
                dist_shape = np.broadcast(
                    np.empty(self.shape), np.empty(param_shape[:mixture_axis])
                ).shape
        else:
            dist_shape = param_shape[:mixture_axis]

        # Try to determine the size that must be used to get the mixture
        # components (i.e. get random choices using w).
        # 1. There must be size-independent choices based on w.
        # 2. There must also be independent draws for each non-singleton
        #    axis of w.
        # 3. There must also be independent draws for each dimension added
        #    by self.shape with respect to w.ndim. These usually correspond
        #    to observed variables with batch shapes.
        wsh = (1,) * (len(dist_shape) - len(w_shape) + 1) + w_shape[:mixture_axis]
        psh = (1,) * (len(dist_shape) - len(param_shape) + 1) + param_shape[:mixture_axis]
        w_sample_size = []
        # Loop through the dist_shape to get conditions 2 and 3 first
        for i in range(len(dist_shape)):
            if dist_shape[i] != psh[i] and wsh[i] == 1:
                # self.shape[i] is a non-singleton dimension (usually
                # caused by observed data)
                sh = dist_shape[i]
            else:
                sh = wsh[i]
            w_sample_size.append(sh)

        if sample_shape is not None and w_sample_size[: len(sample_shape)] != sample_shape:
            w_sample_size = sample_shape + tuple(w_sample_size)

        choices = random_choice(p=w, size=w_sample_size)

        # We now draw samples from the mixture components' random method
        comp_samples = self.comp_dists.random(point=point, size=size)
        if comp_samples.shape[: len(sample_shape)] != sample_shape:
            comp_samples = np.broadcast_to(
                comp_samples,
                shape=sample_shape + comp_samples.shape,
            )

        # At this point the shapes of the arrays involved are:
        # comp_samples.shape = (sample_shape, batch_shape, mixture_axis, event_shape)
        # choices.shape = (sample_shape, batch_shape)
        #
        # To be able to take the choices along the mixture_axis of
        # comp_samples, we have to add dimensions to the right of the
        # choices array.
        # We also need to make sure that the batch_shapes of comp_samples
        # and choices broadcast with each other.
        choices = np.reshape(choices, choices.shape + (1,) * (1 + len(event_shape)))
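        # For example, with a single event dimension, choices of shape
        # (sample_shape, batch_shape) become (sample_shape, batch_shape, 1, 1):
        # one singleton for the mixture axis and one per event dimension.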

        choices, comp_samples = get_broadcastable_dist_samples([choices, comp_samples], size=size)

        # We now take the choices of the mixture components along the
        # mixture_axis, but we use the negative index representation to
        # be able to handle the sample_shape.
        samples = np.take_along_axis(
            comp_samples, choices, axis=mixture_axis - len(self.comp_dists.shape)
        )

        # The `samples` array still has the `mixture_axis`, so we must remove it:
        output = samples[(..., 0) + (slice(None),) * len(event_shape)]

        # Final oddity: if size == 1, pymc3 defaults to reducing the
        # sample_shape dimension. We do this to stay consistent with the
        # rest of the package even though we shouldn't have to.
        if size == 1:
            output = output[0]
        return output

    def _distr_parameters_for_repr(self):
        return []
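For intuition, the `logp` above reduces to a weighted logsumexp along the mixture axis. Here is a minimal NumPy sketch of that computation; all names, shapes, and the choice of Gaussian components are illustrative, not part of this PR:

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

K, E = 3, 2                      # mixture components, event size
w = np.array([0.2, 0.5, 0.3])    # mixture weights, sum to 1
mus = np.random.randn(K, E)      # one mean vector per component
x = np.random.randn(E)           # a single observation

# Log-density of x under each of the K components, shape (K,)
comp_logp = np.array(
    [multivariate_normal(mus[k], np.eye(E)).logpdf(x) for k in range(K)]
)

# Mixture log-density: log-weights plus component log-densities,
# logsumexp-reduced along the mixture axis
logp = logsumexp(np.log(w) + comp_logp, axis=-1)
```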
Review comment: Does mixture_axis=-1 mean the default is to reduce along the last axis? If yes, it'd be nice to add this precision in the docstring just above.

Review comment: Yes, we should say this in the docstring.
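To make that point concrete, here is a hedged usage sketch, loosely patterned on the docstring's `pm.Multinomial.dist(...)` hint; every name, shape, and value below is hypothetical. Because a Multinomial's last axis is its event, the default mixture_axis=-1 (the last axis) would reduce the wrong axis, so the mixture axis is passed explicitly:

```python
import numpy as np
import pymc3 as pm

B, K, E = 20, 3, 4   # batch size, mixture components, categories

# Hypothetical observed counts: one draw of 10 trials over E categories
# per batch row.
data = np.random.multinomial(10, np.ones(E) / E, size=B)

with pm.Model():
    w = pm.Dirichlet("w", a=np.ones(K))
    # One probability vector per (batch, component) pair, fixed here
    # for simplicity.
    p = np.random.dirichlet(np.ones(E), size=(B, K))
    # comp_dists has shape (B, K, E): axis 1 indexes the K mixture
    # components and the trailing E axis is the Multinomial event,
    # so we pass mixture_axis=1 rather than the default -1.
    comps = pm.Multinomial.dist(n=10 * np.ones((B, 1)), p=p, shape=(B, K, E))
    pm.MixtureSameFamily(
        "obs", w=w, comp_dists=comps, mixture_axis=1,
        shape=(B, E), observed=data,
    )
```

With mixture_axis=1 the K axis is consumed, leaving the (B, E) shape that matches the observed data.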