Skip to content

Commit 49fb695

Browse files
committed
Allow broadcasting of mixture components
1 parent b36ccee commit 49fb695

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

pymc/distributions/mixture.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ def dist(cls, w, comp_dists, **kwargs):
176176
)
177177

178178
# Check that components are not associated with a registered variable in the model
179-
components_ndim = set()
180179
components_ndim_supp = set()
181180
for dist in comp_dists:
182181
# TODO: Allow these to not be a RandomVariable as long as we can call `ndim_supp` on them
@@ -188,14 +187,8 @@ def dist(cls, w, comp_dists, **kwargs):
188187
f"Component dist must be a distribution created via the `.dist()` API, got {type(dist)}"
189188
)
190189
check_dist_not_registered(dist)
191-
components_ndim.add(dist.ndim)
192190
components_ndim_supp.add(dist.owner.op.ndim_supp)
193191

194-
if len(components_ndim) > 1:
195-
raise ValueError(
196-
f"Mixture components must all have the same dimensionality, got {components_ndim}"
197-
)
198-
199192
if len(components_ndim_supp) > 1:
200193
raise ValueError(
201194
f"Mixture components must all have the same support dimensionality, got {components_ndim_supp}"
@@ -214,13 +207,18 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
214207
# Create new rng for the mix_indexes internal RV
215208
mix_indexes_rng = aesara.shared(np.random.default_rng())
216209

210+
single_component = len(components) == 1
211+
ndim_supp = components[0].owner.op.ndim_supp
212+
217213
if size is not None:
218214
components = cls._resize_components(size, *components)
215+
elif not single_component:
216+
# We might need to broadcast components when size is not specified
217+
shape = tuple(at.broadcast_shape(*components))
218+
size = shape[: len(shape) - ndim_supp]
219+
components = cls._resize_components(size, *components)
219220

220-
single_component = len(components) == 1
221-
222-
# Extract support and replication ndims from components and weights
223-
ndim_supp = components[0].owner.op.ndim_supp
221+
# Extract replication ndims from components and weights
224222
ndim_batch = components[0].ndim - ndim_supp
225223
if single_component:
226224
# One dimension is taken by the mixture axis in the single component case

pymc/tests/test_mixture.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import scipy.stats as st
2121

2222
from aesara import tensor as at
23+
from aesara.tensor.random.op import RandomVariable
2324
from numpy.testing import assert_allclose
2425
from scipy.special import logsumexp
2526

@@ -677,6 +678,38 @@ def test_mixture_dtype(self):
677678
).dtype
678679
assert mix_dtype == aesara.config.floatX
679680

681+
@pytest.mark.parametrize(
682+
"comp_dists, expected_shape",
683+
[
684+
(
685+
[
686+
Normal.dist([[0, 0, 0], [0, 0, 0]]),
687+
Normal.dist([0, 0, 0]),
688+
Normal.dist([0]),
689+
],
690+
(2, 3),
691+
),
692+
(
693+
[
694+
Dirichlet.dist([[1, 1, 1], [1, 1, 1]]),
695+
Dirichlet.dist([1, 1, 1]),
696+
],
697+
(2, 3),
698+
),
699+
],
700+
)
701+
def test_broadcast_components(self, comp_dists, expected_shape):
702+
n_dists = len(comp_dists)
703+
mix = Mixture.dist(w=np.ones(n_dists) / n_dists, comp_dists=comp_dists)
704+
mix_eval = mix.eval()
705+
assert tuple(mix_eval.shape) == expected_shape
706+
assert np.unique(mix_eval).size == mix.eval().size
707+
for comp_dist in mix.owner.inputs[2:]:
708+
# We check that the input is a "pure" RandomVariable and not a broadcast
709+
# operation. This confirms that all draws will be unique
710+
assert isinstance(comp_dist.owner.op, RandomVariable)
711+
assert tuple(comp_dist.shape.eval()) == expected_shape
712+
680713

681714
class TestNormalMixture(SeededTest):
682715
def test_normal_mixture_sampling(self):

0 commit comments

Comments
 (0)