28 | 28 | _DrawValuesContext,
29 | 29 | _DrawValuesContextBlocker,
30 | 30 | )
31 |
| -from .shape_utils import to_tuple, broadcast_distribution_samples |
| 31 | +from .shape_utils import ( |
| 32 | + to_tuple, |
| 33 | + broadcast_distribution_samples, |
| 34 | + get_broadcastable_dist_samples, |
| 35 | +) |
32 | 36 | from .continuous import get_tau_sigma, Normal
33 |
| -from ..theanof import _conversion_map |
| 37 | +from ..theanof import _conversion_map, take_along_axis |
| 38 | + |
| 39 | +__all__ = ["Mixture", "NormalMixture", "MixtureSameFamily"] |
34 | 40 |
35 | 41 |
36 | 42 | def all_discrete(comp_dists):
@@ -612,3 +618,241 @@ def __init__(self, w, mu, sigma=None, tau=None, sd=None, comp_shape=(), *args, *
612 | 618 |
613 | 619 | def _distr_parameters_for_repr(self):
614 | 620 | return ["w", "mu", "sigma"]
| 621 | + |
| 622 | + |
| 623 | +class MixtureSameFamily(Distribution): |
| 624 | + R""" |
| 625 | + Mixture Same Family log-likelihood. |
| 626 | + This distribution handles mixtures of multivariate distributions in a vectorized |
| 627 | + manner. It is used instead of the ``Mixture`` distribution when the mixture |
| 628 | + components are not on the last axis of the components' distribution. |
| 629 | +
| 630 | + .. math:: f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i) \quad \textrm{along the mixture axis} |
| 631 | +
| 632 | + ======== ============================================ |
| 633 | + Support :math:`\textrm{support}(f)` |
| 634 | + Mean :math:`w\mu` |
| 635 | + ======== ============================================ |
| 636 | +
| 637 | + Parameters |
| 638 | + ---------- |
| 639 | + w: array of floats |
| 640 | + The mixture weights: w >= 0, w <= 1, summing to 1 along the mixture axis. |
| 641 | + comp_dists: PyMC3 distribution (e.g. `pm.Multinomial.dist(...)`) |
| 642 | + `comp_dists` can be a scalar or multidimensional distribution. |
| 643 | + Assuming its shape is (i_0, ..., i_n, mixture_axis, i_n+1, ..., i_N), |
| 644 | + the `mixture_axis` dimension is consumed, resulting in a mixture of |
| 645 | + shape (i_0, ..., i_n, i_n+1, ..., i_N); see the shape sketch below. |
| 646 | + mixture_axis: int, default = -1 |
| 647 | + Axis of the components' distribution that indexes the mixture |
| 648 | + components and is reduced in the mixture. |
| 649 | +
| 650 | + Notes |
| 651 | + ----- |
| 652 | + With the default ``mixture_axis=-1``, the behaviour matches that of the Mixture |
| 653 | + distribution: the last axis of the components' distribution is reduced. |
| 654 | + """ |
| 655 | + |
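As an aside for reviewers, here is a minimal NumPy sketch of the shape contract documented above. The shapes are hypothetical, chosen only to illustrate how the mixture axis is consumed.

```python
import numpy as np

# Hypothetical component shape: (i_0, mixture_axis, i_1) == (2, 3, 4)
comp_shape = (2, 3, 4)
mixture_axis = 1

# The mixture axis is consumed; every other axis is kept.
mixture_shape = comp_shape[:mixture_axis] + comp_shape[mixture_axis + 1:]
print(mixture_shape)  # (2, 4)
```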
| 656 | + def __init__(self, w, comp_dists, mixture_axis=-1, *args, **kwargs): |
| 657 | + self.w = tt.as_tensor_variable(w) |
| 658 | + if not isinstance(comp_dists, Distribution): |
| 659 | + raise TypeError( |
| 660 | + "The MixtureSameFamily distribution only accepts Distribution " |
| 661 | + f"instances as its components. Got {type(comp_dists)} instead." |
| 662 | + ) |
| 663 | + self.comp_dists = comp_dists |
| 664 | + if mixture_axis < 0: |
| 665 | + mixture_axis = len(comp_dists.shape) + mixture_axis |
| 666 | + if mixture_axis < 0: |
| 667 | + raise ValueError( |
| 668 | + "`mixture_axis` is supposed to be in shape of components' distribution. " |
| 669 | + f"Got {mixture_axis + len(comp_dists.shape)} axis instead out of the bounds." |
| 670 | + ) |
| 671 | + comp_shape = to_tuple(comp_dists.shape) |
| 672 | + self.shape = comp_shape[:mixture_axis] + comp_shape[mixture_axis + 1 :] |
| 673 | + self.mixture_axis = mixture_axis |
| 674 | + kwargs.setdefault("dtype", self.comp_dists.dtype) |
| 675 | + |
| 676 | + # Compute the mode so we don't always have to pass a testval |
| 677 | + defaults = kwargs.pop("defaults", []) |
| 678 | + event_shape = self.comp_dists.shape[mixture_axis + 1 :] |
| 679 | + _w = tt.shape_padleft( |
| 680 | + tt.shape_padright(w, len(event_shape)), |
| 681 | + len(self.comp_dists.shape) - w.ndim - len(event_shape), |
| 682 | + ) |
| 683 | + mode = take_along_axis( |
| 684 | + self.comp_dists.mode, |
| 685 | + tt.argmax(_w, keepdims=True), |
| 686 | + axis=mixture_axis, |
| 687 | + ) |
| 688 | + self.mode = mode[(..., 0) + (slice(None),) * len(event_shape)] |
| 689 | + |
| 690 | + if not all_discrete(comp_dists): |
| 691 | + mean = tt.as_tensor_variable(self.comp_dists.mean) |
| 692 | + self.mean = (_w * mean).sum(axis=mixture_axis) |
| 693 | + if "mean" not in defaults: |
| 694 | + defaults.append("mean") |
| 695 | + defaults.append("mode") |
| 696 | + |
| 697 | + super().__init__(defaults=defaults, *args, **kwargs) |
| 698 | + |
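A NumPy analogue of the weight-padding and mode computation above may help reviewers follow the shape bookkeeping. The shapes here are hypothetical, and `np.reshape` stands in for `tt.shape_padright`/`tt.shape_padleft`.

```python
import numpy as np

# Hypothetical: comp_dists.shape == (batch, mixture, event) == (2, 3, 4)
comp_mode = np.zeros((2, 3, 4))
w = np.array([0.2, 0.3, 0.5])  # mixture weights, w.ndim == 1
event_ndim = 1                 # len(event_shape)

# Pad w on the right so it broadcasts over the event dims, then on the
# left up to the component ndim: (3,) -> (1, 3, 1).
_w = w.reshape((1,) * (comp_mode.ndim - w.ndim - event_ndim) + w.shape + (1,) * event_ndim)

# Pick the heaviest component along the mixture axis, then drop the
# consumed axis (argmax's keepdims needs NumPy >= 1.22); take_along_axis
# broadcasts the index array against the other axes.
idx = np.argmax(_w, axis=1, keepdims=True)         # (1, 1, 1)
mode = np.take_along_axis(comp_mode, idx, axis=1)  # (2, 1, 4)
print(mode[..., 0, :].shape)                       # (2, 4)
```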
| 699 | + def logp(self, value): |
| 700 | + """ |
| 701 | + Calculate log-probability of defined ``MixtureSameFamily`` distribution at specified value. |
| 702 | +
| 703 | + Parameters |
| 704 | + ---------- |
| 705 | + value : numeric |
| 706 | + Value(s) for which log-probability is calculated. If the log probabilities for |
| 707 | + multiple values are desired, the values must be provided as a NumPy array or Theano tensor. |
| 708 | +
| 709 | + Returns |
| 710 | + ------- |
| 711 | + TensorVariable |
| 712 | + """ |
| 713 | + |
| 714 | + comp_dists = self.comp_dists |
| 715 | + w = self.w |
| 716 | + mixture_axis = self.mixture_axis |
| 717 | + |
| 718 | + event_shape = comp_dists.shape[mixture_axis + 1 :] |
| 719 | + |
| 720 | + # To be able to broadcast comp_dists.logp with w and value, |
| 721 | + # we first have to pad the shape of w to the right with ones |
| 722 | + # so that it can broadcast with the event_shape. |
| 723 | + |
| 724 | + w = tt.shape_padright(w, len(event_shape)) |
| 725 | + |
| 726 | + # Second, we have to add the mixture_axis to the value tensor. |
| 727 | + # To insert the mixture axis at the correct location, we use a |
| 728 | + # negative axis index. This way, we can also handle situations |
| 729 | + # in which value is an observed value with more batch dimensions |
| 730 | + # than the ones present in the comp_dists. |
| 731 | + comp_dists_ndim = len(comp_dists.shape) |
| 732 | + |
| 733 | + value = tt.shape_padaxis(value, axis=mixture_axis - comp_dists_ndim) |
| 734 | + |
| 735 | + comp_logp = comp_dists.logp(value) |
| 736 | + return bound( |
| 737 | + logsumexp(tt.log(w) + comp_logp, axis=mixture_axis, keepdims=False), |
| 738 | + w >= 0, |
| 739 | + w <= 1, |
| 740 | + tt.allclose(w.sum(axis=mixture_axis - comp_dists_ndim), 1), |
| 741 | + broadcast_conditions=False, |
| 742 | + ) |
| 743 | + |
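The `bound(logsumexp(...))` expression above is the standard mixture log-likelihood, log f(x) = logsumexp_i(log w_i + log f_i(x)), taken along the mixture axis. A small SciPy sketch with made-up numbers:

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

# Hypothetical: 5 observations, 3 mixture components on the last axis.
x = np.linspace(-1.0, 1.0, 5)[:, None]            # (5, 1)
w = np.array([0.2, 0.3, 0.5])                     # weights summing to 1
comp_logp = norm.logpdf(x, loc=[-1.0, 0.0, 1.0])  # (5, 3)

# log f(x) = logsumexp_i(log w_i + log f_i(x)) along the mixture axis.
mix_logp = logsumexp(np.log(w) + comp_logp, axis=-1)
print(mix_logp.shape)  # (5,)
```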
| 744 | + def random(self, point=None, size=None): |
| 745 | + """ |
| 746 | + Draw random values from defined ``MixtureSameFamily`` distribution. |
| 747 | +
| 748 | + Parameters |
| 749 | + ---------- |
| 750 | + point : dict, optional |
| 751 | + Dict of variable values on which random values are to be |
| 752 | + conditioned (uses default point if not specified). |
| 753 | + size : int, optional |
| 754 | + Desired size of random sample (returns one sample if not |
| 755 | + specified). |
| 756 | +
| 757 | + Returns |
| 758 | + ------- |
| 759 | + array |
| 760 | + """ |
| 761 | + sample_shape = to_tuple(size) |
| 762 | + mixture_axis = self.mixture_axis |
| 763 | + |
| 764 | + # First we draw values for the mixture component weights |
| 765 | + (w,) = draw_values([self.w], point=point, size=size) |
| 766 | + |
| 767 | + # We now draw random choices from those weights. |
| 768 | + # However, we have to ensure that the shape of the drawn choices |
| 769 | + # includes the sample_shape. |
| 770 | + w_shape = w.shape |
| 771 | + batch_shape = self.comp_dists.shape[: mixture_axis + 1] |
| 772 | + param_shape = np.broadcast(np.empty(w_shape), np.empty(batch_shape)).shape |
| 773 | + event_shape = self.comp_dists.shape[mixture_axis + 1 :] |
| 774 | + |
| 775 | + if np.asarray(self.shape).size != 0: |
| 776 | + comp_dists_ndim = len(self.comp_dists.shape) |
| 777 | + |
| 778 | + # If the event_shape of comp_dists matches that of the supplied shape, |
| 779 | + # broadcast only the batch_shape; |
| 780 | + # otherwise, broadcast the entire given shape with batch_shape. |
| 781 | + if list(self.shape[mixture_axis - comp_dists_ndim + 1 :]) == list(event_shape): |
| 782 | + dist_shape = np.broadcast( |
| 783 | + np.empty(self.shape[:mixture_axis]), np.empty(param_shape[:mixture_axis]) |
| 784 | + ).shape |
| 785 | + else: |
| 786 | + dist_shape = np.broadcast( |
| 787 | + np.empty(self.shape), np.empty(param_shape[:mixture_axis]) |
| 788 | + ).shape |
| 789 | + else: |
| 790 | + dist_shape = param_shape[:mixture_axis] |
| 791 | + |
| 792 | + # Try to determine the size that must be used to get the mixture |
| 793 | + # components (i.e. get random choices using w). |
| 794 | + # 1. There must be `size` independent choices based on w. |
| 795 | + # 2. There must also be independent draws for each non-singleton axis |
| 796 | + # of w. |
| 797 | + # 3. There must also be independent draws for each dimension added by |
| 798 | + # self.shape with respect to w.ndim. These usually correspond to |
| 799 | + # observed variables with batch shapes. |
| 800 | + wsh = (1,) * (len(dist_shape) - len(w_shape) + 1) + w_shape[:mixture_axis] |
| 801 | + psh = (1,) * (len(dist_shape) - len(param_shape) + 1) + param_shape[:mixture_axis] |
| 802 | + w_sample_size = [] |
| 803 | + # Loop through the dist_shape to get the conditions 2 and 3 first |
| 804 | + for i in range(len(dist_shape)): |
| 805 | + if dist_shape[i] != psh[i] and wsh[i] == 1: |
| 806 | + # self.shape[i] is a non-singleton dimension (usually caused by |
| 807 | + # observed data) |
| 808 | + sh = dist_shape[i] |
| 809 | + else: |
| 810 | + sh = wsh[i] |
| 811 | + w_sample_size.append(sh) |
| 812 | + |
| 813 | + if sample_shape is not None and tuple(w_sample_size)[: len(sample_shape)] != sample_shape: |
| 814 | + w_sample_size = sample_shape + tuple(w_sample_size) |
| 815 | + |
| 816 | + choices = random_choice(p=w, size=w_sample_size) |
| 817 | + |
| 818 | + # We now draw samples from the mixture components random method |
| 819 | + comp_samples = self.comp_dists.random(point=point, size=size) |
| 820 | + if comp_samples.shape[: len(sample_shape)] != sample_shape: |
| 821 | + comp_samples = np.broadcast_to( |
| 822 | + comp_samples, |
| 823 | + shape=sample_shape + comp_samples.shape, |
| 824 | + ) |
| 825 | + |
| 826 | + # At this point the shapes of the arrays involved are: |
| 827 | + # comp_samples.shape = (sample_shape, batch_shape, mixture_axis, event_shape) |
| 828 | + # choices.shape = (sample_shape, batch_shape) |
| 829 | + # |
| 830 | + # To be able to take the choices along the mixture_axis of the |
| 831 | + # comp_samples, we have to add in dimensions to the right of the |
| 832 | + # choices array. |
| 833 | + # We also need to make sure that the batch_shapes of both the comp_samples |
| 834 | + # and choices broadcast with each other. |
| 835 | + |
| 836 | + choices = np.reshape(choices, choices.shape + (1,) * (1 + len(event_shape))) |
| 837 | + |
| 838 | + choices, comp_samples = get_broadcastable_dist_samples([choices, comp_samples], size=size) |
| 839 | + |
| 840 | + # We now take the choices of the mixture components along the mixture_axis, |
| 841 | + # using the negative index representation so that we can handle the |
| 842 | + # sample_shape. |
| 843 | + samples = np.take_along_axis( |
| 844 | + comp_samples, choices, axis=mixture_axis - len(self.comp_dists.shape) |
| 845 | + ) |
| 846 | + |
| 847 | + # The `samples` array still has the `mixture_axis`, so we must remove it: |
| 848 | + output = samples[(..., 0) + (slice(None),) * len(event_shape)] |
| 849 | + |
| 850 | + # Final oddity: if size == 1, pymc3 defaults to reducing the sample_shape dimension. |
| 851 | + # We do this to stay consistent with the rest of the package even though |
| 852 | + # we shouldn't have to do it. |
| 853 | + if size == 1: |
| 854 | + output = output[0] |
| 855 | + return output |
| 856 | + |
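The indexing at the end of `random` boils down to the following NumPy pattern (hypothetical shapes): pad the drawn choices with trailing singleton axes, select along the mixture axis with `take_along_axis`, then drop the consumed axis.

```python
import numpy as np

rng = np.random.default_rng(42)
# Hypothetical: comp_samples (batch, mixture, event) == (5, 3, 4)
comp_samples = rng.normal(size=(5, 3, 4))
choices = rng.choice(3, p=[0.2, 0.3, 0.5], size=5)  # one choice per batch

idx = choices[:, None, None]                             # (5, 1, 1)
picked = np.take_along_axis(comp_samples, idx, axis=-2)  # (5, 1, 4)
samples = picked[..., 0, :]                              # drop mixture axis
print(samples.shape)  # (5, 4)
```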
| 857 | + def _distr_parameters_for_repr(self): |
| 858 | + return [] |
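Finally, an end-to-end usage sketch of the new class, assuming the PyMC3 3.x API; the model, weights, and data below are hypothetical, loosely following the `pm.Multinomial.dist(...)` hint in the docstring.

```python
import numpy as np
import pymc3 as pm

observed = np.array([7, 1, 1, 1])  # hypothetical Multinomial counts, n == 10

with pm.Model():
    # Three mixture components, each a Multinomial over four categories:
    # comp_dists.shape == (3, 4), so mixture_axis=0 indexes the components
    # and the trailing axis of length 4 is the event dimension.
    p = np.array([
        [0.7, 0.1, 0.1, 0.1],
        [0.1, 0.7, 0.1, 0.1],
        [0.1, 0.1, 0.7, 0.1],
    ])
    w = pm.Dirichlet("w", a=np.ones(3))
    comp_dists = pm.Multinomial.dist(n=10, p=p, shape=(3, 4))
    mix = pm.MixtureSameFamily(
        "mix", w=w, comp_dists=comp_dists, mixture_axis=0, observed=observed
    )
```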