Skip to content

Commit 60cf2cd

Browse files
committed
🏷️ type sampling
1 parent 6f15cbb commit 60cf2cd

File tree

4 files changed

+49
-37
lines changed

4 files changed

+49
-37
lines changed

pymc3/sampling.py

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,16 @@
1414

1515
"""Functions for MCMC sampling."""
1616

17+
import collections.abc as abc
1718
import logging
1819
import pickle
1920
import sys
2021
import time
2122
import warnings
2223

2324
from collections import defaultdict
24-
from collections.abc import Iterable
2525
from copy import copy
26-
from typing import Any, Dict
27-
from typing import Iterable as TIterable
28-
from typing import List, Optional, Union, cast
26+
from typing import Any, Dict, Iterable, List, Optional, Union, cast
2927

3028
import arviz
3129
import numpy as np
@@ -57,8 +55,8 @@
5755
HamiltonianMC,
5856
Metropolis,
5957
Slice,
60-
arraystep,
6158
)
59+
from pymc3.step_methods.arraystep import PopulationArrayStepShared
6260
from pymc3.step_methods.hmc import quadpotential
6361
from pymc3.util import (
6462
chains_and_samples,
@@ -93,15 +91,30 @@
9391
CategoricalGibbsMetropolis,
9492
PGBART,
9593
)
94+
Step = Union[
95+
NUTS,
96+
HamiltonianMC,
97+
Metropolis,
98+
BinaryMetropolis,
99+
BinaryGibbsMetropolis,
100+
Slice,
101+
CategoricalGibbsMetropolis,
102+
PGBART,
103+
CompoundStep,
104+
]
105+
96106

97107
ArrayLike = Union[np.ndarray, List[float]]
98108
PointType = Dict[str, np.ndarray]
99109
PointList = List[PointType]
110+
Backend = Union[BaseTrace, MultiTrace, NDArray]
100111

101112
_log = logging.getLogger("pymc3")
102113

103114

