Skip to content

Commit 7e39bfd

Browse files
authored
implement min_scalar_type (#133)
1 parent 1402def commit 7e39bfd

File tree

3 files changed

+54
-1
lines changed

3 files changed

+54
-1
lines changed

torch_np/_funcs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,5 @@ def __getitem__(self, item):
6161
index_exp = IndexExpression(maketuple=True)
6262
s_ = IndexExpression(maketuple=False)
6363

64+
6465
__all__ += ["index_exp", "s_"]

torch_np/_funcs_impl.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1656,3 +1656,43 @@ def histogram(
16561656
b = b.long()
16571657

16581658
return h, b
1659+
1660+
1661+
# ### odds and ends
1662+
1663+
1664+
def min_scalar_type(a: ArrayLike, /):
1665+
# https://github.com/numpy/numpy/blob/maintenance/1.24.x/numpy/core/src/multiarray/convert_datatype.c#L1288
1666+
1667+
from ._dtypes import DType
1668+
1669+
if a.numel() > 1:
1670+
# numpy docs: "For non-scalar array a, returns the vector’s dtype unmodified."
1671+
return DType(a.dtype)
1672+
1673+
if a.dtype == torch.bool:
1674+
dtype = torch.bool
1675+
1676+
elif a.dtype.is_complex:
1677+
fi = torch.finfo(torch.float32)
1678+
fits_in_single = a.dtype == torch.complex64 or (
1679+
fi.min <= a.real <= fi.max and fi.min <= a.imag <= fi.max
1680+
)
1681+
dtype = torch.complex64 if fits_in_single else torch.complex128
1682+
1683+
elif a.dtype.is_floating_point:
1684+
for dt in [torch.float16, torch.float32, torch.float64]:
1685+
fi = torch.finfo(dt)
1686+
if fi.min <= a <= fi.max:
1687+
dtype = dt
1688+
break
1689+
else:
1690+
# must be integer
1691+
for dt in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
1692+
# Prefer unsigned int where possible, as numpy does.
1693+
ii = torch.iinfo(dt)
1694+
if ii.min <= a <= ii.max:
1695+
dtype = dt
1696+
break
1697+
1698+
return DType(dtype)

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6124,14 +6124,26 @@ def test_complex_warning(self):
61246124
assert_equal(x, [1, 2])
61256125

61266126

6127-
@pytest.mark.xfail(reason='TODO')
61286127
class TestMinScalarType:
61296128

61306129
def test_usigned_shortshort(self):
61316130
dt = np.min_scalar_type(2**8-1)
61326131
wanted = np.dtype('uint8')
61336132
assert_equal(wanted, dt)
61346133

6134+
# three tests below are added based on what numpy does
6135+
def test_complex(self):
6136+
dt = np.min_scalar_type(0+0j)
6137+
assert dt == np.dtype('complex64')
6138+
6139+
def test_float(self):
6140+
dt = np.min_scalar_type(0.1)
6141+
assert dt == np.dtype('float16')
6142+
6143+
def test_nonscalar(self):
6144+
dt = np.min_scalar_type([0, 1, 2])
6145+
assert dt == np.dtype('int64')
6146+
61356147

61366148
from numpy.core._internal import _dtype_from_pep3118
61376149

0 commit comments

Comments
 (0)