Skip to content

Commit ceb1db8

Browse files
Fix Mixture distribution mode computation and logp dimensions
Closes pymc-devs#3994.
1 parent 7842072 commit ceb1db8

File tree

3 files changed

+24
-8
lines changed

3 files changed

+24
-8
lines changed

pymc3/distributions/mixture.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,13 @@ def __init__(self, w, comp_dists, *args, **kwargs):
153153
dtype = kwargs.pop('dtype', default_dtype)
154154

155155
try:
156-
comp_modes = self._comp_modes()
157-
comp_mode_logps = self.logp(comp_modes)
158-
self.mode = comp_modes[tt.argmax(w * comp_mode_logps, axis=-1)]
156+
if isinstance(comp_dists, Distribution):
157+
comp_mode_logps = comp_dists.logp(comp_dists.mode)
158+
else:
159+
comp_mode_logps = tt.stack([cd.logp(cd.mode) for cd in comp_dists])
160+
161+
mode_idx = tt.argmax(tt.log(w) + comp_mode_logps, axis=-1)
162+
self.mode = self._comp_modes()[mode_idx]
159163

160164
if 'mode' not in defaults:
161165
defaults.append('mode')
@@ -427,7 +431,7 @@ def logp(self, value):
427431
"""
428432
w = self.w
429433

430-
return bound(logsumexp(tt.log(w) + self._comp_logp(value), axis=-1),
434+
return bound(logsumexp(tt.log(w) + self._comp_logp(value), axis=-1, keepdims=False),
431435
w >= 0, w <= 1, tt.allclose(w.sum(axis=-1), 1),
432436
broadcast_conditions=False)
433437

pymc3/math.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,11 @@ def tround(*args, **kwargs):
168168
return tt.round(*args, **kwargs)
169169

170170

171-
def logsumexp(x, axis=None):
171+
def logsumexp(x, axis=None, keepdims=True):
172172
# Adapted from https://github.com/Theano/Theano/issues/1563
173173
x_max = tt.max(x, axis=axis, keepdims=True)
174-
return tt.log(tt.sum(tt.exp(x - x_max), axis=axis, keepdims=True)) + x_max
174+
res = tt.log(tt.sum(tt.exp(x - x_max), axis=axis, keepdims=True)) + x_max
175+
return res if keepdims else res.squeeze()
175176

176177

177178
def logaddexp(a, b):

pymc3/tests/test_mixture.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,17 @@ def setup_class(cls):
6666
cls.pois_mu = np.array([5., 20.])
6767
cls.pois_x = generate_poisson_mixture_data(cls.pois_w, cls.pois_mu, size=1000)
6868

69+
def test_dimensions(self):
70+
a1 = Normal.dist(mu=0, sigma=1)
71+
a2 = Normal.dist(mu=10, sigma=1)
72+
mix = Mixture.dist(w=np.r_[0.5, 0.5], comp_dists=[a1, a2])
73+
74+
assert mix.mode.ndim == 0
75+
assert mix.logp(0.0).ndim == 0
76+
77+
value = np.r_[0.0, 1.0, 2.0]
78+
assert mix.logp(value).ndim == 1
79+
6980
def test_mixture_list_of_normals(self):
7081
with Model() as model:
7182
w = Dirichlet('w', floatX(np.ones_like(self.norm_w)))
@@ -252,7 +263,7 @@ def test_mixture_of_mvn(self):
252263
# check logp of mixture
253264
testpoint = model.test_point
254265
mixlogp_st = logsumexp(np.log(testpoint['w']) + complogp_st,
255-
axis=-1, keepdims=True)
266+
axis=-1, keepdims=False)
256267
assert_allclose(y.logp_elemwise(testpoint),
257268
mixlogp_st)
258269

@@ -321,7 +332,7 @@ def mixmixlogp(value, point):
321332
complogp_mix = np.concatenate((mixlogp1, mixlogp2), axis=1)
322333
mixmixlogpg = logsumexp(np.log(point['mix_w']).astype(floatX) +
323334
complogp_mix,
324-
axis=-1, keepdims=True)
335+
axis=-1, keepdims=False)
325336
return priorlogp, mixmixlogpg
326337

327338
value = np.exp(self.norm_x)[:, None]

0 commit comments

Comments
 (0)