From b7bed61207908799cf2b78c9bfc8cea04681294c Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Tue, 18 Mar 2025 07:49:58 +0100
Subject: [PATCH] Remove unnecessary handling of no longer supported RandomState

---
 .../extending_pytensor_solution_1.py |  6 +++---
 doc/library/d3viz/index.ipynb        |  2 +-
 doc/library/d3viz/index.rst          |  2 +-
 doc/optimizations.rst                |  2 +-
 pytensor/compile/monitormode.py      |  5 +----
 pytensor/compile/nanguardmode.py     |  2 +-
 pytensor/link/jax/linker.py          |  6 +++---
 pytensor/link/numba/linker.py        | 20 +------------------
 pytensor/tensor/random/type.py       | 17 ++++++++--------
 tests/unittest_tools.py              |  4 ++--
 10 files changed, 23 insertions(+), 43 deletions(-)

diff --git a/doc/extending/extending_pytensor_solution_1.py b/doc/extending/extending_pytensor_solution_1.py
index 45329c73d6..ff470ec420 100644
--- a/doc/extending/extending_pytensor_solution_1.py
+++ b/doc/extending/extending_pytensor_solution_1.py
@@ -118,7 +118,7 @@ def setup_method(self):
         self.op_class = SumDiffOp

     def test_perform(self):
-        rng = np.random.RandomState(43)
+        rng = np.random.default_rng(43)
         x = matrix()
         y = matrix()
         f = pytensor.function([x, y], self.op_class()(x, y))
@@ -128,7 +128,7 @@ def test_perform(self):
         assert np.allclose([x_val + y_val, x_val - y_val], out)

     def test_gradient(self):
-        rng = np.random.RandomState(43)
+        rng = np.random.default_rng(43)

         def output_0(x, y):
             return self.op_class()(x, y)[0]
@@ -150,7 +150,7 @@ def output_1(x, y):
         )

     def test_infer_shape(self):
-        rng = np.random.RandomState(43)
+        rng = np.random.default_rng(43)

         x = dmatrix()
         y = dmatrix()
diff --git a/doc/library/d3viz/index.ipynb b/doc/library/d3viz/index.ipynb
index 778647daa3..5abd13ec01 100644
--- a/doc/library/d3viz/index.ipynb
+++ b/doc/library/d3viz/index.ipynb
@@ -95,7 +95,7 @@
     "noutputs = 10\n",
     "nhiddens = 50\n",
     "\n",
-    "rng = np.random.RandomState(0)\n",
+    "rng = np.random.default_rng(0)\n",
     "x = pt.dmatrix('x')\n",
     "wh = pytensor.shared(rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True)\n",
     "bh = pytensor.shared(np.zeros(nhiddens), borrow=True)\n",
diff --git a/doc/library/d3viz/index.rst b/doc/library/d3viz/index.rst
index d411f874e8..f0727318b0 100644
--- a/doc/library/d3viz/index.rst
+++ b/doc/library/d3viz/index.rst
@@ -58,7 +58,7 @@ hidden layer and a softmax output layer.
     noutputs = 10
     nhiddens = 50

-    rng = np.random.RandomState(0)
+    rng = np.random.default_rng(0)
     x = pt.dmatrix('x')
     wh = pytensor.shared(rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True)
     bh = pytensor.shared(np.zeros(nhiddens), borrow=True)
diff --git a/doc/optimizations.rst b/doc/optimizations.rst
index ed7011b8f2..7c1a0f8b15 100644
--- a/doc/optimizations.rst
+++ b/doc/optimizations.rst
@@ -239,7 +239,7 @@ Optimization o4 o3 o2
     See :func:`insert_inplace_optimizer`

 inplace_random
-    Typically when a graph uses random numbers, the RandomState is stored
+    Typically when a graph uses random numbers, the random Generator is stored
     in a shared variable, used once per call and updated after each
     function call. In this common case, it makes sense to update the random
     number generator in-place.
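
A note on the documentation hunks above: `np.random.default_rng` returns a
`numpy.random.Generator`, the modern replacement for the legacy `RandomState`
that this patch stops handling. A minimal sketch of the migration, standalone
NumPy only, with illustrative shapes (not part of the patch itself):

    import numpy as np

    # Legacy construction (no longer supported): np.random.RandomState(0)
    rng = np.random.default_rng(0)  # returns a numpy.random.Generator

    # Sampling calls like the one in the d3viz examples work unchanged:
    wh = rng.normal(0, 1, (20, 50))
    assert isinstance(rng, np.random.Generator)

Note that `default_rng(seed)` and the old `RandomState(seed)` do not produce
bitwise-identical streams for the same seed; the doc examples only rely on
reproducibility, not on matching the old draws.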
diff --git a/pytensor/compile/monitormode.py b/pytensor/compile/monitormode.py
index 8663bc8832..40c8c41dfe 100644
--- a/pytensor/compile/monitormode.py
+++ b/pytensor/compile/monitormode.py
@@ -104,10 +104,7 @@ def detect_nan(fgraph, i, node, fn):
     from pytensor.printing import debugprint

     for output in fn.outputs:
-        if (
-            not isinstance(output[0], np.random.RandomState | np.random.Generator)
-            and np.isnan(output[0]).any()
-        ):
+        if not isinstance(output[0], np.random.Generator) and np.isnan(output[0]).any():
             print("*** NaN detected ***")  # noqa: T201
             debugprint(node)
             print(f"Inputs : {[input[0] for input in fn.inputs]}")  # noqa: T201
diff --git a/pytensor/compile/nanguardmode.py b/pytensor/compile/nanguardmode.py
index 32a06757d1..463d058155 100644
--- a/pytensor/compile/nanguardmode.py
+++ b/pytensor/compile/nanguardmode.py
@@ -34,7 +34,7 @@ def _is_numeric_value(arr, var):

     if isinstance(arr, _cdata_type):
         return False
-    elif isinstance(arr, np.random.mtrand.RandomState | np.random.Generator):
+    elif isinstance(arr, np.random.Generator):
         return False
     elif var is not None and isinstance(var.type, RandomType):
         return False
diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py
index 06370b4514..80bb48305f 100644
--- a/pytensor/link/jax/linker.py
+++ b/pytensor/link/jax/linker.py
@@ -1,6 +1,6 @@
 import warnings

-from numpy.random import Generator, RandomState
+from numpy.random import Generator

 from pytensor.compile.sharedvalue import SharedVariable, shared
 from pytensor.link.basic import JITLinker
@@ -21,7 +21,7 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):

         # Replace any shared RNG inputs so that their values can be updated in place
         # without affecting the original RNG container. This is necessary because
-        # JAX does not accept RandomState/Generators as inputs, and they will have to
+        # JAX does not accept Generators as inputs, and they will have to
         # be typified
         if shared_rng_inputs:
             warnings.warn(
@@ -79,7 +79,7 @@ def create_thunk_inputs(self, storage_map):
         thunk_inputs = []
         for n in self.fgraph.inputs:
             sinput = storage_map[n]
-            if isinstance(sinput[0], RandomState | Generator):
+            if isinstance(sinput[0], Generator):
                 new_value = jax_typify(
                     sinput[0], dtype=getattr(sinput[0], "dtype", None)
                 )
diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py
index 553c5ef217..59dc81e1b0 100644
--- a/pytensor/link/numba/linker.py
+++ b/pytensor/link/numba/linker.py
@@ -16,22 +16,4 @@ def jit_compile(self, fn):
         return jitted_fn

     def create_thunk_inputs(self, storage_map):
-        from numpy.random import RandomState
-
-        from pytensor.link.numba.dispatch import numba_typify
-
-        thunk_inputs = []
-        for n in self.fgraph.inputs:
-            sinput = storage_map[n]
-            if isinstance(sinput[0], RandomState):
-                new_value = numba_typify(
-                    sinput[0], dtype=getattr(sinput[0], "dtype", None)
-                )
-                # We need to remove the reference-based connection to the
-                # original `RandomState`/shared variable's storage, because
-                # subsequent attempts to use the same shared variable within
-                # other non-Numba-fied graphs will have problems.
-                sinput = [new_value]
-            thunk_inputs.append(sinput)
-
-        return thunk_inputs
+        return [storage_map[n] for n in self.fgraph.inputs]
diff --git a/pytensor/tensor/random/type.py b/pytensor/tensor/random/type.py
index df8e3b691d..107dd4c41a 100644
--- a/pytensor/tensor/random/type.py
+++ b/pytensor/tensor/random/type.py
@@ -1,12 +1,13 @@
 from typing import TypeVar

 import numpy as np
+from numpy.random import Generator

 import pytensor
 from pytensor.graph.type import Type

-T = TypeVar("T", np.random.RandomState, np.random.Generator)
+T = TypeVar("T")


 gen_states_keys = {
@@ -24,14 +25,10 @@


 class RandomType(Type[T]):
-    r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`."""
-
-    @staticmethod
-    def may_share_memory(a: T, b: T):
-        return a._bit_generator is b._bit_generator  # type: ignore[attr-defined]
+    r"""A Type wrapper for `numpy.random.Generator`."""


-class RandomGeneratorType(RandomType[np.random.Generator]):
+class RandomGeneratorType(RandomType[Generator]):
     r"""A Type wrapper for `numpy.random.Generator`.

     The reason this exists (and `Generic` doesn't suffice) is that
@@ -47,6 +44,10 @@ class RandomGeneratorType(RandomType[Generator]):
     def __repr__(self):
         return "RandomGeneratorType"

+    @staticmethod
+    def may_share_memory(a: Generator, b: Generator):
+        return a._bit_generator is b._bit_generator  # type: ignore[attr-defined]
+
     def filter(self, data, strict=False, allow_downcast=None):
         """
         XXX: This doesn't convert `data` to the same type of underlying RNG type
@@ -58,7 +59,7 @@ def filter(self, data, strict=False, allow_downcast=None):
         `Type.filter`, we need to have it here to avoid surprising circular
         dependencies in sub-classes.
         """
-        if isinstance(data, np.random.Generator):
+        if isinstance(data, Generator):
             return data

         if not strict and isinstance(data, dict):
diff --git a/tests/unittest_tools.py b/tests/unittest_tools.py
index adb83fe7c0..1bdfc01410 100644
--- a/tests/unittest_tools.py
+++ b/tests/unittest_tools.py
@@ -27,8 +27,8 @@ def fetch_seed(pseed=None):
     If config.unittest.rseed is set to "random", it will seed the rng with
     None, which is equivalent to seeding with a random seed.

-    Useful for seeding RandomState or Generator objects.
-    >>> rng = np.random.RandomState(fetch_seed())
+    Useful for seeding Generator objects.
+    >>> rng = np.random.default_rng(fetch_seed())
     """
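
The monitormode and nanguardmode hunks keep skipping RNG objects before the
NaN check because a `numpy.random.Generator` appearing in a function's outputs
is an opaque RNG handle, and `np.isnan` raises `TypeError` on it. A standalone
sketch of that guard's logic (plain NumPy, illustrative values only):

    import numpy as np

    outputs = [np.array([1.0, float("nan")]), np.random.default_rng(0)]
    for out in outputs:
        # Test for RNG objects first: np.isnan(out) would raise TypeError
        # if out were a Generator rather than numeric data.
        if not isinstance(out, np.random.Generator) and np.isnan(out).any():
            print("*** NaN detected ***")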
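
In the random/type.py hunk, `may_share_memory` moves from the `RandomType`
base class onto `RandomGeneratorType` and keeps comparing bit-generator
identity: two `Generator` objects alias the same state exactly when they wrap
the same `BitGenerator`. A quick standalone illustration, using the public
`bit_generator` attribute rather than the private `_bit_generator` the patch
accesses:

    import numpy as np

    rng1 = np.random.default_rng(42)
    rng2 = np.random.Generator(rng1.bit_generator)  # wraps the same BitGenerator
    rng3 = np.random.default_rng(42)                # same seed, independent state

    # Identity of the underlying bit generator, not seed equality,
    # decides whether two Generators share state.
    assert rng1.bit_generator is rng2.bit_generator
    assert rng1.bit_generator is not rng3.bit_generator
    rng2.standard_normal()  # also advances rng1, since the state is shared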