Skip to content

Commit 9e7495b

Browse files
author
Junpeng Lao
authored
Fix random sampling in Mixture (#3004)
* WIP fix mixture random
* Add test
* fix test
* remove shape of MultiObservedRV
* Fix test: revert back to original, i.e. not setting the shape of observedRV
* Fix test
* Fix test, add kwarg to NormalMixture (also updated dependent_density_regression notebook)
* improve naming
* fix test
1 parent dd7caab commit 9e7495b

File tree

4 files changed

+846
-703
lines changed

4 files changed

+846
-703
lines changed

docs/source/notebooks/dependent_density_regression.ipynb

Lines changed: 722 additions & 686 deletions
Large diffs are not rendered by default.

pymc3/distributions/mixture.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -158,18 +158,33 @@ def random_choice(*args, **kwargs):
158158
return np.random.choice(k, p=w, *args, **kwargs)
159159

160160
w = draw_values([self.w], point=point)[0]
161-
161+
comp_tmp = self._comp_samples(point=point, size=None)
162+
if self.shape.size == 0:
163+
distshape = np.asarray(np.broadcast(w, comp_tmp).shape)[..., :-1]
164+
else:
165+
distshape = self.shape
162166
w_samples = generate_samples(random_choice,
163167
w=w,
164168
broadcast_shape=w.shape[:-1] or (1,),
165-
dist_shape=self.shape,
169+
dist_shape=distshape,
166170
size=size).squeeze()
167-
comp_samples = self._comp_samples(point=point, size=size)
168-
169-
if comp_samples.ndim > 1:
170-
return np.squeeze(comp_samples[np.arange(w_samples.size), w_samples])
171+
if (size is None) or (distshape.size == 0):
172+
comp_samples = self._comp_samples(point=point, size=size)
173+
if comp_samples.ndim > 1:
174+
samples = np.squeeze(comp_samples[np.arange(w_samples.size), ..., w_samples])
175+
else:
176+
samples = np.squeeze(comp_samples[w_samples])
171177
else:
172-
return np.squeeze(comp_samples[w_samples])
178+
samples = np.zeros((size,)+tuple(distshape))
179+
for i in range(size):
180+
w_tmp = w_samples[i, :]
181+
comp_tmp = self._comp_samples(point=point, size=None)
182+
if comp_tmp.ndim > 1:
183+
samples[i, :] = np.squeeze(comp_tmp[np.arange(w_tmp.size), ..., w_tmp])
184+
else:
185+
samples[i, :] = np.squeeze(comp_tmp[w_tmp])
186+
187+
return samples
173188

174189

175190
class NormalMixture(Mixture):
@@ -197,22 +212,22 @@ class NormalMixture(Mixture):
197212
the component standard deviations
198213
tau : array of floats
199214
the component precisions
215+
comp_shape : shape of the Normal component
216+
notice that it should be different than the shape
217+
of the mixture distribution, with one axis being
218+
the number of components.
200219
201220
Note: You only have to pass in sd or tau, but not both.
202221
"""
203222

204-
def __init__(self, w, mu, *args, **kwargs):
223+
def __init__(self, w, mu, comp_shape=(), *args, **kwargs):
205224
_, sd = get_tau_sd(tau=kwargs.pop('tau', None),
206225
sd=kwargs.pop('sd', None))
207226

208-
distshape = np.broadcast(mu, sd).shape
209227
self.mu = mu = tt.as_tensor_variable(mu)
210228
self.sd = sd = tt.as_tensor_variable(sd)
211229

212-
if not distshape:
213-
distshape = np.broadcast(mu.tag.test_value, sd.tag.test_value).shape
214-
215-
super(NormalMixture, self).__init__(w, Normal.dist(mu, sd=sd, shape=distshape),
230+
super(NormalMixture, self).__init__(w, Normal.dist(mu, sd=sd, shape=comp_shape),
216231
*args, **kwargs)
217232

218233
def _repr_latex_(self, name=None, dist=None):

pymc3/tests/test_distributions_random.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -759,17 +759,79 @@ def ref_rand(size, w, mu, sd):
759759
pymc3_random(pm.NormalMixture, {'w': Simplex(2),
760760
'mu': Domain([[.05, 2.5], [-5., 1.]], edges=(None, None)),
761761
'sd': Domain([[1, 1], [1.5, 2.]], edges=(None, None))},
762+
extra_args={'comp_shape': 2},
762763
size=1000,
763764
ref_rand=ref_rand)
764765
pymc3_random(pm.NormalMixture, {'w': Simplex(3),
765766
'mu': Domain([[-5., 1., 2.5]], edges=(None, None)),
766767
'sd': Domain([[1.5, 2., 3.]], edges=(None, None))},
768+
extra_args={'comp_shape': 3},
767769
size=1000,
768770
ref_rand=ref_rand)
769771

772+
773+
def test_mixture_random_shape():
774+
# test the shape broadcasting in mixture random
775+
y = np.concatenate([nr.poisson(5, size=10),
776+
nr.poisson(9, size=10)])
777+
with pm.Model() as m:
778+
comp0 = pm.Poisson.dist(mu=np.ones(2))
779+
w0 = pm.Dirichlet('w0', a=np.ones(2))
780+
like0 = pm.Mixture('like0',
781+
w=w0,
782+
comp_dists=comp0,
783+
shape=y.shape,
784+
observed=y)
785+
786+
comp1 = pm.Poisson.dist(mu=np.ones((20, 2)),
787+
shape=(20, 2))
788+
w1 = pm.Dirichlet('w1', a=np.ones(2))
789+
like1 = pm.Mixture('like1',
790+
w=w1,
791+
comp_dists=comp1, observed=y)
792+
793+
comp2 = pm.Poisson.dist(mu=np.ones(2))
794+
w2 = pm.Dirichlet('w2',
795+
a=np.ones(2),
796+
shape=(20, 2))
797+
like2 = pm.Mixture('like2',
798+
w=w2,
799+
comp_dists=comp2,
800+
observed=y)
801+
802+
comp3 = pm.Poisson.dist(mu=np.ones(2),
803+
shape=(20, 2))
804+
w3 = pm.Dirichlet('w3',
805+
a=np.ones(2),
806+
shape=(20, 2))
807+
like3 = pm.Mixture('like3',
808+
w=w3,
809+
comp_dists=comp3,
810+
observed=y)
811+
812+
rand0 = like0.distribution.random(m.test_point, size=100)
813+
assert rand0.shape == (100, 20)
814+
815+
rand1 = like1.distribution.random(m.test_point, size=100)
816+
assert rand1.shape == (100, 20)
817+
818+
rand2 = like2.distribution.random(m.test_point, size=100)
819+
assert rand2.shape == (100, 20)
820+
821+
rand3 = like3.distribution.random(m.test_point, size=100)
822+
assert rand3.shape == (100, 20)
823+
824+
with m:
825+
ppc = pm.sample_ppc([m.test_point], samples=200)
826+
assert ppc['like0'].shape == (200, 20)
827+
assert ppc['like1'].shape == (200, 20)
828+
assert ppc['like2'].shape == (200, 20)
829+
assert ppc['like3'].shape == (200, 20)
830+
831+
770832
def test_density_dist_with_random_sampleable():
771833
with pm.Model() as model:
772-
mu = pm.Normal('mu',0,1)
834+
mu = pm.Normal('mu', 0, 1)
773835
normal_dist = pm.Normal.dist(mu, 1)
774836
pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100), random=normal_dist.random)
775837
trace = pm.sample(100)
@@ -781,7 +843,7 @@ def test_density_dist_with_random_sampleable():
781843

782844
def test_density_dist_without_random_not_sampleable():
783845
with pm.Model() as model:
784-
mu = pm.Normal('mu',0,1)
846+
mu = pm.Normal('mu', 0, 1)
785847
normal_dist = pm.Normal.dist(mu, 1)
786848
pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100))
787849
trace = pm.sample(100)

pymc3/tests/test_sampling.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,14 @@ def test_empty_model():
140140
pm.sample()
141141
error.match('any free variables')
142142

143+
143144
def test_partial_trace_sample():
144145
with pm.Model() as model:
145146
a = pm.Normal('a', mu=0, sd=1)
146147
b = pm.Normal('b', mu=0, sd=1)
147148
trace = pm.sample(trace=[a])
148149

150+
149151
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
150152
class TestNamedSampling(SeededTest):
151153
def test_shared_named(self):
@@ -235,7 +237,34 @@ def test_normal_vector(self):
235237
trace = pm.sample()
236238

237239
with model:
238-
# test list input
240+
# test list input
241+
ppc0 = pm.sample_ppc([model.test_point], samples=10)
242+
ppc = pm.sample_ppc(trace, samples=10, vars=[])
243+
assert len(ppc) == 0
244+
ppc = pm.sample_ppc(trace, samples=10, vars=[a])
245+
assert 'a' in ppc
246+
assert ppc['a'].shape == (10, 2)
247+
248+
ppc = pm.sample_ppc(trace, samples=10, vars=[a], size=4)
249+
assert 'a' in ppc
250+
assert ppc['a'].shape == (10, 4, 2)
251+
252+
def test_vector_observed(self):
253+
# This test was initially created to test whether observedRVs
254+
# can assert the shape automatically from the observed data.
255+
# It can make sample_ppc correct for RVs similar to below (i.e.,
256+
# some kind of broadcasting is involved). However, doing so makes
257+
# the application with `theano.shared` array as observed data
258+
# invalid (after the `.set_value` the RV shape could change).
259+
with pm.Model() as model:
260+
mu = pm.Normal('mu', mu=0, sd=1)
261+
a = pm.Normal('a', mu=mu, sd=1,
262+
shape=2, # necessary to make ppc sample correct
263+
observed=np.array([0., 1.]))
264+
trace = pm.sample()
265+
266+
with model:
267+
# test list input
239268
ppc0 = pm.sample_ppc([model.test_point], samples=10)
240269
ppc = pm.sample_ppc(trace, samples=10, vars=[])
241270
assert len(ppc) == 0
@@ -254,7 +283,7 @@ def test_sum_normal(self):
254283
trace = pm.sample()
255284

256285
with model:
257-
# test list input
286+
# test list input
258287
ppc0 = pm.sample_ppc([model.test_point], samples=10)
259288
ppc = pm.sample_ppc(trace, samples=1000, vars=[b])
260289
assert len(ppc) == 1
@@ -263,6 +292,7 @@ def test_sum_normal(self):
263292
_, pval = stats.kstest(ppc['b'], stats.norm(scale=scale).cdf)
264293
assert pval > 0.001
265294

295+
266296
class TestSamplePPCW(SeededTest):
267297
def test_sample_ppc_w(self):
268298
data0 = np.random.normal(0, 1, size=500)

0 commit comments

Comments (0)