 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.abstract import _logprob
-from pymc.logprob.basic import conditional_logp
+from pymc.logprob.basic import conditional_logp, logp
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
 from pymc.pytensorf import constant_fold, inputvars
-from pytensor import Mode
+from pytensor import Mode, scan
 from pytensor.compile import SharedVariable
 from pytensor.compile.builders import OpFromGraph
 from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
+from pytensor.graph.replace import vectorize_graph
 from pytensor.scan import map as scan_map
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.elemwise import Elemwise
@@ -255,6 +256,10 @@ class FiniteDiscreteMarginalRV(MarginalRV):
     """Base class for Finite Discrete Marginalized RVs"""
 
 
+class DiscreteMarginalMarkovChainRV(MarginalRV):
+    """Base class for Discrete Marginal Markov Chain RVs"""
+
+
 def static_shape_ancestors(vars):
     """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
     return [
@@ -383,11 +388,17 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs})
     cloned_outputs = clone_replace(outputs, replace=replace_inputs)
 
-    marginalization_op = FiniteDiscreteMarginalRV(
+    if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
+        marginalize_constructor = DiscreteMarginalMarkovChainRV
+    else:
+        marginalize_constructor = FiniteDiscreteMarginalRV
+
+    marginalization_op = marginalize_constructor(
         inputs=list(replace_inputs.values()),
         outputs=cloned_outputs,
         ndim_supp=ndim_supp,
     )
+
     marginalized_rvs = marginalization_op(*replace_inputs.keys())
     fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
     return rvs_to_marginalize, marginalized_rvs
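
With this dispatch in place, marginalizing a DiscreteMarkovChain should route through the new DiscreteMarginalMarkovChainRV instead of FiniteDiscreteMarginalRV. A hypothetical usage sketch (not part of this diff; it assumes the existing MarginalModel.marginalize API and the DiscreteMarkovChain distribution from pymc_experimental, with made-up data):

import numpy as np
import pymc as pm

from pymc_experimental import MarginalModel
from pymc_experimental.distributions import DiscreteMarkovChain

with MarginalModel() as m:
    P = pm.Dirichlet("P", a=np.ones((2, 2)))           # 2-state transition matrix
    init_dist = pm.Categorical.dist(p=[0.5, 0.5])       # initial state distribution
    chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=9)
    emission = pm.Normal("emission", mu=chain, sigma=0.2, observed=np.zeros(10))

m.marginalize([chain])  # the chain RV should be replaced by a DiscreteMarginalMarkovChainRV node
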
@@ -435,7 +446,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
         values_axis_bcast = [i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v]
         joint_logp += logps_dict[inner_value].sum(values_axis_bcast, keepdims=True)
 
-    # Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different
+    # Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
     # values of the marginalized RV
     # Some inputs are not root inputs (such as transformed projections of value variables)
     # Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
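
For reference, OpFromGraph is PyTensor's way of packaging an (inputs -> outputs) graph as a reusable Op, which is what lets the wrapped joint_logp be re-evaluated at each possible value of the marginalized RV. A minimal standalone sketch (not part of this diff):

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.scalar("x")
wrapped = OpFromGraph([x], [pt.exp(x) + 1])  # package the graph as a reusable Op

# The same wrapped graph evaluated at two different inputs
y0 = wrapped(pt.constant(0.0))
y1 = wrapped(pt.constant(1.0))
print(y0.eval(), y1.eval())  # 2.0 and ~3.72
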
@@ -487,3 +498,55 @@ def logp_fn(marginalized_rv_const, *non_sequences):
 
     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
     return joint_logps, *(pt.constant(0),) * (len(values) - 1)
+
+
+@_logprob.register(DiscreteMarginalMarkovChainRV)
+def marginal_markov_chain_logp(op, values, *inputs, **kwargs):
+    def step_alpha(log_alpha, log_P):
+        return pt.logsumexp(log_alpha[:, None] + log_P, 0)
+
+    def eval_logp(x):
+        return logp(init_dist_, x)
+
+    marginalized_rvs_node = op.make_node(*inputs)
+    inner_rvs = clone_replace(
+        op.inner_outputs,
+        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
+    )
+
+    chain_rv, *dependent_rvs = inner_rvs
+    P_, n_steps_, init_dist_, rng = chain_rv.owner.inputs
+
+    domain = pt.arange(P_.shape[0], dtype="int32")
+
+    vec_eval_logp = pt.vectorize(eval_logp, "()->()")
+    logp_init = vec_eval_logp(domain)
+
+    # This breaks the dependency between the chain and the init_dist_ random variable
+    # TODO: Make this comment more precise once the mechanism is better understood.
+    chain_dummy = chain_rv.clone()
+    dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_dummy})
+    input_dict = dict(zip(dependent_rvs, values))
+    logp_value_dict = conditional_logp(input_dict)
+
+    # TODO: Is values[0] robust to every situation?
+    sub_dict = {
+        chain_dummy: pt.moveaxis(pt.broadcast_to(domain, (*values[0].shape, domain.size)), -1, 0)
+    }
+
+    # TODO: @Ricardo: If you don't concatenate here, you get -inf in the logp (why?)
+    # TODO: Stacking the results adds a batch dim on the left; does summing away that batch dim give the joint probability?
+    vec_logp_emission = pt.stack(vectorize_graph(tuple(logp_value_dict.values()), sub_dict)).sum(
+        axis=0
+    )
+
+    log_alpha_seq, _ = scan(
+        step_alpha, non_sequences=[pt.log(P_)], outputs_info=[logp_init], n_steps=n_steps_
+    )
+
+    log_alpha_seq = pt.moveaxis(pt.concatenate([logp_init[None], log_alpha_seq], axis=0), -1, 0)
+    joint_log_obs_given_states = pt.logsumexp(vec_logp_emission + log_alpha_seq, axis=0)
+
+    # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
+    dummy_logps = (pt.constant(0.0),) * (len(values) - 1)
+    return joint_log_obs_given_states, *dummy_logps
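
The step_alpha recursion added above is the log-space forward algorithm for the marginalized chain: each step computes alpha_t(j) = logsumexp_i(alpha_{t-1}(i) + log P[i, j]), and the per-state emission log-probabilities are added afterwards before the states are summed out. A minimal NumPy sketch of the same recursion, with made-up numbers (illustration only, not part of this diff):

import numpy as np
from scipy.special import logsumexp

log_P = np.log([[0.9, 0.1], [0.2, 0.8]])  # log transition matrix, rows index the previous state
log_alpha = np.log([0.5, 0.5])            # log-probabilities of the initial state

for _ in range(3):  # three transitions, mirroring step_alpha in the scan above
    # entry (i, j) is log_alpha[i] + log_P[i, j]; reduce over the previous state i
    log_alpha = logsumexp(log_alpha[:, None] + log_P, axis=0)

print(np.exp(log_alpha))  # marginal state probabilities after three steps (sums to 1)
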