diff --git a/torch_np/_funcs.py b/torch_np/_funcs.py index 6a510123..0fa4b494 100644 --- a/torch_np/_funcs.py +++ b/torch_np/_funcs.py @@ -61,4 +61,5 @@ def __getitem__(self, item): index_exp = IndexExpression(maketuple=True) s_ = IndexExpression(maketuple=False) + __all__ += ["index_exp", "s_"] diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py index 5a35efd8..0f70893c 100644 --- a/torch_np/_funcs_impl.py +++ b/torch_np/_funcs_impl.py @@ -1656,3 +1656,43 @@ def histogram( b = b.long() return h, b + + +# ### odds and ends + + +def min_scalar_type(a: ArrayLike, /): + # https://github.com/numpy/numpy/blob/maintenance/1.24.x/numpy/core/src/multiarray/convert_datatype.c#L1288 + + from ._dtypes import DType + + if a.numel() > 1: + # numpy docs: "For non-scalar array a, returns the vector’s dtype unmodified." + return DType(a.dtype) + + if a.dtype == torch.bool: + dtype = torch.bool + + elif a.dtype.is_complex: + fi = torch.finfo(torch.float32) + fits_in_single = a.dtype == torch.complex64 or ( + fi.min <= a.real <= fi.max and fi.min <= a.imag <= fi.max + ) + dtype = torch.complex64 if fits_in_single else torch.complex128 + + elif a.dtype.is_floating_point: + for dt in [torch.float16, torch.float32, torch.float64]: + fi = torch.finfo(dt) + if fi.min <= a <= fi.max: + dtype = dt + break + else: + # must be integer + for dt in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]: + # Prefer unsigned int where possible, as numpy does. + ii = torch.iinfo(dt) + if ii.min <= a <= ii.max: + dtype = dt + break + + return DType(dtype) diff --git a/torch_np/tests/numpy_tests/core/test_multiarray.py b/torch_np/tests/numpy_tests/core/test_multiarray.py index ab66c443..03a1adf9 100644 --- a/torch_np/tests/numpy_tests/core/test_multiarray.py +++ b/torch_np/tests/numpy_tests/core/test_multiarray.py @@ -6124,7 +6124,6 @@ def test_complex_warning(self): assert_equal(x, [1, 2]) -@pytest.mark.xfail(reason='TODO') class TestMinScalarType: def test_usigned_shortshort(self): @@ -6132,6 +6131,19 @@ def test_usigned_shortshort(self): wanted = np.dtype('uint8') assert_equal(wanted, dt) + # three tests below are added based on what numpy does + def test_complex(self): + dt = np.min_scalar_type(0+0j) + assert dt == np.dtype('complex64') + + def test_float(self): + dt = np.min_scalar_type(0.1) + assert dt == np.dtype('float16') + + def test_nonscalar(self): + dt = np.min_scalar_type([0, 1, 2]) + assert dt == np.dtype('int64') + from numpy.core._internal import _dtype_from_pep3118