Skip to content

Add DiscreteMarkovChain distribution #100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 44 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
c42c94a
add DiscreteMarkovChainRV
jessegrabowski Dec 16, 2022
fd472bd
remove `validate_transition_matrix`
jessegrabowski Dec 17, 2022
37ef8e0
remove `x0` argument
jessegrabowski Dec 17, 2022
b4e15db
Add reshape logic to `rv_op` based on size and `init_dist`
jessegrabowski Dec 17, 2022
a3408ba
Remove moot TODO comments
jessegrabowski Dec 17, 2022
22bfe17
Update and re-run example notebook
jessegrabowski Dec 18, 2022
3aef66d
Update `pytensor` alias to `pt`
jessegrabowski Dec 18, 2022
c2d5fc6
Remove moment method
jessegrabowski Dec 18, 2022
d77337c
Wrap tests into test class
jessegrabowski Dec 18, 2022
82978f0
Add test for default initial distribution warning
jessegrabowski Dec 18, 2022
b797af3
Replace `.dimshuffle` with `pt.moveaxis` in `rv_op`
jessegrabowski Dec 18, 2022
d827586
Fix scan error
jessegrabowski Dec 18, 2022
0267801
Add test for `change_dist_size`
jessegrabowski Dec 18, 2022
b7794ed
Add code example to `DiscreteMarkovChain` docstring
jessegrabowski Dec 18, 2022
ab00e6c
Update pymc_experimental/distributions/timeseries.py
jessegrabowski Dec 18, 2022
bfb6b43
Remove shape argument from default `init_dist`
jessegrabowski Dec 18, 2022
ad3a878
Use shape parameter in example code
jessegrabowski Dec 18, 2022
7912d02
Remove steps adjustment from `__new__`
jessegrabowski Dec 18, 2022
85233f6
Remove `.squeeze()` from scan output
jessegrabowski Dec 18, 2022
b42184d
Remove dimension check on `init_dist`
jessegrabowski Dec 18, 2022
4773aa8
Fix batch size detection
jessegrabowski Dec 19, 2022
a2b79f2
Add support for n_lags > 1
jessegrabowski Dec 20, 2022
7fb1e56
Fix shape of markov_chain when P has a batch_size and n_lags == 1.
jessegrabowski Dec 20, 2022
847a2f9
Add test to recover P when n_lags > 1
jessegrabowski Dec 20, 2022
c52191e
Fix `logp` shape checking wrong dimension of `init_dist`
jessegrabowski Dec 20, 2022
ed2983f
Updates imports following pymc-devs/pymc#6441
jessegrabowski Apr 15, 2023
bbaa81e
Add a moment function to `DiscreteMarkovRV`
jessegrabowski Apr 15, 2023
9ac136b
Raise `NotImplementedError` if `init_dist` is not `pm.Categorical`
jessegrabowski Apr 15, 2023
2673b12
Update example notebook with some new plots
jessegrabowski Apr 15, 2023
b2df5a2
Fix a bug that broke `n_lags` > 1
jessegrabowski Apr 15, 2023
daffdc8
Rename test function to correctly match test
jessegrabowski Apr 16, 2023
c35bc31
rebase from main
jessegrabowski Apr 17, 2023
170e20b
Add `timeseries.DiscreteMarkovChain` to `api_reference.rst`
jessegrabowski Apr 17, 2023
eb23686
Remove check on `init_dist`
jessegrabowski Apr 17, 2023
c24a86d
Add `DiscreteMarkovChain` to `distributions.__all__`
jessegrabowski Apr 17, 2023
ee141ef
Change example notebook title, add subtitles, add plots comparing res…
jessegrabowski Apr 17, 2023
f138cac
Pass `init_dist` to all tests to avoid `UserWarning`
jessegrabowski Apr 17, 2023
66a3198
Fix flaky `test_moment_function` test
jessegrabowski Apr 17, 2023
10e2817
Fix latex in docstring
jessegrabowski Apr 17, 2023
7ef7ae6
Apply suggestions from code review
jessegrabowski Apr 17, 2023
0624d2d
Fix latex in docstring
jessegrabowski Apr 17, 2023
f07cdba
Merge branch 'discrete-markov' of https://github.com/jessegrabowski/p…
jessegrabowski Apr 17, 2023
994e3fb
Fix latex in docstring
jessegrabowski Apr 17, 2023
3c96dc8
Fix warning in docstring
jessegrabowski Apr 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,246 changes: 1,246 additions & 0 deletions notebooks/discrete_markov_chain.ipynb

