Skip to content

Commit 49761d7

Browse files
committed
Add test for univariate and multivariate marginal mixture
1 parent dd2a060 commit 49761d7

File tree

2 files changed

+51
-5
lines changed

2 files changed

+51
-5
lines changed

pymc_experimental/marginal_model.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -340,10 +340,12 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
340340
raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")
341341

342342
ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs}
343-
if max(ndim_supp) > 0:
344-
raise NotImplementedError(
345-
"Marginalization of withe dependent Multivariate RVs not implemented"
346-
)
343+
if len(ndim_supp) != 1:
344+
raise NotImplementedError()
345+
# if max(ndim_supp) > 0:
346+
# raise NotImplementedError(
347+
# "Marginalization with dependent Multivariate RVs not implemented"
348+
# )
347349

348350
marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
349351
dependent_rvs_input_rvs = [
@@ -381,7 +383,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
381383
marginalization_op = FiniteDiscreteMarginalRV(
382384
inputs=list(replace_inputs.values()),
383385
outputs=cloned_outputs,
384-
ndim_supp=0,
386+
ndim_supp=ndim_supp,
385387
)
386388
marginalized_rvs = marginalization_op(*replace_inputs.keys())
387389
fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))

pymc_experimental/tests/test_marginal_model.py

+44
Original file line numberDiff line numberDiff line change
@@ -426,3 +426,47 @@ def test_is_conditional_dependent_static_shape():
426426
x2 = pt.matrix("x2", shape=(9, 5))
427427
y2 = pt.random.normal(size=pt.shape(x2))
428428
assert not is_conditional_dependent(y2, x2, [x2, y2])
429+
430+
431+
@pytest.mark.parametrize("univariate", (True, False))
432+
def test_vector_univariate_mixture(univariate):
433+
434+
with MarginalModel() as m:
435+
idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())
436+
437+
def dist(idx, size):
438+
return pm.math.switch(
439+
pm.math.eq(idx, 0),
440+
pm.Normal.dist([-10, -10], 1),
441+
pm.Normal.dist([10, 10], 1),
442+
)
443+
444+
pm.CustomDist("norm", idx, dist=dist)
445+
446+
m.marginalize(idx)
447+
logp_fn = m.compile_logp()
448+
449+
if univariate:
450+
with pm.Model() as ref_m:
451+
pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,))
452+
else:
453+
with pm.Model() as ref_m:
454+
pm.Mixture(
455+
"norm",
456+
w=[0.5, 0.5],
457+
comp_dists=[
458+
pm.MvNormal.dist([-10, -10], np.eye(2)),
459+
pm.MvNormal.dist([10, 10], np.eye(2)),
460+
],
461+
shape=(2,),
462+
)
463+
ref_logp_fn = ref_m.compile_logp()
464+
465+
for test_value in (
466+
[-10, -10],
467+
[10, 10],
468+
[-10, 10],
469+
[-10, 10],
470+
):
471+
pt = {"norm": test_value}
472+
np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))

0 commit comments

Comments
 (0)