Skip to content

Commit a9224b6

Browse files
committed
Refactor logic for reducing extra batch dimensions when adding logps
1 parent 8b35cc6 commit a9224b6

File tree

1 file changed

+47
-38
lines changed

1 file changed

+47
-38
lines changed

pymc_experimental/model/marginal_model.py

+47-38
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
vectorize_graph,
2727
)
2828
from pytensor.scan import map as scan_map
29-
from pytensor.tensor import TensorVariable
29+
from pytensor.tensor import TensorType, TensorVariable
3030
from pytensor.tensor.elemwise import Elemwise
3131
from pytensor.tensor.shape import Shape
3232
from pytensor.tensor.special import log_softmax
@@ -381,41 +381,36 @@ def transform_input(inputs):
381381

382382
rv_dict = {}
383383
rv_dims = {}
384-
for seed, rv in zip(seeds, vars_to_recover):
384+
for seed, marginalized_rv in zip(seeds, vars_to_recover):
385385
supported_dists = (Bernoulli, Categorical, DiscreteUniform)
386-
if not isinstance(rv.owner.op, supported_dists):
386+
if not isinstance(marginalized_rv.owner.op, supported_dists):
387387
raise NotImplementedError(
388-
f"RV with distribution {rv.owner.op} cannot be recovered. "
388+
f"RV with distribution {marginalized_rv.owner.op} cannot be recovered. "
389389
f"Supported distribution include {supported_dists}"
390390
)
391391

392392
m = self.clone()
393-
rv = m.vars_to_clone[rv]
394-
m.unmarginalize([rv])
395-
dependent_vars = find_conditional_dependent_rvs(rv, m.basic_RVs)
396-
joint_logps = m.logp(vars=dependent_vars + [rv], sum=False)
393+
marginalized_rv = m.vars_to_clone[marginalized_rv]
394+
m.unmarginalize([marginalized_rv])
395+
dependent_vars = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs)
396+
joint_logps = m.logp(vars=[marginalized_rv] + dependent_vars, sum=False)
397397

398-
marginalized_value = m.rvs_to_values[rv]
398+
marginalized_value = m.rvs_to_values[marginalized_rv]
399399
other_values = [v for v in m.value_vars if v is not marginalized_value]
400400

401401
# Handle batch dims for marginalized value and its dependent RVs
402-
joint_logp = joint_logps[-1]
403-
for dv in joint_logps[:-1]:
404-
dbcast = dv.type.broadcastable
405-
mbcast = marginalized_value.type.broadcastable
406-
mbcast = (True,) * (len(dbcast) - len(mbcast)) + mbcast
407-
values_axis_bcast = [
408-
i for i, (m, v) in enumerate(zip(mbcast, dbcast)) if m and not v
409-
]
410-
joint_logp += dv.sum(values_axis_bcast)
402+
marginalized_logp, *dependent_logps = joint_logps
403+
joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
404+
marginalized_rv.type, dependent_logps
405+
)
411406

