Skip to content

Commit 702889b

Browse files
committed
Remove unnecessary handling of no longer supported RandomState
1 parent c822a8e commit 702889b

File tree

2 files changed

+4
-22
lines changed

2 files changed

+4
-22
lines changed

pytensor/link/jax/linker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import warnings
22

3-
from numpy.random import Generator, RandomState
3+
from numpy.random import Generator
44

55
from pytensor.compile.sharedvalue import SharedVariable, shared
66
from pytensor.link.basic import JITLinker
@@ -21,7 +21,7 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
2121

2222
# Replace any shared RNG inputs so that their values can be updated in place
2323
# without affecting the original RNG container. This is necessary because
24-
# JAX does not accept RandomState/Generators as inputs, and they will have to
24+
# JAX does not accept Generators as inputs, and they will have to
2525
# be tipyfied
2626
if shared_rng_inputs:
2727
warnings.warn(
@@ -79,7 +79,7 @@ def create_thunk_inputs(self, storage_map):
7979
thunk_inputs = []
8080
for n in self.fgraph.inputs:
8181
sinput = storage_map[n]
82-
if isinstance(sinput[0], RandomState | Generator):
82+
if isinstance(sinput[0], Generator):
8383
new_value = jax_typify(
8484
sinput[0], dtype=getattr(sinput[0], "dtype", None)
8585
)

pytensor/link/numba/linker.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,4 @@ def jit_compile(self, fn):
1616
return jitted_fn
1717

1818
def create_thunk_inputs(self, storage_map):
19-
from numpy.random import RandomState
20-
21-
from pytensor.link.numba.dispatch import numba_typify
22-
23-
thunk_inputs = []
24-
for n in self.fgraph.inputs:
25-
sinput = storage_map[n]
26-
if isinstance(sinput[0], RandomState):
27-
new_value = numba_typify(
28-
sinput[0], dtype=getattr(sinput[0], "dtype", None)
29-
)
30-
# We need to remove the reference-based connection to the
31-
# original `RandomState`/shared variable's storage, because
32-
# subsequent attempts to use the same shared variable within
33-
# other non-Numba-fied graphs will have problems.
34-
sinput = [new_value]
35-
thunk_inputs.append(sinput)
36-
37-
return thunk_inputs
19+
return [storage_map[n] for n in self.fgraph.inputs]

0 commit comments

Comments
 (0)