diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal_model.py
index ed9490511..f3a788fa4 100644
--- a/pymc_experimental/model/marginal_model.py
+++ b/pymc_experimental/model/marginal_model.py
@@ -13,13 +13,12 @@
 from pymc.logprob.basic import conditional_logp, logp
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
-from pymc.pytensorf import compile_pymc, constant_fold, inputvars
+from pymc.pytensorf import compile_pymc, constant_fold
 from pymc.util import _get_seeds_per_chain, treedict
 from pytensor import Mode, scan
 from pytensor.compile import SharedVariable
-from pytensor.compile.builders import OpFromGraph
 from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
-from pytensor.graph.replace import vectorize_graph
+from pytensor.graph.replace import graph_replace, vectorize_graph
 from pytensor.scan import map as scan_map
 from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.elemwise import Elemwise
@@ -686,31 +685,23 @@ def _add_reduce_batch_dependent_logps(
 def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     # Clone the inner RV graph of the Marginalized RV
     marginalized_rvs_node = op.make_node(*inputs)
-    inner_rvs = clone_replace(
+    marginalized_rv, *inner_rvs = clone_replace(
         op.inner_outputs,
         replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
     )
-    marginalized_rv = inner_rvs[0]
 
     # Obtain the joint_logp graph of the inner RV graph
-    inner_rvs_to_values = {rv: rv.clone() for rv in inner_rvs}
-    logps_dict = conditional_logp(rv_values=inner_rvs_to_values, **kwargs)
+    inner_rv_values = dict(zip(inner_rvs, values))
+    marginalized_vv = marginalized_rv.clone()
+    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
+    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
 
     # Reduce logp dimensions corresponding to broadcasted variables
-    marginalized_logp = logps_dict.pop(inner_rvs_to_values[marginalized_rv])
+    marginalized_logp = logps_dict.pop(marginalized_vv)
     joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
         marginalized_rv.type, logps_dict.values()
     )
 
-    # Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
-    # values of the marginalized RV
-    # Some inputs are not root inputs (such as transformed projections of value variables)
-    # Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
-    inputs = list(inputvars(inputs))
-    joint_logp_op = OpFromGraph(
-        list(inner_rvs_to_values.values()) + inputs, [joint_logp], inline=True
-    )
-
     # Compute the joint_logp for all possible n values of the marginalized RV. We assume
     # each original dimension is independent so that it suffices to evaluate the graph
     # n times, once with each possible value of the marginalized RV replicated across
@@ -729,17 +720,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
         0,
     )
 
-    # Arbitrary cutoff to switch to Scan implementation to keep graph size under control
-    # TODO: Try vectorize here
-    if len(marginalized_rv_domain) <= 10:
-        joint_logps = [
-            joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs)
-            for i in range(len(marginalized_rv_domain))
-        ]
-    else:
-
+    try:
+        joint_logps = vectorize_graph(
+            joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor}
+        )
+    except Exception:
+        # Fallback to Scan
         def logp_fn(marginalized_rv_const, *non_sequences):
-            return joint_logp_op(marginalized_rv_const, *non_sequences)
+            return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})
 
         joint_logps, _ = scan_map(
             fn=logp_fn,
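
Note (reviewer sketch, not part of the patch): the new code path leans on two PyTensor helpers, vectorize_graph and graph_replace. A minimal, self-contained illustration of what each one does, using toy variables (x, logp_x, domain) that do not appear in marginal_model.py:

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.graph.replace import graph_replace, vectorize_graph

    # Stand-in for the joint_logp graph built on a scalar "marginalized" value
    x = pt.scalar("x")
    logp_x = -0.5 * x**2

    # vectorize_graph: evaluate the same graph at every point of a 1-d domain
    # in one shot by swapping the scalar input for a batched (vector) input
    domain = pt.vector("domain")
    batched_logp = vectorize_graph(logp_x, replace={x: domain})
    print(batched_logp.eval({domain: np.arange(3.0)}))  # -0.5 * domain**2

    # graph_replace: substitute a single concrete value into the original graph,
    # which is what the Scan fallback does once per domain point
    logp_at_2 = graph_replace(logp_x, replace={x: pt.constant(2.0)})
    print(logp_at_2.eval())  # -2.0

Compared to the removed code, which unrolled one OpFromGraph call per domain value below an arbitrary cutoff of 10 and used Scan otherwise, the patch tries vectorization first and only falls back to the Scan implementation when vectorize_graph raises.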