Skip to content

Use vectorize in finite_discrete_marginal_rv_logp #337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 15 additions & 27 deletions pymc_experimental/model/marginal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@
from pymc.logprob.basic import conditional_logp, logp
from pymc.logprob.transforms import IntervalTransform
from pymc.model import Model
from pymc.pytensorf import compile_pymc, constant_fold, inputvars
from pymc.pytensorf import compile_pymc, constant_fold
from pymc.util import _get_seeds_per_chain, treedict
from pytensor import Mode, scan
from pytensor.compile import SharedVariable
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
from pytensor.graph.replace import vectorize_graph
from pytensor.graph.replace import graph_replace, vectorize_graph
from pytensor.scan import map as scan_map
from pytensor.tensor import TensorType, TensorVariable
from pytensor.tensor.elemwise import Elemwise
Expand Down Expand Up @@ -686,31 +685,23 @@ def _add_reduce_batch_dependent_logps(
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
# Clone the inner RV graph of the Marginalized RV
marginalized_rvs_node = op.make_node(*inputs)
inner_rvs = clone_replace(
marginalized_rv, *inner_rvs = clone_replace(
op.inner_outputs,
replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
)
marginalized_rv = inner_rvs[0]

# Obtain the joint_logp graph of the inner RV graph
inner_rvs_to_values = {rv: rv.clone() for rv in inner_rvs}
logps_dict = conditional_logp(rv_values=inner_rvs_to_values, **kwargs)
inner_rv_values = dict(zip(inner_rvs, values))
marginalized_vv = marginalized_rv.clone()
rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
logps_dict = conditional_logp(rv_values=rv_values, **kwargs)

# Reduce logp dimensions corresponding to broadcasted variables
marginalized_logp = logps_dict.pop(inner_rvs_to_values[marginalized_rv])
marginalized_logp = logps_dict.pop(marginalized_vv)
joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
marginalized_rv.type, logps_dict.values()
)

# Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
# values of the marginalized RV
# Some inputs are not root inputs (such as transformed projections of value variables)
# Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
inputs = list(inputvars(inputs))
joint_logp_op = OpFromGraph(
list(inner_rvs_to_values.values()) + inputs, [joint_logp], inline=True
)

# Compute the joint_logp for all possible n values of the marginalized RV. We assume
# each original dimension is independent so that it suffices to evaluate the graph
# n times, once with each possible value of the marginalized RV replicated across
Expand All @@ -729,17 +720,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
0,
)

# Arbitrary cutoff to switch to Scan implementation to keep graph size under control
# TODO: Try vectorize here
if len(marginalized_rv_domain) <= 10:
joint_logps = [
joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs)
for i in range(len(marginalized_rv_domain))
]
else:

try:
joint_logps = vectorize_graph(
joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor}
)
except Exception:
# Fallback to Scan
def logp_fn(marginalized_rv_const, *non_sequences):
return joint_logp_op(marginalized_rv_const, *non_sequences)
return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})

joint_logps, _ = scan_map(
fn=logp_fn,
Expand Down