Skip to content

Commit b386d94

Browse files
authored
Fixup mypy errors in sampling.py (#4327)
* 🏷️ type sampling * 🔥 remove deprecated vars from sample_prior_predictive
1 parent 70fdcf9 commit b386d94

File tree

5 files changed

+40
-49
lines changed

5 files changed

+40
-49
lines changed

pymc3/sampling.py

Lines changed: 31 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,16 @@
1414

1515
"""Functions for MCMC sampling."""
1616

17+
import collections.abc as abc
1718
import logging
1819
import pickle
1920
import sys
2021
import time
2122
import warnings
2223

2324
from collections import defaultdict
24-
from collections.abc import Iterable
2525
from copy import copy
26-
from typing import Any, Dict
27-
from typing import Iterable as TIterable
28-
from typing import List, Optional, Union, cast
26+
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
2927

3028
import arviz
3129
import numpy as np
@@ -57,8 +55,8 @@
5755
HamiltonianMC,
5856
Metropolis,
5957
Slice,
60-
arraystep,
6158
)
59+
from pymc3.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
6260
from pymc3.step_methods.hmc import quadpotential
6361
from pymc3.util import (
6462
chains_and_samples,
@@ -93,15 +91,19 @@
9391
CategoricalGibbsMetropolis,
9492
PGBART,
9593
)
94+
Step = Union[BlockedStep, CompoundStep]
9695

9796
ArrayLike = Union[np.ndarray, List[float]]
9897
PointType = Dict[str, np.ndarray]
9998
PointList = List[PointType]
99+
Backend = Union[BaseTrace, MultiTrace, NDArray]
100100

101101
_log = logging.getLogger("pymc3")
102102

103103