Large diffs are not rendered by default.

215 changes: 215 additions & 0 deletions pymc_experimental/distributions/timeseries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
import warnings

import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import Distribution, SymbolicRandomVariable
from pymc.distributions.logprob import ignore_logprob, logp
from pymc.distributions.shape_utils import (
_change_dist_size,
change_dist_size,
get_support_shape_1d,
)
from pymc.logprob.abstract import _logprob
from pymc.pytensorf import intX
from pymc.util import check_dist_not_registered
from pytensor.graph.basic import Node
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable


class DiscreteMarkovChainRV(SymbolicRandomVariable):
    """Symbolic random variable representing a discrete Markov chain.

    Wraps the scan-based sampling graph built in ``DiscreteMarkovChain.rv_op``
    so PyMC can treat the whole chain as a single random variable.
    """

    # Order of the chain: how many previous states the transition depends on.
    n_lags: int
    # The op's outputs are [next_rng, chain]; expose the chain (index 1) as the RV.
    default_output = 1
    _print_name = ("DiscreteMC", "\\operatorname{DiscreteMC}")

    def __init__(self, *args, n_lags, **kwargs):
        # Stash n_lags on the op so downstream dispatchers (logp, resize) can read it.
        self.n_lags = n_lags
        super().__init__(*args, **kwargs)

    def update(self, node: Node):
        # Map the op's last input to its first output (the advanced rng state),
        # telling PyMC how to propagate the random state between draws.
        # NOTE(review): relies on the rng being the final input of the node — confirm
        # against how SymbolicRandomVariable appends shared inputs.
        return {node.inputs[-1]: node.outputs[0]}


