Skip to content

Commit 7286b66

Browse files
committed
MAINT: random: encapsula tearray_or_scalar logic into a helper function
1 parent bd550d0 commit 7286b66

File tree

1 file changed

+19
-18
lines changed

1 file changed

+19
-18
lines changed

torch_np/random.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,23 @@
1818
__all__ = ["seed", "random_sample", "sample", "random", "rand", "randn", "normal"]
1919

2020

21+
def array_or_scalar(values, py_type=float):
22+
if values.numel() == 1:
23+
return py_type(values.item())
24+
else:
25+
return asarray(values)
26+
27+
2128
def seed(seed=None):
2229
if seed is not None:
2330
torch.random.manual_seed()
2431

2532

2633
def random_sample(size=None):
2734
if size is None:
28-
values = torch.rand(())
29-
return float(values)
30-
else:
31-
values = torch.rand(size).to(_default_dtype)
32-
return asarray(values)
35+
size = ()
36+
values = torch.empty(size, dtype=_default_dtype).uniform_()
37+
return array_or_scalar(values)
3338

3439

3540
def rand(*size):
@@ -42,25 +47,21 @@ def rand(*size):
4247

4348
def uniform(low=0.0, high=1.0, size=None):
4449
if size is None:
45-
values = torch.rand(())
46-
return float(low + (high - low) * values)
47-
else:
48-
values = torch.rand(size).to(_default_dtype)
49-
return asarray(low + (high - low) * values)
50+
size = ()
51+
values = torch.empty(size, dtype=_default_dtype).uniform_(low, high)
52+
return array_or_scalar(values)
5053

5154

5255
def randn(*size):
53-
if size == ():
54-
return float(torch.randn(size))
55-
else:
56-
values = torch.randn(*size).to(_default_dtypes)
57-
return asarray(values)
56+
values = torch.randn(size, dtype=_default_dtype)
57+
return array_or_scalar(values)
5858

5959

6060
def normal(loc=0.0, scale=1.0, size=None):
6161
if size is None:
6262
size = ()
63-
return loc + scale * randn(*size).to(_default_dtype)
63+
values = torch.empty(size, dtype=_default_dtype).normal_(loc, scale)
64+
return array_or_scalar(values)
6465

6566

6667
def shuffle(x):
@@ -76,9 +77,9 @@ def randint(low, high=None, size=None):
7677
if not isinstance(size, (tuple, list)):
7778
size = (size,)
7879
if high is None:
79-
low, high = 0, high
80+
low, high = 0, low
8081
values = torch.randint(low, high, size=size)
81-
return asarray(values)
82+
return array_or_scalar(values)
8283

8384

8485
def choice(a, size=None, replace=True, p=None):

0 commit comments

Comments
 (0)