Skip to content

Commit 8b35cc6

Browse files
committed
Minor cleanup MarginalModel
1 parent 10b6fed commit 8b35cc6

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

pymc_experimental/model/marginal_model.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -582,10 +582,13 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
582582
raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")
583583

584584
ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs}
585-
if max(ndim_supp) > 0:
585+
if len(ndim_supp) != 1:
586586
raise NotImplementedError(
587-
"Marginalization of withe dependent Multivariate RVs not implemented"
587+
"Marginalization with dependent variables of different support dimensionality not implemented"
588588
)
589+
[ndim_supp] = ndim_supp
590+
if ndim_supp > 0:
591+
raise NotImplementedError("Marginalization with dependent Multivariate RVs not implemented")
589592

590593
marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
591594
dependent_rvs_input_rvs = [
@@ -623,7 +626,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
623626
marginalization_op = FiniteDiscreteMarginalRV(
624627
inputs=list(replace_inputs.values()),
625628
outputs=cloned_outputs,
626-
ndim_supp=0,
629+
ndim_supp=ndim_supp,
627630
)
628631
marginalized_rvs = marginalization_op(*replace_inputs.keys())
629632
fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))

pymc_experimental/tests/model/test_marginal_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def test_not_supported_marginalized():
472472
y = pm.Dirichlet("y", a=pm.math.switch(x, [1, 1, 1], [10, 10, 10]))
473473
with pytest.raises(
474474
NotImplementedError,
475-
match="Marginalization of withe dependent Multivariate RVs not implemented",
475+
match="Marginalization with dependent Multivariate RVs not implemented",
476476
):
477477
m.marginalize(x)
478478

0 commit comments

Comments
 (0)