104-
def instantiate_steppers(_model, steps, selected_steps, step_kwargs=None):
115+
def instantiate_steppers(
116+
_model, steps: List[Step], selected_steps, step_kwargs=None
117+
) -> Union[Step, List[Step]]:
105118
"""Instantiate steppers assigned to the model variables.
106119
107120
This function is intended to be called automatically from ``sample()``, but
@@ -142,7 +155,7 @@ def instantiate_steppers(_model, steps, selected_steps, step_kwargs=None):
142155
raise ValueError("Unused step method arguments: %s" % unused_args)
143156

144157
if len(steps) == 1:
145-
steps = steps[0]
158+
return steps[0]
146159

147160
return steps
148161

@@ -216,7 +229,7 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None
216229
return instantiate_steppers(model, steps, selected_steps, step_kwargs)
217230

218231

219-
def _print_step_hierarchy(s, level=0):
232+
def _print_step_hierarchy(s: Step, level=0) -> None:
220233
if isinstance(s, CompoundStep):
221234
_log.info(">" * level + "CompoundStep")
222235
for i in s.methods:
@@ -447,7 +460,7 @@ def sample(
447460
if random_seed is not None:
448461
np.random.seed(random_seed)
449462
random_seed = [np.random.randint(2 ** 30) for _ in range(chains)]
450-
if not isinstance(random_seed, Iterable):
463+
if not isinstance(random_seed, abc.Iterable):
451464
raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")
452465

453466
if not discard_tuned_samples and not return_inferencedata:
@@ -542,7 +555,7 @@ def sample(
542555

543556
has_population_samplers = np.any(
544557
[
545-
isinstance(m, arraystep.PopulationArrayStepShared)
558+
isinstance(m, PopulationArrayStepShared)
546559
for m in (step.methods if isinstance(step, CompoundStep) else [step])
547560
]
548561
)
@@ -706,7 +719,7 @@ def _sample_many(
706719
trace: MultiTrace
707720
Contains samples of all chains
708721
"""
709-
traces = []
722+
traces: List[Backend] = []
710723
for i in range(chains):
711724
trace = _sample(
712725
draws=draws,
@@ -1140,7 +1153,7 @@ def _run_secondary(c, stepper_dumps, secondary_end):
11401153
# has to be updated, therefore we identify the substeppers first.
11411154
population_steppers = []
11421155
for sm in stepper.methods if isinstance(stepper, CompoundStep) else [stepper]:
1143-
if isinstance(sm, arraystep.PopulationArrayStepShared):
1156+
if isinstance(sm, PopulationArrayStepShared):
11441157
population_steppers.append(sm)
11451158
while True:
11461159
incoming = secondary_end.recv()
@@ -1259,7 +1272,7 @@ def _prepare_iter_population(
12591272
population = [Point(start[c], model=model) for c in range(nchains)]
12601273

12611274
# 3. Set up the steppers
1262-
steppers = [None] * nchains
1275+
steppers: List[Step] = []
12631276
for c in range(nchains):
12641277
# need independent samplers for each chain
12651278
# it is important to copy the actual steppers (but not the delta_logp)
@@ -1269,9 +1282,9 @@ def _prepare_iter_population(
12691282
chainstep = copy(step)
12701283
# link population samplers to the shared population state
12711284
for sm in chainstep.methods if isinstance(step, CompoundStep) else [chainstep]:
1272-
if isinstance(sm, arraystep.PopulationArrayStepShared):
1285+
if isinstance(sm, PopulationArrayStepShared):
12731286
sm.link_population(population, c)
1274-
steppers[c] = chainstep
1287+
steppers.append(chainstep)
12751288

12761289
# 4. configure tracking of sampler stats
12771290
for c in range(nchains):
@@ -1349,7 +1362,7 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
13491362
steppers[c].report._finalize(strace)
13501363

13511364

1352-
def _choose_backend(trace, chain, **kwds):
1365+
def _choose_backend(trace, chain, **kwds) -> Backend:
13531366
"""Selects or creates a NDArray trace backend for a particular chain.
13541367
13551368
Parameters
@@ -1562,8 +1575,8 @@ class _DefaultTrace:
15621575
`insert()` method
15631576
"""
15641577

1565-
trace_dict = {} # type: Dict[str, np.ndarray]
1566-
_len = None # type: int
1578+
trace_dict: Dict[str, np.ndarray] = {}
1579+
_len: Optional[int] = None
15671580

15681581
def __init__(self, samples: int):
15691582
self._len = samples
@@ -1600,7 +1613,7 @@ def sample_posterior_predictive(
16001613
trace,
16011614
samples: Optional[int] = None,
16021615
model: Optional[Model] = None,
1603-
vars: Optional[TIterable[Tensor]] = None,
1616+
vars: Optional[Iterable[Tensor]] = None,
16041617
var_names: Optional[List[str]] = None,
16051618
size: Optional[int] = None,
16061619
keep_size: Optional[bool] = False,
@@ -1885,8 +1898,8 @@ def sample_posterior_predictive_w(
18851898
def sample_prior_predictive(
18861899
samples=500,
18871900
model: Optional[Model] = None,
1888-
vars: Optional[TIterable[str]] = None,
1889-
var_names: Optional[TIterable[str]] = None,
1901+
vars: Optional[Iterable[str]] = None,
1902+
var_names: Optional[Iterable[str]] = None,
18901903
random_seed=None,
18911904
) -> Dict[str, np.ndarray]:
18921905
"""Generate samples from the prior predictive distribution.
@@ -1918,17 +1931,15 @@ def sample_prior_predictive(
19181931
prior_vars = (
19191932
get_default_varnames(model.unobserved_RVs, include_transformed=True) + model.potentials
19201933
)
1921-
vars_ = [var.name for var in prior_vars + prior_pred_vars]
1922-
vars = set(vars_)
1934+
vars_: Iterable[str] = [var.name for var in prior_vars + prior_pred_vars]
19231935
elif vars is None:
1924-
vars = var_names
1925-
vars_ = vars
1926-
elif vars is not None:
1936+
assert var_names is not None # help mypy
1937+
vars_ = var_names
1938+
elif var_names is None:
19271939
warnings.warn("vars argument is deprecated in favor of var_names.", DeprecationWarning)
19281940
vars_ = vars
19291941
else:
19301942
raise ValueError("Cannot supply both vars and var_names arguments.")
1931-
vars = cast(TIterable[str], vars) # tell mypy that vars cannot be None here.
19321943

19331944
if random_seed is not None:
19341945
np.random.seed(random_seed)
@@ -1940,8 +1951,8 @@ def sample_prior_predictive(
19401951
if data is None:
19411952
raise AssertionError("No variables sampled: attempting to sample %s" % names)
19421953

1943-
prior = {} # type: Dict[str, np.ndarray]
1944-
for var_name in vars:
1954+
prior: Dict[str, np.ndarray] = {}
1955+
for var_name in vars_:
19451956
if var_name in data:
19461957
prior[var_name] = data[var_name]
19471958
elif is_transformed_name(var_name):
@@ -2093,15 +2104,15 @@ def init_nuts(
20932104
var = np.ones_like(mean)
20942105
potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10)
20952106
elif init == "advi+adapt_diag_grad":
2096-
approx = pm.fit(
2107+
approx: pm.MeanField = pm.fit(
20972108
random_seed=random_seed,
20982109
n=n_init,
20992110
method="advi",
21002111
model=model,
21012112
callbacks=cb,
21022113
progressbar=progressbar,
21032114
obj_optimizer=pm.adagrad_window,
2104-
) # type: pm.MeanField
2115+
)
21052116
start = approx.sample(draws=chains)
21062117
start = list(start)
21072118
stds = approx.bij.rmap(approx.std.eval())
@@ -2119,7 +2130,7 @@ def init_nuts(
21192130
callbacks=cb,
21202131
progressbar=progressbar,
21212132
obj_optimizer=pm.adagrad_window,
2122-
) # type: pm.MeanField
2133+
)
21232134
start = approx.sample(draws=chains)
21242135
start = list(start)
21252136
stds = approx.bij.rmap(approx.std.eval())
@@ -2137,7 +2148,7 @@ def init_nuts(
21372148
callbacks=cb,
21382149
progressbar=progressbar,
21392150
obj_optimizer=pm.adagrad_window,
2140-
) # type: pm.MeanField
2151+
)
21412152
start = approx.sample(draws=chains)
21422153
start = list(start)
21432154
stds = approx.bij.rmap(approx.std.eval())

pymc3/step_methods/arraystep.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from enum import IntEnum, unique
16+
from typing import Dict, List
1617

1718
import numpy as np
1819

@@ -46,6 +47,7 @@ class Competence(IntEnum):
4647
class BlockedStep:
4748

4849
generates_stats = False
50+
stats_dtypes: List[Dict[str, np.dtype]] = []
4951

5052
def __new__(cls, *args, **kwargs):
5153
blocked = kwargs.get("blocked")

pymc3/step_methods/mlda.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -356,10 +356,6 @@ class MLDA(ArrayStepShared):
356356
default_blocked = True
357357
generates_stats = True
358358

359-
# stat data types are different, depending on the base sampler.
360-
# these are assigned in the init method.
361-
stats_dtypes = None
362-
363359
def __init__(
364360
self,
365361
coarse_models: List[Model],

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ convention = numpy
1111
[isort]
1212
lines_between_types = 1
1313
profile = black
14+
15+
[mypy]
16+
ignore_missing_imports = True

0 commit comments

Comments
 (0)