9
9
10
10
import torch
11
11
12
- from . import asarray
13
12
from ._detail import _dtypes_impl , _util
13
+ from . import _helpers
14
+ from ._normalizations import normalizer , ArrayLike
15
+ from typing import Optional
14
16
15
17
_default_dtype = _dtypes_impl .default_float_dtype
16
18
@@ -33,7 +35,7 @@ def array_or_scalar(values, py_type=float, return_scalar=False):
33
35
if return_scalar :
34
36
return py_type (values .item ())
35
37
else :
36
- return asarray (values )
38
+ return _helpers . array_from (values )
37
39
38
40
39
41
def seed (seed = None ):
@@ -75,11 +77,11 @@ def normal(loc=0.0, scale=1.0, size=None):
75
77
return array_or_scalar (values , return_scalar = size is None )
76
78
77
79
78
- def shuffle ( x ):
79
- x_tensor = asarray ( x ). get ()
80
- perm = torch .randperm (x_tensor .shape [0 ])
81
- xp = x_tensor [perm ]
82
- x_tensor .copy_ (xp )
80
+ @ normalizer
81
+ def shuffle ( x : ArrayLike ):
82
+ perm = torch .randperm (x .shape [0 ])
83
+ xp = x [perm ]
84
+ x .copy_ (xp )
83
85
84
86
85
87
def randint (low , high = None , size = None ):
@@ -93,12 +95,14 @@ def randint(low, high=None, size=None):
93
95
return array_or_scalar (values , int , return_scalar = size is None )
94
96
95
97
96
- def choice (a , size = None , replace = True , p = None ):
98
+ @normalizer
99
+ def choice (a : ArrayLike , size = None , replace = True , p : Optional [ArrayLike ]= None ):
100
+
97
101
# https://stackoverflow.com/questions/59461811/random-choice-with-pytorch
98
- if isinstance ( a , int ) :
99
- a_tensor = torch .arange (a )
100
- else :
101
- a_tensor = asarray ( a ). get ()
102
+ if a . numel () == 1 :
103
+ a = torch .arange (a )
104
+
105
+ # TODO: check a.dtype is integer -- cf np.random.choice(3.4) which raises
102
106
103
107
# number of draws
104
108
if size is None :
@@ -112,21 +116,19 @@ def choice(a, size=None, replace=True, p=None):
112
116
113
117
# prepare the probabilities
114
118
if p is None :
115
- p_tensor = torch .ones_like (a_tensor ) / a_tensor .shape [0 ]
116
- else :
117
- p_tensor = asarray (p , dtype = float ).get ()
119
+ p = torch .ones_like (a ) / a .shape [0 ]
118
120
119
121
# cf https://github.com/numpy/numpy/blob/main/numpy/random/mtrand.pyx#L973
120
122
atol = sqrt (torch .finfo (torch .float64 ).eps )
121
- if abs (p_tensor .sum () - 1.0 ) > atol :
123
+ if abs (p .sum () - 1.0 ) > atol :
122
124
raise ValueError ("probabilities do not sum to 1." )
123
125
124
126
# actually sample
125
- indices = torch .multinomial (p_tensor , num_el , replacement = replace )
127
+ indices = torch .multinomial (p , num_el , replacement = replace )
126
128
127
129
if _util .is_sequence (size ):
128
130
indices = indices .reshape (size )
129
131
130
- samples = a_tensor [indices ]
132
+ samples = a [indices ]
131
133
132
- return asarray (samples )
134
+ return _helpers . array_from (samples )
0 commit comments