class DiscreteMarkovChain(Distribution):
    r"""
    A Discrete Markov Chain is a sequence of random variables

    .. math::

        \{x_t\}_{t=0}^T

    Where transition probability :math:`P(x_t | x_{t-1})` depends only on the
    state of the system at :math:`x_{t-1}`.

    Parameters
    ----------
    P: tensor
        Matrix of transition probabilities between states. Rows must sum to 1.
        One of P or logit_P must be provided.
    logit_P: tensor, optional
        Matrix of transition logits. Converted to probabilities via Softmax activation.
        One of P or logit_P must be provided.
    steps: tensor, optional
        Length of the markov chain. Only needed if state is not provided.
    init_dist : unnamed distribution, optional
        Vector distribution for initial values. Unnamed refers to distributions
        created with the ``.dist()`` API. Distribution should have shape n_states.
        If not, it will be automatically resized. Defaults to
        pm.Categorical.dist(p=np.full(n_states, 1/n_states)).

        .. warning:: init_dist will be cloned, rendering it independent of the one passed as input.

    Notes
    -----
    The initial distribution will be cloned, rendering it distinct from the one passed as
    input.

    Examples
    --------
    .. code-block:: python

        # Create a Markov Chain of length 100 with 3 states
        with pm.Model() as markov_chain:
            # The transition probability matrix should be square with rows that sum to 1
            # The number of states in the markov chain is given by the shape of P, 3 in this example
            P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))

            # The initial state probabilities should have size = n_states, or 3 in this case.
            init = pm.Categorical.dist(p = np.full(3, 1 / 3))

            markov_chain = pm.DiscreteMarkovChain("markov_chain", P=P, init_dist=init, shape=(100,))
    """

    rv_type = DiscreteMarkovChainRV

    def __new__(cls, *args, steps=None, n_lags=1, initval="prior", **kwargs):
        # Resolve `steps` from dims/observed when not given explicitly. The
        # realized chain is `n_lags` longer than `steps` (the initial states
        # are prepended), hence the offset.
        steps = get_support_shape_1d(
            support_shape=steps,
            shape=None,
            dims=kwargs.get("dims", None),
            observed=kwargs.get("observed", None),
            support_shape_offset=n_lags,
        )

        return super().__new__(cls, *args, steps=steps, n_lags=n_lags, **kwargs)

    @classmethod
    def dist(cls, P=None, logit_P=None, steps=None, init_dist=None, n_lags=1, **kwargs):
        """Create an unnamed DiscreteMarkovChain variable.

        Raises
        ------
        ValueError
            If neither ``steps`` nor ``shape`` is given, if neither or both of
            ``P``/``logit_P`` are given, or if ``init_dist`` is not an unnamed
            distribution with scalar/vector support.
        """
        # BUG FIX: the offset must be `n_lags` to match `__new__`; the previous
        # hard-coded 1 mis-derived `steps` from `shape` whenever n_lags > 1.
        steps = get_support_shape_1d(
            support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=n_lags
        )

        if steps is None:
            raise ValueError("Must specify steps or shape parameter")
        if P is None and logit_P is None:
            raise ValueError("Must specify P or logit_P parameter")
        if P is not None and logit_P is not None:
            raise ValueError("Must specify only one of either P or logit_P parameter")

        if logit_P is not None:
            # Softmax over the last axis turns rows of logits into valid
            # probability rows (each summing to 1).
            P = pm.math.softmax(logit_P, axis=-1)

        P = pt.as_tensor_variable(P)
        steps = pt.as_tensor_variable(intX(steps))

        if init_dist is not None:
            # Must be an unnamed (`.dist()`) random variable, not a model-registered one.
            if not isinstance(init_dist, TensorVariable) or not isinstance(
                init_dist.owner.op, (RandomVariable, SymbolicRandomVariable)
            ):
                raise ValueError(
                    f"Init dist must be a distribution created via the `.dist()` API, "
                    f"got {type(init_dist)}"
                )
            check_dist_not_registered(init_dist)
            if init_dist.owner.op.ndim_supp > 1:
                raise ValueError(
                    "Init distribution must have a scalar or vector support dimension, ",
                    f"got ndim_supp={init_dist.owner.op.ndim_supp}.",
                )
        else:
            warnings.warn(
                "Initial distribution not specified, defaulting to "
                "`Categorical.dist(p=pt.full((k_states, ), 1/k_states), shape=...)`. You can specify an init_dist "
                "manually to suppress this warning.",
                UserWarning,
            )
            # Default: uniform over the k states implied by the last axis of P.
            k = P.shape[-1]
            init_dist = pm.Categorical.dist(p=pt.full((k,), 1 / k))

        # We can ignore init_dist, as it will be accounted for in the logp term
        init_dist = ignore_logprob(init_dist)

        return super().dist([P, steps, init_dist], n_lags=n_lags, **kwargs)

    @classmethod
    def rv_op(cls, P, steps, init_dist, n_lags, size=None):
        """Build the symbolic sampling graph: a scan that draws each state
        conditional on the previous ``n_lags`` states."""
        if size is not None:
            batch_size = size
        else:
            # Infer batch shape from the leading (non-square) dims of P and
            # the batch dims of init_dist.
            batch_size = pt.broadcast_shape(
                P[tuple([...] + [0] * (n_lags + 1))], pt.atleast_1d(init_dist)[..., 0]
            )

        # The scan needs n_lags initial states per batch element.
        init_dist = change_dist_size(init_dist, (n_lags, *batch_size))

        # Dummy inputs used to define the inner graph of the symbolic RV.
        init_dist_ = init_dist.type()
        P_ = P.type()
        steps_ = steps.type()

        state_rng = pytensor.shared(np.random.default_rng())

        def transition(*args):
            # args = (state_{t-n_lags}, ..., state_{t-1}, P, rng); index P by
            # the lagged states to get the categorical distribution of state_t.
            *states, transition_probs, old_rng = args
            p = transition_probs[tuple(states)]
            next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
            return next_state, {old_rng: next_rng}

        markov_chain, state_updates = pytensor.scan(
            transition,
            non_sequences=[P_, state_rng],
            outputs_info=[{"initial": init_dist_, "taps": list(range(-n_lags, 0))}],
            n_steps=steps_,
            strict=True,
        )

        (state_next_rng,) = tuple(state_updates.values())

        # Prepend the initial states, then move the time axis from the front
        # to the last position to match PyMC's batch-first convention.
        discrete_mc_ = pt.moveaxis(
            pt.concatenate([init_dist_, markov_chain.squeeze()], axis=0), 0, -1
        )

        discrete_mc_op = DiscreteMarkovChainRV(
            inputs=[P_, steps_, init_dist_],
            outputs=[state_next_rng, discrete_mc_],
            ndim_supp=1,
            n_lags=n_lags,
        )

        discrete_mc = discrete_mc_op(P, steps, init_dist)
        return discrete_mc


