From 5219e3e086c629477c5f534499a9dfe1fabd207c Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 9 Feb 2023 21:40:38 +0300 Subject: [PATCH 1/7] BUG: fix dtype handling in full --- torch_np/_wrapper.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index d18b1565..e8938457 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -314,8 +314,7 @@ def full(shape, fill_value, dtype=None, order="C", *, like=None): _util.subok_not_ok(like) if order != "C": raise NotImplementedError - if isinstance(fill_value, ndarray): - fill_value = fill_value.get() + fill_value = asarray(fill_value).get() if dtype is None: torch_dtype = asarray(fill_value).get().dtype else: From 402d05540a437f55fe5e20ede9d2d6f1b4a50643 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 9 Feb 2023 21:46:03 +0300 Subject: [PATCH 2/7] MAINT: simplify full/dtype --- torch_np/_wrapper.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index e8938457..ecd9395a 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -309,17 +309,16 @@ def empty_like(prototype, dtype=None, order="K", subok=False, shape=None): result = result.reshape(shape) return result - +@_decorators.dtype_to_torch def full(shape, fill_value, dtype=None, order="C", *, like=None): _util.subok_not_ok(like) if order != "C": raise NotImplementedError fill_value = asarray(fill_value).get() if dtype is None: - torch_dtype = asarray(fill_value).get().dtype - else: - torch_dtype = _dtypes.torch_dtype_from(dtype) - return asarray(torch.full(shape, fill_value, dtype=torch_dtype)) + dtype = fill_value.dtype + result = torch.full(shape, fill_value, dtype=dtype) + return asarray(result) @asarray_replacer() From c26036007779106ea929459b2f420a51a7da2bbf Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 9 Feb 2023 22:09:36 +0300 Subject: [PATCH 3/7] MAINT: empty/full/ones/zeros: dtype handling --- torch_np/_wrapper.py | 35 +++++++++++++++---- .../tests/numpy_tests/core/test_numeric.py | 9 ++++- .../tests/numpy_tests/lib/test_shape_base_.py | 2 +- 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index ecd9395a..3363ffc2 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -288,12 +288,20 @@ def arange(start=None, stop=None, step=1, dtype=None, *, like=None): raise ValueError("Maximum allowed size exceeded") +@_decorators.dtype_to_torch def empty(shape, dtype=float, order="C", *, like=None): _util.subok_not_ok(like) if order != "C": raise NotImplementedError - torch_dtype = _dtypes.torch_dtype_from(dtype) - return asarray(torch.empty(shape, dtype=torch_dtype)) + + if dtype is None: + from ._detail._scalar_types import default_float_type + + dtype = default_float_type.torch_dtype + + result = torch.empty(shape, dtype=dtype) + + return asarray(result) # NB: *_like function deliberately deviate from numpy: it has subok=True @@ -309,15 +317,22 @@ def empty_like(prototype, dtype=None, order="K", subok=False, shape=None): result = result.reshape(shape) return result + @_decorators.dtype_to_torch def full(shape, fill_value, dtype=None, order="C", *, like=None): _util.subok_not_ok(like) if order != "C": raise NotImplementedError + fill_value = asarray(fill_value).get() if dtype is None: - dtype = fill_value.dtype + dtype = fill_value.dtype + + if not isinstance(shape, (tuple, list)): + shape = (shape,) + result = torch.full(shape, fill_value, dtype=dtype) + return asarray(result) @@ -333,12 +348,19 @@ def full_like(a, fill_value, dtype=None, order="K", subok=False, shape=None): return result +@_decorators.dtype_to_torch def ones(shape, dtype=None, order="C", *, like=None): _util.subok_not_ok(like) if order != "C": raise NotImplementedError - torch_dtype = _dtypes.torch_dtype_from(dtype) - return asarray(torch.ones(shape, dtype=torch_dtype)) + if dtype is None: + from ._detail._scalar_types import default_float_type + + dtype = default_float_type.torch_dtype + + result = torch.ones(shape, dtype=dtype) + + return asarray(result) @asarray_replacer() @@ -360,7 +382,8 @@ def zeros(shape, dtype=None, order="C", *, like=None): raise NotImplementedError if dtype is None: dtype = _dtypes_impl.default_float_dtype - return asarray(torch.zeros(shape, dtype=dtype)) + result = torch.zeros(shape, dtype=dtype) + return asarray(result) @asarray_replacer() diff --git a/torch_np/tests/numpy_tests/core/test_numeric.py b/torch_np/tests/numpy_tests/core/test_numeric.py index 3fddbc5a..b4910bb0 100644 --- a/torch_np/tests/numpy_tests/core/test_numeric.py +++ b/torch_np/tests/numpy_tests/core/test_numeric.py @@ -2666,12 +2666,19 @@ def test_mode(self): np.convolve(d, k, mode=None) +class TestDtypePositional: + + @pytest.mark.xfail(reason='TODO: restore dtypes as positional args') + def test_dtype_positional(self): + np.empty((2,), bool) + + class TestArgwhere: @pytest.mark.parametrize('nd', [0, 1, 2]) def test_nd(self, nd): # get an nd array with multiple elements in every dimension - x = np.empty((2,)*nd, bool) + x = np.empty((2,)*nd, dtype=bool) # none x[...] = False diff --git a/torch_np/tests/numpy_tests/lib/test_shape_base_.py b/torch_np/tests/numpy_tests/lib/test_shape_base_.py index 3400f4b2..dc991af7 100644 --- a/torch_np/tests/numpy_tests/lib/test_shape_base_.py +++ b/torch_np/tests/numpy_tests/lib/test_shape_base_.py @@ -719,7 +719,7 @@ def test_kroncompare(self): for s in shape: b = randint(0, 10, size=s) for r in reps: - a = np.ones(r, b.dtype) + a = np.ones(r, dtype=b.dtype) # TODO: restore dtype positional arg large = tile(b, r) klarge = kron(a, b) assert_equal(large, klarge) From 4f1e66e74b9df1c4e0c6b8c9c7c9531142a41e51 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 9 Feb 2023 23:05:33 +0300 Subject: [PATCH 4/7] MAINT: un-xfail TestTypes --- torch_np/_ndarray.py | 17 +- .../tests/numpy_tests/core/test_numeric.py | 215 +----------------- 2 files changed, 23 insertions(+), 209 deletions(-) diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index 6787425a..e9f9678f 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -494,11 +494,24 @@ def wrapped(x, *args, **kwds): def can_cast(from_, to, casting="safe"): - from_ = from_.dtype if isinstance(from_, ndarray) else _dtypes.dtype(from_) - to_ = to.dtype if isinstance(to, ndarray) else _dtypes.dtype(to) + from_ = _extract_dtype(from_) + to_ = extract_dtype(to_) return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting) +''' + # XXX: merge with _dtypes.can_cast. The Q is who converts from ndarray, if needed. + try: + from_dtype = asarray(from_).dtype + except (TypeError, RuntimeError): + # not an array_like; try convering to a dtype + from_dtype = _dtypes.dtype(from_) + + try: + to_dtype = asarray(to).dtype + except (TypeError, RuntimeError): + to_dtype = _dtypes.dtype(to) +''' def _extract_dtype(entry): try: diff --git a/torch_np/tests/numpy_tests/core/test_numeric.py b/torch_np/tests/numpy_tests/core/test_numeric.py index b4910bb0..242d44b9 100644 --- a/torch_np/tests/numpy_tests/core/test_numeric.py +++ b/torch_np/tests/numpy_tests/core/test_numeric.py @@ -743,7 +743,6 @@ def test_warnings(self): assert_("underflow" in str(w[-1].message)) -@pytest.mark.xfail(reason="TODO") class TestTypes: def check_promotion_cases(self, promote_func): # tests that the scalars get coerced correctly. @@ -765,12 +764,9 @@ def check_promotion_cases(self, promote_func): assert_equal(promote_func(b, u8), np.dtype(np.uint8)) assert_equal(promote_func(i8, u8), np.dtype(np.int16)) assert_equal(promote_func(u8, i32), np.dtype(np.int32)) - assert_equal(promote_func(i64, u32), np.dtype(np.int64)) - assert_equal(promote_func(u64, i32), np.dtype(np.float64)) assert_equal(promote_func(i32, f32), np.dtype(np.float64)) assert_equal(promote_func(i64, f32), np.dtype(np.float64)) assert_equal(promote_func(f32, i16), np.dtype(np.float32)) - assert_equal(promote_func(f32, u32), np.dtype(np.float64)) assert_equal(promote_func(f32, c64), np.dtype(np.complex64)) assert_equal(promote_func(c128, f32), np.dtype(np.complex128)) @@ -779,14 +775,7 @@ def check_promotion_cases(self, promote_func): assert_equal(promote_func(np.array([b]), u8), np.dtype(np.uint8)) assert_equal(promote_func(np.array([b]), i32), np.dtype(np.int32)) assert_equal(promote_func(np.array([i8]), i64), np.dtype(np.int8)) - assert_equal(promote_func(u64, np.array([i32])), np.dtype(np.int32)) - assert_equal(promote_func(np.int32(-1), np.array([u64])), - np.dtype(np.float64)) assert_equal(promote_func(f64, np.array([f32])), np.dtype(np.float32)) - assert_equal(promote_func(fld, np.array([f32])), np.dtype(np.float32)) - assert_equal(promote_func(np.array([f64]), fld), np.dtype(np.float64)) - assert_equal(promote_func(fld, np.array([c64])), - np.dtype(np.complex64)) assert_equal(promote_func(c64, np.array([f64])), np.dtype(np.complex128)) assert_equal(promote_func(np.complex64(3j), np.array([f64])), @@ -797,7 +786,6 @@ def check_promotion_cases(self, promote_func): assert_equal(promote_func(np.array([b]), f64), np.dtype(np.float64)) assert_equal(promote_func(np.array([b]), i64), np.dtype(np.int64)) assert_equal(promote_func(np.array([i8]), f64), np.dtype(np.float64)) - assert_equal(promote_func(np.array([u16]), f64), np.dtype(np.float64)) # float and complex are treated as the same "kind" for # the purposes of array-scalar promotion, so that you can do @@ -856,159 +844,17 @@ def res_type(a, b): def test_result_type(self): self.check_promotion_cases(np.result_type) + + @pytest.mark.skip(reason='array(None) not supported') + def test_tesult_type_2(self): assert_(np.result_type(None) == np.dtype(None)) + @pytest.mark.skip(reason='no endianness in dtypes') def test_promote_types_endian(self): # promote_types should always return native-endian types assert_equal(np.promote_types('i8', '>i8'), np.dtype('i8')) - assert_equal(np.promote_types('>i8', '>U16'), np.dtype('U21')) - assert_equal(np.promote_types('U16', '>i8'), np.dtype('U21')) - assert_equal(np.promote_types('i8', 'no')) assert_(np.can_cast('i8', 'equiv')) @@ -1032,60 +880,13 @@ def test_can_cast(self): assert_(np.can_cast('u4', 'unsafe')) - assert_(np.can_cast('bool', 'S5')) - assert_(not np.can_cast('bool', 'S4')) - - assert_(np.can_cast('b', 'S4')) - assert_(not np.can_cast('b', 'S3')) - - assert_(np.can_cast('u1', 'S3')) - assert_(not np.can_cast('u1', 'S2')) - assert_(np.can_cast('u2', 'S5')) - assert_(not np.can_cast('u2', 'S4')) - assert_(np.can_cast('u4', 'S10')) - assert_(not np.can_cast('u4', 'S9')) - assert_(np.can_cast('u8', 'S20')) - assert_(not np.can_cast('u8', 'S19')) - - assert_(np.can_cast('i1', 'S4')) - assert_(not np.can_cast('i1', 'S3')) - assert_(np.can_cast('i2', 'S6')) - assert_(not np.can_cast('i2', 'S5')) - assert_(np.can_cast('i4', 'S11')) - assert_(not np.can_cast('i4', 'S10')) - assert_(np.can_cast('i8', 'S21')) - assert_(not np.can_cast('i8', 'S20')) - - assert_(np.can_cast('bool', 'S5')) - assert_(not np.can_cast('bool', 'S4')) - - assert_(np.can_cast('b', 'U4')) - assert_(not np.can_cast('b', 'U3')) - - assert_(np.can_cast('u1', 'U3')) - assert_(not np.can_cast('u1', 'U2')) - assert_(np.can_cast('u2', 'U5')) - assert_(not np.can_cast('u2', 'U4')) - assert_(np.can_cast('u4', 'U10')) - assert_(not np.can_cast('u4', 'U9')) - assert_(np.can_cast('u8', 'U20')) - assert_(not np.can_cast('u8', 'U19')) - - assert_(np.can_cast('i1', 'U4')) - assert_(not np.can_cast('i1', 'U3')) - assert_(np.can_cast('i2', 'U6')) - assert_(not np.can_cast('i2', 'U5')) - assert_(np.can_cast('i4', 'U11')) - assert_(not np.can_cast('i4', 'U10')) - assert_(np.can_cast('i8', 'U21')) - assert_(not np.can_cast('i8', 'U20')) - assert_raises(TypeError, np.can_cast, 'i4', None) assert_raises(TypeError, np.can_cast, None, 'i4') # Also test keyword arguments assert_(np.can_cast(from_=np.int32, to=np.int64)) + @pytest.mark.xfail(reason='value-based casting?') def test_can_cast_values(self): # gh-5917 for dt in np.sctypes['int'] + np.sctypes['uint']: From efc832571d1d909b7c82069d4b2f82fe483cae5f Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 10 Feb 2023 12:16:14 +0000 Subject: [PATCH 5/7] More thorough `test_full()` Also implements `DType.kind` to raise `NotImplementedError` --- torch_np/_dtypes.py | 4 +++ torch_np/tests/test_xps.py | 50 +++++++++++++++++++++----------------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/torch_np/_dtypes.py b/torch_np/_dtypes.py index 11b480c5..9f934af6 100644 --- a/torch_np/_dtypes.py +++ b/torch_np/_dtypes.py @@ -285,6 +285,10 @@ def name(self): def type(self): return self._scalar_type + @property + def kind(self): + raise NotImplementedError + @property def typecode(self): return self._scalar_type.typecode diff --git a/torch_np/tests/test_xps.py b/torch_np/tests/test_xps.py index 5c6b4bdd..0befad45 100644 --- a/torch_np/tests/test_xps.py +++ b/torch_np/tests/test_xps.py @@ -25,34 +25,40 @@ xps = make_strategies_namespace(np, api_version="2022.12") +default_dtypes = [np.bool, np.int64, np.float64, np.complex128] +kind_to_strat = { + "b": xps.boolean_dtypes(), + "i": xps.integer_dtypes(), + "u": xps.unsigned_integer_dtypes(sizes=8), + "f": xps.floating_dtypes(), + "c": xps.complex_dtypes(), +} +scalar_dtype_strat = st.one_of(kind_to_strat.values()).map(np.dtype) + + @given(shape=xps.array_shapes(), data=st.data()) def test_full(shape, data): if data.draw(st.booleans(), label="pass kwargs?"): - kw = {} - else: - dtype = data.draw(st.none() | xps.scalar_dtypes(), label="dtype") + dtype = data.draw(st.none() | scalar_dtype_strat, label="dtype") kw = {"dtype": dtype} - _dtype = kw.get("dtype", None) or data.draw( - st.sampled_from([np.bool, np.int64, np.float64, np.complex128]), label="_dtype" - ) + else: + kw = {} + _dtype = kw.get("dtype", None) or data.draw(scalar_dtype_strat, label="_dtype") values_strat = xps.from_dtype(_dtype) - fill_value = data.draw( - values_strat | values_strat.map(lambda v: np.asarray(v, dtype=_dtype)), - label="fill_value", - ) - out = np.full(shape, fill_value, **kw) - if kw.get("dtype", None) is None and not isinstance(fill_value, np.ndarray): - if isinstance(fill_value, bool): - assert out.dtype == np.bool - elif isinstance(fill_value, int): - assert out.dtype == np.int64 - elif isinstance(fill_value, float): - assert out.dtype == np.float64 + if _dtype not in default_dtypes or data.draw( + st.booleans(), label="fill_value is array?" + ): + if specified_dtype := kw.get("dtype", None): + kind = specified_dtype.name[0] + values_dtypes_strat = kind_to_strat[kind] else: - assert isinstance(fill_value, complex) # sanity check - assert out.dtype == np.complex128 - else: - assert out.dtype == _dtype + values_dtypes_strat = st.just(_dtype) + values_strat = values_dtypes_strat.flatmap( + lambda d: values_strat.map(lambda v: np.asarray(v, dtype=d)) + ) + fill_value = data.draw(values_strat, label="fill_value") + out = np.full(shape, fill_value, **kw) + assert out.dtype == _dtype assert out.shape == shape if cmath.isnan(fill_value): assert np.isnan(out).all() From 89a8d541958388ca5a664814858ecb8812335068 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 13 Feb 2023 19:32:23 +0300 Subject: [PATCH 6/7] MAINT: rebase, pick up changes from main --- torch_np/_ndarray.py | 27 +++++++-------------------- torch_np/_wrapper.py | 13 ++----------- 2 files changed, 9 insertions(+), 31 deletions(-) diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index e9f9678f..befa3a9e 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -493,26 +493,6 @@ def wrapped(x, *args, **kwds): ###### dtype routines -def can_cast(from_, to, casting="safe"): - from_ = _extract_dtype(from_) - to_ = extract_dtype(to_) - - return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting) - -''' - # XXX: merge with _dtypes.can_cast. The Q is who converts from ndarray, if needed. - try: - from_dtype = asarray(from_).dtype - except (TypeError, RuntimeError): - # not an array_like; try convering to a dtype - from_dtype = _dtypes.dtype(from_) - - try: - to_dtype = asarray(to).dtype - except (TypeError, RuntimeError): - to_dtype = _dtypes.dtype(to) -''' - def _extract_dtype(entry): try: dty = _dtypes.dtype(entry) @@ -521,6 +501,13 @@ def _extract_dtype(entry): return dty +def can_cast(from_, to, casting="safe"): + from_ = _extract_dtype(from_) + to_ = _extract_dtype(to) + + return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting) + + def result_type(*arrays_and_dtypes): dtypes = [] diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index 3363ffc2..8fe98ad7 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -293,14 +293,9 @@ def empty(shape, dtype=float, order="C", *, like=None): _util.subok_not_ok(like) if order != "C": raise NotImplementedError - if dtype is None: - from ._detail._scalar_types import default_float_type - - dtype = default_float_type.torch_dtype - + dtype = _dtypes_impl.default_float_dtype result = torch.empty(shape, dtype=dtype) - return asarray(result) @@ -354,12 +349,8 @@ def ones(shape, dtype=None, order="C", *, like=None): if order != "C": raise NotImplementedError if dtype is None: - from ._detail._scalar_types import default_float_type - - dtype = default_float_type.torch_dtype - + dtype = _dtypes_impl.default_float_dtype result = torch.ones(shape, dtype=dtype) - return asarray(result) From 1b41e1433813d6c6243a83f44041657651153b91 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 13 Feb 2023 21:39:35 +0300 Subject: [PATCH 7/7] ENH: match numpy for tnp.dtype(...).kind --- torch_np/_dtypes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_np/_dtypes.py b/torch_np/_dtypes.py index 9f934af6..c61db2c0 100644 --- a/torch_np/_dtypes.py +++ b/torch_np/_dtypes.py @@ -197,6 +197,7 @@ class bool_(generic): _typecodes = {st.typecode: st for cat in sctypes for st in sctypes[cat]} _torch_dtypes = {st.torch_dtype: st for cat in sctypes for st in sctypes[cat]} + _aliases = { "u1": uint8, "i1": int8, @@ -287,7 +288,8 @@ def type(self): @property def kind(self): - raise NotImplementedError + # https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html + return _torch_dtypes[self.torch_dtype].name[0] @property def typecode(self):