104-
def instantiate_steppers(_model, steps, selected_steps, step_kwargs=None):
104+
def instantiate_steppers(
105+
_model, steps: List[Step], selected_steps, step_kwargs=None
106+
) -> Union[Step, List[Step]]:
105107
"""Instantiate steppers assigned to the model variables.
106108
107109
This function is intended to be called automatically from ``sample()``, but
@@ -142,7 +144,7 @@ def instantiate_steppers(_model, steps, selected_steps, step_kwargs=None):
142144
raise ValueError("Unused step method arguments: %s" % unused_args)
143145

144146
if len(steps) == 1:
145-
steps = steps[0]
147+
return steps[0]
146148

147149
return steps
148150

@@ -216,7 +218,7 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None
216218
return instantiate_steppers(model, steps, selected_steps, step_kwargs)
217219

218220

219-
def _print_step_hierarchy(s, level=0):
221+
def _print_step_hierarchy(s: Step, level=0) -> None:
220222
if isinstance(s, CompoundStep):
221223
_log.info(">" * level + "CompoundStep")
222224
for i in s.methods:
@@ -447,7 +449,7 @@ def sample(
447449
if random_seed is not None:
448450
np.random.seed(random_seed)
449451
random_seed = [np.random.randint(2 ** 30) for _ in range(chains)]
450-
if not isinstance(random_seed, Iterable):
452+
if not isinstance(random_seed, abc.Iterable):
451453
raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")
452454

453455
if not discard_tuned_samples and not return_inferencedata:
@@ -542,7 +544,7 @@ def sample(
542544

543545
has_population_samplers = np.any(
544546
[
545-
isinstance(m, arraystep.PopulationArrayStepShared)
547+
isinstance(m, PopulationArrayStepShared)
546548
for m in (step.methods if isinstance(step, CompoundStep) else [step])
547549
]
548550
)
@@ -706,7 +708,7 @@ def _sample_many(
706708
trace: MultiTrace
707709
Contains samples of all chains
708710
"""
709-
traces = []
711+
traces: List[Backend] = []
710712
for i in range(chains):
711713
trace = _sample(
712714
draws=draws,
@@ -1140,7 +1142,7 @@ def _run_secondary(c, stepper_dumps, secondary_end):
11401142
# has to be updated, therefore we identify the substeppers first.
11411143
population_steppers = []
11421144
for sm in stepper.methods if isinstance(stepper, CompoundStep) else [stepper]:
1143-
if isinstance(sm, arraystep.PopulationArrayStepShared):
1145+
if isinstance(sm, PopulationArrayStepShared):
11441146
population_steppers.append(sm)
11451147
while True:
11461148
incoming = secondary_end.recv()
@@ -1259,7 +1261,7 @@ def _prepare_iter_population(
12591261
population = [Point(start[c], model=model) for c in range(nchains)]
12601262

12611263
# 3. Set up the steppers
1262-
steppers = [None] * nchains
1264+
steppers: List[Step] = []
12631265
for c in range(nchains):
12641266
# need independent samplers for each chain
12651267
# it is important to copy the actual steppers (but not the delta_logp)
@@ -1269,9 +1271,9 @@ def _prepare_iter_population(
12691271
chainstep = copy(step)
12701272
# link population samplers to the shared population state
12711273
for sm in chainstep.methods if isinstance(step, CompoundStep) else [chainstep]:
1272-
if isinstance(sm, arraystep.PopulationArrayStepShared):
1274+
if isinstance(sm, PopulationArrayStepShared):
12731275
sm.link_population(population, c)
1274-
steppers[c] = chainstep
1276+
steppers.append(chainstep)
12751277

12761278
# 4. configure tracking of sampler stats
12771279
for c in range(nchains):
@@ -1349,7 +1351,7 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
13491351
steppers[c].report._finalize(strace)
13501352

13511353

1352-
def _choose_backend(trace, chain, **kwds):
1354+
def _choose_backend(trace, chain, **kwds) -> Backend:
13531355
"""Selects or creates an NDArray trace backend for a particular chain.
13541356
13551357
Parameters
@@ -1562,8 +1564,8 @@ class _DefaultTrace:
15621564
`insert()` method
15631565
"""
15641566

1565-
trace_dict = {} # type: Dict[str, np.ndarray]
1566-
_len = None # type: int
1567+
trace_dict: Dict[str, np.ndarray] = {}
1568+
_len: Optional[int] = None
15671569

15681570
def __init__(self, samples: int):
15691571
self._len = samples
@@ -1600,7 +1602,7 @@ def sample_posterior_predictive(
16001602
trace,
16011603
samples: Optional[int] = None,
16021604
model: Optional[Model] = None,
1603-
vars: Optional[TIterable[Tensor]] = None,
1605+
vars: Optional[Iterable[Tensor]] = None,
16041606
var_names: Optional[List[str]] = None,
16051607
size: Optional[int] = None,
16061608
keep_size: Optional[bool] = False,
@@ -1885,8 +1887,7 @@ def sample_posterior_predictive_w(
18851887
def sample_prior_predictive(
18861888
samples=500,
18871889
model: Optional[Model] = None,
1888-
vars: Optional[TIterable[str]] = None,
1889-
var_names: Optional[TIterable[str]] = None,
1890+
var_names: Optional[Iterable[str]] = None,
18901891
random_seed=None,
18911892
) -> Dict[str, np.ndarray]:
18921893
"""Generate samples from the prior predictive distribution.
@@ -1896,9 +1897,6 @@ def sample_prior_predictive(
18961897
samples : int
18971898
Number of samples from the prior predictive to generate. Defaults to 500.
18981899
model : Model (optional if in ``with`` context)
1899-
vars : Iterable[str]
1900-
A list of names of variables for which to compute the posterior predictive
1901-
samples. *DEPRECATED* - Use ``var_names`` argument instead.
19021900
var_names : Iterable[str]
19031901
A list of names of variables for which to compute the prior predictive
19041902
samples. Defaults to both observed and unobserved RVs.
@@ -1913,22 +1911,14 @@ def sample_prior_predictive(
19131911
"""
19141912
model = modelcontext(model)
19151913

1916-
if vars is None and var_names is None:
1914+
if var_names is None:
19171915
prior_pred_vars = model.observed_RVs
19181916
prior_vars = (
19191917
get_default_varnames(model.unobserved_RVs, include_transformed=True) + model.potentials
19201918
)
1921-
vars_ = [var.name for var in prior_vars + prior_pred_vars]
1922-
vars = set(vars_)
1923-
elif vars is None:
1924-
vars = var_names
1925-
vars_ = vars
1926-
elif vars is not None:
1927-
warnings.warn("vars argument is deprecated in favor of var_names.", DeprecationWarning)
1928-
vars_ = vars
1919+
vars_: Set[str] = {var.name for var in prior_vars + prior_pred_vars}
19291920
else:
1930-
raise ValueError("Cannot supply both vars and var_names arguments.")
1931-
vars = cast(TIterable[str], vars) # tell mypy that vars cannot be None here.
1921+
vars_ = set(var_names)
19321922

19331923
if random_seed is not None:
19341924
np.random.seed(random_seed)
@@ -1940,8 +1930,8 @@ def sample_prior_predictive(
19401930
if data is None:
19411931
raise AssertionError("No variables sampled: attempting to sample %s" % names)
19421932

1943-
prior = {} # type: Dict[str, np.ndarray]
1944-
for var_name in vars:
1933+
prior: Dict[str, np.ndarray] = {}
1934+
for var_name in vars_:
19451935
if var_name in data:
19461936
prior[var_name] = data[var_name]
19471937
elif is_transformed_name(var_name):
@@ -2093,15 +2083,15 @@ def init_nuts(
20932083
var = np.ones_like(mean)
20942084
potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10)
20952085
elif init == "advi+adapt_diag_grad":
2096-
approx = pm.fit(
2086+
approx: pm.MeanField = pm.fit(
20972087
random_seed=random_seed,
20982088
n=n_init,
20992089
method="advi",
21002090
model=model,
21012091
callbacks=cb,
21022092
progressbar=progressbar,
21032093
obj_optimizer=pm.adagrad_window,
2104-
) # type: pm.MeanField
2094+
)
21052095
start = approx.sample(draws=chains)
21062096
start = list(start)
21072097
stds = approx.bij.rmap(approx.std.eval())
@@ -2119,7 +2109,7 @@ def init_nuts(
21192109
callbacks=cb,
21202110
progressbar=progressbar,
21212111
obj_optimizer=pm.adagrad_window,
2122-
) # type: pm.MeanField
2112+
)
21232113
start = approx.sample(draws=chains)
21242114
start = list(start)
21252115
stds = approx.bij.rmap(approx.std.eval())
@@ -2137,7 +2127,7 @@ def init_nuts(
21372127
callbacks=cb,
21382128
progressbar=progressbar,
21392129
obj_optimizer=pm.adagrad_window,
2140-
) # type: pm.MeanField
2130+
)
21412131
start = approx.sample(draws=chains)
21422132
start = list(start)
21432133
stds = approx.bij.rmap(approx.std.eval())

pymc3/step_methods/arraystep.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
# limitations under the License.
1414

1515
from enum import IntEnum, unique
16+
from typing import Dict, List
1617

1718
import numpy as np
1819

1920
from numpy.random import uniform
2021

2122
from pymc3.blocking import ArrayOrdering, DictToArrayBijection
22-
from pymc3.model import modelcontext
23+
from pymc3.model import PyMC3Variable, modelcontext
2324
from pymc3.step_methods.compound import CompoundStep
2425
from pymc3.theanof import inputvars
2526
from pymc3.util import get_var_name
@@ -46,6 +47,8 @@ class Competence(IntEnum):
4647
class BlockedStep:
4748

4849
generates_stats = False
50+
stats_dtypes: List[Dict[str, np.dtype]] = []
51+
vars: List[PyMC3Variable] = []
4952

5053
def __new__(cls, *args, **kwargs):
5154
blocked = kwargs.get("blocked")

pymc3/step_methods/mlda.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -356,10 +356,6 @@ class MLDA(ArrayStepShared):
356356
default_blocked = True
357357
generates_stats = True
358358

359-
# stat data types are different, depending on the base sampler.
360-
# these are assigned in the init method.
361-
stats_dtypes = None
362-
363359
def __init__(
364360
self,
365361
coarse_models: List[Model],

pymc3/tests/test_sampling.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -903,9 +903,8 @@ def test_respects_shape(self):
903903
with pm.Model():
904904
mu = pm.Gamma("mu", 3, 1, shape=1)
905905
goals = pm.Poisson("goals", mu, shape=shape)
906-
with pytest.warns(DeprecationWarning):
907-
trace1 = pm.sample_prior_predictive(10, vars=["mu", "goals"])
908-
trace2 = pm.sample_prior_predictive(10, var_names=["mu", "goals"])
906+
trace1 = pm.sample_prior_predictive(10, var_names=["mu", "mu", "goals"])
907+
trace2 = pm.sample_prior_predictive(10, var_names=["mu", "goals"])
909908
if shape == 2: # want to test shape as an int
910909
shape = (2,)
911910
assert trace1["goals"].shape == (10,) + shape

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ convention = numpy
1111
[isort]
1212
lines_between_types = 1
1313
profile = black
14+
15+
[mypy]
16+
ignore_missing_imports = True

0 commit comments

Comments
 (0)