Skip to content

Commit 6f5d0a6

Browse files
Add missing Dirichlet shape parameters to tests
1 parent c5f01a0 commit 6f5d0a6

File tree

3 files changed

+18
-18
lines changed

3 files changed

+18
-18
lines changed

pymc3/tests/test_dist_math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,11 @@ def test_multinomial_bound():
126126
n = x.sum()
127127

128128
with pm.Model() as modelA:
129-
p_a = pm.Dirichlet('p', floatX(np.ones(2)))
129+
p_a = pm.Dirichlet('p', floatX(np.ones(2)), shape=(2,))
130130
MultinomialA('x', n, p_a, observed=x)
131131

132132
with pm.Model() as modelB:
133-
p_b = pm.Dirichlet('p', floatX(np.ones(2)))
133+
p_b = pm.Dirichlet('p', floatX(np.ones(2)), shape=(2,))
134134
MultinomialB('x', n, p_b, observed=x)
135135

136136
assert np.isclose(modelA.logp({'p_stickbreaking__': [0]}),

pymc3/tests/test_distributions_random.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -912,15 +912,15 @@ def test_mixture_random_shape():
912912
nr.poisson(9, size=10)])
913913
with pm.Model() as m:
914914
comp0 = pm.Poisson.dist(mu=np.ones(2))
915-
w0 = pm.Dirichlet('w0', a=np.ones(2))
915+
w0 = pm.Dirichlet('w0', a=np.ones(2), shape=(2,))
916916
like0 = pm.Mixture('like0',
917917
w=w0,
918918
comp_dists=comp0,
919919
observed=y)
920920

921921
comp1 = pm.Poisson.dist(mu=np.ones((20, 2)),
922922
shape=(20, 2))
923-
w1 = pm.Dirichlet('w1', a=np.ones(2))
923+
w1 = pm.Dirichlet('w1', a=np.ones(2), shape=(2,))
924924
like1 = pm.Mixture('like1',
925925
w=w1,
926926
comp_dists=comp1,
@@ -967,15 +967,15 @@ def test_mixture_random_shape_fast():
967967
nr.poisson(9, size=10)])
968968
with pm.Model() as m:
969969
comp0 = pm.Poisson.dist(mu=np.ones(2))
970-
w0 = pm.Dirichlet('w0', a=np.ones(2))
970+
w0 = pm.Dirichlet('w0', a=np.ones(2), shape=(2,))
971971
like0 = pm.Mixture('like0',
972972
w=w0,
973973
comp_dists=comp0,
974974
observed=y)
975975