@_change_dist_size.register(DiscreteMarkovChainRV)
def change_mc_size(op, dist, new_size, expand=False):
    """Resize a DiscreteMarkovChainRV by rebuilding it with a new batch size.

    When ``expand`` is True the new size is prepended to the existing batch
    dimensions (everything but the trailing time axis).
    """
    size = tuple(new_size)
    if expand:
        size = size + tuple(dist.shape[:-1])

    return DiscreteMarkovChain.rv_op(*dist.owner.inputs[:-1], size=size, n_lags=op.n_lags)


@_logprob.register(DiscreteMarkovChainRV)
def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
    """Log-probability of an observed discrete Markov chain.

    Sums the logp of the first ``n_lags`` states under ``init_dist`` with the
    log transition probabilities of every subsequent step, then guards the
    result with parameter validity checks.
    """
    value = values[0]
    n_lags = op.n_lags

    # indexes[i] is the chain shifted by i steps, so zipping them indexes P by
    # (state_{t-n_lags}, ..., state_t) for every valid t.
    indexes = [value[..., i : -(n_lags - i) if n_lags != i else None] for i in range(n_lags + 1)]

    mc_logprob = logp(init_dist, value[..., :n_lags]).sum(axis=-1)
    mc_logprob += pt.log(P[tuple(indexes)]).sum(axis=-1)

    # NOTE(review): the third check compares init_dist's last axis to n_lags —
    # confirm that axis is the lag axis here and not a batch axis.
    return check_parameters(
        mc_logprob,
        pt.all(pt.eq(P.shape[-(n_lags + 1) :], P.shape[-1])),
        pt.all(pt.allclose(P.sum(axis=-1), 1.0)),
        pt.eq(pt.atleast_1d(init_dist).shape[-1], n_lags),
        # BUG FIX: the concatenated message fragments lacked separators and
        # rendered as "...last axisLast dimension...".
        msg="Last (n_lags + 1) dimensions of P must be square, "
        "P must sum to 1 along the last axis, "
        "last dimension of init_dist must be n_lags",
    )
145 changes: 145 additions & 0 deletions pymc_experimental/tests/distributions/test_discrete_markov_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import numpy as np
import pymc as pm

# general imports
import pytensor.tensor as pt
import pytest
from pymc.distributions.shape_utils import change_dist_size
from pymc.logprob.utils import ParameterValueError

from pymc_experimental.distributions.timeseries import DiscreteMarkovChain


