From 81abc36890302370dd3f0c36dd4fa51f1dc78b82 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 10 May 2023 20:39:04 +0300 Subject: [PATCH 1/3] ENH: control a random stream from numpy or pytorch Add a global switch to select the random stream origin: pytorch or numpy. If it's numpy, `tnp.random.uniform` just call `numpy.random.uniform` etc. --- torch_np/random.py | 30 +++++++++++++++++++++++++ torch_np/tests/test_random.py | 42 +++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 torch_np/tests/test_random.py diff --git a/torch_np/random.py b/torch_np/random.py index 381f4d2e..bbf45cac 100644 --- a/torch_np/random.py +++ b/torch_np/random.py @@ -7,6 +7,7 @@ """ from __future__ import annotations +import functools from math import sqrt from typing import Optional @@ -29,14 +30,36 @@ "randint", "shuffle", "uniform", + "RANDOM_STREAM", ] +RANDOM_STREAM = "pytorch" + + +def deco_stream(func): + @functools.wraps(func) + def inner(*args, **kwds): + if RANDOM_STREAM == "pytorch": + return func(*args, **kwds) + elif RANDOM_STREAM == "numpy": + from numpy import random as nr + + f = getattr(nr, func.__name__) + return f(*args, **kwds) + else: + raise ValueError(f"RANDOM_STREAM={RANDOM_STREAM} not understood.") + + return inner + + +@deco_stream def seed(seed=None): if seed is not None: torch.random.manual_seed(seed) +@deco_stream def random_sample(size=None): if size is None: size = () @@ -44,6 +67,7 @@ def random_sample(size=None): return array_or_scalar(values, return_scalar=size is None) +@deco_stream def rand(*size): return random_sample(size) @@ -52,6 +76,7 @@ def rand(*size): random = random_sample +@deco_stream def uniform(low=0.0, high=1.0, size=None): if size is None: size = () @@ -59,11 +84,13 @@ def uniform(low=0.0, high=1.0, size=None): return array_or_scalar(values, return_scalar=size is None) +@deco_stream def randn(*size): values = torch.randn(size, dtype=_default_dtype) return array_or_scalar(values, return_scalar=size is None) +@deco_stream def normal(loc=0.0, scale=1.0, size=None): if size is None: size = () @@ -71,6 +98,7 @@ def normal(loc=0.0, scale=1.0, size=None): return array_or_scalar(values, return_scalar=size is None) +@deco_stream @normalizer def shuffle(x: ArrayLike): perm = torch.randperm(x.shape[0]) @@ -78,6 +106,7 @@ def shuffle(x: ArrayLike): x.copy_(xp) +@deco_stream def randint(low, high=None, size=None): if size is None: size = () @@ -89,6 +118,7 @@ def randint(low, high=None, size=None): return array_or_scalar(values, int, return_scalar=size is None) +@deco_stream @normalizer def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike] = None): diff --git a/torch_np/tests/test_random.py b/torch_np/tests/test_random.py new file mode 100644 index 00000000..dc60fbfe --- /dev/null +++ b/torch_np/tests/test_random.py @@ -0,0 +1,42 @@ +"""Light smoke test switching between numpy to pytorch random streams. +""" +import pytest + +import torch_np as tnp +from torch_np.testing import assert_equal + + +def test_uniform(): + r = tnp.random.uniform(0, 1, size=10) + + +def test_shuffle(): + x = tnp.arange(10) + tnp.random.shuffle(x) + + +def test_numpy_global(): + tnp.random.RANDOM_STREAM = "numpy" + tnp.random.seed(12345) + x = tnp.random.uniform(0, 1, size=11) + + # check that the stream is identical to numpy's + import numpy as _np + + _np.random.seed(12345) + x_np = _np.random.uniform(0, 1, size=11) + + assert_equal(x, tnp.asarray(x_np)) + + # switch to the pytorch stream, variates differ + tnp.random.RANDOM_STREAM = "pytorch" + tnp.random.seed(12345) + + x_1 = tnp.random.uniform(0, 1, size=11) + assert not (x_1 == x).all() + + +def test_wrong_global(): + tnp.random.RANDOM_STREAM = "oops" + with pytest.raises(ValueError): + tnp.random.rand() From b0c01a717a7fde1b26e8eaa59d0d8d71ed2d8f46 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 10 May 2023 23:21:47 +0300 Subject: [PATCH 2/3] MAINT: make USE_NUNMPY_RANDOM boolean flag --- torch_np/random.py | 10 +++++----- torch_np/tests/test_random.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/torch_np/random.py b/torch_np/random.py index bbf45cac..c14a3e06 100644 --- a/torch_np/random.py +++ b/torch_np/random.py @@ -30,25 +30,25 @@ "randint", "shuffle", "uniform", - "RANDOM_STREAM", + "USE_NUMPY_RANDOM", ] -RANDOM_STREAM = "pytorch" +USE_NUMPY_RANDOM = False def deco_stream(func): @functools.wraps(func) def inner(*args, **kwds): - if RANDOM_STREAM == "pytorch": + if USE_NUMPY_RANDOM is False: return func(*args, **kwds) - elif RANDOM_STREAM == "numpy": + elif USE_NUMPY_RANDOM is True: from numpy import random as nr f = getattr(nr, func.__name__) return f(*args, **kwds) else: - raise ValueError(f"RANDOM_STREAM={RANDOM_STREAM} not understood.") + raise ValueError(f"USE_NUMPY_RANDOM={USE_NUMPY_RANDOM} not understood.") return inner diff --git a/torch_np/tests/test_random.py b/torch_np/tests/test_random.py index dc60fbfe..39911509 100644 --- a/torch_np/tests/test_random.py +++ b/torch_np/tests/test_random.py @@ -16,7 +16,7 @@ def test_shuffle(): def test_numpy_global(): - tnp.random.RANDOM_STREAM = "numpy" + tnp.random.USE_NUMPY_RANDOM = True tnp.random.seed(12345) x = tnp.random.uniform(0, 1, size=11) @@ -29,7 +29,7 @@ def test_numpy_global(): assert_equal(x, tnp.asarray(x_np)) # switch to the pytorch stream, variates differ - tnp.random.RANDOM_STREAM = "pytorch" + tnp.random.USE_NUMPY_RANDOM = False tnp.random.seed(12345) x_1 = tnp.random.uniform(0, 1, size=11) @@ -37,6 +37,6 @@ def test_numpy_global(): def test_wrong_global(): - tnp.random.RANDOM_STREAM = "oops" + tnp.random.USE_NUMPY_RANDOM = "oops" with pytest.raises(ValueError): tnp.random.rand() From 1bdfc6598f7224b8f88953fecf25bbc6c27ab54b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 10 May 2023 23:34:10 +0300 Subject: [PATCH 3/3] BUG: restore the global state in tests --- torch_np/tests/test_random.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/torch_np/tests/test_random.py b/torch_np/tests/test_random.py index 39911509..f2614f17 100644 --- a/torch_np/tests/test_random.py +++ b/torch_np/tests/test_random.py @@ -37,6 +37,12 @@ def test_numpy_global(): def test_wrong_global(): - tnp.random.USE_NUMPY_RANDOM = "oops" - with pytest.raises(ValueError): - tnp.random.rand() + try: + oldstate = tnp.random.USE_NUMPY_RANDOM + + tnp.random.USE_NUMPY_RANDOM = "oops" + with pytest.raises(ValueError): + tnp.random.rand() + + finally: + tnp.random.USE_NUMPY_RANDOM = oldstate