Add DiscreteMarkovChain distribution #100

Merged · 44 commits · Apr 20, 2023

Commits
c42c94a
add DiscreteMarkovChainRV
jessegrabowski Dec 16, 2022
fd472bd
remove `validate_transition_matrix`
jessegrabowski Dec 17, 2022
37ef8e0
remove `x0` argument
jessegrabowski Dec 17, 2022
b4e15db
Add reshape logic to `rv_op` based on size and `init_dist`
jessegrabowski Dec 17, 2022
a3408ba
Remove moot TODO comments
jessegrabowski Dec 17, 2022
22bfe17
Update and re-run example notebook
jessegrabowski Dec 18, 2022
3aef66d
Update `pytensor` alias to `pt`
jessegrabowski Dec 18, 2022
c2d5fc6
Remove moment method
jessegrabowski Dec 18, 2022
d77337c
Wrap tests into test class
jessegrabowski Dec 18, 2022
82978f0
Add test for default initial distribution warning
jessegrabowski Dec 18, 2022
b797af3
Replace `.dimshuffle` with `pt.moveaxis` in `rv_op`
jessegrabowski Dec 18, 2022
d827586
Fix scan error
jessegrabowski Dec 18, 2022
0267801
Add test for `change_dist_size`
jessegrabowski Dec 18, 2022
b7794ed
Add code example to `DiscreteMarkovChain` docstring
jessegrabowski Dec 18, 2022
ab00e6c
Update pymc_experimental/distributions/timeseries.py
jessegrabowski Dec 18, 2022
bfb6b43
Remove shape argument from default `init_dist`
jessegrabowski Dec 18, 2022
ad3a878
Use shape parameter in example code
jessegrabowski Dec 18, 2022
7912d02
Remove steps adjustment from `__new__`
jessegrabowski Dec 18, 2022
85233f6
Remove `.squeeze()` from scan output
jessegrabowski Dec 18, 2022
b42184d
Remove dimension check on `init_dist`
jessegrabowski Dec 18, 2022
4773aa8
Fix batch size detection
jessegrabowski Dec 19, 2022
a2b79f2
Add support for n_lags > 1
jessegrabowski Dec 20, 2022
7fb1e56
Fix shape of markov_chain when P has a batch_size and n_lags == 1.
jessegrabowski Dec 20, 2022
847a2f9
Add test to recover P when n_lags > 1
jessegrabowski Dec 20, 2022
c52191e
Fix `logp` shape checking wrong dimension of `init_dist`
jessegrabowski Dec 20, 2022
ed2983f
Updates imports following pymc-devs/pymc#6441
jessegrabowski Apr 15, 2023
bbaa81e
Add a moment function to `DiscreteMarkovRV`
jessegrabowski Apr 15, 2023
9ac136b
Raise `NotImplementedError` if `init_dist` is not `pm.Categorical`
jessegrabowski Apr 15, 2023
2673b12
Update example notebook with some new plots
jessegrabowski Apr 15, 2023
b2df5a2
Fix a bug that broke `n_lags` > 1
jessegrabowski Apr 15, 2023
daffdc8
Rename test function to correctly match test
jessegrabowski Apr 16, 2023
c35bc31
rebase from main
jessegrabowski Apr 17, 2023
170e20b
Add `timeseries.DiscreteMarkovChain` to `api_reference.rst`
jessegrabowski Apr 17, 2023
eb23686
Remove check on `init_dist`
jessegrabowski Apr 17, 2023
c24a86d
Add `DiscreteMarkovChain` to `distributions.__all__`
jessegrabowski Apr 17, 2023
ee141ef
Change example notebook title, add subtitles, add plots comparing res…
jessegrabowski Apr 17, 2023
f138cac
Pass `init_dist` to all tests to avoid `UserWarning`
jessegrabowski Apr 17, 2023
66a3198
Fix flaky `test_moment_function` test
jessegrabowski Apr 17, 2023
10e2817
Fix latex in docstring
jessegrabowski Apr 17, 2023
7ef7ae6
Apply suggestions from code review
jessegrabowski Apr 17, 2023
0624d2d
Fix latex in docstring
jessegrabowski Apr 17, 2023
f07cdba
Merge branch 'discrete-markov' of https://github.com/jessegrabowski/p…
jessegrabowski Apr 17, 2023
994e3fb
Fix latex in docstring
jessegrabowski Apr 17, 2023
3c96dc8
Fix warning in docstring
jessegrabowski Apr 17, 2023
1 change: 1 addition & 0 deletions docs/api_reference.rst
@@ -30,6 +30,7 @@ Distributions

GenExtreme
histogram_utils.histogram_approximation
DiscreteMarkovChain


Gaussian Processes
1,403 changes: 1,403 additions & 0 deletions notebooks/discrete_markov_chain.ipynb

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions pymc_experimental/distributions/__init__.py
@@ -18,7 +18,6 @@
"""

from pymc_experimental.distributions.continuous import GenExtreme
from pymc_experimental.distributions.timeseries import DiscreteMarkovChain

__all__ = [
"GenExtreme",
]
__all__ = ["GenExtreme", "DiscreteMarkovChain"]
267 changes: 267 additions & 0 deletions pymc_experimental/distributions/timeseries.py
@@ -0,0 +1,267 @@
import warnings
from typing import List, Union

import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import (
    Distribution,
    SymbolicRandomVariable,
    _moment,
    moment,
)
from pymc.distributions.shape_utils import (
    _change_dist_size,
    change_dist_size,
    get_support_shape_1d,
)
from pymc.logprob.abstract import _logprob
from pymc.logprob.basic import logp
from pymc.logprob.utils import ignore_logprob
from pymc.pytensorf import intX
from pymc.util import check_dist_not_registered
from pytensor.graph.basic import Node
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable


def _make_outputs_info(n_lags: int, init_dist: Distribution) -> List[Union[Distribution, dict]]:
    """
    Two cases are needed for the outputs_info of the scans used by DiscreteMarkovChainRV. If
    n_lags > 1, scan requires an explicit dictionary of taps. If n_lags == 1, we need to throw
    away the first dimension of init_dist_, or else markov_chain will have shape
    (steps, 1, *batch_size) instead of the desired (steps, *batch_size).

    Parameters
    ----------
    n_lags: int
        Number of lags the Markov chain considers when transitioning to the next state
    init_dist: Distribution
        Distribution over the initial states

    Returns
    -------
    outputs_info: list
        Initial states and taps to be fed to pytensor.scan when drawing a Markov chain
    """

    if n_lags > 1:
        return [{"initial": init_dist, "taps": list(range(-n_lags, 0))}]
    else:
        return [init_dist[0]]
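
# For example, _make_outputs_info(2, init_dist) returns
# [{"initial": init_dist, "taps": [-2, -1]}], so each scan step receives the two
# previous states, while _make_outputs_info(1, init_dist) returns [init_dist[0]].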


class DiscreteMarkovChainRV(SymbolicRandomVariable):
    n_lags: int
    default_output = 1
    _print_name = ("DiscreteMC", "\\operatorname{DiscreteMC}")

    def __init__(self, *args, n_lags, **kwargs):
        self.n_lags = n_lags
        super().__init__(*args, **kwargs)

    def update(self, node: Node):
        # Map the rng input (last input) to the updated rng (first output) so
        # the random state advances between draws
        return {node.inputs[-1]: node.outputs[0]}


class DiscreteMarkovChain(Distribution):
    r"""
    A discrete Markov chain is a sequence of random variables

    .. math::

        \{x_t\}_{t=0}^T

    where the transition probability :math:`P(x_t | x_{t-1})` depends only on the state of the
    system at :math:`x_{t-1}`.
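
    The Markov property means the joint probability of a chain of length :math:`T` factors as

    .. math::

        P(x_0, \ldots, x_T) = P(x_0) \prod_{t=1}^{T} P(x_t | x_{t-1})

    which is what the ``logp`` implementation below computes, generalized to ``n_lags`` lags of
    dependence.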

    Parameters
    ----------
    P: tensor
        Matrix of transition probabilities between states. Rows must sum to 1.
        One of P or logit_P must be provided.
    logit_P: tensor, optional
        Matrix of transition logits, converted to probabilities along the last axis via softmax.
        One of P or logit_P must be provided.
    steps: tensor, optional
        Length of the Markov chain. Only needed if shape is not provided.
    init_dist : unnamed distribution, optional
        Vector distribution for initial values. Unnamed refers to distributions
        created with the ``.dist()`` API. The distribution should have shape n_states;
        if not, it will be automatically resized.

        .. warning:: init_dist will be cloned, rendering it independent of the one passed as input.

    Examples
    --------
    Create a Markov chain of length 100 with 3 states. The number of states is given by the shape
    of P, 3 in this case.

    >>> with pm.Model() as markov_chain:
    ...     P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
    ...     init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
    ...     markov_chain = DiscreteMarkovChain(
    ...         "markov_chain", P=P, init_dist=init_dist, shape=(100,)
    ...     )
    """

    rv_type = DiscreteMarkovChainRV

    def __new__(cls, *args, steps=None, n_lags=1, **kwargs):
        # steps counts only the scanned transitions; the full chain also contains
        # n_lags initial states, hence the support_shape_offset
        steps = get_support_shape_1d(
            support_shape=steps,
            shape=None,
            dims=kwargs.get("dims", None),
            observed=kwargs.get("observed", None),
            support_shape_offset=n_lags,
        )

        return super().__new__(cls, *args, steps=steps, n_lags=n_lags, **kwargs)

    @classmethod
    def dist(cls, P=None, logit_P=None, steps=None, init_dist=None, n_lags=1, **kwargs):
        steps = get_support_shape_1d(
            support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=n_lags
        )

        if steps is None:
            raise ValueError("Must specify steps or shape parameter")
        if P is None and logit_P is None:
            raise ValueError("Must specify P or logit_P parameter")
        if P is not None and logit_P is not None:
            raise ValueError("Must specify only one of either P or logit_P parameter")

        if logit_P is not None:
            P = pm.math.softmax(logit_P, axis=-1)

        P = pt.as_tensor_variable(P)
        steps = pt.as_tensor_variable(intX(steps))

        if init_dist is not None:
            if not isinstance(init_dist, TensorVariable) or not isinstance(
                init_dist.owner.op, (RandomVariable, SymbolicRandomVariable)
            ):
                raise ValueError(
                    f"Init dist must be a distribution created via the `.dist()` API, "
                    f"got {type(init_dist)}"
                )

            check_dist_not_registered(init_dist)
            if init_dist.owner.op.ndim_supp > 1:
                raise ValueError(
                    "Init distribution must have a scalar or vector support dimension, "
                    f"got ndim_supp={init_dist.owner.op.ndim_supp}."
                )
        else:
            warnings.warn(
                "Initial distribution not specified, defaulting to "
                "`Categorical.dist(p=pt.full((k_states,), 1/k_states), shape=...)`. You can specify "
                "an init_dist manually to suppress this warning.",
                UserWarning,
            )
            k = P.shape[-1]
            init_dist = pm.Categorical.dist(p=pt.full((k,), 1 / k))

        # We can ignore init_dist here, as it is accounted for explicitly in the logp term
        init_dist = ignore_logprob(init_dist)

        return super().dist([P, steps, init_dist], n_lags=n_lags, **kwargs)

    @classmethod
    def rv_op(cls, P, steps, init_dist, n_lags, size=None):
        if size is not None:
            batch_size = size
        else:
            # P's batch dims are everything to the left of its (n_lags + 1) square
            # dims; broadcast them against the batch dims of init_dist
            batch_size = pt.broadcast_shape(
                P[tuple([...] + [0] * (n_lags + 1))], pt.atleast_1d(init_dist)[..., 0]
            )

        init_dist = change_dist_size(init_dist, (n_lags, *batch_size))
        init_dist_ = init_dist.type()
        P_ = P.type()
        steps_ = steps.type()

        state_rng = pytensor.shared(np.random.default_rng())

        def transition(*args):
            # `states` holds the previous n_lags states; indexing P with them
            # selects the distribution over the next state
            *states, transition_probs, old_rng = args
            p = transition_probs[tuple(states)]
            next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
            return next_state, {old_rng: next_rng}

        markov_chain, state_updates = pytensor.scan(
            transition,
            non_sequences=[P_, state_rng],
            outputs_info=_make_outputs_info(n_lags, init_dist_),
            n_steps=steps_,
            strict=True,
        )

        (state_next_rng,) = tuple(state_updates.values())

        # Prepend the initial states, then move the time axis to the end
        discrete_mc_ = pt.moveaxis(pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1)

        discrete_mc_op = DiscreteMarkovChainRV(
            inputs=[P_, steps_, init_dist_],
            outputs=[state_next_rng, discrete_mc_],
            ndim_supp=1,
            n_lags=n_lags,
        )

        discrete_mc = discrete_mc_op(P, steps, init_dist)
        return discrete_mc


@_change_dist_size.register(DiscreteMarkovChainRV)
def change_mc_size(op, dist, new_size, expand=False):
    if expand:
        old_size = dist.shape[:-1]
        new_size = tuple(new_size) + tuple(old_size)

    return DiscreteMarkovChain.rv_op(*dist.owner.inputs[:-1], size=new_size, n_lags=op.n_lags)
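
# With expand=True the new batch dims are prepended: resizing a chain of shape
# (100,) with new_size=(5,) yields draws of shape (5, 100)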


@_moment.register(DiscreteMarkovChainRV)
def discrete_mc_moment(op, rv, P, steps, init_dist, state_rng):
    init_dist_moment = moment(init_dist)
    n_lags = op.n_lags

    def greedy_transition(*args):
        # Rather than sampling, deterministically pick the most probable next state
        *states, transition_probs, old_rng = args
        p = transition_probs[tuple(states)]
        return pt.argmax(p)

    chain_moment, moment_updates = pytensor.scan(
        greedy_transition,
        non_sequences=[P, state_rng],
        outputs_info=_make_outputs_info(n_lags, init_dist),
        n_steps=steps,
        strict=True,
    )
    chain_moment = pt.concatenate([init_dist_moment, chain_moment])
    return chain_moment


@_logprob.register(DiscreteMarkovChainRV)
def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
    value = values[0]
    n_lags = op.n_lags

    # Build n_lags + 1 staggered views of the chain; e.g. with n_lags = 1 this is
    # [value[..., :-1], value[..., 1:]], so P[tuple(indexes)] is P[x_{t-1}, x_t]
    # for every transition in the chain
    indexes = [value[..., i : -(n_lags - i) if n_lags != i else None] for i in range(n_lags + 1)]

    mc_logprob = logp(init_dist, value[..., :n_lags]).sum(axis=-1)
    mc_logprob += pt.log(P[tuple(indexes)]).sum(axis=-1)

    return check_parameters(
        mc_logprob,
        pt.all(pt.eq(P.shape[-(n_lags + 1) :], P.shape[-1])),
        pt.all(pt.allclose(P.sum(axis=-1), 1.0)),
        pt.eq(pt.atleast_1d(init_dist).shape[0], n_lags),
        msg="Last (n_lags + 1) dimensions of P must be square, "
        "P must sum to 1 along the last axis, "
        "and the first dimension of init_dist must be n_lags",
    )
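
For reference, a minimal usage sketch along the lines of the docstring example (the model, the
variable names, and the prior-predictive call are illustrative, not part of the diff):

import numpy as np
import pymc as pm
from pymc_experimental.distributions import DiscreteMarkovChain

with pm.Model():
    # 3x3 transition matrix: three Dirichlet rows, each summing to 1
    P = pm.Dirichlet("P", a=np.ones(3), size=(3,))
    # Uniform distribution over the three initial states
    init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
    # Chain of total length 100 (1 initial state + 99 scanned transitions)
    chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=(100,))
    idata = pm.sample_prior_predictive()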