412-
rv_shape = constant_fold(tuple(rv.shape))
413-
rv_domain = get_domain_of_finite_discrete_rv(rv)
407+
rv_shape = constant_fold(tuple(marginalized_rv.shape))
408+
rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
414409
rv_domain_tensor = pt.moveaxis(
415410
pt.full(
416411
(*rv_shape, len(rv_domain)),
417412
rv_domain,
418-
dtype=rv.dtype,
413+
dtype=marginalized_rv.dtype,
419414
),
420415
-1,
421416
0,
@@ -431,7 +426,7 @@ def transform_input(inputs):
431426
joint_logps_norm = log_softmax(joint_logps, axis=-1)
432427
if return_samples:
433428
sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
434-
if isinstance(rv.owner.op, DiscreteUniform):
429+
if isinstance(marginalized_rv.owner.op, DiscreteUniform):
435430
sample_rv_outs += rv_domain[0]
436431

437432
rv_loglike_fn = compile_pymc(
@@ -456,18 +451,20 @@ def transform_input(inputs):
456451
logps, samples = zip(*logvs)
457452
logps = np.array(logps)
458453
samples = np.array(samples)
459-
rv_dict[rv.name] = samples.reshape(
454+
rv_dict[marginalized_rv.name] = samples.reshape(
460455
tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
461456
)
462457
else:
463458
logps = np.array(logvs)
464459

465-
rv_dict["lp_" + rv.name] = logps.reshape(
460+
rv_dict["lp_" + marginalized_rv.name] = logps.reshape(
466461
tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
467462
)
468-
if rv.name in m.named_vars_to_dims:
469-
rv_dims[rv.name] = list(m.named_vars_to_dims[rv.name])
470-
rv_dims["lp_" + rv.name] = rv_dims[rv.name] + ["lp_" + rv.name + "_dim"]
463+
if marginalized_rv.name in m.named_vars_to_dims:
464+
rv_dims[marginalized_rv.name] = list(m.named_vars_to_dims[marginalized_rv.name])
465+
rv_dims["lp_" + marginalized_rv.name] = rv_dims[marginalized_rv.name] + [
466+
"lp_" + marginalized_rv.name + "_dim"
467+
]
471468

472469
coords, dims = coords_and_dims_for_inferencedata(self)
473470
dims.update(rv_dims)
@@ -647,6 +644,22 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
647644
raise NotImplementedError(f"Cannot compute domain for op {op}")
648645

649646

647+
def _add_reduce_batch_dependent_logps(
    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
):
    """Sum the logps of dependent RVs, collapsing batch dims not present in `marginalized_type`.

    Each dependent logp may carry extra (or non-broadcastable) batch dimensions
    relative to the marginalized variable's type; those axes are reduced with a
    `sum` before the terms are added together.
    """
    mbcast = marginalized_type.broadcastable

    def _reduce_extra_dims(dependent_logp):
        dbcast = dependent_logp.type.broadcastable
        # Left-pad the marginalized broadcast pattern so both patterns align on the right.
        aligned_mbcast = (True,) * (len(dbcast) - len(mbcast)) + mbcast
        # Axes where the marginalized RV broadcasts but the dependent logp does not
        # are batch dims introduced by the dependent RV — sum them away.
        extra_axes = [
            axis
            for axis, (marg_b, dep_b) in enumerate(zip(aligned_mbcast, dbcast))
            if marg_b and not dep_b
        ]
        return dependent_logp.sum(extra_axes)

    return pt.add(*(_reduce_extra_dims(dependent_logp) for dependent_logp in dependent_logps))
661+
662+
650663
@_logprob.register(FiniteDiscreteMarginalRV)
651664
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
652665
# Clone the inner RV graph of the Marginalized RV
@@ -662,17 +675,12 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
662675
logps_dict = conditional_logp(rv_values=inner_rvs_to_values, **kwargs)
663676

664677
# Reduce logp dimensions corresponding to broadcasted variables
665-
joint_logp = logps_dict[inner_rvs_to_values[marginalized_rv]]
666-
for inner_rv, inner_value in inner_rvs_to_values.items():
667-
if inner_rv is marginalized_rv:
668-
continue
669-
vbcast = inner_value.type.broadcastable
670-
mbcast = marginalized_rv.type.broadcastable
671-
mbcast = (True,) * (len(vbcast) - len(mbcast)) + mbcast
672-
values_axis_bcast = [i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v]
673-
joint_logp += logps_dict[inner_value].sum(values_axis_bcast, keepdims=True)
674-
675-
# Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different
678+
marginalized_logp = logps_dict.pop(inner_rvs_to_values[marginalized_rv])
679+
joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
680+
marginalized_rv.type, logps_dict.values()
681+
)
682+
683+
# Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
676684
# values of the marginalized RV
677685
# Some inputs are not root inputs (such as transformed projections of value variables)
678686
# Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
@@ -700,6 +708,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
700708
)
701709

702710
# Arbitrary cutoff to switch to Scan implementation to keep graph size under control
711+
# TODO: Try vectorize here
703712
if len(marginalized_rv_domain) <= 10:
704713
joint_logps = [
705714
joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs)

0 commit comments

Comments
 (0)