Skip to content

Commit 6e0de43

Browse files
committed
Minor cleanup MarginalModel
1 parent 74776cb commit 6e0de43

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

pymc_experimental/model/marginal_model.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -580,10 +580,13 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
580580
raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")
581581

582582
ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs}
583-
if max(ndim_supp) > 0:
583+
if len(ndim_supp) != 1:
584584
raise NotImplementedError(
585-
"Marginalization of withe dependent Multivariate RVs not implemented"
585+
"Marginalization with dependent variables of different support not implemented"
586586
)
587+
[ndim_supp] = ndim_supp
588+
if ndim_supp > 0:
589+
raise NotImplementedError("Marginalization with dependent Multivariate RVs not implemented")
587590

588591
marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
589592
dependent_rvs_input_rvs = [
@@ -621,7 +624,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
621624
marginalization_op = FiniteDiscreteMarginalRV(
622625
inputs=list(replace_inputs.values()),
623626
outputs=cloned_outputs,
624-
ndim_supp=0,
627+
ndim_supp=ndim_supp,
625628
)
626629
marginalized_rvs = marginalization_op(*replace_inputs.keys())
627630
fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))

0 commit comments

Comments
 (0)