class TestDiscreteMarkovRV:
    """Tests for DiscreteMarkovChain: parameter validation, logp values and
    shapes, steps/shape/dims interplay, random draws, and resizing."""

    def test_fail_if_P_not_square(self):
        # A non-square transition matrix must be rejected at logp evaluation.
        P = pt.eye(3, 2)
        chain = DiscreteMarkovChain.dist(P=P, steps=3)
        with pytest.raises(ParameterValueError):
            pm.logp(chain, np.zeros((3,))).eval()

    def test_fail_if_P_not_valid(self):
        # Rows of P must sum to 1; an all-zero matrix must be rejected.
        P = pt.zeros((3, 3))
        chain = DiscreteMarkovChain.dist(P=P, steps=3)
        with pytest.raises(ParameterValueError):
            pm.logp(chain, np.zeros((3,))).eval()

    def test_default_init_dist_warns_user(self):
        # Omitting init_dist should emit the "defaulting to Categorical" warning.
        P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]))

        with pytest.warns(UserWarning):
            DiscreteMarkovChain.dist(P=P, steps=3)

    def test_logp_shape(self):
        # logp over 5 batched draws should have shape (5,), whether the chain
        # length comes from `steps` or from `shape`.
        P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]))

        # Test with steps
        chain = DiscreteMarkovChain.dist(P=P, steps=3)
        draws = pm.draw(chain, 5)
        logp = pm.logp(chain, draws).eval()

        assert logp.shape == (5,)

        # Test with shape
        chain = DiscreteMarkovChain.dist(P=P, shape=(3,))
        draws = pm.draw(chain, 5)
        logp = pm.logp(chain, draws).eval()

        assert logp.shape == (5,)

    def test_logp_with_default_init_dist(self):
        # With the uniform default init, logp([0,1,2]) = log(1/3 * P[0,1] * P[1,2]).
        P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]))
        chain = DiscreteMarkovChain.dist(P=P, steps=3)

        logp = pm.logp(chain, [0, 1, 2]).eval()
        assert logp == np.log((1 / 3) * 0.5 * 0.3)

    def test_logp_with_user_defined_init_dist(self):
        # Same as above but the initial-state probability comes from x0 (0.2 for state 0).
        P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]))
        x0 = pm.Categorical.dist(p=[0.2, 0.6, 0.2])
        chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, steps=3)

        logp = pm.logp(chain, [0, 1, 2]).eval()
        assert logp == np.log(0.2 * 0.5 * 0.3)

    def test_define_steps_via_shape_arg(self):
        # `shape` alone should determine the chain length (and batch dims).
        P = pt.full((3, 3), 1 / 3)
        chain = DiscreteMarkovChain.dist(P=P, shape=(3,))
        assert chain.eval().shape == (3,)

        chain = DiscreteMarkovChain.dist(P=P, shape=(3, 2))
        assert chain.eval().shape == (3, 2)

    def test_define_steps_via_dim_arg(self):
        # A named dim with 3 coordinates should produce a length-3 chain.
        coords = {"steps": [1, 2, 3]}

        with pm.Model(coords=coords):
            P = pt.full((3, 3), 1 / 3)
            chain = DiscreteMarkovChain("chain", P=P, dims=["steps"])

        assert chain.eval().shape == (3,)

    def test_dims_when_steps_are_defined(self):
        # Explicit steps=3 plus n_lags=1 initial state gives a length-4 chain,
        # matching the 4-coordinate dim.
        coords = {"steps": [1, 2, 3, 4]}

        with pm.Model(coords=coords):
            P = pt.full((3, 3), 1 / 3)
            chain = DiscreteMarkovChain("chain", P=P, steps=3, dims=["steps"])

        assert chain.eval().shape == (4,)

    def test_multiple_dims_with_steps(self):
        # Two dims: time ("steps", length steps+1 = 3) and batch ("mc_chains", 3).
        coords = {"steps": [1, 2, 3], "mc_chains": [1, 2, 3]}

        with pm.Model(coords=coords):
            P = pt.full((3, 3), 1 / 3)
            chain = DiscreteMarkovChain("chain", P=P, steps=2, dims=["steps", "mc_chains"])

        assert chain.eval().shape == (3, 3)

    def test_mutiple_dims_with_steps_and_init_dist(self):
        # Same as above, but the batch dimension comes from a sized init_dist.
        coords = {"steps": [1, 2, 3], "mc_chains": [1, 2, 3]}

        with pm.Model(coords=coords):
            P = pt.full((3, 3), 1 / 3)
            x0 = pm.Categorical.dist(p=[0.1, 0.1, 0.8], size=(3,))
            chain = DiscreteMarkovChain(
                "chain", P=P, init_dist=x0, steps=2, dims=["steps", "mc_chains"]
            )

        assert chain.eval().shape == (3, 3)

    def test_random_draws(self):
        # Statistical check: with a uniform P, initial-state frequencies should
        # be ~uniform and empirical transition frequencies should approximate P.
        steps = 3
        n_states = 2
        n_draws = 2500
        atol = 0.05

        P = np.full((n_states, n_states), 1 / n_states)
        chain = DiscreteMarkovChain.dist(P=pt.as_tensor_variable(P), steps=steps)

        draws = pm.draw(chain, n_draws)

        # Test x0 is uniform over n_states
        assert np.allclose(
            np.histogram(draws[:, ..., 0], bins=n_states)[0] / n_draws, 1 / n_states, atol=atol
        )

        # Count (state_t, state_{t+1}) pairs across all draws, then normalize
        # rows to compare against P.
        bigrams = [(chain[..., i], chain[..., i + 1]) for chain in draws for i in range(1, steps)]
        freq_table = np.zeros((n_states, n_states))
        for bigram in bigrams:
            i, j = bigram
            freq_table[i, j] += 1
        freq_table /= freq_table.sum(axis=1)[:, None]

        # Test continuation probabilities match P
        assert np.allclose(P, freq_table, atol=atol)

    def test_change_size_univariate(self):
        # change_dist_size should replace (non-expand) or prepend (expand) the
        # batch dims while keeping the trailing time axis (length 5 here).
        P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]))
        chain = DiscreteMarkovChain.dist(P=P, shape=(100, 5))

        new_rw = change_dist_size(chain, new_size=(7,))
        assert tuple(new_rw.shape.eval()) == (7, 5)

        new_rw = change_dist_size(chain, new_size=(4, 3), expand=True)
        assert tuple(new_rw.shape.eval()) == (4, 3, 100, 5)