diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal_model.py
index 8a832ef8a..eca27ceb1 100644
--- a/pymc_experimental/model/marginal_model.py
+++ b/pymc_experimental/model/marginal_model.py
@@ -579,6 +579,15 @@ def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
     return True
 
 
+from pytensor.graph.basic import graph_inputs
+
+
+def collect_shared_vars(outputs, blockers):
+    return [
+        inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable)
+    ]
+
+
 def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
     # TODO: This should eventually be integrated in a more general routine that can
     # identify other types of supported marginalization, of which finite discrete
@@ -621,14 +630,8 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs]
 
     outputs = rvs_to_marginalize
-    # Clone replace inner RV rng inputs so that we can be sure of the update order
-    # replace_inputs = {rng: rng.type() for rng in updates_rvs_to_marginalize.keys()}
-    # Clone replace outter RV inputs, so that their shared RNGs don't make it into
-    # the inner graph of the marginalized RVs
-    # FIXME: This shouldn't be needed!
-    replace_inputs = {}
-    replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs})
-    cloned_outputs = clone_replace(outputs, replace=replace_inputs)
+    # We are strict about shared variables in SymbolicRandomVariables
+    inputs = dependent_rvs_input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)
 
     if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
         marginalize_constructor = DiscreteMarginalMarkovChainRV
@@ -636,12 +639,12 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
         marginalize_constructor = FiniteDiscreteMarginalRV
 
     marginalization_op = marginalize_constructor(
-        inputs=list(replace_inputs.values()),
-        outputs=cloned_outputs,
+        inputs=inputs,
+        outputs=outputs,
         ndim_supp=ndim_supp,
     )
 
-    marginalized_rvs = marginalization_op(*replace_inputs.keys())
+    marginalized_rvs = marginalization_op(*inputs)
     fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
     return rvs_to_marginalize, marginalized_rvs
 
diff --git a/pymc_experimental/tests/model/test_marginal_model.py b/pymc_experimental/tests/model/test_marginal_model.py
index f9a0a344b..4085cad7c 100644
--- a/pymc_experimental/tests/model/test_marginal_model.py
+++ b/pymc_experimental/tests/model/test_marginal_model.py
@@ -758,3 +758,19 @@ def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2):
     test_value_emission2 = np.broadcast_to(-test_value, emission2_shape)
     test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2}
     np.testing.assert_allclose(logp_fn(test_point), expected_logp)
+
+
+def test_mutable_indexing_jax_backend():
+    pytest.importorskip("jax")
+    from pymc.sampling.jax import get_jaxified_logp
+
+    with MarginalModel() as model:
+        data = pm.MutableData("data", np.zeros(10))
+
+        cat_effect = pm.Normal("cat_effect", sigma=1, shape=5)
+        cat_effect_idx = pm.MutableData("cat_effect_idx", np.array([0, 1] * 5))
+
+        is_outlier = pm.Bernoulli("is_outlier", 0.4, shape=10)
+        pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data)
+    model.marginalize(["is_outlier"])
+    get_jaxified_logp(model)
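A minimal sketch (not part of the patch) of the mechanism `collect_shared_vars` relies on: PyTensor's `graph_inputs` walks a graph back to its owner-less inputs, stopping at `blockers`, so filtering on `SharedVariable` picks up implicitly captured shared variables (RNGs, `MutableData` storage) without traversing past the explicit inputs. The variable names below are illustrative only.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import graph_inputs

# A graph with one explicit symbolic input and one implicitly captured shared variable
shared_data = pytensor.shared(np.zeros(3), name="shared_data")
x = pt.vector("x")
out = x + shared_data

# Traversal stops at the blocker `x`; the isinstance filter then keeps only the
# shared input, mirroring what collect_shared_vars does in the patch.
shared_inputs = [
    inp for inp in graph_inputs([out], blockers=[x]) if isinstance(inp, SharedVariable)
]
assert shared_inputs == [shared_data]
```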