Skip to content

Commit c6cd151

Browse files
committed
Refactor logic to reduce add batched logp dimensions
1 parent 6e0de43 commit c6cd151

File tree

1 file changed

+47
-38
lines changed

1 file changed

+47
-38
lines changed

pymc_experimental/model/marginal_model.py

+47-38
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
vectorize_graph,
2727
)
2828
from pytensor.scan import map as scan_map
29-
from pytensor.tensor import TensorVariable
29+
from pytensor.tensor import TensorType, TensorVariable
3030
from pytensor.tensor.elemwise import Elemwise
3131
from pytensor.tensor.shape import Shape
3232
from pytensor.tensor.special import log_softmax
@@ -379,41 +379,36 @@ def transform_input(inputs):
379379

380380
rv_dict = {}
381381
rv_dims = {}
382-
for seed, rv in zip(seeds, vars_to_recover):
382+
for seed, marginalized_rv in zip(seeds, vars_to_recover):
383383
supported_dists = (Bernoulli, Categorical, DiscreteUniform)
384-
if not isinstance(rv.owner.op, supported_dists):
384+
if not isinstance(marginalized_rv.owner.op, supported_dists):
385385
raise NotImplementedError(
386-
f"RV with distribution {rv.owner.op} cannot be recovered. "
386+
f"RV with distribution {marginalized_rv.owner.op} cannot be recovered. "
387387
f"Supported distribution include {supported_dists}"
388388
)
389389

390390
m = self.clone()
391-
rv = m.vars_to_clone[rv]
392-
m.unmarginalize([rv])
393-
dependent_vars = find_conditional_dependent_rvs(rv, m.basic_RVs)
394-
joint_logps = m.logp(vars=dependent_vars + [rv], sum=False)
391+
marginalized_rv = m.vars_to_clone[marginalized_rv]
392+
m.unmarginalize([marginalized_rv])
393+
dependent_vars = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs)
394+
joint_logps = m.logp(vars=[marginalized_rv] + dependent_vars, sum=False)
395395

396-
marginalized_value = m.rvs_to_values[rv]
396+
marginalized_value = m.rvs_to_values[marginalized_rv]
397397
other_values = [v for v in m.value_vars if v is not marginalized_value]
398398

399399
# Handle batch dims for marginalized value and its dependent RVs
400-
joint_logp = joint_logps[-1]
401-
for dv in joint_logps[:-1]:
402-
dbcast = dv.type.broadcastable
403-
mbcast = marginalized_value.type.broadcastable
404-
mbcast = (True,) * (len(dbcast) - len(mbcast)) + mbcast
405-
values_axis_bcast = [
406-
i for i, (m, v) in enumerate(zip(mbcast, dbcast)) if m and not v
407-
]
408-
joint_logp += dv.sum(values_axis_bcast)
400+
marginalized_logp, *dependent_logps = joint_logps
401+
joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
402+
marginalized_rv.type, dependent_logps
403+
)
409404

