Skip to content

Commit a9d34bf

Browse files
authored
Merge pull request #25 from Quansight-Labs/random2
add a minimal `np.random` analog
2 parents 186be13 + f3ac7f1 commit a9d34bf

File tree

2 files changed

+122
-0
lines changed

2 files changed

+122
-0
lines changed

torch_np/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ._wrapper import * # isort: skip # XXX: currently this prevents circular imports
2+
from . import random
23
from ._binary_ufuncs import *
34
from ._detail._scalar_types import *
45
from ._detail._util import AxisError, UFuncTypeError

torch_np/random.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""Wrapper to mimic (parts of) np.random API surface.
2+
3+
NumPy has strict guarantees on reproducibility etc; here we don't give any.
4+
5+
Q: default dtype is float64 in numpy
6+
7+
"""
8+
from math import sqrt
9+
10+
import torch
11+
12+
from . import asarray
13+
from ._detail import _util
14+
from ._detail._scalar_types import default_float_type as _default_float_type
15+
16+
_default_dtype = _default_float_type.torch_dtype
17+
18+
__all__ = ["seed", "random_sample", "sample", "random", "rand", "randn", "normal"]
19+
20+
21+
def array_or_scalar(values, py_type=float):
22+
if values.numel() == 1:
23+
return py_type(values.item())
24+
else:
25+
return asarray(values)
26+
27+
28+
def seed(seed=None):
29+
if seed is not None:
30+
torch.random.manual_seed()
31+
32+
33+
def random_sample(size=None):
34+
if size is None:
35+
size = ()
36+
values = torch.empty(size, dtype=_default_dtype).uniform_()
37+
return array_or_scalar(values)
38+
39+
40+
def rand(*size):
41+
return random_sample(size)
42+
43+
44+
sample = random_sample
45+
random = random_sample
46+
47+
48+
def uniform(low=0.0, high=1.0, size=None):
49+
if size is None:
50+
size = ()
51+
values = torch.empty(size, dtype=_default_dtype).uniform_(low, high)
52+
return array_or_scalar(values)
53+
54+
55+
def randn(*size):
56+
values = torch.randn(size, dtype=_default_dtype)
57+
return array_or_scalar(values)
58+
59+
60+
def normal(loc=0.0, scale=1.0, size=None):
61+
if size is None:
62+
size = ()
63+
values = torch.empty(size, dtype=_default_dtype).normal_(loc, scale)
64+
return array_or_scalar(values)
65+
66+
67+
def shuffle(x):
68+
x_tensor = asarray(x).get()
69+
perm = torch.randperm(x_tensor.shape[0])
70+
xp = x_tensor[perm]
71+
x_tensor.copy_(xp)
72+
73+
74+
def randint(low, high=None, size=None):
75+
if size is None:
76+
size = ()
77+
if not isinstance(size, (tuple, list)):
78+
size = (size,)
79+
if high is None:
80+
low, high = 0, low
81+
values = torch.randint(low, high, size=size)
82+
return array_or_scalar(values)
83+
84+
85+
def choice(a, size=None, replace=True, p=None):
86+
# https://stackoverflow.com/questions/59461811/random-choice-with-pytorch
87+
if isinstance(a, int):
88+
a_tensor = torch.arange(a)
89+
else:
90+
a_tensor = asarray(a).get()
91+
92+
# number of draws
93+
if size is None:
94+
num_el = 1
95+
elif _util.is_sequence(size):
96+
num_el = 1
97+
for el in size:
98+
num_el *= el
99+
else:
100+
num_el = size
101+
102+
# prepare the probabilities
103+
if p is None:
104+
p_tensor = torch.ones_like(a_tensor) / a_tensor.shape[0]
105+
else:
106+
p_tensor = asarray(p, dtype=float).get()
107+
108+
# cf https://github.com/numpy/numpy/blob/main/numpy/random/mtrand.pyx#L973
109+
atol = sqrt(torch.finfo(torch.float64).eps)
110+
if abs(p_tensor.sum() - 1.0) > atol:
111+
raise ValueError("probabilities do not sum to 1.")
112+
113+
# actually sample
114+
indices = torch.multinomial(p_tensor, num_el, replacement=replace)
115+
116+
if _util.is_sequence(size):
117+
indices = indices.reshape(size)
118+
119+
samples = a_tensor[indices]
120+
121+
return asarray(samples)

0 commit comments

Comments
 (0)