diff --git a/torch_np/random.py b/torch_np/random.py index 6bb446c4..a72a1e33 100644 --- a/torch_np/random.py +++ b/torch_np/random.py @@ -15,7 +15,19 @@ _default_dtype = _default_float_type.torch_dtype -__all__ = ["seed", "random_sample", "sample", "random", "rand", "randn", "normal"] +__all__ = [ + "seed", + "random_sample", + "sample", + "random", + "rand", + "randn", + "normal", + "choice", + "randint", + "shuffle", + "uniform", +] def array_or_scalar(values, py_type=float): @@ -79,7 +91,7 @@ def randint(low, high=None, size=None): if high is None: low, high = 0, low values = torch.randint(low, high, size=size) - return array_or_scalar(values) + return array_or_scalar(values, int) def choice(a, size=None, replace=True, p=None):