Skip to content

Commit a6e79e0

Browse files
Restore ndim_supp check, add assumed ndim_supp based on the type of RV being marginalized
1 parent 178be94 commit a6e79e0

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

pymc_experimental/marginal_model.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -349,11 +349,9 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
349349
ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs}
350350
if len(ndim_supp) != 1:
351351
raise NotImplementedError()
352-
ndim_supp = tuple(ndim_supp)[0]
353-
# if max(ndim_supp) > 0:
354-
# raise NotImplementedError(
355-
# "Marginalization with dependent Multivariate RVs not implemented"
356-
# )
352+
353+
if max(ndim_supp) > 0:
354+
raise NotImplementedError("Marginalization with dependent Multivariate RVs not implemented")
357355

358356
marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
359357
dependent_rvs_input_rvs = [
@@ -390,13 +388,15 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
390388

391389
if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
392390
marginalize_constructor = DiscreteMarginalMarkovChainRV
391+
marginalize_target_rv_ndim_supp = 0
393392
else:
394393
marginalize_constructor = FiniteDiscreteMarginalRV
394+
marginalize_target_rv_ndim_supp = 1
395395

396396
marginalization_op = marginalize_constructor(
397397
inputs=list(replace_inputs.values()),
398398
outputs=cloned_outputs,
399-
ndim_supp=ndim_supp,
399+
ndim_supp=marginalize_target_rv_ndim_supp,
400400
)
401401

402402
marginalized_rvs = marginalization_op(*replace_inputs.keys())

0 commit comments

Comments
 (0)