Commit 88c3b67

Progress on marginalization of DiscreteMarkovChain
1 parent f9d1c0e · commit 88c3b67

2 files changed: 88 additions, 75 deletions


pymc_experimental/marginal_model.py

+67 -4

@@ -7,14 +7,15 @@
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.abstract import _logprob
-from pymc.logprob.basic import conditional_logp
+from pymc.logprob.basic import conditional_logp, logp
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
 from pymc.pytensorf import constant_fold, inputvars
-from pytensor import Mode
+from pytensor import Mode, scan
 from pytensor.compile import SharedVariable
 from pytensor.compile.builders import OpFromGraph
 from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
+from pytensor.graph.replace import vectorize_graph
 from pytensor.scan import map as scan_map
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.elemwise import Elemwise
@@ -255,6 +256,10 @@ class FiniteDiscreteMarginalRV(MarginalRV):
     """Base class for Finite Discrete Marginalized RVs"""


+class DiscreteMarginalMarkovChainRV(MarginalRV):
+    """Base class for Discrete Marginal Markov Chain RVs"""
+
+
 def static_shape_ancestors(vars):
     """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
     return [
@@ -383,11 +388,17 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs})
     cloned_outputs = clone_replace(outputs, replace=replace_inputs)

-    marginalization_op = FiniteDiscreteMarginalRV(
+    if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
+        marginalize_constructor = DiscreteMarginalMarkovChainRV
+    else:
+        marginalize_constructor = FiniteDiscreteMarginalRV
+
+    marginalization_op = marginalize_constructor(
         inputs=list(replace_inputs.values()),
         outputs=cloned_outputs,
         ndim_supp=ndim_supp,
     )
+
     marginalized_rvs = marginalization_op(*replace_inputs.keys())
     fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
     return rvs_to_marginalize, marginalized_rvs
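
Choosing a different Op subclass here is what lets the logp be swapped out later: PyMC's `_logprob` is a singledispatch function keyed on the Op type, so `DiscreteMarginalMarkovChainRV` can register a forward-algorithm implementation while `FiniteDiscreteMarginalRV` keeps the enumerate-and-sum one. A minimal sketch of that dispatch pattern, using stand-in names (`_logprob_sketch` and the `*Sketch` classes are illustrative, not the commit's code):

from functools import singledispatch

class FiniteDiscreteMarginalRVSketch: ...
class DiscreteMarginalMarkovChainRVSketch: ...

@singledispatch
def _logprob_sketch(op, values, *inputs):
    raise NotImplementedError(f"No logprob registered for {type(op).__name__}")

@_logprob_sketch.register(FiniteDiscreteMarginalRVSketch)
def _(op, values, *inputs):
    # enumerate every admissible value of the marginalized RV and logsumexp the joint terms
    return "enumeration logp"

@_logprob_sketch.register(DiscreteMarginalMarkovChainRVSketch)
def _(op, values, *inputs):
    # run the forward algorithm over the hidden chain instead of enumerating whole paths
    return "forward-algorithm logp"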
@@ -435,7 +446,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
         values_axis_bcast = [i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v]
         joint_logp += logps_dict[inner_value].sum(values_axis_bcast, keepdims=True)

-    # Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different
+    # Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
     # values of the marginalized RV
     # Some inputs are not root inputs (such as transformed projections of value variables)
     # Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
@@ -487,3 +498,55 @@ def logp_fn(marginalized_rv_const, *non_sequences):

     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
     return joint_logps, *(pt.constant(0),) * (len(values) - 1)
+
+
+@_logprob.register(DiscreteMarginalMarkovChainRV)
+def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
+    def step_alpha(log_alpha, log_P):
+        return pt.logsumexp(log_alpha[:, None] + log_P, 0)
+
+    def eval_logp(x):
+        return logp(init_dist_, x)
+
+    marginalized_rvs_node = op.make_node(*inputs)
+    inner_rvs = clone_replace(
+        op.inner_outputs,
+        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
+    )
+
+    chain_rv, *dependent_rvs = inner_rvs
+    P_, n_steps_, init_dist_, rng = chain_rv.owner.inputs
+
+    domain = pt.arange(P_.shape[0], dtype="int32")
+
+    vec_eval_logp = pt.vectorize(eval_logp, "()->()")
+    logp_init = vec_eval_logp(domain)
+
+    # This will break the dependency between chain and the init_dist_ random variable
+    # TODO: Make this comment more robust after I understand better.
+    chain_dummy = chain_rv.clone()
+    dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_dummy})
+    input_dict = dict(zip(dependent_rvs, values))
+    logp_value_dict = conditional_logp(input_dict)
+
+    # TODO: Is values[0] robust to every situation?
+    sub_dict = {
+        chain_dummy: pt.moveaxis(pt.broadcast_to(domain, (*values[0].shape, domain.size)), -1, 0)
+    }
+
+    # TODO: @Ricardo: If you don't concatenate here, you get -inf in the logp (why?)
+    # TODO: I'm stacking the results (adds a batch dim to the left) and summing away the batch dim == joint probability?
+    vec_logp_emission = pt.stack(vectorize_graph(tuple(logp_value_dict.values()), sub_dict)).sum(
+        axis=0
+    )
+
+    log_alpha_seq, _ = scan(
+        step_alpha, non_sequences=[pt.log(P_)], outputs_info=[logp_init], n_steps=n_steps_
+    )
+
+    log_alpha_seq = pt.moveaxis(pt.concatenate([logp_init[None], log_alpha_seq], axis=0), -1, 0)
+    joint_log_obs_given_states = pt.logsumexp(pt.add(vec_logp_emission) + log_alpha_seq, axis=0)
+
+    # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
+    dummy_logps = (pt.constant(0.0),) * (len(values) - 1)
+    return joint_log_obs_given_states, dummy_logps

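The `_logprob` registration above is the log-space forward algorithm for a hidden Markov chain: `logp_init` holds the initial-state log-probabilities over `domain`, `step_alpha` pushes them through `log(P)` with a logsumexp, and the emission log-probabilities evaluated at every candidate state (`vec_logp_emission`) are combined with the resulting `log_alpha_seq`. For reference, a minimal NumPy sketch of the textbook recursion this code is working toward (the names `forward_logp` and `emission_logps` are illustrative, not from the diff):

import numpy as np
from scipy.special import logsumexp

def forward_logp(log_P, log_init, emission_logps):
    """Log-marginal likelihood of the emissions, with the hidden chain summed out.

    log_P          : (S, S) log transition matrix, log_P[i, j] = log P(s_t = j | s_{t-1} = i)
    log_init       : (S,)   log-probabilities of the initial state
    emission_logps : (T, S) log p(obs_t | state_t = s) for every candidate state s
    """
    log_alpha = log_init + emission_logps[0]
    for t in range(1, emission_logps.shape[0]):
        # propagate the state log-probabilities one step, then fold in the emission term
        log_alpha = logsumexp(log_alpha[:, None] + log_P, axis=0) + emission_logps[t]
    return logsumexp(log_alpha)  # marginalize the final state

The `step_alpha`/`scan` pair in the diff corresponds to the `logsumexp(log_alpha[:, None] + log_P, axis=0)` line here; the open TODO comments are about how the emission terms get folded into that recursion.
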
pymc_experimental/tests/test_marginal_model.py

+21 -71

@@ -473,88 +473,38 @@ def dist(idx, size):
     np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))


-def hmm_logp(values, P, steps, init_dist, state_rng):
-
-    [e_value] = values
-
-    # P = [[0, 1], [1, 0]]
-    domain = tuple(range(pt.get_vector_length(P[-1])))
-
-    # This should be done on log-scale
-    # Probability of states at t0
-    logprob_states = pm.math.stack([logp(init_dist, d) for d in domain])
-
-    logprob_emiss_ts = []
-    for e_value_t in e_value:
-        # Use vectorize
-        logprob_emiss_t = pt.sum(
-            [
-                logpprob_state + logp(clone_replace(emission_rv, replace={state_rv: state_value}), e_value_t)
-                for (logpprob_state, state_value) in zip(logprob_states, domain)
-            ]
-        )
-
-        # Probability next state
-        # prob_states = prob_states @ P
-        logprob_states = P[:, None]
-
-        logprob_emiss_ts.append(logprob_emiss_t)
+def test_marginalized_hmm_with_one_emission():
+    with MarginalModel() as m:
+        P = [[0, 1], [1, 0]]
+        init_dist = pm.Categorical.dist(p=[1, 0])
+        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3)
+        emission = pm.Normal("emission", mu=chain * 2 - 1, sigma=1e-1)

-    return logprob_emiss_ts.sum()
+    m.marginalize([chain])

+    logp_fn = m.compile_logp()
+    test_value = [-1, 1, -1, 1]

+    expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
+    np.testing.assert_allclose(logp_fn({f"emission": test_value}), expected_logp)

-def test_hmm():

+def test_marginalized_hmm_with_many_emissions():
     with MarginalModel() as m:
-        p = pt.as_tensor(np.array([1, 0]))
-
-        chain_0 = pm.Bernoulli("chain_0", p=0)
-        chain_1 = pm.Bernoulli("chain_1", p=p[chain_0])
-        chain_2 = pm.Bernoulli("chain_2", p=p[chain_1])
-        chain_3 = pm.Bernoulli("chain_3", p=p[chain_2])
-
-        pm.Normal("emission_0", chain_0 * 2 - 1, sigma=1e-1)
-        pm.Normal("emission_1", chain_1 * 2 - 1, sigma=1e-1)
-        pm.Normal("emission_2", chain_2 * 2 - 1, sigma=1e-1)
-        pm.Normal("emission_3", chain_3 * 2 - 1, sigma=1e-1)
-
-
+        P = [[0, 1], [1, 0]]
+        init_dist = pm.Categorical.dist(p=[1, 0])
+        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3)
+        emission_1 = pm.Normal("emission_1", mu=chain * 2 - 1, sigma=1e-1)
+        emission_2 = pm.Normal("emission_2", mu=(1 - chain) * 2 - 1, sigma=1e-1)

     with pytest.warns(UserWarning, match="multiple dependent variables"):
-        m.marginalize([chain_0, chain_1, chain_2, chain_3])
-        import pytensor
-        print()
-        pytensor.dprint(m.clone()._marginalize().free_RVs)
+        m.marginalize([chain])

     logp_fn = m.compile_logp()
     test_value = [-1, 1, -1, 1]

     expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
+    test_point = {"emission_1": test_value, "emission_2": test_value * -1}

-    np.testing.assert_allclose(
-        logp_fn({f"emission_{i}": test_value_i for i, test_value_i in enumerate(test_value)}),
-        expected_logp,
-    )
-    return
-
-    # with MarginalModel() as m:
-    #     P = [[0, 1], [1, 0]]
-    #     zero = pm.DiracDelta.dist(np.array(0, dtype="int64"))
-    #     chain = DiscreteMarkovChain("chain", P=P, init_dist=zero, steps=3)
-    #     emmission = pm.Normal("emission", mu=chain * 2 - 1, sigma=1e-1)
-    #     np.testing.assert_equal(pm.draw(chain), [0, 1, 0, 1])
-    #     m.marginalize(chain)
-
-    #     test_value = [-1, 1, -1, 1]
-    #     expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), [0, 0, 0, 0]).eval()
-
-    #     np.testing.assert_allclose(
-    #         logp_fn({"emission": test_value}),
-    #         expected_logp,
-    #     )
-    #
-    #     np.testing.assert_allclose(
-    #         logp_fn({f"emission{i}": test_value_i for i, test_value_i in enumerate(test_value)}),
-    #         expected_logp,
-    #     )
+    assert False
+    # np.testing.assert_allclose(logp_fn(test_point), expected_logp)

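Why the tests expect pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum(): with P = [[0, 1], [1, 0]] and an init_dist that puts all probability on state 0, the marginalized chain can only follow the path [0, 1, 0, 1], so mu = chain * 2 - 1 reproduces test_value exactly (and (1 - chain) * 2 - 1 reproduces test_value * -1), making every residual zero. A small standalone check of that arithmetic (a hypothetical snippet, not part of the test file):

import numpy as np
import pymc as pm

chain = np.array([0, 1, 0, 1])        # the only path with non-zero probability
test_value = np.array([-1, 1, -1, 1])
mu = chain * 2 - 1                    # equals test_value, so every residual is zero

lp_along_path = pm.logp(pm.Normal.dist(mu, 1e-1), test_value).sum().eval()
lp_reference = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros(4)).sum().eval()
np.testing.assert_allclose(lp_along_path, lp_reference)
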