Commit 2aa19cc

Add typing to mlda.py
1 parent 1ceb4bd commit 2aa19cc

File tree

1 file changed: 86 additions, 79 deletions

pymc3/step_methods/mlda.py

Lines changed: 86 additions & 79 deletions
@@ -15,81 +15,17 @@
 import numpy as np
 import warnings
 import logging
+from typing import Union, List, Optional, Type

 from .arraystep import ArrayStepShared, metrop_select, Competence
 from .compound import CompoundStep
 from .metropolis import Proposal, Metropolis, delta_logp
+from ..model import Model
 import pymc3 as pm

 __all__ = ["MetropolisMLDA", "RecursiveDAProposal", "MLDA"]


-# Available proposal distributions for MLDA
-
-
-class RecursiveDAProposal(Proposal):
-    """
-    Recursive Delayed Acceptance proposal to be used with MLDA step sampler.
-    Recursively calls an MLDA sampler if level > 0 and calls MetropolisMLDA
-    sampler if level = 0. The sampler generates subsampling_rate samples and
-    the last one is used as a proposal. Results in a hierarchy of chains
-    each of which is used to propose samples to the chain above.
-    """
-
-    def __init__(self, next_step_method, next_model, tune, subsampling_rate):
-
-        self.next_step_method = next_step_method
-        self.next_model = next_model
-        self.tune = tune
-        self.subsampling_rate = subsampling_rate
-
-    def __call__(self, q0_dict):
-        """Returns proposed sample given the current sample
-        in dictionary form (q0_dict).
-        """
-
-        # Logging is reduced to avoid extensive console output
-        # during multiple recursive calls of sample()
-        _log = logging.getLogger("pymc3")
-        _log.setLevel(logging.ERROR)
-
-        with self.next_model:
-            # Check if the tuning flag has been set to False
-            # in which case tuning is stopped. The flag is set
-            # to False (by MLDA's astep) when the burn-in
-            # iterations of the highest-level MLDA sampler run out.
-            # The change propagates to all levels.
-            if self.tune:
-                # Sample in tuning mode
-                output = pm.sample(
-                    draws=0,
-                    step=self.next_step_method,
-                    start=q0_dict,
-                    tune=self.subsampling_rate,
-                    chains=1,
-                    progressbar=False,
-                    compute_convergence_checks=False,
-                    discard_tuned_samples=False,
-                ).point(-1)
-            else:
-                # Sample in normal mode without tuning
-                output = pm.sample(
-                    draws=self.subsampling_rate,
-                    step=self.next_step_method,
-                    start=q0_dict,
-                    tune=0,
-                    chains=1,
-                    progressbar=False,
-                    compute_convergence_checks=False,
-                    discard_tuned_samples=False,
-                ).point(-1)
-
-        # set logging back to normal
-        _log.setLevel(logging.NOTSET)
-
-        return output
-
-
 class MetropolisMLDA(Metropolis):
     """
     Metropolis-Hastings sampling step tailored for use as base sampler in MLDA
@@ -214,19 +150,19 @@ class MLDA(ArrayStepShared):

     def __init__(
         self,
-        coarse_models,
-        vars=None,
-        base_S=None,
-        base_proposal_dist=None,
-        base_scaling=1.0,
-        tune=True,
-        base_tune_interval=100,
-        model=None,
-        mode=None,
-        subsampling_rates=5,
-        base_blocked=False,
-        **kwargs,
-    ):
+        coarse_models: List[Model],
+        vars: Optional[list] = None,
+        base_S: Optional = None,
+        base_proposal_dist: Optional[Type[Proposal]] = None,
+        base_scaling: Union[float, int] = 1.0,
+        tune: bool = True,
+        base_tune_interval: int = 100,
+        model: Optional[Model] = None,
+        mode: Optional = None,
+        subsampling_rates: List[int] = 5,
+        base_blocked: bool = False,
+        **kwargs
+    ) -> None:

         warnings.warn(
             "The MLDA implementation in PyMC3 is very young. "
@@ -416,3 +352,74 @@ def competence(var, has_grad):
         if var.dtype in pm.discrete_types:
             return Competence.INCOMPATIBLE
         return Competence.COMPATIBLE
+
+
+# Available proposal distributions for MLDA
+
+
+class RecursiveDAProposal(Proposal):
+    """
+    Recursive Delayed Acceptance proposal to be used with MLDA step sampler.
+    Recursively calls an MLDA sampler if level > 0 and calls MetropolisMLDA
+    sampler if level = 0. The sampler generates subsampling_rate samples and
+    the last one is used as a proposal. Results in a hierarchy of chains
+    each of which is used to propose samples to the chain above.
+    """
+
+    def __init__(self,
+                 next_step_method: Union[MLDA, Metropolis, CompoundStep],
+                 next_model: Model,
+                 tune: bool,
+                 subsampling_rate: int) -> None:
+
+        self.next_step_method = next_step_method
+        self.next_model = next_model
+        self.tune = tune
+        self.subsampling_rate = subsampling_rate
+
+    def __call__(self,
+                 q0_dict: dict) -> dict:
+        """Returns proposed sample given the current sample
+        in dictionary form (q0_dict).
+        """
+
+        # Logging is reduced to avoid extensive console output
+        # during multiple recursive calls of sample()
+        _log = logging.getLogger("pymc3")
+        _log.setLevel(logging.ERROR)
+
+        with self.next_model:
+            # Check if the tuning flag has been set to False
+            # in which case tuning is stopped. The flag is set
+            # to False (by MLDA's astep) when the burn-in
+            # iterations of the highest-level MLDA sampler run out.
+            # The change propagates to all levels.
+            if self.tune:
+                # Sample in tuning mode
+                output = pm.sample(
+                    draws=0,
+                    step=self.next_step_method,
+                    start=q0_dict,
+                    tune=self.subsampling_rate,
+                    chains=1,
+                    progressbar=False,
+                    compute_convergence_checks=False,
+                    discard_tuned_samples=False,
+                ).point(-1)
+            else:
+                # Sample in normal mode without tuning
+                output = pm.sample(
+                    draws=self.subsampling_rate,
+                    step=self.next_step_method,
+                    start=q0_dict,
+                    tune=0,
+                    chains=1,
+                    progressbar=False,
+                    compute_convergence_checks=False,
+                    discard_tuned_samples=False,
+                ).point(-1)
+
+        # set logging back to normal
+        _log.setLevel(logging.NOTSET)
+
+        return output
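
Below is a minimal usage sketch of the step method this commit annotates, matching the typed signature above (coarse_models: List[Model], subsampling_rates: List[int]). It is not part of the commit: the toy models, data, and sampler settings are illustrative assumptions, and it assumes a PyMC3 release in which MLDA is exposed as pm.MLDA.

import numpy as np
import pymc3 as pm

data = np.array([1.2, 0.8, 1.1])  # hypothetical observations

# Coarse (cheaper) approximation of the target posterior.
# The coarse model shares its variable names with the fine model.
with pm.Model() as coarse_model:
    x = pm.Normal("x", mu=0.0, sigma=10.0)
    pm.Normal("y", mu=x, sigma=2.0, observed=data)

# Fine (target) model; the coarse chain generates proposals for this chain.
with pm.Model() as fine_model:
    x = pm.Normal("x", mu=0.0, sigma=10.0)
    pm.Normal("y", mu=x, sigma=1.0, observed=data)

    # One subsampling rate per coarse level, as in subsampling_rates: List[int]
    step = pm.MLDA(coarse_models=[coarse_model], subsampling_rates=[5])
    trace = pm.sample(draws=500, tune=200, chains=2, step=step)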
