18
18
__all__ = ["seed" , "random_sample" , "sample" , "random" , "rand" , "randn" , "normal" ]
19
19
20
20
21
+ def array_or_scalar (values , py_type = float ):
22
+ if values .numel () == 1 :
23
+ return py_type (values .item ())
24
+ else :
25
+ return asarray (values )
26
+
27
+
21
28
def seed (seed = None ):
22
29
if seed is not None :
23
30
torch .random .manual_seed ()
24
31
25
32
26
33
def random_sample (size = None ):
27
34
if size is None :
28
- values = torch .rand (())
29
- return float (values )
30
- else :
31
- values = torch .rand (size ).to (_default_dtype )
32
- return asarray (values )
35
+ size = ()
36
+ values = torch .empty (size , dtype = _default_dtype ).uniform_ ()
37
+ return array_or_scalar (values )
33
38
34
39
35
40
def rand (* size ):
@@ -42,25 +47,21 @@ def rand(*size):
42
47
43
48
def uniform (low = 0.0 , high = 1.0 , size = None ):
44
49
if size is None :
45
- values = torch .rand (())
46
- return float (low + (high - low ) * values )
47
- else :
48
- values = torch .rand (size ).to (_default_dtype )
49
- return asarray (low + (high - low ) * values )
50
+ size = ()
51
+ values = torch .empty (size , dtype = _default_dtype ).uniform_ (low , high )
52
+ return array_or_scalar (values )
50
53
51
54
52
55
def randn (* size ):
53
- if size == ():
54
- return float (torch .randn (size ))
55
- else :
56
- values = torch .randn (* size ).to (_default_dtypes )
57
- return asarray (values )
56
+ values = torch .randn (size , dtype = _default_dtype )
57
+ return array_or_scalar (values )
58
58
59
59
60
60
def normal (loc = 0.0 , scale = 1.0 , size = None ):
61
61
if size is None :
62
62
size = ()
63
- return loc + scale * randn (* size ).to (_default_dtype )
63
+ values = torch .empty (size , dtype = _default_dtype ).normal_ (loc , scale )
64
+ return array_or_scalar (values )
64
65
65
66
66
67
def shuffle (x ):
@@ -76,9 +77,9 @@ def randint(low, high=None, size=None):
76
77
if not isinstance (size , (tuple , list )):
77
78
size = (size ,)
78
79
if high is None :
79
- low , high = 0 , high
80
+ low , high = 0 , low
80
81
values = torch .randint (low , high , size = size )
81
- return asarray (values )
82
+ return array_or_scalar (values )
82
83
83
84
84
85
def choice (a , size = None , replace = True , p = None ):
0 commit comments