7
7
"""
8
8
from __future__ import annotations
9
9
10
+ import functools
10
11
from math import sqrt
11
12
from typing import Optional
12
13
29
30
"randint" ,
30
31
"shuffle" ,
31
32
"uniform" ,
33
+ "USE_NUMPY_RANDOM" ,
32
34
]
33
35
34
36
37
+ USE_NUMPY_RANDOM = False
38
+
39
+
40
+ def deco_stream (func ):
41
+ @functools .wraps (func )
42
+ def inner (* args , ** kwds ):
43
+ if USE_NUMPY_RANDOM is False :
44
+ return func (* args , ** kwds )
45
+ elif USE_NUMPY_RANDOM is True :
46
+ from numpy import random as nr
47
+
48
+ f = getattr (nr , func .__name__ )
49
+ return f (* args , ** kwds )
50
+ else :
51
+ raise ValueError (f"USE_NUMPY_RANDOM={ USE_NUMPY_RANDOM } not understood." )
52
+
53
+ return inner
54
+
55
+
56
+ @deco_stream
35
57
def seed (seed = None ):
36
58
if seed is not None :
37
59
torch .random .manual_seed (seed )
38
60
39
61
62
+ @deco_stream
40
63
def random_sample (size = None ):
41
64
if size is None :
42
65
size = ()
43
66
values = torch .empty (size , dtype = _default_dtype ).uniform_ ()
44
67
return array_or_scalar (values , return_scalar = size is None )
45
68
46
69
70
+ @deco_stream
47
71
def rand (* size ):
48
72
return random_sample (size )
49
73
@@ -52,32 +76,37 @@ def rand(*size):
52
76
random = random_sample
53
77
54
78
79
+ @deco_stream
55
80
def uniform (low = 0.0 , high = 1.0 , size = None ):
56
81
if size is None :
57
82
size = ()
58
83
values = torch .empty (size , dtype = _default_dtype ).uniform_ (low , high )
59
84
return array_or_scalar (values , return_scalar = size is None )
60
85
61
86
87
+ @deco_stream
62
88
def randn (* size ):
63
89
values = torch .randn (size , dtype = _default_dtype )
64
90
return array_or_scalar (values , return_scalar = size is None )
65
91
66
92
93
+ @deco_stream
67
94
def normal (loc = 0.0 , scale = 1.0 , size = None ):
68
95
if size is None :
69
96
size = ()
70
97
values = torch .empty (size , dtype = _default_dtype ).normal_ (loc , scale )
71
98
return array_or_scalar (values , return_scalar = size is None )
72
99
73
100
101
+ @deco_stream
74
102
@normalizer
75
103
def shuffle (x : ArrayLike ):
76
104
perm = torch .randperm (x .shape [0 ])
77
105
xp = x [perm ]
78
106
x .copy_ (xp )
79
107
80
108
109
+ @deco_stream
81
110
def randint (low , high = None , size = None ):
82
111
if size is None :
83
112
size = ()
@@ -89,6 +118,7 @@ def randint(low, high=None, size=None):
89
118
return array_or_scalar (values , int , return_scalar = size is None )
90
119
91
120
121
+ @deco_stream
92
122
@normalizer
93
123
def choice (a : ArrayLike , size = None , replace = True , p : Optional [ArrayLike ] = None ):
94
124
0 commit comments