Skip to content

Commit eac88c2

Browse files
lucianopaz authored and ricardoV94 committed
DataClassState uses RandomGeneratorState instead of Generator objects
1 parent a918d02 commit eac88c2

File tree

5 files changed

+53
-9
lines changed

5 files changed

+53
-9
lines changed

pymc/step_methods/compound.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@
3131

3232
from pymc.blocking import PointType, StatDtype, StatsDict, StatShape, StatsType
3333
from pymc.model import modelcontext
34-
from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state
34+
from pymc.step_methods.state import (
35+
DataClassState,
36+
RandomGeneratorState,
37+
WithSamplingState,
38+
dataclass_state,
39+
)
3540
from pymc.util import RandomGenerator, get_random_generator
3641

3742
__all__ = ("Competence", "CompoundStep")
@@ -91,7 +96,7 @@ def infer_warn_stats_info(
9196

9297
@dataclass_state
9398
class StepMethodState(DataClassState):
94-
rng: np.random.Generator
99+
rng: RandomGeneratorState
95100

96101

97102
class BlockedStep(ABC, WithSamplingState):

pymc/step_methods/hmc/quadpotential.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626
from scipy.sparse import issparse
2727

2828
from pymc.pytensorf import floatX
29-
from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state
29+
from pymc.step_methods.state import (
30+
DataClassState,
31+
RandomGeneratorState,
32+
WithSamplingState,
33+
dataclass_state,
34+
)
3035
from pymc.util import RandomGenerator, get_random_generator
3136

3237
__all__ = [
@@ -105,7 +110,7 @@ def __str__(self):
105110

106111
@dataclass_state
107112
class PotentialState(DataClassState):
108-
rng: np.random.Generator
113+
rng: RandomGeneratorState
109114

110115

111116
class QuadPotential(WithSamplingState):

pymc/step_methods/state.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import numpy as np
1919

20+
from pymc.util import RandomGeneratorState, get_state_from_generator, random_generator_from_state
21+
2022
dataclass_state = dataclass(kw_only=True)
2123

2224

@@ -66,8 +68,11 @@ def sampling_state(self) -> DataClassState:
6668
kwargs = {}
6769
for field in fields(state_class):
6870
val = getattr(self, field.name)
71+
_val: Any
6972
if isinstance(val, WithSamplingState):
7073
_val = val.sampling_state
74+
elif isinstance(val, np.random.Generator):
75+
_val = get_state_from_generator(val)
7176
else:
7277
_val = val
7378
kwargs[field.name] = deepcopy(_val)
@@ -81,6 +86,8 @@ def sampling_state(self, state: DataClassState):
8186
), f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'"
8287
for field in fields(state_class):
8388
state_val = deepcopy(getattr(state, field.name))
89+
if isinstance(state_val, RandomGeneratorState):
90+
state_val = random_generator_from_state(state_val)
8491
self_val = getattr(self, field.name)
8592
is_frozen = field.metadata.get("frozen", False)
8693
if is_frozen:

pymc/util.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import functools
1616
import warnings
1717

18+
from collections import namedtuple
1819
from collections.abc import Sequence
1920
from copy import deepcopy
2021
from typing import NewType, cast
@@ -601,6 +602,31 @@ def update(
601602
return None
602603

603604

605+
RandomGeneratorState = namedtuple("RandomGeneratorState", ["bit_generator_state", "seed_seq_state"])
606+
607+
608+
def get_state_from_generator(
609+
rng: np.random.Generator | np.random.BitGenerator,
610+
) -> RandomGeneratorState:
611+
assert isinstance(rng, (np.random.Generator | np.random.BitGenerator))
612+
bit_gen: np.random.BitGenerator = (
613+
rng.bit_generator if isinstance(rng, np.random.Generator) else rng
614+
)
615+
616+
return RandomGeneratorState(
617+
bit_generator_state=bit_gen.state,
618+
seed_seq_state=bit_gen.seed_seq.state, # type: ignore[attr-defined]
619+
)
620+
621+
622+
def random_generator_from_state(state: RandomGeneratorState) -> np.random.Generator:
623+
seed_seq = np.random.SeedSequence(**state.seed_seq_state)
624+
bit_generator_class = getattr(np.random, state.bit_generator_state["bit_generator"])
625+
bit_generator = bit_generator_class(seed_seq)
626+
bit_generator.state = state.bit_generator_state
627+
return np.random.Generator(bit_generator)
628+
629+
604630
def get_random_generator(
605631
seed: RandomGenerator | np.random.RandomState = None, copy: bool = True
606632
) -> np.random.Generator:
@@ -645,6 +671,10 @@ def get_random_generator(
645671
# In the former case, it will return seed, in the latter it will return
646672
# a new Generator object that has the same BitGenerator. This would potentially
647673
# make the new generator be shared across many users. To avoid this, we
648-
# deepcopy by default.
674+
# copy by default.
675+
# Also, because of https://github.com/numpy/numpy/issues/27727, we can't use
676+
# deepcopy. We must rebuild a Generator without losing the SeedSequence information
677+
if isinstance(seed, np.random.Generator | np.random.BitGenerator):
678+
return random_generator_from_state(get_state_from_generator(seed))
649679
seed = deepcopy(seed)
650680
return np.random.default_rng(seed)

tests/step_methods/test_metropolis.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414

1515
import warnings
1616

17-
from copy import deepcopy
18-
1917
import arviz as az
2018
import numpy as np
2119
import numpy.testing as npt
@@ -406,8 +404,7 @@ def test_sampling_state(step_method, model_fn):
406404
sampler = step_method(model.value_vars)
407405
if hasattr(sampler, "link_population"):
408406
sampler.link_population([initial_point] * 100, 0)
409-
sampler_orig = deepcopy(sampler)
410-
state_orig = sampler_orig.sampling_state
407+
state_orig = sampler.sampling_state
411408

412409
sample1, stat1 = sampler.step(initial_point)
413410
sampler.tune = False

0 commit comments

Comments
 (0)