diff --git a/pymc3/distributions/mixture.py b/pymc3/distributions/mixture.py
index 9e461e95eb..a041c3bc15 100644
--- a/pymc3/distributions/mixture.py
+++ b/pymc3/distributions/mixture.py
@@ -153,9 +153,13 @@ def __init__(self, w, comp_dists, *args, **kwargs):
         dtype = kwargs.pop('dtype', default_dtype)
 
         try:
-            comp_modes = self._comp_modes()
-            comp_mode_logps = self.logp(comp_modes)
-            self.mode = comp_modes[tt.argmax(w * comp_mode_logps, axis=-1)]
+            if isinstance(comp_dists, Distribution):
+                comp_mode_logps = comp_dists.logp(comp_dists.mode)
+            else:
+                comp_mode_logps = tt.stack([cd.logp(cd.mode) for cd in comp_dists])
+
+            mode_idx = tt.argmax(tt.log(w) + comp_mode_logps, axis=-1)
+            self.mode = self._comp_modes()[mode_idx]
 
             if 'mode' not in defaults:
                 defaults.append('mode')
@@ -427,7 +431,7 @@ def logp(self, value):
         """
         w = self.w
 
-        return bound(logsumexp(tt.log(w) + self._comp_logp(value), axis=-1),
+        return bound(logsumexp(tt.log(w) + self._comp_logp(value), axis=-1, keepdims=False),
                     w >= 0, w <= 1,
                     tt.allclose(w.sum(axis=-1), 1),
                     broadcast_conditions=False)
diff --git a/pymc3/math.py b/pymc3/math.py
index 4f2f319c9a..2a44453cfb 100644
--- a/pymc3/math.py
+++ b/pymc3/math.py
@@ -168,10 +168,11 @@ def tround(*args, **kwargs):
     return tt.round(*args, **kwargs)
 
 
-def logsumexp(x, axis=None):
+def logsumexp(x, axis=None, keepdims=True):
     # Adapted from https://github.com/Theano/Theano/issues/1563
     x_max = tt.max(x, axis=axis, keepdims=True)
-    return tt.log(tt.sum(tt.exp(x - x_max), axis=axis, keepdims=True)) + x_max
+    res = tt.log(tt.sum(tt.exp(x - x_max), axis=axis, keepdims=True)) + x_max
+    return res if keepdims else res.squeeze()
 
 
 def logaddexp(a, b):
diff --git a/pymc3/tests/test_mixture.py b/pymc3/tests/test_mixture.py
index 34fd8dd275..2547b09dda 100644
--- a/pymc3/tests/test_mixture.py
+++ b/pymc3/tests/test_mixture.py
@@ -66,6 +66,17 @@ def setup_class(cls):
         cls.pois_mu = np.array([5., 20.])
         cls.pois_x = generate_poisson_mixture_data(cls.pois_w, cls.pois_mu, size=1000)
 
+    def test_dimensions(self):
+        a1 = Normal.dist(mu=0, sigma=1)
+        a2 = Normal.dist(mu=10, sigma=1)
+        mix = Mixture.dist(w=np.r_[0.5, 0.5], comp_dists=[a1, a2])
+
+        assert mix.mode.ndim == 0
+        assert mix.logp(0.0).ndim == 0
+
+        value = np.r_[0.0, 1.0, 2.0]
+        assert mix.logp(value).ndim == 1
+
     def test_mixture_list_of_normals(self):
         with Model() as model:
             w = Dirichlet('w', floatX(np.ones_like(self.norm_w)))
@@ -252,7 +263,7 @@ def test_mixture_of_mvn(self):
         # check logp of mixture
         testpoint = model.test_point
         mixlogp_st = logsumexp(np.log(testpoint['w']) + complogp_st,
-                               axis=-1, keepdims=True)
+                               axis=-1, keepdims=False)
         assert_allclose(y.logp_elemwise(testpoint), mixlogp_st)
 
@@ -321,7 +332,7 @@ def mixmixlogp(value, point):
            complogp_mix = np.concatenate((mixlogp1, mixlogp2),
                                          axis=1)
            mixmixlogpg = logsumexp(np.log(point['mix_w']).astype(floatX) + complogp_mix,
-                                   axis=-1, keepdims=True)
+                                   axis=-1, keepdims=False)
            return priorlogp, mixmixlogpg
 
        value = np.exp(self.norm_x)[:, None]
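
Below is a small usage sketch (not part of the patch itself) illustrating the behaviour the new `keepdims` flag on `pymc3.math.logsumexp` is intended to have; it assumes a pymc3 checkout with this change applied and a working theano install.

# Usage sketch only; assumes pymc3 with this patch applied and theano importable.
import numpy as np
import theano.tensor as tt

from pymc3.math import logsumexp

# Two rows of (already normalized) log-weights.
x = tt.as_tensor_variable(np.log([[0.25, 0.75],
                                  [0.50, 0.50]]))

# Default keepdims=True preserves the reduced axis: shape (2, 1).
kept = logsumexp(x, axis=-1).eval()

# keepdims=False drops it: shape (2,), which is what Mixture.logp now
# relies on to return a 1-d elementwise log-probability.
dropped = logsumexp(x, axis=-1, keepdims=False).eval()

assert kept.shape == (2, 1)
assert dropped.shape == (2,)
assert np.allclose(kept.ravel(), dropped)  # both rows sum to 1, so log(1) == 0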