Skip to content

Commit b1e7807

Browse files
authored
Merge pull request #37 from Quansight-Labs/randint_int
BUG: random: randint returns integers, not floats
2 parents 1b2937e + 46af221 commit b1e7807

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

torch_np/random.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,19 @@
1515

1616
_default_dtype = _default_float_type.torch_dtype
1717

18-
__all__ = ["seed", "random_sample", "sample", "random", "rand", "randn", "normal"]
18+
__all__ = [
19+
"seed",
20+
"random_sample",
21+
"sample",
22+
"random",
23+
"rand",
24+
"randn",
25+
"normal",
26+
"choice",
27+
"randint",
28+
"shuffle",
29+
"uniform",
30+
]
1931

2032

2133
def array_or_scalar(values, py_type=float):
@@ -79,7 +91,7 @@ def randint(low, high=None, size=None):
7991
if high is None:
8092
low, high = 0, low
8193
values = torch.randint(low, high, size=size)
82-
return array_or_scalar(values)
94+
return array_or_scalar(values, int)
8395

8496

8597
def choice(a, size=None, replace=True, p=None):

0 commit comments

Comments
 (0)