from collections.abc import Sequence

import numpy as np

from pymc import Bernoulli, Categorical, DiscreteUniform, SymbolicRandomVariable, logp
from pymc.logprob import conditional_logp
from pymc.logprob.abstract import _logprob
from pymc.pytensorf import constant_fold
from pytensor import Mode, clone_replace, graph_replace, scan
from pytensor import map as scan_map
from pytensor import tensor as pt
from pytensor.graph import vectorize_graph
from pytensor.tensor import TensorType, TensorVariable

from pymc_experimental.distributions import DiscreteMarkovChain


class MarginalRV(SymbolicRandomVariable):
    """Base class for Marginalized RVs"""


class FiniteDiscreteMarginalRV(MarginalRV):
    """Base class for Finite Discrete Marginalized RVs"""


class DiscreteMarginalMarkovChainRV(MarginalRV):
    """Base class for Discrete Marginal Markov Chain RVs"""


def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
    """Return the finite domain (all possible values) of a supported discrete RV."""
    op = rv.owner.op
    dist_params = rv.owner.op.dist_params(rv.owner)
    if isinstance(op, Bernoulli):
        return (0, 1)
    elif isinstance(op, Categorical):
        [p_param] = dist_params
        return tuple(range(pt.get_vector_length(p_param)))
    elif isinstance(op, DiscreteUniform):
        lower, upper = constant_fold(dist_params)
        return tuple(np.arange(lower, upper + 1))
    elif isinstance(op, DiscreteMarkovChain):
        P, *_ = dist_params
        return tuple(range(pt.get_vector_length(P[-1])))

    raise NotImplementedError(f"Cannot compute domain for op {op}")
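
# Illustrative sketch (hypothetical usage, not part of the module): the helper above maps a
# supported RV to its finite support, e.g.
#   get_domain_of_finite_discrete_rv(Bernoulli.dist(p=0.3))                  -> (0, 1)
#   get_domain_of_finite_discrete_rv(Categorical.dist(p=[0.2, 0.3, 0.5]))    -> (0, 1, 2)
#   get_domain_of_finite_discrete_rv(DiscreteUniform.dist(lower=1, upper=3)) -> (1, 2, 3)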


def _add_reduce_batch_dependent_logps(
    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
):
    """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`."""

    mbcast = marginalized_type.broadcastable
    reduced_logps = []
    for dependent_logp in dependent_logps:
        dbcast = dependent_logp.type.broadcastable
        dim_diff = len(dbcast) - len(mbcast)
        mbcast_aligned = (True,) * dim_diff + mbcast
        vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
        reduced_logps.append(dependent_logp.sum(vbcast_axis))
    return pt.add(*reduced_logps)
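
# Illustrative sketch (hypothetical shapes, not part of the module): if the marginalized RV
# has shape (3,) and a dependent logp has shape (5, 3), the extra leading batch axis is
# summed out by the helper above, leaving a (3,)-shaped logp aligned with the marginalized RV.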


@_logprob.register(FiniteDiscreteMarginalRV)
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
    # Clone the inner RV graph of the Marginalized RV
    marginalized_rvs_node = op.make_node(*inputs)
    marginalized_rv, *inner_rvs = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

    # Obtain the joint_logp graph of the inner RV graph
    inner_rv_values = dict(zip(inner_rvs, values))
    marginalized_vv = marginalized_rv.clone()
    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)

    # Reduce logp dimensions corresponding to broadcasted variables
    marginalized_logp = logps_dict.pop(marginalized_vv)
    joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
        marginalized_rv.type, logps_dict.values()
    )

    # Compute the joint_logp for each of the n possible values of the marginalized RV.
    # We assume each original dimension is independent, so it suffices to evaluate the
    # graph n times, once with each possible value of the marginalized RV replicated
    # across its batch dimensions.

    # PyMC does not allow RVs in the logp graph, even if we are just using the shape
    marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
    marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
    marginalized_rv_domain_tensor = pt.moveaxis(
        pt.full(
            (*marginalized_rv_shape, len(marginalized_rv_domain)),
            marginalized_rv_domain,
            dtype=marginalized_rv.dtype,
        ),
        -1,
        0,
    )
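    # Illustrative sketch (hypothetical shapes, not part of the module): for a marginalized
    # Bernoulli RV of shape (3,), the domain tensor above has shape (2, 3), with row 0 all
    # zeros and row 1 all ones. Vectorizing the joint logp over it adds a leading axis of
    # size 2, which the logsumexp below collapses to marginalize the RV out.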

    try:
        joint_logps = vectorize_graph(
            joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor}
        )
    except Exception:
        # Fall back to Scan
        def logp_fn(marginalized_rv_const, *non_sequences):
            return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})

        joint_logps, _ = scan_map(
            fn=logp_fn,
            sequences=marginalized_rv_domain_tensor,
            non_sequences=[*values, *inputs],
            mode=Mode().including("local_remove_check_parameter"),
        )

    joint_logps = pt.logsumexp(joint_logps, axis=0)

    # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise an error
    return joint_logps, *(pt.constant(0),) * (len(values) - 1)


@_logprob.register(DiscreteMarginalMarkovChainRV)
def marginal_hmm_logp(op, values, *inputs, **kwargs):
    # Clone the inner RV graph of the Marginalized RV
    marginalized_rvs_node = op.make_node(*inputs)
    inner_rvs = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

    chain_rv, *dependent_rvs = inner_rvs
    P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
    domain = pt.arange(P.shape[-1], dtype="int32")

    # Construct the logp in two steps
    # Step 1: Compute the probability of the data ("emissions") under every possible state (batch_logp_emissions)

    # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
    # around. To do this, we need to break the dependency between the chain and the init_dist_ random variable.
    # Otherwise, PyMC will detect a random variable (init_dist_) in the logp graph that isn't relevant at this step.
    chain_value = chain_rv.clone()
    dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value})
    logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))

    # Reduce and add the batch dims beyond the chain dimension
    reduced_logp_emissions = _add_reduce_batch_dependent_logps(
        chain_rv.type, logp_emissions_dict.values()
    )

    # Add a batch dimension for the domain of the chain
    chain_shape = constant_fold(tuple(chain_rv.shape))
    batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
    batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value})

    # Step 2: Compute the transition probabilities
    # This is the "forward algorithm": alpha_t(s_t) = p(y_t | s_t) * sum_{s_{t-1}} p(s_t | s_{t-1}) * alpha_{t-1}(s_{t-1})
    # We do it entirely in log space, though.

    # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states)
    # under the initial distribution. This is robust to everything the user can throw at it.
    init_dist_value = init_dist_.type()
    logp_init_dist = logp(init_dist_, init_dist_value)
    # There is a degenerate batch dim for lags=1 (the only supported case), which we have to work around
    # by expanding the batch value and then squeezing it out of the logp
    batch_logp_init_dist = vectorize_graph(
        logp_init_dist, {init_dist_value: batch_chain_value[:, None, ..., 0]}
    ).squeeze(1)
    log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]

    def step_alpha(logp_emission, log_alpha, log_P):
        step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
        return logp_emission + step_log_prob

    P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
    log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
    log_alpha_seq, _ = scan(
        step_alpha,
        non_sequences=[log_P],
        outputs_info=[log_alpha_init],
        # Scan needs the time dimension first, and we already consumed the first emission logp
        # when computing the initial value
        sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
    )
    # The final logp marginalizes over the last state (logsumexp of the last scan output)
    joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
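    # Illustrative sketch (schematic, not part of the module) of the log-space forward
    # recursion implemented above, for a single chain with states s and emissions y:
    #   log_alpha[0, s] = logp_init(s) + logp_emission[s, 0]
    #   log_alpha[t, s] = logp_emission[s, t] + logsumexp_{s'}(log_alpha[t - 1, s'] + log_P[s', s])
    #   logp(y)         = logsumexp_s(log_alpha[T, s])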

    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables.
    # The first return value is the joint logp of everything together, but PyMC still expects one logp per
    # value variable.
    dummy_logps = (pt.constant(0),) * (len(values) - 1)
    return joint_logp, *dummy_logps