Commit 9041d11

Implement step method sampler for DiscreteMarkovChain
1 parent aec5c76 commit 9041d11

2 files changed (+122 -6 lines)

pymc_experimental/distributions/timeseries.py (+82 -5)
@@ -20,7 +20,12 @@
 from pymc.logprob.abstract import _logprob
 from pymc.logprob.basic import logp
 from pymc.pytensorf import constant_fold, intX
-from pymc.util import check_dist_not_registered
+from pymc.step_methods import STEP_METHODS
+from pymc.step_methods.arraystep import ArrayStep
+from pymc.step_methods.compound import Competence
+from pymc.step_methods.metropolis import CategoricalGibbsMetropolis
+from pymc.util import check_dist_not_registered, get_value_vars_from_user_vars
+from pytensor import Mode
 from pytensor.graph.basic import Node
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.random.op import RandomVariable
@@ -101,10 +106,15 @@ class DiscreteMarkovChain(Distribution):
     Create a Markov Chain of length 100 with 3 states. The number of states is given by the shape of P,
     3 in this case.

-    >>> with pm.Model() as markov_chain:
-    >>>     P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
-    >>>     init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
-    >>>     markov_chain = pm.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
+    .. code-block:: python
+
+        import pymc as pm
+        import pymc_experimental as pmx
+
+        with pm.Model() as markov_chain:
+            P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
+            init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
+            markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))

     """

@@ -266,3 +276,70 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
         "P must sum to 1 along the last axis, "
         "First dimension of init_dist must be n_lags",
     )
+
+
+class DiscreteMarkovChainGibbsMetropolis(CategoricalGibbsMetropolis):
+
+    name = "discrete_markov_chain_gibbs_metropolis"
+
+    def __init__(self, vars, proposal="uniform", order="random", model=None):
+        model = pm.modelcontext(model)
+        vars = get_value_vars_from_user_vars(vars, model)
+        initial_point = model.initial_point()
+
+        dimcats = []
+        # The above variable is a list of pairs (aggregate dimension, number
+        # of categories). For example, if vars = [x, y] with x being a 2-D
+        # variable with M categories and y being a 3-D variable with N
+        # categories, we will have dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)].
+        for v in vars:
+            v_init_val = initial_point[v.name]
+            rv_var = model.values_to_rvs[v]
+            rv_op = rv_var.owner.op
+
+            if not isinstance(rv_op, DiscreteMarkovChainRV):
+                raise TypeError("All variables must be DiscreteMarkovChainRV")
+
+            k_graph = rv_var.owner.inputs[0].shape[-1]
+            (k_graph,) = model.replace_rvs_by_values((k_graph,))
+            k = model.compile_fn(
+                k_graph,
+                inputs=model.value_vars,
+                on_unused_input="ignore",
+                mode=Mode(linker="py", optimizer=None),
+            )(initial_point)
+            start = len(dimcats)
+            dimcats += [(dim, k) for dim in range(start, start + v_init_val.size)]
+
+        if order == "random":
+            self.shuffle_dims = True
+            self.dimcats = dimcats
+        else:
+            if sorted(order) != list(range(len(dimcats))):
+                raise ValueError("Argument 'order' has to be a permutation")
+            self.shuffle_dims = False
+            self.dimcats = [dimcats[j] for j in order]
+
+        if proposal == "uniform":
+            self.astep = self.astep_unif
+        elif proposal == "proportional":
+            # Use the optimized "Metropolized Gibbs Sampler" described in Liu96.
+            self.astep = self.astep_prop
+        else:
+            raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'")
+
+        # Doesn't actually tune, but it's required to emit a sampler stat
+        # that indicates whether a draw was done in a tuning phase.
+        self.tune = True
+
+        # We bypass CategoricalGibbsMetropolis's __init__ to avoid its specialized initialization logic.
+        ArrayStep.__init__(self, vars, [model.compile_logp()])
+
+    @staticmethod
+    def competence(var):
+        if isinstance(var.owner.op, DiscreteMarkovChainRV):
+            return Competence.IDEAL
+        return Competence.INCOMPATIBLE
+
+
+STEP_METHODS.append(DiscreteMarkovChainGibbsMetropolis)
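
For orientation, here is a minimal sketch of how the newly registered step method could be used once this commit is installed; the model, variable names, and sampler settings below are illustrative and not part of the commit. Because DiscreteMarkovChainGibbsMetropolis is appended to STEP_METHODS and reports Competence.IDEAL for DiscreteMarkovChainRV variables, pm.sample() should pick it for the chain automatically, but it can also be constructed and passed explicitly:

# Illustrative sketch, not part of the commit.
import numpy as np
import pymc as pm
import pymc_experimental as pmx
from pymc_experimental.distributions.timeseries import DiscreteMarkovChainGibbsMetropolis

with pm.Model() as model:
    P = pm.Dirichlet("P", a=np.ones(3), size=(3,))
    init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
    chain = pmx.DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=(100,))

    # Explicit assignment; omitting `step` should give the same sampler for `chain`
    # via the competence mechanism, with P handled by a continuous sampler (e.g. NUTS).
    step = DiscreteMarkovChainGibbsMetropolis([chain])
    idata = pm.sample(step=step, chains=2, draws=500)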

tests/distributions/test_discrete_markov_chain.py (+40 -1)
@@ -5,10 +5,15 @@
 import pytensor.tensor as pt
 import pytest

+from pymc.distributions import Categorical
 from pymc.distributions.shape_utils import change_dist_size
 from pymc.logprob.utils import ParameterValueError
+from pymc.sampling.mcmc import assign_step_methods

-from pymc_experimental.distributions.timeseries import DiscreteMarkovChain
+from pymc_experimental.distributions.timeseries import (
+    DiscreteMarkovChain,
+    DiscreteMarkovChainGibbsMetropolis,
+)


 def transition_probability_tests(steps, n_states, n_lags, n_draws, atol):
@@ -216,3 +221,37 @@ def test_change_size_univariate(self):

         new_rw = change_dist_size(chain, new_size=(4, 3), expand=True)
         assert tuple(new_rw.shape.eval()) == (4, 3, 100, 5)
+
+    def test_mcmc_sampling(self):
+
+        with pm.Model(coords={"step": range(100)}) as model:
+            init_dist = Categorical.dist(p=[0.5, 0.5])
+            DiscreteMarkovChain(
+                "markov_chain",
+                P=[[0.1, 0.9], [0.1, 0.9]],
+                init_dist=init_dist,
+                shape=(100,),
+                dims="step",
+            )
+
+            step_method = assign_step_methods(model)
+            assert isinstance(step_method, DiscreteMarkovChainGibbsMetropolis)
+
+            # Sampler needs no tuning
+            idata = pm.sample(
+                tune=0, chains=4, draws=250, progressbar=False, compute_convergence_checks=False
+            )
+
+        np.testing.assert_allclose(
+            idata.posterior["markov_chain"].isel(step=0).mean(("chain", "draw")),
+            0.5,
+            atol=0.05,
+        )
+
+        np.testing.assert_allclose(
+            idata.posterior["markov_chain"].isel(step=slice(1, None)).mean(("chain", "draw")),
+            0.9,
+            atol=0.05,
+        )
+
+        assert pm.stats.ess(idata, method="tail").min() > 950
