|
15 | 15 | import numpy as np
|
16 | 16 | import warnings
|
17 | 17 | import logging
|
| 18 | +from typing import Union, List, Optional, Type |
18 | 19 |
|
19 | 20 | from .arraystep import ArrayStepShared, metrop_select, Competence
|
20 | 21 | from .compound import CompoundStep
|
21 | 22 | from .metropolis import Proposal, Metropolis, delta_logp
|
| 23 | +from ..model import Model |
22 | 24 | import pymc3 as pm
|
23 | 25 |
|
24 | 26 | __all__ = ["MetropolisMLDA", "RecursiveDAProposal", "MLDA"]
|
25 | 27 |
|
26 | 28 |
|
27 |
| -# Available proposal distributions for MLDA |
28 |
| - |
29 |
| - |
30 |
| -class RecursiveDAProposal(Proposal): |
31 |
| - """ |
32 |
| - Recursive Delayed Acceptance proposal to be used with MLDA step sampler. |
33 |
| - Recursively calls an MLDA sampler if level > 0 and calls MetropolisMLDA |
34 |
| - sampler if level = 0. The sampler generates subsampling_rate samples and |
35 |
| - the last one is used as a proposal. Results in a hierarchy of chains |
36 |
| - each of which is used to propose samples to the chain above. |
37 |
| - """ |
38 |
| - |
39 |
| - def __init__(self, next_step_method, next_model, tune, subsampling_rate): |
40 |
| - |
41 |
| - self.next_step_method = next_step_method |
42 |
| - self.next_model = next_model |
43 |
| - self.tune = tune |
44 |
| - self.subsampling_rate = subsampling_rate |
45 |
| - |
46 |
| - def __call__(self, q0_dict): |
47 |
| - """Returns proposed sample given the current sample |
48 |
| - in dictionary form (q0_dict). |
49 |
| - """ |
50 |
| - |
51 |
| - # Logging is reduced to avoid extensive console output |
52 |
| - # during multiple recursive calls of sample() |
53 |
| - _log = logging.getLogger("pymc3") |
54 |
| - _log.setLevel(logging.ERROR) |
55 |
| - |
56 |
| - with self.next_model: |
57 |
| - # Check if the tuning flag has been set to False |
58 |
| - # in which case tuning is stopped. The flag is set |
59 |
| - # to False (by MLDA's astep) when the burn-in |
60 |
| - # iterations of the highest-level MLDA sampler run out. |
61 |
| - # The change propagates to all levels. |
62 |
| - if self.tune: |
63 |
| - # Sample in tuning mode |
64 |
| - output = pm.sample( |
65 |
| - draws=0, |
66 |
| - step=self.next_step_method, |
67 |
| - start=q0_dict, |
68 |
| - tune=self.subsampling_rate, |
69 |
| - chains=1, |
70 |
| - progressbar=False, |
71 |
| - compute_convergence_checks=False, |
72 |
| - discard_tuned_samples=False, |
73 |
| - ).point(-1) |
74 |
| - else: |
75 |
| - # Sample in normal mode without tuning |
76 |
| - output = pm.sample( |
77 |
| - draws=self.subsampling_rate, |
78 |
| - step=self.next_step_method, |
79 |
| - start=q0_dict, |
80 |
| - tune=0, |
81 |
| - chains=1, |
82 |
| - progressbar=False, |
83 |
| - compute_convergence_checks=False, |
84 |
| - discard_tuned_samples=False, |
85 |
| - ).point(-1) |
86 |
| - |
87 |
| - # set logging back to normal |
88 |
| - _log.setLevel(logging.NOTSET) |
89 |
| - |
90 |
| - return output |
91 |
| - |
92 |
| - |
93 | 29 | class MetropolisMLDA(Metropolis):
|
94 | 30 | """
|
95 | 31 | Metropolis-Hastings sampling step tailored for use as base sampler in MLDA
|
@@ -214,19 +150,19 @@ class MLDA(ArrayStepShared):
|
214 | 150 |
|
215 | 151 | def __init__(
|
216 | 152 | self,
|
217 |
| - coarse_models, |
218 |
| - vars=None, |
219 |
| - base_S=None, |
220 |
| - base_proposal_dist=None, |
221 |
| - base_scaling=1.0, |
222 |
| - tune=True, |
223 |
| - base_tune_interval=100, |
224 |
| - model=None, |
225 |
| - mode=None, |
226 |
| - subsampling_rates=5, |
227 |
| - base_blocked=False, |
228 |
| - **kwargs, |
229 |
| - ): |
| 153 | + coarse_models: List[Model], |
| 154 | + vars: Optional[list] = None, |
| 155 | + base_S: Optional = None, |
| 156 | + base_proposal_dist: Optional[Type[Proposal]] = None, |
| 157 | + base_scaling: Union[float, int] = 1.0, |
| 158 | + tune: bool = True, |
| 159 | + base_tune_interval: int = 100, |
| 160 | + model: Optional[Model] = None, |
| 161 | + mode: Optional = None, |
| 162 | + subsampling_rates: List[int] = 5, |
| 163 | + base_blocked: bool = False, |
| 164 | + **kwargs |
| 165 | + ) -> None: |
230 | 166 |
|
231 | 167 | warnings.warn(
|
232 | 168 | "The MLDA implementation in PyMC3 is very young. "
|
@@ -416,3 +352,74 @@ def competence(var, has_grad):
|
416 | 352 | if var.dtype in pm.discrete_types:
|
417 | 353 | return Competence.INCOMPATIBLE
|
418 | 354 | return Competence.COMPATIBLE
|
| 355 | + |
| 356 | + |
| 357 | +# Available proposal distributions for MLDA |
| 358 | + |
| 359 | + |
| 360 | +class RecursiveDAProposal(Proposal): |
| 361 | + """ |
| 362 | + Recursive Delayed Acceptance proposal to be used with MLDA step sampler. |
| 363 | + Recursively calls an MLDA sampler if level > 0 and calls MetropolisMLDA |
| 364 | + sampler if level = 0. The sampler generates subsampling_rate samples and |
| 365 | + the last one is used as a proposal. Results in a hierarchy of chains |
| 366 | + each of which is used to propose samples to the chain above. |
| 367 | + """ |
| 368 | + |
| 369 | + def __init__(self, |
| 370 | + next_step_method: Union[MLDA, Metropolis, CompoundStep], |
| 371 | + next_model: Model, |
| 372 | + tune: bool, |
| 373 | + subsampling_rate: int) -> None: |
| 374 | + |
| 375 | + self.next_step_method = next_step_method |
| 376 | + self.next_model = next_model |
| 377 | + self.tune = tune |
| 378 | + self.subsampling_rate = subsampling_rate |
| 379 | + |
| 380 | + def __call__(self, |
| 381 | + q0_dict: dict) -> dict: |
| 382 | + """Returns proposed sample given the current sample |
| 383 | + in dictionary form (q0_dict). |
| 384 | + """ |
| 385 | + |
| 386 | + # Logging is reduced to avoid extensive console output |
| 387 | + # during multiple recursive calls of sample() |
| 388 | + _log = logging.getLogger("pymc3") |
| 389 | + _log.setLevel(logging.ERROR) |
| 390 | + |
| 391 | + with self.next_model: |
| 392 | + # Check if the tuning flag has been set to False |
| 393 | + # in which case tuning is stopped. The flag is set |
| 394 | + # to False (by MLDA's astep) when the burn-in |
| 395 | + # iterations of the highest-level MLDA sampler run out. |
| 396 | + # The change propagates to all levels. |
| 397 | + if self.tune: |
| 398 | + # Sample in tuning mode |
| 399 | + output = pm.sample( |
| 400 | + draws=0, |
| 401 | + step=self.next_step_method, |
| 402 | + start=q0_dict, |
| 403 | + tune=self.subsampling_rate, |
| 404 | + chains=1, |
| 405 | + progressbar=False, |
| 406 | + compute_convergence_checks=False, |
| 407 | + discard_tuned_samples=False, |
| 408 | + ).point(-1) |
| 409 | + else: |
| 410 | + # Sample in normal mode without tuning |
| 411 | + output = pm.sample( |
| 412 | + draws=self.subsampling_rate, |
| 413 | + step=self.next_step_method, |
| 414 | + start=q0_dict, |
| 415 | + tune=0, |
| 416 | + chains=1, |
| 417 | + progressbar=False, |
| 418 | + compute_convergence_checks=False, |
| 419 | + discard_tuned_samples=False, |
| 420 | + ).point(-1) |
| 421 | + |
| 422 | + # set logging back to normal |
| 423 | + _log.setLevel(logging.NOTSET) |
| 424 | + |
| 425 | + return output |
0 commit comments