diff --git a/torch_np/random.py b/torch_np/random.py index 381f4d2e..c14a3e06 100644 --- a/torch_np/random.py +++ b/torch_np/random.py @@ -7,6 +7,7 @@ """ from __future__ import annotations +import functools from math import sqrt from typing import Optional @@ -29,14 +30,36 @@ "randint", "shuffle", "uniform", + "USE_NUMPY_RANDOM", ] +USE_NUMPY_RANDOM = False + + +def deco_stream(func): + @functools.wraps(func) + def inner(*args, **kwds): + if USE_NUMPY_RANDOM is False: + return func(*args, **kwds) + elif USE_NUMPY_RANDOM is True: + from numpy import random as nr + + f = getattr(nr, func.__name__) + return f(*args, **kwds) + else: + raise ValueError(f"USE_NUMPY_RANDOM={USE_NUMPY_RANDOM} not understood.") + + return inner + + +@deco_stream def seed(seed=None): if seed is not None: torch.random.manual_seed(seed) +@deco_stream def random_sample(size=None): if size is None: size = () @@ -44,6 +67,7 @@ def random_sample(size=None): return array_or_scalar(values, return_scalar=size is None) +@deco_stream def rand(*size): return random_sample(size) @@ -52,6 +76,7 @@ def rand(*size): random = random_sample +@deco_stream def uniform(low=0.0, high=1.0, size=None): if size is None: size = () @@ -59,11 +84,13 @@ def uniform(low=0.0, high=1.0, size=None): return array_or_scalar(values, return_scalar=size is None) +@deco_stream def randn(*size): values = torch.randn(size, dtype=_default_dtype) return array_or_scalar(values, return_scalar=size is None) +@deco_stream def normal(loc=0.0, scale=1.0, size=None): if size is None: size = () @@ -71,6 +98,7 @@ def normal(loc=0.0, scale=1.0, size=None): return array_or_scalar(values, return_scalar=size is None) +@deco_stream @normalizer def shuffle(x: ArrayLike): perm = torch.randperm(x.shape[0]) @@ -78,6 +106,7 @@ def shuffle(x: ArrayLike): x.copy_(xp) +@deco_stream def randint(low, high=None, size=None): if size is None: size = () @@ -89,6 +118,7 @@ def randint(low, high=None, size=None): return array_or_scalar(values, int, 
return_scalar=size is None) +@deco_stream @normalizer def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike] = None): diff --git a/torch_np/tests/test_random.py b/torch_np/tests/test_random.py new file mode 100644 index 00000000..f2614f17 --- /dev/null +++ b/torch_np/tests/test_random.py @@ -0,0 +1,48 @@ +"""Light smoke test switching between numpy and pytorch random streams. """ +import pytest + +import torch_np as tnp +from torch_np.testing import assert_equal + + +def test_uniform(): + r = tnp.random.uniform(0, 1, size=10) + + +def test_shuffle(): + x = tnp.arange(10) + tnp.random.shuffle(x) + + +def test_numpy_global(): + tnp.random.USE_NUMPY_RANDOM = True + tnp.random.seed(12345) + x = tnp.random.uniform(0, 1, size=11) + + # check that the stream is identical to numpy's + import numpy as _np + + _np.random.seed(12345) + x_np = _np.random.uniform(0, 1, size=11) + + assert_equal(x, tnp.asarray(x_np)) + + # switch to the pytorch stream, variates differ + tnp.random.USE_NUMPY_RANDOM = False + tnp.random.seed(12345) + + x_1 = tnp.random.uniform(0, 1, size=11) + assert not (x_1 == x).all() + + +def test_wrong_global(): + try: + oldstate = tnp.random.USE_NUMPY_RANDOM + + tnp.random.USE_NUMPY_RANDOM = "oops" + with pytest.raises(ValueError): + tnp.random.rand() + + finally: + tnp.random.USE_NUMPY_RANDOM = oldstate