Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit edda25b

Browse files
committed Mar 10, 2023
MAINT: use normalizations in tnp.random
1 parent 302e862 commit edda25b

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed
 

‎torch_np/random.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99

1010
import torch
1111

12-
from . import asarray
1312
from ._detail import _dtypes_impl, _util
13+
from . import _helpers
14+
from ._normalizations import normalizer, ArrayLike
15+
from typing import Optional
1416

1517
_default_dtype = _dtypes_impl.default_float_dtype
1618

@@ -33,7 +35,7 @@ def array_or_scalar(values, py_type=float, return_scalar=False):
3335
if return_scalar:
3436
return py_type(values.item())
3537
else:
36-
return asarray(values)
38+
return _helpers.array_from(values)
3739

3840

3941
def seed(seed=None):
@@ -75,11 +77,11 @@ def normal(loc=0.0, scale=1.0, size=None):
7577
return array_or_scalar(values, return_scalar=size is None)
7678

7779

78-
def shuffle(x):
79-
x_tensor = asarray(x).get()
80-
perm = torch.randperm(x_tensor.shape[0])
81-
xp = x_tensor[perm]
82-
x_tensor.copy_(xp)
80+
@normalizer
81+
def shuffle(x: ArrayLike):
82+
perm = torch.randperm(x.shape[0])
83+
xp = x[perm]
84+
x.copy_(xp)
8385

8486

8587
def randint(low, high=None, size=None):
@@ -93,12 +95,14 @@ def randint(low, high=None, size=None):
9395
return array_or_scalar(values, int, return_scalar=size is None)
9496

9597

96-
def choice(a, size=None, replace=True, p=None):
98+
@normalizer
99+
def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike]=None):
100+
97101
# https://stackoverflow.com/questions/59461811/random-choice-with-pytorch
98-
if isinstance(a, int):
99-
a_tensor = torch.arange(a)
100-
else:
101-
a_tensor = asarray(a).get()
102+
if a.numel() == 1:
103+
a = torch.arange(a)
104+
105+
# TODO: check a.dtype is integer -- cf np.random.choice(3.4) which raises
102106

103107
# number of draws
104108
if size is None:
@@ -112,21 +116,19 @@ def choice(a, size=None, replace=True, p=None):
112116

113117
# prepare the probabilities
114118
if p is None:
115-
p_tensor = torch.ones_like(a_tensor) / a_tensor.shape[0]
116-
else:
117-
p_tensor = asarray(p, dtype=float).get()
119+
p = torch.ones_like(a) / a.shape[0]
118120

119121
# cf https://github.com/numpy/numpy/blob/main/numpy/random/mtrand.pyx#L973
120122
atol = sqrt(torch.finfo(torch.float64).eps)
121-
if abs(p_tensor.sum() - 1.0) > atol:
123+
if abs(p.sum() - 1.0) > atol:
122124
raise ValueError("probabilities do not sum to 1.")
123125

124126
# actually sample
125-
indices = torch.multinomial(p_tensor, num_el, replacement=replace)
127+
indices = torch.multinomial(p, num_el, replacement=replace)
126128

127129
if _util.is_sequence(size):
128130
indices = indices.reshape(size)
129131

130-
samples = a_tensor[indices]
132+
samples = a[indices]
131133

132-
return asarray(samples)
134+
return _helpers.array_from(samples)

0 commit comments

Comments (0)
Please sign in to comment.