"""Functions for MCMC sampling."""

+ import collections.abc as abc
import logging
import pickle
import sys
import time
import warnings

from collections import defaultdict
- from collections.abc import Iterable
from copy import copy
- from typing import Any, Dict
- from typing import Iterable as TIterable
- from typing import List, Optional, Union, cast
+ from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast

import arviz
import numpy as np
    HamiltonianMC,
    Metropolis,
    Slice,
-     arraystep,
)
+ from pymc3.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
from pymc3.step_methods.hmc import quadpotential
from pymc3.util import (
    chains_and_samples,
    CategoricalGibbsMetropolis,
    PGBART,
)
+ Step = Union[BlockedStep, CompoundStep]

ArrayLike = Union[np.ndarray, List[float]]
PointType = Dict[str, np.ndarray]
PointList = List[PointType]
+ Backend = Union[BaseTrace, MultiTrace, NDArray]

_log = logging.getLogger("pymc3")
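Reviewer's note on the two new aliases: `Step` collapses "one blocked stepper or a compound of them" into a single name, and `Backend` names everything `_choose_backend` may hand back. A minimal sketch of the pattern, using stand-in classes rather than the real pymc3 types:

```python
from typing import Union

class BlockedStep: ...      # stand-in for pymc3's BlockedStep
class CompoundStep: ...     # stand-in for pymc3's CompoundStep

Step = Union[BlockedStep, CompoundStep]

def describe(step: Step) -> str:
    # code that accepts either stepper kind now needs only one annotation
    return type(step).__name__

print(describe(CompoundStep()))  # -> "CompoundStep"
```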
- def instantiate_steppers(_model, steps, selected_steps, step_kwargs=None):
+ def instantiate_steppers(
+     _model, steps: List[Step], selected_steps, step_kwargs=None
+ ) -> Union[Step, List[Step]]:
    """Instantiate steppers assigned to the model variables.

    This function is intended to be called automatically from ``sample()``, but
@@ -142,7 +144,7 @@ def instantiate_steppers(_model, steps, selected_steps, step_kwargs=None):
        raise ValueError("Unused step method arguments: %s" % unused_args)

    if len(steps) == 1:
-         steps = steps[0]
+         return steps[0]

    return steps
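Returning `steps[0]` explicitly, instead of rebinding `steps`, is what lets the `Union[Step, List[Step]]` annotation typecheck. A toy mirror of the contract, showing that callers narrow the union before iterating:

```python
from typing import List, Union

def one_or_many(steps: List[str]) -> Union[str, List[str]]:
    # mirrors instantiate_steppers: a single stepper escapes the list
    if len(steps) == 1:
        return steps[0]
    return steps

result = one_or_many(["metropolis"])
if isinstance(result, list):
    print(len(result))   # several steppers
else:
    print(result)        # -> "metropolis"
```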
@@ -216,7 +218,7 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None
    return instantiate_steppers(model, steps, selected_steps, step_kwargs)


- def _print_step_hierarchy(s, level=0):
+ def _print_step_hierarchy(s: Step, level=0) -> None:
    if isinstance(s, CompoundStep):
        _log.info(">" * level + "CompoundStep")
        for i in s.methods:
@@ -447,7 +449,7 @@ def sample(
        if random_seed is not None:
            np.random.seed(random_seed)
        random_seed = [np.random.randint(2 ** 30) for _ in range(chains)]
-     if not isinstance(random_seed, Iterable):
+     if not isinstance(random_seed, abc.Iterable):
        raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")

    if not discard_tuned_samples and not return_inferencedata:
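With `Iterable` now bound to `typing.Iterable` in this module, the runtime check switches to the concrete ABC, `collections.abc.Iterable`. A self-contained sketch of the seed normalization this hunk belongs to (values chosen for illustration):

```python
import collections.abc as abc
import numpy as np

def normalize_seed(random_seed, chains=4):
    # an int seeds NumPy once, then expands to one seed per chain
    if random_seed is None or isinstance(random_seed, int):
        if random_seed is not None:
            np.random.seed(random_seed)
        random_seed = [np.random.randint(2 ** 30) for _ in range(chains)]
    if not isinstance(random_seed, abc.Iterable):
        raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")
    return list(random_seed)

print(normalize_seed(42, chains=2))  # two reproducible per-chain seeds
```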
@@ -542,7 +544,7 @@ def sample(

    has_population_samplers = np.any(
        [
-             isinstance(m, arraystep.PopulationArrayStepShared)
+             isinstance(m, PopulationArrayStepShared)
            for m in (step.methods if isinstance(step, CompoundStep) else [step])
        ]
    )
@@ -706,7 +708,7 @@ def _sample_many(
    trace: MultiTrace
        Contains samples of all chains
    """
-     traces = []
+     traces: List[Backend] = []
    for i in range(chains):
        trace = _sample(
            draws=draws,
@@ -1140,7 +1142,7 @@ def _run_secondary(c, stepper_dumps, secondary_end):
        # has to be updated, therefore we identify the substeppers first.
        population_steppers = []
        for sm in stepper.methods if isinstance(stepper, CompoundStep) else [stepper]:
-             if isinstance(sm, arraystep.PopulationArrayStepShared):
+             if isinstance(sm, PopulationArrayStepShared):
                population_steppers.append(sm)
        while True:
            incoming = secondary_end.recv()
@@ -1259,7 +1261,7 @@ def _prepare_iter_population(
    population = [Point(start[c], model=model) for c in range(nchains)]

    # 3. Set up the steppers
-     steppers = [None] * nchains
+     steppers: List[Step] = []
    for c in range(nchains):
        # need independent samplers for each chain
        # it is important to copy the actual steppers (but not the delta_logp)
@@ -1269,9 +1271,9 @@ def _prepare_iter_population(
        chainstep = copy(step)
        # link population samplers to the shared population state
        for sm in chainstep.methods if isinstance(step, CompoundStep) else [chainstep]:
-             if isinstance(sm, arraystep.PopulationArrayStepShared):
+             if isinstance(sm, PopulationArrayStepShared):
                sm.link_population(population, c)
-         steppers[c] = chainstep
+         steppers.append(chainstep)

    # 4. configure tracking of sampler stats
    for c in range(nchains):
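A hedged aside on the `[None] * nchains` to `append` change: preallocating with `None` gives the list element type `Optional[Step]`, which every later read would have to narrow, while growing the list with `append` keeps it `List[Step]` throughout. Toy illustration (the class name is made up):

```python
from typing import List, Optional

class FakeStep:
    """Stand-in for a pymc3 stepper, for illustration only."""

nchains = 4
old: List[Optional[FakeStep]] = [None] * nchains   # every element may be None
new: List[FakeStep] = []
for _ in range(nchains):
    new.append(FakeStep())                         # no Optional to narrow later
```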
@@ -1349,7 +1351,7 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
                steppers[c].report._finalize(strace)


- def _choose_backend(trace, chain, **kwds):
+ def _choose_backend(trace, chain, **kwds) -> Backend:
    """Selects or creates an NDArray trace backend for a particular chain.

    Parameters
@@ -1562,8 +1564,8 @@ class _DefaultTrace:
    `insert()` method
    """

-     trace_dict = {}  # type: Dict[str, np.ndarray]
-     _len = None  # type: int
+     trace_dict: Dict[str, np.ndarray] = {}
+     _len: Optional[int] = None

    def __init__(self, samples: int):
        self._len = samples
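The same migration pattern in isolation: PEP 484 trailing type comments become PEP 526 variable annotations, and `_len` is now honestly `Optional[int]`, since it holds `None` until `__init__` assigns it. Sketch with illustrative names:

```python
from typing import Dict, Optional

import numpy as np

class Example:
    # before: trace_dict = {}    # type: Dict[str, np.ndarray]
    trace_dict: Dict[str, np.ndarray] = {}
    # before: _len = None        # type: int   (untrue: None is not an int)
    _len: Optional[int] = None
```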
@@ -1600,7 +1602,7 @@ def sample_posterior_predictive(
    trace,
    samples: Optional[int] = None,
    model: Optional[Model] = None,
-     vars: Optional[TIterable[Tensor]] = None,
+     vars: Optional[Iterable[Tensor]] = None,
    var_names: Optional[List[str]] = None,
    size: Optional[int] = None,
    keep_size: Optional[bool] = False,
@@ -1885,8 +1887,7 @@ def sample_posterior_predictive_w(
def sample_prior_predictive(
    samples=500,
    model: Optional[Model] = None,
-     vars: Optional[TIterable[str]] = None,
-     var_names: Optional[TIterable[str]] = None,
+     var_names: Optional[Iterable[str]] = None,
    random_seed=None,
) -> Dict[str, np.ndarray]:
    """Generate samples from the prior predictive distribution.
@@ -1896,9 +1897,6 @@ def sample_prior_predictive(
    samples : int
        Number of samples from the prior predictive to generate. Defaults to 500.
    model : Model (optional if in ``with`` context)
-     vars : Iterable[str]
-         A list of names of variables for which to compute the posterior predictive
-         samples. *DEPRECATED* - Use ``var_names`` argument instead.
    var_names : Iterable[str]
        A list of names of variables for which to compute the prior predictive
        samples. Defaults to both observed and unobserved RVs.
@@ -1913,22 +1911,14 @@ def sample_prior_predictive(
    """
    model = modelcontext(model)

-     if vars is None and var_names is None:
+     if var_names is None:
        prior_pred_vars = model.observed_RVs
        prior_vars = (
            get_default_varnames(model.unobserved_RVs, include_transformed=True) + model.potentials
        )
-         vars_ = [var.name for var in prior_vars + prior_pred_vars]
-         vars = set(vars_)
-     elif vars is None:
-         vars = var_names
-         vars_ = vars
-     elif vars is not None:
-         warnings.warn("vars argument is deprecated in favor of var_names.", DeprecationWarning)
-         vars_ = vars
+         vars_: Set[str] = {var.name for var in prior_vars + prior_pred_vars}
    else:
-         raise ValueError("Cannot supply both vars and var_names arguments.")
-     vars = cast(TIterable[str], vars)  # tell mypy that vars cannot be None here.
+         vars_ = set(var_names)

    if random_seed is not None:
        np.random.seed(random_seed)
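With the deprecated `vars` keyword gone, selection is by name only. A hedged usage sketch (model and variable names invented for illustration):

```python
import pymc3 as pm

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=[0.1, -0.3])

    # select variables by name; passing vars=... now raises TypeError
    prior = pm.sample_prior_predictive(samples=100, var_names=["mu", "y"])

print(prior["mu"].shape)  # (100,)
```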
@@ -1940,8 +1930,8 @@ def sample_prior_predictive(
    if data is None:
        raise AssertionError("No variables sampled: attempting to sample %s" % names)

-     prior = {}  # type: Dict[str, np.ndarray]
-     for var_name in vars:
+     prior: Dict[str, np.ndarray] = {}
+     for var_name in vars_:
        if var_name in data:
            prior[var_name] = data[var_name]
        elif is_transformed_name(var_name):
@@ -2093,15 +2083,15 @@ def init_nuts(
        var = np.ones_like(mean)
        potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10)
    elif init == "advi+adapt_diag_grad":
-         approx = pm.fit(
+         approx: pm.MeanField = pm.fit(
            random_seed=random_seed,
            n=n_init,
            method="advi",
            model=model,
            callbacks=cb,
            progressbar=progressbar,
            obj_optimizer=pm.adagrad_window,
-         )  # type: pm.MeanField
+         )
        start = approx.sample(draws=chains)
        start = list(start)
        stds = approx.bij.rmap(approx.std.eval())
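Moving `# type: pm.MeanField` from the closing parenthesis onto the assignment is more than cosmetic: a trailing type comment on a multi-line call sits far from the name it types and is easy to orphan when the call is reformatted, while a PEP 526 annotation travels with the target. Minimal sketch (names invented):

```python
def fit(**kwargs: object) -> float:
    """Stand-in for pm.fit(), for illustration only."""
    return 0.0

# before: the comment annotates whatever line the ")" ends up on
approx = fit(
    n=200_000,
)  # type: float

# after: the annotation is attached to the assignment target itself
approx2: float = fit(n=200_000)
```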
@@ -2119,7 +2109,7 @@ def init_nuts(
            callbacks=cb,
            progressbar=progressbar,
            obj_optimizer=pm.adagrad_window,
-         )  # type: pm.MeanField
+         )
        start = approx.sample(draws=chains)
        start = list(start)
        stds = approx.bij.rmap(approx.std.eval())
@@ -2137,7 +2127,7 @@ def init_nuts(
            callbacks=cb,
            progressbar=progressbar,
            obj_optimizer=pm.adagrad_window,
-         )  # type: pm.MeanField
+         )
        start = approx.sample(draws=chains)
        start = list(start)
        stds = approx.bij.rmap(approx.std.eval())