Skip to content

Commit 789ec48

Browse files
authored
ENH: control a random stream to be from either numpy or pytorch (#135)
1 parent b975238 commit 789ec48

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

torch_np/random.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88
from __future__ import annotations
99

10+
import functools
1011
from math import sqrt
1112
from typing import Optional
1213

@@ -29,21 +30,44 @@
2930
"randint",
3031
"shuffle",
3132
"uniform",
33+
"USE_NUMPY_RANDOM",
3234
]
3335

3436

37+
USE_NUMPY_RANDOM = False
38+
39+
40+
def deco_stream(func):
41+
@functools.wraps(func)
42+
def inner(*args, **kwds):
43+
if USE_NUMPY_RANDOM is False:
44+
return func(*args, **kwds)
45+
elif USE_NUMPY_RANDOM is True:
46+
from numpy import random as nr
47+
48+
f = getattr(nr, func.__name__)
49+
return f(*args, **kwds)
50+
else:
51+
raise ValueError(f"USE_NUMPY_RANDOM={USE_NUMPY_RANDOM} not understood.")
52+
53+
return inner
54+
55+
56+
@deco_stream
3557
def seed(seed=None):
3658
if seed is not None:
3759
torch.random.manual_seed(seed)
3860

3961

62+
@deco_stream
4063
def random_sample(size=None):
4164
if size is None:
4265
size = ()
4366
values = torch.empty(size, dtype=_default_dtype).uniform_()
4467
return array_or_scalar(values, return_scalar=size is None)
4568

4669

70+
@deco_stream
4771
def rand(*size):
4872
return random_sample(size)
4973

@@ -52,32 +76,37 @@ def rand(*size):
5276
random = random_sample
5377

5478

79+
@deco_stream
5580
def uniform(low=0.0, high=1.0, size=None):
5681
if size is None:
5782
size = ()
5883
values = torch.empty(size, dtype=_default_dtype).uniform_(low, high)
5984
return array_or_scalar(values, return_scalar=size is None)
6085

6186

87+
@deco_stream
6288
def randn(*size):
6389
values = torch.randn(size, dtype=_default_dtype)
6490
return array_or_scalar(values, return_scalar=size is None)
6591

6692

93+
@deco_stream
6794
def normal(loc=0.0, scale=1.0, size=None):
6895
if size is None:
6996
size = ()
7097
values = torch.empty(size, dtype=_default_dtype).normal_(loc, scale)
7198
return array_or_scalar(values, return_scalar=size is None)
7299

73100

101+
@deco_stream
74102
@normalizer
75103
def shuffle(x: ArrayLike):
76104
perm = torch.randperm(x.shape[0])
77105
xp = x[perm]
78106
x.copy_(xp)
79107

80108

109+
@deco_stream
81110
def randint(low, high=None, size=None):
82111
if size is None:
83112
size = ()
@@ -89,6 +118,7 @@ def randint(low, high=None, size=None):
89118
return array_or_scalar(values, int, return_scalar=size is None)
90119

91120

121+
@deco_stream
92122
@normalizer
93123
def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike] = None):
94124

torch_np/tests/test_random.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""Light smoke test switching between numpy to pytorch random streams.
2+
"""
3+
import pytest
4+
5+
import torch_np as tnp
6+
from torch_np.testing import assert_equal
7+
8+
9+
def test_uniform():
10+
r = tnp.random.uniform(0, 1, size=10)
11+
12+
13+
def test_shuffle():
14+
x = tnp.arange(10)
15+
tnp.random.shuffle(x)
16+
17+
18+
def test_numpy_global():
19+
tnp.random.USE_NUMPY_RANDOM = True
20+
tnp.random.seed(12345)
21+
x = tnp.random.uniform(0, 1, size=11)
22+
23+
# check that the stream is identical to numpy's
24+
import numpy as _np
25+
26+
_np.random.seed(12345)
27+
x_np = _np.random.uniform(0, 1, size=11)
28+
29+
assert_equal(x, tnp.asarray(x_np))
30+
31+
# switch to the pytorch stream, variates differ
32+
tnp.random.USE_NUMPY_RANDOM = False
33+
tnp.random.seed(12345)
34+
35+
x_1 = tnp.random.uniform(0, 1, size=11)
36+
assert not (x_1 == x).all()
37+
38+
39+
def test_wrong_global():
40+
try:
41+
oldstate = tnp.random.USE_NUMPY_RANDOM
42+
43+
tnp.random.USE_NUMPY_RANDOM = "oops"
44+
with pytest.raises(ValueError):
45+
tnp.random.rand()
46+
47+
finally:
48+
tnp.random.USE_NUMPY_RANDOM = oldstate

0 commit comments

Comments
 (0)