23
23
24
24
from collections import defaultdict
25
25
from copy import copy
26
- from typing import Dict , Iterable , List , Optional , Sequence , Set , Tuple , Union , cast
26
+ from typing import (
27
+ Dict ,
28
+ Iterable ,
29
+ Iterator ,
30
+ List ,
31
+ Optional ,
32
+ Sequence ,
33
+ Set ,
34
+ Tuple ,
35
+ Union ,
36
+ cast ,
37
+ )
27
38
28
39
import aesara .gradient as tg
29
40
import cloudpickle
@@ -700,7 +711,7 @@ def _sample_many(
700
711
step ,
701
712
callback = None ,
702
713
** kwargs ,
703
- ):
714
+ ) -> MultiTrace :
704
715
"""Samples all chains sequentially.
705
716
706
717
Parameters
@@ -760,7 +771,7 @@ def _sample_population(
760
771
progressbar : bool = True ,
761
772
parallelize = False ,
762
773
** kwargs ,
763
- ):
774
+ ) -> MultiTrace :
764
775
"""Performs sampling of a population of chains using the ``PopulationStepper``.
765
776
766
777
Parameters
@@ -823,7 +834,7 @@ def _sample(
823
834
model : Optional [Model ] = None ,
824
835
callback = None ,
825
836
** kwargs ,
826
- ):
837
+ ) -> MultiTrace :
827
838
"""Main iteration for singleprocess sampling.
828
839
829
840
Multiple step methods are supported via compound step methods.
@@ -853,8 +864,8 @@ def _sample(
853
864
854
865
Returns
855
866
-------
856
- strace : pymc.backends.base.BaseTrace
857
- A ``BaseTrace `` object that contains the samples for this chain.
867
+ strace : MultiTrace
868
+ A ``MultiTrace `` object that contains the samples for this chain.
858
869
"""
859
870
skip_first = kwargs .get ("skip_first" , 0 )
860
871
@@ -888,7 +899,7 @@ def iter_sample(
888
899
model : Optional [Model ] = None ,
889
900
random_seed : Optional [Union [int , List [int ]]] = None ,
890
901
callback = None ,
891
- ):
902
+ ) -> Iterator [ MultiTrace ] :
892
903
"""Generate a trace on each iteration using the given step method.
893
904
894
905
Multiple step methods ared supported via compound step methods. Returns the
@@ -947,7 +958,7 @@ def _iter_sample(
947
958
model = None ,
948
959
random_seed = None ,
949
960
callback = None ,
950
- ):
961
+ ) -> Iterator [ Tuple [ Backend , bool ]] :
951
962
"""Generator for sampling one chain. (Used in singleprocess sampling.)
952
963
953
964
Parameters
@@ -1207,7 +1218,7 @@ def _prepare_iter_population(
1207
1218
model = None ,
1208
1219
random_seed = None ,
1209
1220
progressbar = True ,
1210
- ):
1221
+ ) -> Iterator [ Sequence [ BaseTrace ]] :
1211
1222
"""Prepare a PopulationStepper and traces for population sampling.
1212
1223
1213
1224
Parameters
@@ -1287,7 +1298,9 @@ def _prepare_iter_population(
1287
1298
return _iter_population (draws , tune , popstep , steppers , traces , population )
1288
1299
1289
1300
1290
- def _iter_population (draws , tune , popstep , steppers , traces , points ):
1301
+ def _iter_population (
1302
+ draws : int , tune : int , popstep : PopulationStepper , steppers , traces : Sequence [BaseTrace ], points
1303
+ ) -> Iterator [Sequence [BaseTrace ]]:
1291
1304
"""Iterate a ``PopulationStepper``.
1292
1305
1293
1306
Parameters
@@ -1393,14 +1406,14 @@ def _mp_sample(
1393
1406
discard_tuned_samples = True ,
1394
1407
mp_ctx = None ,
1395
1408
** kwargs ,
1396
- ):
1409
+ ) -> MultiTrace :
1397
1410
"""Main iteration for multiprocess sampling.
1398
1411
1399
1412
Parameters
1400
1413
----------
1401
1414
draws : int
1402
1415
The number of samples to draw
1403
- tune : int, optional
1416
+ tune : int
1404
1417
Number of iterations to tune, if applicable (defaults to None)
1405
1418
step : function
1406
1419
Step function
@@ -1501,7 +1514,7 @@ def _mp_sample(
1501
1514
trace .close ()
1502
1515
1503
1516
1504
- def _choose_chains (traces , tune ) :
1517
+ def _choose_chains (traces : Sequence [ BaseTrace ] , tune : Optional [ int ]) -> Tuple [ List [ BaseTrace ], int ] :
1505
1518
"""
1506
1519
Filter and slice traces such that (n_traces * len(shortest_trace)) is maximized.
1507
1520
@@ -1514,7 +1527,7 @@ def _choose_chains(traces, tune):
1514
1527
tune = 0
1515
1528
1516
1529
if not traces :
1517
- return []
1530
+ raise ValueError ( "No traces to slice." )
1518
1531
1519
1532
lengths = [max (0 , len (trace ) - tune ) for trace in traces ]
1520
1533
if not sum (lengths ):
0 commit comments