410-
rv_shape = constant_fold(tuple(rv.shape))
411-
rv_domain = get_domain_of_finite_discrete_rv(rv)
405+
rv_shape = constant_fold(tuple(marginalized_rv.shape))
406+
rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
412407
rv_domain_tensor = pt.moveaxis(
413408
pt.full(
414409
(*rv_shape, len(rv_domain)),
415410
rv_domain,
416-
dtype=rv.dtype,
411+
dtype=marginalized_rv.dtype,
417412
),
418413
-1,
419414
0,
@@ -429,7 +424,7 @@ def transform_input(inputs):
429424
joint_logps_norm = log_softmax(joint_logps, axis=-1)
430425
if return_samples:
431426
sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
432-
if isinstance(rv.owner.op, DiscreteUniform):
427+
if isinstance(marginalized_rv.owner.op, DiscreteUniform):
433428
sample_rv_outs += rv_domain[0]
434429

435430
rv_loglike_fn = compile_pymc(
@@ -454,18 +449,20 @@ def transform_input(inputs):
454449
logps, samples = zip(*logvs)
455450
logps = np.array(logps)
456451
samples = np.array(samples)
457-
rv_dict[rv.name] = samples.reshape(
452+
rv_dict[marginalized_rv.name] = samples.reshape(
458453
tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
459454
)
460455
else:
461456
logps = np.array(logvs)
462457

463-
rv_dict["lp_" + rv.name] = logps.reshape(
458+
rv_dict["lp_" + marginalized_rv.name] = logps.reshape(
464459
tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
465460
)
466-
if rv.name in m.named_vars_to_dims:
467-
rv_dims[rv.name] = list(m.named_vars_to_dims[rv.name])
468-
rv_dims["lp_" + rv.name] = rv_dims[rv.name] + ["lp_" + rv.name + "_dim"]
461+
if marginalized_rv.name in m.named_vars_to_dims:
462+
rv_dims[marginalized_rv.name] = list(m.named_vars_to_dims[marginalized_rv.name])
463+
rv_dims["lp_" + marginalized_rv.name] = rv_dims[marginalized_rv.name] + [
464+
"lp_" + marginalized_rv.name + "_dim"
465+
]
469466

470467
coords, dims = coords_and_dims_for_inferencedata(self)
471468
dims.update(rv_dims)
@@ -645,6 +642,22 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
645642
raise NotImplementedError(f"Cannot compute domain for op {op}")
646643

647644

def _add_reduce_batch_dependent_logps(
    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
) -> TensorVariable:
    """Add the logps of dependent RVs while reducing extra batch dims as assessed from the `marginalized_type`.

    Parameters
    ----------
    marginalized_type : TensorType
        Type of the marginalized RV; its broadcast pattern defines which
        dims of the dependent logps are "extra" batch dims.
    dependent_logps : Sequence[TensorVariable]
        Logp graphs of the RVs that depend on the marginalized RV. Each may
        have more (leading) dims than the marginalized RV.

    Returns
    -------
    TensorVariable
        The elementwise sum of the dependent logps, each first reduced over
        its extra batch dims so that all terms broadcast against the
        marginalized RV's logp.
    """
    mbcast = marginalized_type.broadcastable
    reduced_logps = []
    for dependent_logp in dependent_logps:
        dbcast = dependent_logp.type.broadcastable
        # Dependent logps may have extra leading dims relative to the
        # marginalized RV; left-pad the marginalized broadcast pattern with
        # broadcastable (True) entries so the two patterns align on the right.
        dim_diff = len(dbcast) - len(mbcast)
        mbcast_aligned = (True,) * dim_diff + mbcast
        # Sum over every axis where the marginalized RV is broadcastable but
        # the dependent logp is not: those are batch dims introduced by the
        # dependent RV and must be reduced before the logps can be added.
        vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
        reduced_logps.append(dependent_logp.sum(vbcast_axis))
    # `pt.add()` with zero arguments raises; an empty sequence of dependent
    # logps contributes nothing, so return the additive identity instead.
    if not reduced_logps:
        return pt.constant(0.0)
    return pt.add(*reduced_logps)
659+
660+
648661
@_logprob.register(FiniteDiscreteMarginalRV)
649662
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
650663
# Clone the inner RV graph of the Marginalized RV
@@ -660,17 +673,12 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
660673
logps_dict = conditional_logp(rv_values=inner_rvs_to_values, **kwargs)
661674

662675
# Reduce logp dimensions corresponding to broadcasted variables
663-
joint_logp = logps_dict[inner_rvs_to_values[marginalized_rv]]
664-
for inner_rv, inner_value in inner_rvs_to_values.items():
665-
if inner_rv is marginalized_rv:
666-
continue
667-
vbcast = inner_value.type.broadcastable
668-
mbcast = marginalized_rv.type.broadcastable
669-
mbcast = (True,) * (len(vbcast) - len(mbcast)) + mbcast
670-
values_axis_bcast = [i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v]
671-
joint_logp += logps_dict[inner_value].sum(values_axis_bcast, keepdims=True)
672-
673-
# Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different
676+
marginalized_logp = logps_dict.pop(inner_rvs_to_values[marginalized_rv])
677+
joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
678+
marginalized_rv.type, logps_dict.values()
679+
)
680+
681+
# Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
674682
# values of the marginalized RV
675683
# Some inputs are not root inputs (such as transformed projections of value variables)
676684
# Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
@@ -698,6 +706,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
698706
)
699707

700708
# Arbitrary cutoff to switch to Scan implementation to keep graph size under control
709+
# TODO: Try vectorize here
701710
if len(marginalized_rv_domain) <= 10:
702711
joint_logps = [
703712
joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs)

0 commit comments

Comments
 (0)