18
18
import warnings
19
19
20
20
from collections import defaultdict
21
- from collections .abc import Iterable
22
21
from itertools import repeat
23
22
24
23
import cloudpickle
34
33
from pymc .model import modelcontext
35
34
from pymc .sampling .parallel import _cpu_count
36
35
from pymc .smc .kernels import IMH
36
+ from pymc .util import RandomState , _get_seeds_per_chain
37
37
38
38
39
39
def sample_smc (
@@ -42,7 +42,7 @@ def sample_smc(
42
42
* ,
43
43
start = None ,
44
44
model = None ,
45
- random_seed = None ,
45
+ random_seed : RandomState = None ,
46
46
chains = None ,
47
47
cores = None ,
48
48
compute_convergence_checks = True ,
@@ -64,8 +64,10 @@ def sample_smc(
64
64
Starting point in parameter space. It should be a list of dict with length `chains`.
65
65
When None (default) the starting point is sampled from the prior distribution.
66
66
model: Model (optional if in ``with`` context)).
67
- random_seed: int
68
- random seed
67
+ random_seed : int, array-like of int, RandomState or Generator, optional
68
+ Random seed(s) used by the sampling steps. If a list, tuple or array of ints
69
+ is passed, each entry will be used to seed each chain. A ValueError will be
70
+ raised if the length does not match the number of chains.
69
71
chains : int
70
72
The number of chains to sample. Running independent chains is important for some
71
73
convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever
@@ -183,20 +185,7 @@ def sample_smc(
183
185
else :
184
186
cores = min (chains , cores )
185
187
186
- if random_seed == - 1 :
187
- raise FutureWarning (
188
- f"random_seed should be a non-negative integer or None, got: { random_seed } "
189
- "This will raise a ValueError in the Future"
190
- )
191
- random_seed = None
192
- if isinstance (random_seed , int ) or random_seed is None :
193
- rng = np .random .default_rng (seed = random_seed )
194
- random_seed = list (rng .integers (2 ** 30 , size = chains ))
195
- elif isinstance (random_seed , Iterable ):
196
- if len (random_seed ) != chains :
197
- raise ValueError (f"Length of seeds ({ len (seeds )} ) must match number of chains { chains } " )
198
- else :
199
- raise TypeError ("Invalid value for `random_seed`. Must be tuple, list, int or None" )
188
+ random_seed = _get_seeds_per_chain (random_state = random_seed , chains = chains )
200
189
201
190
model = modelcontext (model )
202
191
0 commit comments