Skip to content

Commit 4caa5b9

Browse files
committed
Refactor RandomState/Generator reseeding logic in initial_point into aesaraf
1 parent b5a5b56 commit 4caa5b9

File tree

3 files changed

+72
-25
lines changed

3 files changed

+72
-25
lines changed

pymc/aesaraf.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Iterable,
2121
List,
2222
Optional,
23+
Sequence,
2324
Set,
2425
Tuple,
2526
Union,
@@ -893,6 +894,40 @@ def local_check_parameter_to_ninf_switch(fgraph, node):
893894
)
894895

895896

897+
def find_rng_nodes(variables: Iterable[TensorVariable]):
    """Collect the shared RNG variables found among the inputs of a graph.

    Parameters
    ----------
    variables
        Variables whose graph inputs (as reported by ``graph_inputs``) are
        inspected.

    Returns
    -------
    list
        Every graph input that is an instance of either
        ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``,
        in the order ``graph_inputs`` yields them.
    """
    rng_types = (
        at.random.var.RandomStateSharedVariable,
        at.random.var.RandomGeneratorSharedVariable,
    )
    return [inp for inp in graph_inputs(variables) if isinstance(inp, rng_types)]
910+
911+
912+
# Type of the `seed` argument of `reseed_rngs`, which hands the value to
# `np.random.SeedSequence`. `None` is allowed.
SeedSequenceSeed = Union[
    int,
    Sequence[int],
    np.ndarray,
    np.random.SeedSequence,
    None,
]
913+
914+
915+
def reseed_rngs(
    rngs: Sequence[SharedVariable],
    seed: SeedSequenceSeed,
) -> None:
    """Replace the value of each shared RNG with a freshly seeded one.

    A ``np.random.SeedSequence`` is built from ``seed`` and one child
    sequence is spawned per entry in ``rngs``. Each child seeds its own
    ``np.random.PCG64`` bit generator.

    Parameters
    ----------
    rngs
        Shared variables whose values are overwritten. A
        ``RandomStateSharedVariable`` receives a ``np.random.RandomState``;
        anything else receives a ``np.random.Generator``.
    seed
        Value passed to ``np.random.SeedSequence``.
    """
    child_seeds = np.random.SeedSequence(seed).spawn(len(rngs))
    for shared_rng, child_seed in zip(rngs, child_seeds):
        bit_gen = np.random.PCG64(child_seed)
        # Pick the wrapper class that matches the kind of shared variable.
        rng_cls = (
            np.random.RandomState
            if isinstance(shared_rng, at.random.var.RandomStateSharedVariable)
            else np.random.Generator
        )
        shared_rng.set_value(rng_cls(bit_gen), borrow=True)
929+
930+
896931
def compile_pymc(
897932
inputs, outputs, mode=None, **kwargs
898933
) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:

pymc/initial_point.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
import aesara.tensor as at
2121
import numpy as np
2222

23-
from aesara.graph.basic import Variable, graph_inputs
23+
from aesara.graph.basic import Variable
2424
from aesara.graph.fg import FunctionGraph
2525
from aesara.tensor.var import TensorVariable
2626

27-
from pymc.aesaraf import compile_pymc
27+
from pymc.aesaraf import compile_pymc, find_rng_nodes, reseed_rngs
2828
from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name
2929

3030
StartDict = Dict[Union[Variable, str], Union[np.ndarray, Variable, str]]
@@ -150,19 +150,6 @@ def make_initial_point_fn(
150150
If `True` the returned variables will correspond to transformed initial values.
151151
"""
152152

153-
def find_rng_nodes(variables):
154-
return [
155-
node
156-
for node in graph_inputs(variables)
157-
if isinstance(
158-
node,
159-
(
160-
at.random.var.RandomStateSharedVariable,
161-
at.random.var.RandomGeneratorSharedVariable,
162-
),
163-
)
164-
]
165-
166153
sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
167154
initval_strats = {
168155
**model.initial_values,
@@ -208,16 +195,7 @@ def make_seeded_function(func):
208195

209196
@functools.wraps(func)
210197
def inner(seed, *args, **kwargs):
211-
seeds = [
212-
np.random.PCG64(sub_seed)
213-
for sub_seed in np.random.SeedSequence(seed).spawn(len(rngs))
214-
]
215-
for rng, seed in zip(rngs, seeds):
216-
if isinstance(rng, at.random.var.RandomStateSharedVariable):
217-
new_rng = np.random.RandomState(seed)
218-
else:
219-
new_rng = np.random.Generator(seed)
220-
rng.set_value(new_rng, True)
198+
reseed_rngs(rngs, seed)
221199
values = func(*args, **kwargs)
222200
return dict(zip(varnames, values))
223201

pymc/tests/test_aesaraf.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from aesara.graph.basic import Constant, Variable, ancestors, equal_computations
2727
from aesara.tensor.random.basic import normal, uniform
2828
from aesara.tensor.random.op import RandomVariable
29+
from aesara.tensor.random.var import RandomStateSharedVariable
2930
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
3031
from aesara.tensor.var import TensorVariable
3132

@@ -36,6 +37,7 @@
3637
compile_pymc,
3738
convert_observed_data,
3839
extract_obs_data,
40+
reseed_rngs,
3941
rvs_to_value_vars,
4042
walk_model,
4143
)
@@ -507,3 +509,35 @@ def update(self, node):
507509
fn = compile_pymc(inputs=[], outputs=dummy_x)
508510
assert fn() == 2.0
509511
assert fn() == 3.0
512+
513+
514+
def test_reseed_rngs():
    """Check that `reseed_rngs` installs the expected PCG64-based RNG states."""
    # `reseed_rngs` uses the `PCG64` bit generator, which is currently the default
    # bit generator used by NumPy. If this default changes in the future, this test
    # will catch that. We will then have to decide whether to switch to the new
    # default in PyMC or whether to stick with the older one (PCG64). This poses a
    # trade-off between backwards reproducibility and better/faster seeding. If we
    # decide to change, the next line should be updated:
    bit_generator_cls = np.random.PCG64
    assert isinstance(np.random.default_rng().bit_generator, bit_generator_cls)

    seed = 543

    # Reference bit generators: one per shared RNG, derived from `seed` via
    # `SeedSequence.spawn`.
    expected = [
        bit_generator_cls(child_seed) for child_seed in np.random.SeedSequence(seed).spawn(2)
    ]

    # One shared variable of each flavour, initialised without an explicit seed.
    rngs = [
        aesara.shared(np.random.Generator(bit_generator_cls())),
        aesara.shared(np.random.RandomState(bit_generator_cls())),
    ]

    def current_state(shared_rng):
        """Return the bit generator state held by a shared RNG variable."""
        value = shared_rng.get_value()
        if isinstance(shared_rng, RandomStateSharedVariable):
            return value._bit_generator.state
        return value.bit_generator.state

    # Before reseeding, no shared RNG should already match its reference state.
    for shared_rng, reference in zip(rngs, expected):
        assert current_state(shared_rng) != reference.state

    reseed_rngs(rngs, seed)

    # After reseeding, every shared RNG must match its reference state exactly.
    for shared_rng, reference in zip(rngs, expected):
        assert current_state(shared_rng) == reference.state

0 commit comments

Comments
 (0)