976976
comp1 = pm.Poisson.dist(mu=np.ones((20, 2)),
977977
shape=(20, 2))
978-
w1 = pm.Dirichlet('w1', a=np.ones(2))
978+
w1 = pm.Dirichlet('w1', a=np.ones(2), shape=(2,))
979979
like1 = pm.Mixture('like1',
980980
w=w1,
981981
comp_dists=comp1,

pymc3/tests/test_mixture.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def test_dimensions(self):
7979

8080
def test_mixture_list_of_normals(self):
8181
with Model() as model:
82-
w = Dirichlet('w', floatX(np.ones_like(self.norm_w)))
82+
w = Dirichlet('w', floatX(np.ones_like(self.norm_w)), shape=self.norm_w.size)
8383
mu = Normal('mu', 0., 10., shape=self.norm_w.size)
8484
tau = Gamma('tau', 1., 1., shape=self.norm_w.size)
8585
Mixture('x_obs', w,
@@ -98,7 +98,7 @@ def test_mixture_list_of_normals(self):
9898

9999
def test_normal_mixture(self):
100100
with Model() as model:
101-
w = Dirichlet('w', floatX(np.ones_like(self.norm_w)))
101+
w = Dirichlet('w', floatX(np.ones_like(self.norm_w)), shape=self.norm_w.size)
102102
mu = Normal('mu', 0., 10., shape=self.norm_w.size)
103103
tau = Gamma('tau', 1., 1., shape=self.norm_w.size)
104104
NormalMixture('x_obs', w, mu, tau=tau, observed=self.norm_x)
@@ -135,7 +135,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
135135
with Model() as model0:
136136
mus = Normal('mus', shape=comp_shape)
137137
taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
138-
ws = Dirichlet('ws', np.ones(ncomp))
138+
ws = Dirichlet('ws', np.ones(ncomp), shape=(ncomp,))
139139
mixture0 = NormalMixture('m', w=ws, mu=mus, tau=taus, shape=nd,
140140
comp_shape=comp_shape)
141141
obs0 = NormalMixture('obs', w=ws, mu=mus, tau=taus, shape=nd,
@@ -145,7 +145,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
145145
with Model() as model1:
146146
mus = Normal('mus', shape=comp_shape)
147147
taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
148-
ws = Dirichlet('ws', np.ones(ncomp))
148+
ws = Dirichlet('ws', np.ones(ncomp), shape=(ncomp,))
149149
comp_dist = [Normal.dist(mu=mus[..., i], tau=taus[..., i],
150150
shape=nd)
151151
for i in range(ncomp)]
@@ -163,7 +163,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
163163
# comp_dists.
164164
mus = Normal('mus', shape=comp_shape)
165165
taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
166-
ws = Dirichlet('ws', np.ones(ncomp))
166+
ws = Dirichlet('ws', np.ones(ncomp), shape=(ncomp,))
167167
if len(nd) > 1:
168168
if nd[-1] != ncomp:
169169
with pytest.raises(ValueError):
@@ -208,7 +208,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
208208

209209
def test_poisson_mixture(self):
210210
with Model() as model:
211-
w = Dirichlet('w', floatX(np.ones_like(self.pois_w)))
211+
w = Dirichlet('w', floatX(np.ones_like(self.pois_w)), shape=self.pois_w.shape)
212212
mu = Gamma('mu', 1., 1., shape=self.pois_w.size)
213213
Mixture('x_obs', w, Poisson.dist(mu), observed=self.pois_x)
214214
step = Metropolis()
@@ -224,7 +224,7 @@ def test_poisson_mixture(self):
224224

225225
def test_mixture_list_of_poissons(self):
226226
with Model() as model:
227-
w = Dirichlet('w', floatX(np.ones_like(self.pois_w)))
227+
w = Dirichlet('w', floatX(np.ones_like(self.pois_w)), shape=self.pois_w.shape)
228228
mu = Gamma('mu', 1., 1., shape=self.pois_w.size)
229229
Mixture('x_obs', w,
230230
[Poisson.dist(mu[0]), Poisson.dist(mu[1])],
@@ -247,7 +247,7 @@ def test_mixture_of_mvn(self):
247247
cov2 = np.diag([2.5, 3.5])
248248
obs = np.asarray([[.5, .5], mu1, mu2])
249249
with Model() as model:
250-
w = Dirichlet('w', floatX(np.ones(2)), transform=None)
250+
w = Dirichlet('w', floatX(np.ones(2)), transform=None, shape=(2,))
251251
mvncomp1 = MvNormal.dist(mu=mu1, cov=cov1)
252252
mvncomp2 = MvNormal.dist(mu=mu2, cov=cov2)
253253
y = Mixture('x_obs', w, [mvncomp1, mvncomp2],
@@ -291,13 +291,13 @@ def test_mixture_of_mixture(self):
291291
sigma=1,
292292
shape=nbr)
293293
# weight vector for the mixtures
294-
g_w = Dirichlet('g_w', a=floatX(np.ones(nbr)*0.0000001), transform=None)
295-
l_w = Dirichlet('l_w', a=floatX(np.ones(nbr)*0.0000001), transform=None)
294+
g_w = Dirichlet('g_w', a=floatX(np.ones(nbr)*0.0000001), transform=None, shape=(nbr,))
295+
l_w = Dirichlet('l_w', a=floatX(np.ones(nbr)*0.0000001), transform=None, shape=(nbr,))
296296
# mixture components
297297
g_mix = Mixture.dist(w=g_w, comp_dists=g_comp)
298298
l_mix = Mixture.dist(w=l_w, comp_dists=l_comp)
299299
# mixture of mixtures
300-
mix_w = Dirichlet('mix_w', a=floatX(np.ones(2)), transform=None)
300+
mix_w = Dirichlet('mix_w', a=floatX(np.ones(2)), transform=None, shape=(2,))
301301
mix = Mixture('mix', w=mix_w,
302302
comp_dists=[g_mix, l_mix],
303303
observed=np.exp(self.norm_x))
@@ -378,7 +378,7 @@ def build_toy_dataset(N, K):
378378
X, y = build_toy_dataset(N, K)
379379

380380
with pm.Model() as model:
381-
pi = pm.Dirichlet('pi', np.ones(K))
381+
pi = pm.Dirichlet('pi', np.ones(K), shape=(K,))
382382

383383
comp_dist = []
384384
mu = []

0 commit comments

Comments
 (0)