Skip to content

Commit d36bed3

Browse files
author
Juan Orduz
authored
Improve random seed processing for SMC sampling (#6298)
* improve random seed processing * improve type-hint
1 parent 3f9d2e2 commit d36bed3

File tree

1 file changed

+7
-18
lines changed

1 file changed

+7
-18
lines changed

pymc/smc/sampling.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import warnings
1919

2020
from collections import defaultdict
21-
from collections.abc import Iterable
2221
from itertools import repeat
2322

2423
import cloudpickle
@@ -34,6 +33,7 @@
3433
from pymc.model import modelcontext
3534
from pymc.sampling.parallel import _cpu_count
3635
from pymc.smc.kernels import IMH
36+
from pymc.util import RandomState, _get_seeds_per_chain
3737

3838

3939
def sample_smc(
@@ -42,7 +42,7 @@ def sample_smc(
4242
*,
4343
start=None,
4444
model=None,
45-
random_seed=None,
45+
random_seed: RandomState = None,
4646
chains=None,
4747
cores=None,
4848
compute_convergence_checks=True,
@@ -64,8 +64,10 @@ def sample_smc(
6464
Starting point in parameter space. It should be a list of dict with length `chains`.
6565
When None (default) the starting point is sampled from the prior distribution.
6666
model: Model (optional if in ``with`` context)).
67-
random_seed: int
68-
random seed
67+
random_seed : int, array-like of int, RandomState or Generator, optional
68+
Random seed(s) used by the sampling steps. If a list, tuple or array of ints
69+
is passed, each entry will be used to seed each chain. A ValueError will be
70+
raised if the length does not match the number of chains.
6971
chains : int
7072
The number of chains to sample. Running independent chains is important for some
7173
convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever
@@ -183,20 +185,7 @@ def sample_smc(
183185
else:
184186
cores = min(chains, cores)
185187

186-
if random_seed == -1:
187-
raise FutureWarning(
188-
f"random_seed should be a non-negative integer or None, got: {random_seed}"
189-
"This will raise a ValueError in the Future"
190-
)
191-
random_seed = None
192-
if isinstance(random_seed, int) or random_seed is None:
193-
rng = np.random.default_rng(seed=random_seed)
194-
random_seed = list(rng.integers(2**30, size=chains))
195-
elif isinstance(random_seed, Iterable):
196-
if len(random_seed) != chains:
197-
raise ValueError(f"Length of seeds ({len(seeds)}) must match number of chains {chains}")
198-
else:
199-
raise TypeError("Invalid value for `random_seed`. Must be tuple, list, int or None")
188+
random_seed = _get_seeds_per_chain(random_state=random_seed, chains=chains)
200189

201190
model = modelcontext(model)
202191

0 commit comments

Comments
 (0)