@@ -176,7 +176,6 @@ def dist(cls, w, comp_dists, **kwargs):
176
176
)
177
177
178
178
# Check that components are not associated with a registered variable in the model
179
- components_ndim = set ()
180
179
components_ndim_supp = set ()
181
180
for dist in comp_dists :
182
181
# TODO: Allow these to not be a RandomVariable as long as we can call `ndim_supp` on them
@@ -188,14 +187,8 @@ def dist(cls, w, comp_dists, **kwargs):
188
187
f"Component dist must be a distribution created via the `.dist()` API, got { type (dist )} "
189
188
)
190
189
check_dist_not_registered (dist )
191
- components_ndim .add (dist .ndim )
192
190
components_ndim_supp .add (dist .owner .op .ndim_supp )
193
191
194
- if len (components_ndim ) > 1 :
195
- raise ValueError (
196
- f"Mixture components must all have the same dimensionality, got { components_ndim } "
197
- )
198
-
199
192
if len (components_ndim_supp ) > 1 :
200
193
raise ValueError (
201
194
f"Mixture components must all have the same support dimensionality, got { components_ndim_supp } "
@@ -214,13 +207,18 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
214
207
# Create new rng for the mix_indexes internal RV
215
208
mix_indexes_rng = aesara .shared (np .random .default_rng ())
216
209
210
+ single_component = len (components ) == 1
211
+ ndim_supp = components [0 ].owner .op .ndim_supp
212
+
217
213
if size is not None :
218
214
components = cls ._resize_components (size , * components )
215
+ elif not single_component :
216
+ # We might need to broadcast components when size is not specified
217
+ shape = tuple (at .broadcast_shape (* components ))
218
+ size = shape [: len (shape ) - ndim_supp ]
219
+ components = cls ._resize_components (size , * components )
219
220
220
- single_component = len (components ) == 1
221
-
222
- # Extract support and replication ndims from components and weights
223
- ndim_supp = components [0 ].owner .op .ndim_supp
221
+ # Extract replication ndims from components and weights
224
222
ndim_batch = components [0 ].ndim - ndim_supp
225
223
if single_component :
226
224
# One dimension is taken by the mixture axis in the single component case
0 commit comments