Skip to content

Commit c0cc253

Browse files
committed
BUG Convert mean and mode to tensors before applying theano indexing.
1 parent 8a59d87 commit c0cc253

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

pymc3/distributions/mixture.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ def __init__(self, w, comp_dists, *args, **kwargs):
7171

7272
super(Mixture, self).__init__(shape, dtype, defaults=defaults,
7373
*args, **kwargs)
74-
74+
7575
def _comp_logp(self, value):
7676
comp_dists = self.comp_dists
77-
77+
7878
try:
7979
value_ = value if value.ndim > 1 else tt.shape_padright(value)
8080

@@ -85,14 +85,14 @@ def _comp_logp(self, value):
8585

8686
def _comp_means(self):
8787
try:
88-
return self.comp_dists.mean
88+
return tt.as_tensor_variable(self.comp_dists.mean)
8989
except AttributeError:
9090
return tt.stack([comp_dist.mean for comp_dist in self.comp_dists],
9191
axis=1)
9292

9393
def _comp_modes(self):
9494
try:
95-
return self.comp_dists.mode
95+
return tt.as_tensor_variable(self.comp_dists.mode)
9696
except AttributeError:
9797
return tt.stack([comp_dist.mode for comp_dist in self.comp_dists],
9898
axis=1)
@@ -137,7 +137,7 @@ def random_choice(*args, **kwargs):
137137
else:
138138
return np.squeeze(comp_samples[w_samples])
139139

140-
140+
141141
class NormalMixture(Mixture):
142142
R"""
143143
Normal mixture log-likelihood
@@ -164,6 +164,6 @@ class NormalMixture(Mixture):
164164
def __init__(self, w, mu, *args, **kwargs):
165165
_, sd = get_tau_sd(tau=kwargs.pop('tau', None),
166166
sd=kwargs.pop('sd', None))
167-
167+
168168
super(NormalMixture, self).__init__(w, Normal.dist(mu, sd=sd),
169169
*args, **kwargs)

0 commit comments

Comments
 (0)