@@ -71,10 +71,10 @@ def __init__(self, w, comp_dists, *args, **kwargs):
 
         super(Mixture, self).__init__(shape, dtype, defaults=defaults,
                                       *args, **kwargs)
-
+
     def _comp_logp(self, value):
         comp_dists = self.comp_dists
-
+
         try:
             value_ = value if value.ndim > 1 else tt.shape_padright(value)
 
@@ -85,14 +85,14 @@ def _comp_logp(self, value):
 
     def _comp_means(self):
         try:
-            return self.comp_dists.mean
+            return tt.as_tensor_variable(self.comp_dists.mean)
         except AttributeError:
             return tt.stack([comp_dist.mean for comp_dist in self.comp_dists],
                             axis=1)
 
     def _comp_modes(self):
         try:
-            return self.comp_dists.mode
+            return tt.as_tensor_variable(self.comp_dists.mode)
         except AttributeError:
             return tt.stack([comp_dist.mode for comp_dist in self.comp_dists],
                             axis=1)
@@ -137,7 +137,7 @@ def random_choice(*args, **kwargs):
         else:
             return np.squeeze(comp_samples[w_samples])
 
-
+
 class NormalMixture(Mixture):
     R"""
     Normal mixture log-likelihood
@@ -164,6 +164,6 @@ class NormalMixture(Mixture):
     def __init__(self, w, mu, *args, **kwargs):
         _, sd = get_tau_sd(tau=kwargs.pop('tau', None),
                            sd=kwargs.pop('sd', None))
-
+
         super(NormalMixture, self).__init__(w, Normal.dist(mu, sd=sd),
                                             *args, **kwargs)
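
The substantive change in this diff is the tt.as_tensor_variable wrapping in _comp_means and _comp_modes, presumably so both methods always hand back a Theano tensor regardless of whether comp_dists is a single distribution (whose .mean or .mode may be a plain NumPy array or Python scalar) or a list of distributions. A minimal sketch of the distinction, assuming only numpy and theano; the variable names here are illustrative, not from the patch:

import numpy as np
import theano.tensor as tt

# A distribution's .mean attribute may be a plain NumPy array or scalar
# rather than a symbolic Theano variable.
raw_mean = np.array([0.0, 1.0, 2.0])

# as_tensor_variable converts NumPy/Python values into tensors and leaves
# existing Theano variables untouched, so callers get tensor semantics
# from either branch of the try/except.
mean_t = tt.as_tensor_variable(raw_mean)

# Downstream graph-building code can now treat the result uniformly,
# e.g. when forming the weighted mixture mean:
w = tt.vector('w')
mixture_mean = tt.dot(w, mean_t)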
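For orientation, NormalMixture simply pops sd (or tau) out of kwargs, normalises it through get_tau_sd, and delegates to Mixture with Normal.dist(mu, sd=sd) as the component distribution. A hedged usage sketch against the PyMC3 API of this vintage; the data and variable names are made up for illustration:

import numpy as np
import pymc3 as pm

# Two well-separated clusters of synthetic observations.
data = np.concatenate([np.random.randn(100), np.random.randn(100) + 5.0])

with pm.Model():
    w = pm.Dirichlet('w', a=np.ones(2))             # mixture weights
    mu = pm.Normal('mu', mu=0.0, sd=10.0, shape=2)  # component means
    # sd (or tau) is consumed by the NormalMixture constructor shown above.
    obs = pm.NormalMixture('obs', w=w, mu=mu, sd=1.0, observed=data)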