14
14
15
15
"""Functions for MCMC sampling."""
16
16
17
+ import collections .abc as abc
17
18
import logging
18
19
import pickle
19
20
import sys
20
21
import time
21
22
import warnings
22
23
23
24
from collections import defaultdict
24
- from collections .abc import Iterable
25
25
from copy import copy
26
- from typing import Any , Dict
27
- from typing import Iterable as TIterable
28
- from typing import List , Optional , Union , cast
26
+ from typing import Any , Dict , Iterable , List , Optional , Union , cast
29
27
30
28
import arviz
31
29
import numpy as np
57
55
HamiltonianMC ,
58
56
Metropolis ,
59
57
Slice ,
60
- arraystep ,
61
58
)
59
+ from pymc3 .step_methods .arraystep import PopulationArrayStepShared
62
60
from pymc3 .step_methods .hmc import quadpotential
63
61
from pymc3 .util import (
64
62
chains_and_samples ,
93
91
CategoricalGibbsMetropolis ,
94
92
PGBART ,
95
93
)
94
+ Step = Union [
95
+ NUTS ,
96
+ HamiltonianMC ,
97
+ Metropolis ,
98
+ BinaryMetropolis ,
99
+ BinaryGibbsMetropolis ,
100
+ Slice ,
101
+ CategoricalGibbsMetropolis ,
102
+ PGBART ,
103
+ CompoundStep ,
104
+ ]
105
+
96
106
97
107
ArrayLike = Union [np .ndarray , List [float ]]
98
108
PointType = Dict [str , np .ndarray ]
99
109
PointList = List [PointType ]
110
+ Backend = Union [BaseTrace , MultiTrace , NDArray ]
100
111
101
112
_log = logging .getLogger ("pymc3" )
102
113
103
114
104
- def instantiate_steppers (_model , steps , selected_steps , step_kwargs = None ):
115
+ def instantiate_steppers (
116
+ _model , steps : List [Step ], selected_steps , step_kwargs = None
117
+ ) -> Union [Step , List [Step ]]:
105
118
"""Instantiate steppers assigned to the model variables.
106
119
107
120
This function is intended to be called automatically from ``sample()``, but
@@ -142,7 +155,7 @@ def instantiate_steppers(_model, steps, selected_steps, step_kwargs=None):
142
155
raise ValueError ("Unused step method arguments: %s" % unused_args )
143
156
144
157
if len (steps ) == 1 :
145
- steps = steps [0 ]
158
+ return steps [0 ]
146
159
147
160
return steps
148
161
@@ -216,7 +229,7 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None
216
229
return instantiate_steppers (model , steps , selected_steps , step_kwargs )
217
230
218
231
219
- def _print_step_hierarchy (s , level = 0 ):
232
+ def _print_step_hierarchy (s : Step , level = 0 ) -> None :
220
233
if isinstance (s , CompoundStep ):
221
234
_log .info (">" * level + "CompoundStep" )
222
235
for i in s .methods :
@@ -447,7 +460,7 @@ def sample(
447
460
if random_seed is not None :
448
461
np .random .seed (random_seed )
449
462
random_seed = [np .random .randint (2 ** 30 ) for _ in range (chains )]
450
- if not isinstance (random_seed , Iterable ):
463
+ if not isinstance (random_seed , abc . Iterable ):
451
464
raise TypeError ("Invalid value for `random_seed`. Must be tuple, list or int" )
452
465
453
466
if not discard_tuned_samples and not return_inferencedata :
@@ -542,7 +555,7 @@ def sample(
542
555
543
556
has_population_samplers = np .any (
544
557
[
545
- isinstance (m , arraystep . PopulationArrayStepShared )
558
+ isinstance (m , PopulationArrayStepShared )
546
559
for m in (step .methods if isinstance (step , CompoundStep ) else [step ])
547
560
]
548
561
)
@@ -706,7 +719,7 @@ def _sample_many(
706
719
trace: MultiTrace
707
720
Contains samples of all chains
708
721
"""
709
- traces = []
722
+ traces : List [ Backend ] = []
710
723
for i in range (chains ):
711
724
trace = _sample (
712
725
draws = draws ,
@@ -1140,7 +1153,7 @@ def _run_secondary(c, stepper_dumps, secondary_end):
1140
1153
# has to be updated, therefore we identify the substeppers first.
1141
1154
population_steppers = []
1142
1155
for sm in stepper .methods if isinstance (stepper , CompoundStep ) else [stepper ]:
1143
- if isinstance (sm , arraystep . PopulationArrayStepShared ):
1156
+ if isinstance (sm , PopulationArrayStepShared ):
1144
1157
population_steppers .append (sm )
1145
1158
while True :
1146
1159
incoming = secondary_end .recv ()
@@ -1259,7 +1272,7 @@ def _prepare_iter_population(
1259
1272
population = [Point (start [c ], model = model ) for c in range (nchains )]
1260
1273
1261
1274
# 3. Set up the steppers
1262
- steppers = [ None ] * nchains
1275
+ steppers : List [ Step ] = []
1263
1276
for c in range (nchains ):
1264
1277
# need independent samplers for each chain
1265
1278
# it is important to copy the actual steppers (but not the delta_logp)
@@ -1269,9 +1282,9 @@ def _prepare_iter_population(
1269
1282
chainstep = copy (step )
1270
1283
# link population samplers to the shared population state
1271
1284
for sm in chainstep .methods if isinstance (step , CompoundStep ) else [chainstep ]:
1272
- if isinstance (sm , arraystep . PopulationArrayStepShared ):
1285
+ if isinstance (sm , PopulationArrayStepShared ):
1273
1286
sm .link_population (population , c )
1274
- steppers [ c ] = chainstep
1287
+ steppers . append ( chainstep )
1275
1288
1276
1289
# 4. configure tracking of sampler stats
1277
1290
for c in range (nchains ):
@@ -1349,7 +1362,7 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
1349
1362
steppers [c ].report ._finalize (strace )
1350
1363
1351
1364
1352
- def _choose_backend (trace , chain , ** kwds ):
1365
+ def _choose_backend (trace , chain , ** kwds ) -> Backend :
1353
1366
"""Selects or creates a NDArray trace backend for a particular chain.
1354
1367
1355
1368
Parameters
@@ -1562,8 +1575,8 @@ class _DefaultTrace:
1562
1575
`insert()` method
1563
1576
"""
1564
1577
1565
- trace_dict = {} # type : Dict[str, np.ndarray]
1566
- _len = None # type: int
1578
+ trace_dict : Dict [str , np .ndarray ] = {}
1579
+ _len : Optional [ int ] = None
1567
1580
1568
1581
def __init__ (self , samples : int ):
1569
1582
self ._len = samples
@@ -1600,7 +1613,7 @@ def sample_posterior_predictive(
1600
1613
trace ,
1601
1614
samples : Optional [int ] = None ,
1602
1615
model : Optional [Model ] = None ,
1603
- vars : Optional [TIterable [Tensor ]] = None ,
1616
+ vars : Optional [Iterable [Tensor ]] = None ,
1604
1617
var_names : Optional [List [str ]] = None ,
1605
1618
size : Optional [int ] = None ,
1606
1619
keep_size : Optional [bool ] = False ,
@@ -1885,8 +1898,8 @@ def sample_posterior_predictive_w(
1885
1898
def sample_prior_predictive (
1886
1899
samples = 500 ,
1887
1900
model : Optional [Model ] = None ,
1888
- vars : Optional [TIterable [str ]] = None ,
1889
- var_names : Optional [TIterable [str ]] = None ,
1901
+ vars : Optional [Iterable [str ]] = None ,
1902
+ var_names : Optional [Iterable [str ]] = None ,
1890
1903
random_seed = None ,
1891
1904
) -> Dict [str , np .ndarray ]:
1892
1905
"""Generate samples from the prior predictive distribution.
@@ -1918,17 +1931,15 @@ def sample_prior_predictive(
1918
1931
prior_vars = (
1919
1932
get_default_varnames (model .unobserved_RVs , include_transformed = True ) + model .potentials
1920
1933
)
1921
- vars_ = [var .name for var in prior_vars + prior_pred_vars ]
1922
- vars = set (vars_ )
1934
+ vars_ : Iterable [str ] = [var .name for var in prior_vars + prior_pred_vars ]
1923
1935
elif vars is None :
1924
- vars = var_names
1925
- vars_ = vars
1926
- elif vars is not None :
1936
+ assert var_names is not None # help mypy
1937
+ vars_ = var_names
1938
+ elif var_names is None :
1927
1939
warnings .warn ("vars argument is deprecated in favor of var_names." , DeprecationWarning )
1928
1940
vars_ = vars
1929
1941
else :
1930
1942
raise ValueError ("Cannot supply both vars and var_names arguments." )
1931
- vars = cast (TIterable [str ], vars ) # tell mypy that vars cannot be None here.
1932
1943
1933
1944
if random_seed is not None :
1934
1945
np .random .seed (random_seed )
@@ -1940,8 +1951,8 @@ def sample_prior_predictive(
1940
1951
if data is None :
1941
1952
raise AssertionError ("No variables sampled: attempting to sample %s" % names )
1942
1953
1943
- prior = {} # type : Dict[str, np.ndarray]
1944
- for var_name in vars :
1954
+ prior : Dict [str , np .ndarray ] = {}
1955
+ for var_name in vars_ :
1945
1956
if var_name in data :
1946
1957
prior [var_name ] = data [var_name ]
1947
1958
elif is_transformed_name (var_name ):
@@ -2093,15 +2104,15 @@ def init_nuts(
2093
2104
var = np .ones_like (mean )
2094
2105
potential = quadpotential .QuadPotentialDiagAdapt (model .ndim , mean , var , 10 )
2095
2106
elif init == "advi+adapt_diag_grad" :
2096
- approx = pm .fit (
2107
+ approx : pm . MeanField = pm .fit (
2097
2108
random_seed = random_seed ,
2098
2109
n = n_init ,
2099
2110
method = "advi" ,
2100
2111
model = model ,
2101
2112
callbacks = cb ,
2102
2113
progressbar = progressbar ,
2103
2114
obj_optimizer = pm .adagrad_window ,
2104
- ) # type: pm.MeanField
2115
+ )
2105
2116
start = approx .sample (draws = chains )
2106
2117
start = list (start )
2107
2118
stds = approx .bij .rmap (approx .std .eval ())
@@ -2119,7 +2130,7 @@ def init_nuts(
2119
2130
callbacks = cb ,
2120
2131
progressbar = progressbar ,
2121
2132
obj_optimizer = pm .adagrad_window ,
2122
- ) # type: pm.MeanField
2133
+ )
2123
2134
start = approx .sample (draws = chains )
2124
2135
start = list (start )
2125
2136
stds = approx .bij .rmap (approx .std .eval ())
@@ -2137,7 +2148,7 @@ def init_nuts(
2137
2148
callbacks = cb ,
2138
2149
progressbar = progressbar ,
2139
2150
obj_optimizer = pm .adagrad_window ,
2140
- ) # type: pm.MeanField
2151
+ )
2141
2152
start = approx .sample (draws = chains )
2142
2153
start = list (start )
2143
2154
stds = approx .bij .rmap (approx .std .eval ())
0 commit comments