Skip to content

Commit 98cc942

Browse files
Add or fix type hints related to backends (#5294)
Also raise a `ValueError` in `_choose_chains` instead of returning an incompatible type.
1 parent 0757e57 commit 98cc942

File tree

1 file changed

+27
-14
lines changed

1 file changed

+27
-14
lines changed

pymc/sampling.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,18 @@
2323

2424
from collections import defaultdict
2525
from copy import copy
26-
from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union, cast
26+
from typing import (
27+
Dict,
28+
Iterable,
29+
Iterator,
30+
List,
31+
Optional,
32+
Sequence,
33+
Set,
34+
Tuple,
35+
Union,
36+
cast,
37+
)
2738

2839
import aesara.gradient as tg
2940
import cloudpickle
@@ -700,7 +711,7 @@ def _sample_many(
700711
step,
701712
callback=None,
702713
**kwargs,
703-
):
714+
) -> MultiTrace:
704715
"""Samples all chains sequentially.
705716
706717
Parameters
@@ -760,7 +771,7 @@ def _sample_population(
760771
progressbar: bool = True,
761772
parallelize=False,
762773
**kwargs,
763-
):
774+
) -> MultiTrace:
764775
"""Performs sampling of a population of chains using the ``PopulationStepper``.
765776
766777
Parameters
@@ -823,7 +834,7 @@ def _sample(
823834
model: Optional[Model] = None,
824835
callback=None,
825836
**kwargs,
826-
):
837+
) -> MultiTrace:
827838
"""Main iteration for singleprocess sampling.
828839
829840
Multiple step methods are supported via compound step methods.
@@ -853,8 +864,8 @@ def _sample(
853864
854865
Returns
855866
-------
856-
strace : pymc.backends.base.BaseTrace
857-
A ``BaseTrace`` object that contains the samples for this chain.
867+
strace : MultiTrace
868+
A ``MultiTrace`` object that contains the samples for this chain.
858869
"""
859870
skip_first = kwargs.get("skip_first", 0)
860871

@@ -888,7 +899,7 @@ def iter_sample(
888899
model: Optional[Model] = None,
889900
random_seed: Optional[Union[int, List[int]]] = None,
890901
callback=None,
891-
):
902+
) -> Iterator[MultiTrace]:
892903
"""Generate a trace on each iteration using the given step method.
893904
894905
Multiple step methods ared supported via compound step methods. Returns the
@@ -947,7 +958,7 @@ def _iter_sample(
947958
model=None,
948959
random_seed=None,
949960
callback=None,
950-
):
961+
) -> Iterator[Tuple[Backend, bool]]:
951962
"""Generator for sampling one chain. (Used in singleprocess sampling.)
952963
953964
Parameters
@@ -1207,7 +1218,7 @@ def _prepare_iter_population(
12071218
model=None,
12081219
random_seed=None,
12091220
progressbar=True,
1210-
):
1221+
) -> Iterator[Sequence[BaseTrace]]:
12111222
"""Prepare a PopulationStepper and traces for population sampling.
12121223
12131224
Parameters
@@ -1287,7 +1298,9 @@ def _prepare_iter_population(
12871298
return _iter_population(draws, tune, popstep, steppers, traces, population)
12881299

12891300

1290-
def _iter_population(draws, tune, popstep, steppers, traces, points):
1301+
def _iter_population(
1302+
draws: int, tune: int, popstep: PopulationStepper, steppers, traces: Sequence[BaseTrace], points
1303+
) -> Iterator[Sequence[BaseTrace]]:
12911304
"""Iterate a ``PopulationStepper``.
12921305
12931306
Parameters
@@ -1393,14 +1406,14 @@ def _mp_sample(
13931406
discard_tuned_samples=True,
13941407
mp_ctx=None,
13951408
**kwargs,
1396-
):
1409+
) -> MultiTrace:
13971410
"""Main iteration for multiprocess sampling.
13981411
13991412
Parameters
14001413
----------
14011414
draws : int
14021415
The number of samples to draw
1403-
tune : int, optional
1416+
tune : int
14041417
Number of iterations to tune, if applicable (defaults to None)
14051418
step : function
14061419
Step function
@@ -1501,7 +1514,7 @@ def _mp_sample(
15011514
trace.close()
15021515

15031516

1504-
def _choose_chains(traces, tune):
1517+
def _choose_chains(traces: Sequence[BaseTrace], tune: Optional[int]) -> Tuple[List[BaseTrace], int]:
15051518
"""
15061519
Filter and slice traces such that (n_traces * len(shortest_trace)) is maximized.
15071520
@@ -1514,7 +1527,7 @@ def _choose_chains(traces, tune):
15141527
tune = 0
15151528

15161529
if not traces:
1517-
return []
1530+
raise ValueError("No traces to slice.")
15181531

15191532
lengths = [max(0, len(trace) - tune) for trace in traces]
15201533
if not sum(lengths):

0 commit comments

Comments
 (0)