Skip to content

Commit 490650f

Browse files
committed
BUG: random: randint returns integers, not floats
1 parent a9d34bf commit 490650f

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torch_np/random.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
_default_dtype = _default_float_type.torch_dtype
1717

18-
__all__ = ["seed", "random_sample", "sample", "random", "rand", "randn", "normal"]
18+
__all__ = ["seed", "random_sample", "sample", "random", "rand", "randn",
19+
"normal", "choice", "randint", "shuffle", "uniform"]
1920

2021

2122
def array_or_scalar(values, py_type=float):
@@ -79,7 +80,7 @@ def randint(low, high=None, size=None):
7980
if high is None:
8081
low, high = 0, low
8182
values = torch.randint(low, high, size=size)
82-
return array_or_scalar(values)
83+
return array_or_scalar(values, int)
8384

8485

8586
def choice(a, size=None, replace=True, p=None):

0 commit comments

Comments
 (0)