Skip to content

Support HMM via marginalization of DiscreteMarkovChain #257

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 152 additions & 55 deletions pymc_experimental/model/marginal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,26 @@
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
from pymc.distributions.transforms import Chain
from pymc.logprob.abstract import _logprob
from pymc.logprob.basic import conditional_logp
from pymc.logprob.basic import conditional_logp, logp
from pymc.logprob.transforms import IntervalTransform
from pymc.model import Model
from pymc.pytensorf import compile_pymc, constant_fold, inputvars
from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict
from pytensor import Mode
from pytensor import Mode, scan
from pytensor.compile import SharedVariable
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import (
Constant,
FunctionGraph,
ancestors,
clone_replace,
vectorize_graph,
)
from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
from pytensor.graph.replace import vectorize_graph
from pytensor.scan import map as scan_map
from pytensor.tensor import TensorVariable
from pytensor.tensor import TensorType, TensorVariable
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.shape import Shape
from pytensor.tensor.special import log_softmax

__all__ = ["MarginalModel"]

from pymc_experimental.distributions import DiscreteMarkovChain


class MarginalModel(Model):
"""Subclass of PyMC Model that implements functionality for automatic
Expand Down Expand Up @@ -247,16 +244,25 @@ def marginalize(
self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
]

supported_dists = (Bernoulli, Categorical, DiscreteUniform)
for rv_to_marginalize in rvs_to_marginalize:
if rv_to_marginalize not in self.free_RVs:
raise ValueError(
f"Marginalized RV {rv_to_marginalize} is not a free RV in the model"
)
if not isinstance(rv_to_marginalize.owner.op, supported_dists):

rv_op = rv_to_marginalize.owner.op
if isinstance(rv_op, DiscreteMarkovChain):
if rv_op.n_lags > 1:
raise NotImplementedError(
"Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
)
if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
raise NotImplementedError(
"Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can a markov chain have a non-matrix transition probability?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be valid for batch dims

)
elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
raise NotImplementedError(
f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. "
f"Supported distribution include {supported_dists}"
f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought the old error message was more helpful

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was but it not gonna scale

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe link to the docs where it lists all the supported distributions then?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The notes state that functionality is restricted, only finite discrete RVs are supported which is kind of true. Although we don't yet support Truncated/Censored of infinite discrete RVs which thus become finite: #95

We also don't support Multinomial which in theory is finite... So I think we the disclaimer functionality is restricted and this error message indicating the type of the RV that could not be marginalized it's fair game?


if rv_to_marginalize.name in self.named_vars_to_dims:
Expand Down Expand Up @@ -381,41 +387,36 @@ def transform_input(inputs):

rv_dict = {}
rv_dims = {}
for seed, rv in zip(seeds, vars_to_recover):
for seed, marginalized_rv in zip(seeds, vars_to_recover):
supported_dists = (Bernoulli, Categorical, DiscreteUniform)
if not isinstance(rv.owner.op, supported_dists):
if not isinstance(marginalized_rv.owner.op, supported_dists):
raise NotImplementedError(
f"RV with distribution {rv.owner.op} cannot be recovered. "
f"RV with distribution {marginalized_rv.owner.op} cannot be recovered. "
f"Supported distribution include {supported_dists}"
)

m = self.clone()
rv = m.vars_to_clone[rv]
m.unmarginalize([rv])
dependent_vars = find_conditional_dependent_rvs(rv, m.basic_RVs)
joint_logps = m.logp(vars=dependent_vars + [rv], sum=False)
marginalized_rv = m.vars_to_clone[marginalized_rv]
m.unmarginalize([marginalized_rv])
dependent_vars = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs)
joint_logps = m.logp(vars=[marginalized_rv] + dependent_vars, sum=False)

marginalized_value = m.rvs_to_values[rv]
marginalized_value = m.rvs_to_values[marginalized_rv]
other_values = [v for v in m.value_vars if v is not marginalized_value]

# Handle batch dims for marginalized value and its dependent RVs
joint_logp = joint_logps[-1]
for dv in joint_logps[:-1]:
dbcast = dv.type.broadcastable
mbcast = marginalized_value.type.broadcastable
mbcast = (True,) * (len(dbcast) - len(mbcast)) + mbcast
values_axis_bcast = [
i for i, (m, v) in enumerate(zip(mbcast, dbcast)) if m and not v
]
joint_logp += dv.sum(values_axis_bcast)
marginalized_logp, *dependent_logps = joint_logps
joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
marginalized_rv.type, dependent_logps
)

rv_shape = constant_fold(tuple(rv.shape))
rv_domain = get_domain_of_finite_discrete_rv(rv)
rv_shape = constant_fold(tuple(marginalized_rv.shape))
rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
rv_domain_tensor = pt.moveaxis(
pt.full(
(*rv_shape, len(rv_domain)),
rv_domain,
dtype=rv.dtype,
dtype=marginalized_rv.dtype,
),
-1,
0,
Expand All @@ -431,7 +432,7 @@ def transform_input(inputs):
joint_logps_norm = log_softmax(joint_logps, axis=-1)
if return_samples:
sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
if isinstance(rv.owner.op, DiscreteUniform):
if isinstance(marginalized_rv.owner.op, DiscreteUniform):
sample_rv_outs += rv_domain[0]

rv_loglike_fn = compile_pymc(
Expand All @@ -456,18 +457,20 @@ def transform_input(inputs):
logps, samples = zip(*logvs)
logps = np.array(logps)
samples = np.array(samples)
rv_dict[rv.name] = samples.reshape(
rv_dict[marginalized_rv.name] = samples.reshape(
tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
)
else:
logps = np.array(logvs)

rv_dict["lp_" + rv.name] = logps.reshape(
rv_dict["lp_" + marginalized_rv.name] = logps.reshape(
tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
)
if rv.name in m.named_vars_to_dims:
rv_dims[rv.name] = list(m.named_vars_to_dims[rv.name])
rv_dims["lp_" + rv.name] = rv_dims[rv.name] + ["lp_" + rv.name + "_dim"]
if marginalized_rv.name in m.named_vars_to_dims:
rv_dims[marginalized_rv.name] = list(m.named_vars_to_dims[marginalized_rv.name])
rv_dims["lp_" + marginalized_rv.name] = rv_dims[marginalized_rv.name] + [
"lp_" + marginalized_rv.name + "_dim"
]

coords, dims = coords_and_dims_for_inferencedata(self)
dims.update(rv_dims)
Expand Down Expand Up @@ -495,6 +498,10 @@ class FiniteDiscreteMarginalRV(MarginalRV):
"""Base class for Finite Discrete Marginalized RVs"""


class DiscreteMarginalMarkovChainRV(MarginalRV):
"""Base class for Discrete Marginal Markov Chain RVs"""


def static_shape_ancestors(vars):
"""Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
return [
Expand Down Expand Up @@ -582,10 +589,13 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")

ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs}
if max(ndim_supp) > 0:
if len(ndim_supp) != 1:
raise NotImplementedError(
"Marginalization of withe dependent Multivariate RVs not implemented"
"Marginalization with dependent variables of different support dimensionality not implemented"
)
[ndim_supp] = ndim_supp
if ndim_supp > 0:
raise NotImplementedError("Marginalization with dependent Multivariate RVs not implemented")

marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
dependent_rvs_input_rvs = [
Expand Down Expand Up @@ -620,11 +630,17 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs})
cloned_outputs = clone_replace(outputs, replace=replace_inputs)

marginalization_op = FiniteDiscreteMarginalRV(
if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
marginalize_constructor = DiscreteMarginalMarkovChainRV
else:
marginalize_constructor = FiniteDiscreteMarginalRV

marginalization_op = marginalize_constructor(
inputs=list(replace_inputs.values()),
outputs=cloned_outputs,
ndim_supp=0,
ndim_supp=ndim_supp,
)

marginalized_rvs = marginalization_op(*replace_inputs.keys())
fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
return rvs_to_marginalize, marginalized_rvs
Expand All @@ -640,10 +656,29 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
elif isinstance(op, DiscreteUniform):
lower, upper = constant_fold(rv.owner.inputs[3:])
return tuple(range(lower, upper + 1))
elif isinstance(op, DiscreteMarkovChain):
P = rv.owner.inputs[0]
return tuple(range(pt.get_vector_length(P[-1])))

raise NotImplementedError(f"Cannot compute domain for op {op}")


def _add_reduce_batch_dependent_logps(
marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
):
"""Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`."""

mbcast = marginalized_type.broadcastable
reduced_logps = []
for dependent_logp in dependent_logps:
dbcast = dependent_logp.type.broadcastable
dim_diff = len(dbcast) - len(mbcast)
mbcast_aligned = (True,) * dim_diff + mbcast
vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
reduced_logps.append(dependent_logp.sum(vbcast_axis))
return pt.add(*reduced_logps)


@_logprob.register(FiniteDiscreteMarginalRV)
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
# Clone the inner RV graph of the Marginalized RV
Expand All @@ -659,17 +694,12 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
logps_dict = conditional_logp(rv_values=inner_rvs_to_values, **kwargs)

# Reduce logp dimensions corresponding to broadcasted variables
joint_logp = logps_dict[inner_rvs_to_values[marginalized_rv]]
for inner_rv, inner_value in inner_rvs_to_values.items():
if inner_rv is marginalized_rv:
continue
vbcast = inner_value.type.broadcastable
mbcast = marginalized_rv.type.broadcastable
mbcast = (True,) * (len(vbcast) - len(mbcast)) + mbcast
values_axis_bcast = [i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v]
joint_logp += logps_dict[inner_value].sum(values_axis_bcast, keepdims=True)

# Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different
marginalized_logp = logps_dict.pop(inner_rvs_to_values[marginalized_rv])
joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
marginalized_rv.type, logps_dict.values()
)

# Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
# values of the marginalized RV
# Some inputs are not root inputs (such as transformed projections of value variables)
# Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
Expand Down Expand Up @@ -697,6 +727,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
)

# Arbitrary cutoff to switch to Scan implementation to keep graph size under control
# TODO: Try vectorize here
if len(marginalized_rv_domain) <= 10:
joint_logps = [
joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs)
Expand All @@ -718,3 +749,69 @@ def logp_fn(marginalized_rv_const, *non_sequences):

# We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
return joint_logps, *(pt.constant(0),) * (len(values) - 1)


@_logprob.register(DiscreteMarginalMarkovChainRV)
def marginal_hmm_logp(op, values, *inputs, **kwargs):

marginalized_rvs_node = op.make_node(*inputs)
inner_rvs = clone_replace(
op.inner_outputs,
replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
)

chain_rv, *dependent_rvs = inner_rvs
P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
domain = pt.arange(P.shape[-1], dtype="int32")

# Construct logp in two steps
# Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission)

# First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
# around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise,
# PyMC will detect a random variable in the logp graph (init_dist_), that isn't relevant at this step.
chain_value = chain_rv.clone()
dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value})
logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))

# Reduce and add the batch dims beyond the chain dimension
reduced_logp_emissions = _add_reduce_batch_dependent_logps(
chain_rv.type, logp_emissions_dict.values()
)

# Add a batch dimension for the domain of the chain
chain_shape = constant_fold(tuple(chain_rv.shape))
batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value})

# Step 2: Compute the transition probabilities
# This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
# We do it entirely in logs, though.

# To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) under
# the initial distribution. This is robust to everything the user can throw at it.
batch_logp_init_dist = pt.vectorize(lambda x: logp(init_dist_, x), "()->()")(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No way to avoid this lambda here with vectorize_graph? I recall this used to be a little function.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way to avoid it is to be a little function, but seems like a fine use for lambda?

batch_chain_value[..., 0]
)
log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]

def step_alpha(logp_emission, log_alpha, log_P):
step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
return logp_emission + step_log_prob

P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
log_alpha_seq, _ = scan(
step_alpha,
non_sequences=[log_P],
outputs_info=[log_alpha_init],
# Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
)
# Final logp is just the sum of the last scan state
joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)

# If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
# return is the joint probability of everything together, but PyMC still expects one logp for each one.
dummy_logps = (pt.constant(0),) * (len(values) - 1)
return joint_logp, *dummy_logps
Loading