 from pymc.logprob.abstract import _logprob
 from pymc.logprob.basic import logp
 from pymc.pytensorf import constant_fold, intX
-from pymc.util import check_dist_not_registered
+from pymc.step_methods import STEP_METHODS
+from pymc.step_methods.arraystep import ArrayStep
+from pymc.step_methods.compound import Competence
+from pymc.step_methods.metropolis import CategoricalGibbsMetropolis
+from pymc.util import check_dist_not_registered, get_value_vars_from_user_vars
+from pytensor import Mode
 from pytensor.graph.basic import Node
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.random.op import RandomVariable
@@ -101,10 +106,15 @@ class DiscreteMarkovChain(Distribution):
     Create a Markov Chain of length 100 with 3 states. The number of states is given by the shape of P,
     3 in this case.

-    >>> with pm.Model() as markov_chain:
-    >>>     P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
-    >>>     init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
-    >>>     markov_chain = pm.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
+    .. code-block:: python
+
+        import pymc as pm
+        import pymc_experimental as pmx
+
+        with pm.Model() as markov_chain:
+            P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
+            init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
+            markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))

     """

@@ -266,3 +276,70 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
         "P must sum to 1 along the last axis, "
         "First dimension of init_dist must be n_lags",
     )
+
+
+class DiscreteMarkovChainGibbsMetropolis(CategoricalGibbsMetropolis):
+
+    name = "discrete_markov_chain_gibbs_metropolis"
+
+    def __init__(self, vars, proposal="uniform", order="random", model=None):
+        model = pm.modelcontext(model)
+        vars = get_value_vars_from_user_vars(vars, model)
+        initial_point = model.initial_point()
+
+        dimcats = []
+        # dimcats is a list of pairs (aggregate dimension, number
+        # of categories). For example, if vars = [x, y] with x being a 2-D
+        # variable with M categories and y being a 3-D variable with N
+        # categories, we will have dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)].
+        for v in vars:
+            v_init_val = initial_point[v.name]
+            rv_var = model.values_to_rvs[v]
+            rv_op = rv_var.owner.op
+
+            if not isinstance(rv_op, DiscreteMarkovChainRV):
+                raise TypeError("All variables must be DiscreteMarkovChainRV")
+
+            k_graph = rv_var.owner.inputs[0].shape[-1]
+            (k_graph,) = model.replace_rvs_by_values((k_graph,))
+            k = model.compile_fn(
+                k_graph,
+                inputs=model.value_vars,
+                on_unused_input="ignore",
+                mode=Mode(linker="py", optimizer=None),
+            )(initial_point)
+            start = len(dimcats)
+            dimcats += [(dim, k) for dim in range(start, start + v_init_val.size)]
+
+        if order == "random":
+            self.shuffle_dims = True
+            self.dimcats = dimcats
+        else:
+            if sorted(order) != list(range(len(dimcats))):
+                raise ValueError("Argument 'order' has to be a permutation")
+            self.shuffle_dims = False
+            self.dimcats = [dimcats[j] for j in order]
+
+        if proposal == "uniform":
+            self.astep = self.astep_unif
+        elif proposal == "proportional":
+            # Use the optimized "Metropolized Gibbs Sampler" described in Liu96.
+            self.astep = self.astep_prop
+        else:
+            raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'")
+
+        # Doesn't actually tune, but it's required to emit a sampler stat
+        # that indicates whether a draw was done in a tuning phase.
+        self.tune = True
+
+        # We bypass CategoricalGibbsMetropolis's __init__ to avoid its specialized initialization logic.
+        ArrayStep.__init__(self, vars, [model.compile_logp()])
+
+    @staticmethod
+    def competence(var):
+        if isinstance(var.owner.op, DiscreteMarkovChainRV):
+            return Competence.IDEAL
+        return Competence.INCOMPATIBLE
+
+
+STEP_METHODS.append(DiscreteMarkovChainGibbsMetropolis)
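As a sketch of how this registration is exercised (not part of the diff): appending the class to `STEP_METHODS` lets PyMC's automatic step assignment consult `competence()`, but the sampler can also be requested explicitly. The import path below is an assumption; adjust it to wherever this module actually lives.

    import numpy as np
    import pymc as pm
    import pymc_experimental as pmx

    # Assumed location of the step method defined above.
    from pymc_experimental.distributions.timeseries import DiscreteMarkovChainGibbsMetropolis

    with pm.Model():
        P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
        init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
        chain = pmx.DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=(100,))
        # Use the Liu (1996) "proportional" proposal instead of the uniform default;
        # NUTS is still assigned automatically to the continuous transition matrix P.
        step = DiscreteMarkovChainGibbsMetropolis([chain], proposal="proportional")
        idata = pm.sample(step=step)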