Skip to content

Commit 94211f6

Browse files
committed
WIP: start adding a minimal np.random
1 parent 7cf3344 commit 94211f6

File tree

1 file changed

+117
-0
lines changed

1 file changed

+117
-0
lines changed

torch_np/random.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Wrapper to mimic (parts of) np.random API surface.
2+
3+
NumPy has strict guarantees on reproducibility etc; here we don't give any.
4+
5+
Q: default dtype is float64 in numpy
6+
7+
"""
8+
import torch
9+
10+
from . import asarray, arange, ones_like, sqrt, finfo
11+
from ._detail._scalar_types import default_float_type as _default_float_type
12+
from ._detail import _util
13+
14+
_default_dtype = _default_float_type.torch_dtype
15+
16+
__all__ = ["seed", "random_sample", "sample", "random", "rand", "randn", "normal"]
17+
18+
19+
def seed(seed=None):
20+
if seed is not None:
21+
torch.random.manual_seed()
22+
23+
24+
def random_sample(size=None):
25+
if size is None:
26+
values = torch.rand(())
27+
return float(values)
28+
else:
29+
values = torch.rand(size).to(_default_dtype)
30+
return asarray(values)
31+
32+
33+
def rand(*size):
34+
return random_sample(size)
35+
36+
37+
sample = random_sample
38+
random = random_sample
39+
40+
41+
def uniform(low=0.0, high=1.0, size=None):
42+
if size is None:
43+
values = torch.rand(())
44+
return float(low + (high - low) * values)
45+
else:
46+
values = torch.rand(size).to(_default_dtype)
47+
return asarray(low + (high - low) * values)
48+
49+
50+
def randn(*size):
51+
if size == ():
52+
return float(torch.randn(size))
53+
else:
54+
values = torch.randn(*size).to(_default_dtypes)
55+
return asarray(values)
56+
57+
58+
def normal(loc=0.0, scale=1.0, size=None):
59+
if size is None:
60+
size = ()
61+
return loc + scale * randn(*size).to(_default_dtype)
62+
63+
64+
def shuffle(x):
65+
x_tensor = asarray(x).get()
66+
perm = torch.randperm(x_tensor.shape[0])
67+
xp = x_tensor[perm]
68+
x_tensor.copy_(xp)
69+
70+
71+
def randint(low, high=None, size=None):
72+
if size is None:
73+
size = ()
74+
if not isinstance(size, (tuple, list)):
75+
size = (size,)
76+
if high is None:
77+
low, high = 0, high
78+
values = torch.randint(low, high, size=size)
79+
return asarray(values)
80+
81+
82+
def choice(a, size=None, replace=True, p=None):
83+
# https://stackoverflow.com/questions/59461811/random-choice-with-pytorch
84+
if isinstance(a, int):
85+
a = arange(a)
86+
a_tensor = asarray(a).get()
87+
88+
# number of draws
89+
if size is None:
90+
num_el = 1
91+
elif _util.is_sequence(size):
92+
num_el = 1
93+
for el in size:
94+
num_el *= el
95+
else:
96+
num_el = size
97+
98+
# prepare the probabilities
99+
if p is None:
100+
p_tensor = torch.ones_like(a_tensor) / a_tensor.shape[0]
101+
else:
102+
p_tensor = asarray(p, dtype=float).get()
103+
104+
# cf https://github.com/numpy/numpy/blob/main/numpy/random/mtrand.pyx#L973
105+
atol = sqrt(finfo('float64').eps)
106+
if abs(p_tensor.sum() - 1.) > atol:
107+
raise ValueError("probabilities do not sum to 1.")
108+
109+
# actually sample
110+
indices = torch.multinomial(p_tensor, num_el, replacement=replace)
111+
112+
if _util.is_sequence(size):
113+
indices = indices.reshape(size)
114+
115+
samples = a_tensor[indices]
116+
117+
return asarray(samples)

0 commit comments

Comments
 (0)