From 1253fd1a12ef01754fbdadd350cafec42d4f700d Mon Sep 17 00:00:00 2001
From: Matthew Barber
Date: Wed, 12 Apr 2023 18:28:13 +0100
Subject: [PATCH 01/14] Also suppress `RuntimeError` inside `assert_array_compare()`

---
 torch_np/testing/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_np/testing/utils.py b/torch_np/testing/utils.py
index 292134fc..b7654774 100644
--- a/torch_np/testing/utils.py
+++ b/torch_np/testing/utils.py
@@ -679,7 +679,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval="nan"):
 
         ## with errstate(all='ignore'):
         # ignore errors for non-numeric types
-        with contextlib.suppress(TypeError):
+        with contextlib.suppress(TypeError, RuntimeError):
            error = abs(x - y)
            if np.issubdtype(x.dtype, np.unsignedinteger):
                error2 = abs(y - x)

From 73b69e5963a398c61576840b533e4998f6ea061a Mon Sep 17 00:00:00 2001
From: Matthew Barber
Date: Thu, 13 Apr 2023 12:21:45 +0100
Subject: [PATCH 02/14] Use repr in `sctype_from_string()` error message

---
 torch_np/_dtypes.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_np/_dtypes.py b/torch_np/_dtypes.py
index 2c308cac..751de9b3 100644
--- a/torch_np/_dtypes.py
+++ b/torch_np/_dtypes.py
@@ -243,7 +243,7 @@ def sctype_from_string(s):
         return _aliases[s]
     if s in _python_types:
         return _python_types[s]
-    raise TypeError(f"data type '{s}' not understood")
+    raise TypeError(f"data type {s!r} not understood")
 
 
 def sctype_from_torch_dtype(torch_dtype):

From efb912aa4464da6f67d681a7581e933fc7bd15e1 Mon Sep 17 00:00:00 2001
From: Matthew Barber
Date: Thu, 13 Apr 2023 12:22:18 +0100
Subject: [PATCH 03/14] `tnp.put` and `test_put`
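
Implements a `np.put` equivalent on top of `torch.Tensor.put_`. For
reference, the NumPy behaviour being mirrored, where indices apply to
the flattened target (illustrative snippet only):

    >>> import numpy as np
    >>> a = np.arange(5)
    >>> np.put(a, [0, 2], [-44, -55])
    >>> a
    array([-44,   1, -55,   3,   4])

Duplicated indices are rejected with `NotImplementedError` for now.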

---
 torch_np/_funcs_impl.py    | 22 ++++++++++++
 torch_np/tests/test_xps.py | 68 +++++++++++++++++++++++++++++++++-----
 2 files changed, 81 insertions(+), 9 deletions(-)

diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py
index e1171391..865c079a 100644
--- a/torch_np/_funcs_impl.py
+++ b/torch_np/_funcs_impl.py
@@ -894,6 +894,28 @@ def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
     return torch.take_along_dim(arr, indices, axis)
 
 
+def put(
+    a: ArrayLike,
+    ind: Sequence[ArrayLike],
+    v: ArrayLike,
+    mode: NotImplementedType = "raise",
+):
+    index = torch.concat(ind)
+    index[index < 0] += a.numel()  # normalise negative indices
+    index_u, index_c = torch.unique(index, return_counts=True)
+    duplicated_indices = index_u[index_c > 1]
+    if duplicated_indices.numel() > 0:
+        raise NotImplementedError(
+            "duplicated indices are not supported. duplicated indices: "
+            f"{duplicated_indices}"
+        )
+    source = v
+    if source.numel() < index.numel():
+        source = torch.broadcast_to(source, index.size())
+    a.put_(index, source)
+    return None
+
+
 def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
     (arr,), axis = _util.axis_none_ravel(arr, axis=axis)
     axis = _util.normalize_axis_index(axis, arr.ndim)

diff --git a/torch_np/tests/test_xps.py b/torch_np/tests/test_xps.py
index 1d64a26c..eae62939 100644
--- a/torch_np/tests/test_xps.py
+++ b/torch_np/tests/test_xps.py
@@ -10,22 +10,27 @@
 
 pytest.importorskip("hypothesis")
 
-from hypothesis import given
+import numpy as np
+import torch
+from hypothesis import given, note
 from hypothesis import strategies as st
 from hypothesis.errors import HypothesisWarning
+from hypothesis.extra import numpy as nps
 from hypothesis.extra.array_api import make_strategies_namespace
 
-import torch_np as np
+import torch_np as tnp
+from torch_np._dtypes import sctypes
+from torch_np.testing import assert_array_equal
 
 __all__ = ["xps"]
 
 with warnings.catch_warnings():
     warnings.filterwarnings("ignore", category=HypothesisWarning)
-    np.bool = np.bool_
-    xps = make_strategies_namespace(np, api_version="2022.12")
+    tnp.bool = tnp.bool_
+    xps = make_strategies_namespace(tnp, api_version="2022.12")
 
-default_dtypes = [np.bool, np.int64, np.float64, np.complex128]
+default_dtypes = [tnp.bool, tnp.int64, tnp.float64, tnp.complex128]
 kind_to_strat = {
     "b": xps.boolean_dtypes(),
     "i": xps.integer_dtypes(),
@@ -33,7 +38,7 @@
     "f": xps.floating_dtypes(),
     "c": xps.complex_dtypes(),
 }
-scalar_dtype_strat = st.one_of(kind_to_strat.values()).map(np.dtype)
+scalar_dtype_strat = st.one_of(kind_to_strat.values()).map(tnp.dtype)
 
 
 @pytest.mark.skip(reason="flaky")
@@ -55,14 +60,14 @@ def test_full(shape, data):
     else:
         values_dtypes_strat = st.just(_dtype)
     values_strat = values_dtypes_strat.flatmap(
-        lambda d: values_strat.map(lambda v: np.asarray(v, dtype=d))
+        lambda d: values_strat.map(lambda v: tnp.asarray(v, dtype=d))
     )
 
     fill_value = data.draw(values_strat, label="fill_value")
-    out = np.full(shape, fill_value, **kw)
+    out = tnp.full(shape, fill_value, **kw)
     assert out.dtype == _dtype
     assert out.shape == shape
     if cmath.isnan(fill_value):
-        assert np.isnan(out).all()
+        assert tnp.isnan(out).all()
     else:
         assert (out == fill_value).all()
@@ -89,3 +94,48 @@ def test_integer_indexing(x, data):
     idx = data.draw(integer_array_indices(x.shape, result_shape), label="idx")
     result = x[idx]
     assert result.shape == result_shape
+
+
+@given(
+    np_x=nps.arrays(
+        # We specifically use namespaced dtypes to prevent non-native byte-order issues
+        dtype=scalar_dtype_strat.map(lambda d: getattr(np, d.name)),
+        shape=nps.array_shapes(),
+    ),
+    data=st.data(),
+)
+def test_put(np_x, data):
+    # We cast arrays from torch_np.asarray as it currently doesn't carry over
+    # dtypes. XXX: Remove the below sanity check and subsequent casting when
+    # this is fixed.
+    assert tnp.asarray(np.zeros(5, dtype=np.int16)).dtype != tnp.int16
+
+    tnp_x = tnp.asarray(np_x.copy()).astype(np_x.dtype.name)
+
+    result_shapes = st.shared(nps.array_shapes())
+    ind = data.draw(
+        nps.integer_array_indices(np_x.shape, result_shape=result_shapes), label="ind"
+    )
+    v = data.draw(nps.arrays(dtype=np_x.dtype, shape=result_shapes), label="v")
+
+    tnp_x_copy = tnp_x.copy()
+    np.put(np_x, ind, v)
+    note(f"(after put) {np_x=}")
+    assert_array_equal(tnp_x, tnp_x_copy)  # sanity check
+
+    note(f"{tnp_x=}")
+    tnp_ind = []
+    for np_indices in ind:
+        tnp_indices = tnp.asarray(np_indices).astype(np_indices.dtype.name)
+        tnp_ind.append(tnp_indices)
+    tnp_ind = tuple(tnp_ind)
+    note(f"{tnp_ind=}")
+    tnp_v = tnp.asarray(v.copy()).astype(v.dtype.name)
+    note(f"{tnp_v=}")
+    try:
+        tnp.put(tnp_x, tnp_ind, tnp_v)
+    except NotImplementedError:
+        return
+    note(f"(after put) {tnp_x=}")
+
+    assert_array_equal(tnp_x, tnp.asarray(np_x).astype(tnp_x.dtype))

From 3cc0ecf4468525343ab0b75d436507888257f3e0 Mon Sep 17 00:00:00 2001
From: Matthew Barber
Date: Tue, 18 Apr 2023 12:12:42 +0100
Subject: [PATCH 04/14] Support non-array indices in `tnp.put()`

---
 torch_np/_funcs_impl.py    | 10 ++++++++--
 torch_np/tests/test_xps.py | 13 ++++++++++---
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py
index 865c079a..95812ea2 100644
--- a/torch_np/_funcs_impl.py
+++ b/torch_np/_funcs_impl.py
@@ -900,7 +900,11 @@ def put(
     v: ArrayLike,
     mode: NotImplementedType = "raise",
 ):
-    index = torch.concat(ind)
+    indexes = list(ind)
+    for i, index in enumerate(indexes):
+        if not isinstance(index, torch.Tensor):
+            indexes[i] = torch.as_tensor(index)
+    index = torch.concat(indexes)
     index[index < 0] += a.numel()  # normalise negative indices
     index_u, index_c = torch.unique(index, return_counts=True)
     duplicated_indices = index_u[index_c > 1]
@@ -911,7 +915,9 @@ def put(
         )
     source = v
     if source.numel() < index.numel():
-        source = torch.broadcast_to(source, index.size())
+        numel_ratio = float(index.numel() / source.numel())
+        if numel_ratio.is_integer():
+            source = torch.stack([source for _ in range(int(numel_ratio))])
     a.put_(index, source)
     return None

diff --git a/torch_np/tests/test_xps.py b/torch_np/tests/test_xps.py
index eae62939..aa4fb817 100644
--- a/torch_np/tests/test_xps.py
+++ b/torch_np/tests/test_xps.py
@@ -125,9 +125,16 @@ def test_put(np_x, data):
 
     note(f"{tnp_x=}")
     tnp_ind = []
-    for np_indices in ind:
-        tnp_indices = tnp.asarray(np_indices).astype(np_indices.dtype.name)
-        tnp_ind.append(tnp_indices)
+    list_at_ind = data.draw(
+        st.lists(st.booleans(), min_size=len(ind), max_size=len(ind)),
+        label="list_at_ind",
+    )
+    for np_indices, use_list in zip(ind, list_at_ind):
+        if use_list:
+            indices = np_indices.tolist()
+        else:
+            indices = tnp.asarray(np_indices).astype(np_indices.dtype.name)
+        tnp_ind.append(indices)
     tnp_ind = tuple(tnp_ind)
     note(f"{tnp_ind=}")
     tnp_v = tnp.asarray(v.copy()).astype(v.dtype.name)

From e7ac4e60c3c91eae122fe5b60aff29581e564f7b Mon Sep 17 00:00:00 2001
From: Matthew Barber
Date: Tue, 18 Apr 2023 12:21:17 +0100
Subject: [PATCH 05/14] Treat `ind` as `ArrayLike` and rely on its normalisation

---
 torch_np/_funcs_impl.py    | 24 +++++-------------------
 torch_np/tests/test_xps.py |  5 +----
 2 files changed, 6 insertions(+), 23 deletions(-)

diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py
index 95812ea2..7c5b5faa 100644
--- a/torch_np/_funcs_impl.py
+++ b/torch_np/_funcs_impl.py
@@ -896,29 +896,15 @@ def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
 
 def put(
     a: ArrayLike,
-    ind: Sequence[ArrayLike],
+    ind: ArrayLike,
     v: ArrayLike,
     mode: NotImplementedType = "raise",
 ):
-    indexes = list(ind)
-    for i, index in enumerate(indexes):
-        if not isinstance(index, torch.Tensor):
-            indexes[i] = torch.as_tensor(index)
-    index = torch.concat(indexes)
-    index[index < 0] += a.numel()  # normalise negative indices
-    index_u, index_c = torch.unique(index, return_counts=True)
-    duplicated_indices = index_u[index_c > 1]
-    if duplicated_indices.numel() > 0:
-        raise NotImplementedError(
-            "duplicated indices are not supported. duplicated indices: "
-            f"{duplicated_indices}"
-        )
-    source = v
-    if source.numel() < index.numel():
-        numel_ratio = float(index.numel() / source.numel())
-        if numel_ratio.is_integer():
-            source = torch.stack([source for _ in range(int(numel_ratio))])
-    a.put_(index, source)
+    if v.numel() < ind.numel():
+        numel_ratio = float(ind.numel() / v.numel())
+        if numel_ratio.is_integer():
+            v = torch.stack([v for _ in range(int(numel_ratio))])
+    a.put_(ind, v)
     return None

diff --git a/torch_np/tests/test_xps.py b/torch_np/tests/test_xps.py
index aa4fb817..0e2ea279 100644
--- a/torch_np/tests/test_xps.py
+++ b/torch_np/tests/test_xps.py
@@ -139,10 +139,7 @@ def test_put(np_x, data):
     note(f"{tnp_ind=}")
     tnp_v = tnp.asarray(v.copy()).astype(v.dtype.name)
     note(f"{tnp_v=}")
-    try:
-        tnp.put(tnp_x, tnp_ind, tnp_v)
-    except NotImplementedError:
-        return
+    tnp.put(tnp_x, tnp_ind, tnp_v)
     note(f"(after put) {tnp_x=}")
 
     assert_array_equal(tnp_x, tnp.asarray(np_x).astype(tnp_x.dtype))

From 6f70b715b1a730207ca184ba9ddede4b6fe83540 Mon Sep 17 00:00:00 2001
From: Matthew Barber
Date: Tue, 18 Apr 2023 18:04:43 +0100
Subject: [PATCH 06/14] Alternative approach for repeating `v` in `tnp.put()`

---
 torch_np/_funcs_impl.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py
index 7c5b5faa..9f58cff8 100644
--- a/torch_np/_funcs_impl.py
+++ b/torch_np/_funcs_impl.py
@@ -900,10 +900,11 @@ def put(
     v: ArrayLike,
     mode: NotImplementedType = "raise",
 ):
-    if v.numel() < ind.numel():
-        numel_ratio = float(ind.numel() / v.numel())
-        if numel_ratio.is_integer():
-            v = torch.stack([v for _ in range(int(numel_ratio))])
+    numel_ratio = ind.numel() / v.numel()
+    if numel_ratio.is_integer():
+        sizes = [int(numel_ratio)]
+        sizes.extend([1 for _ in range(v.dim() - 1)])
+        v = v.repeat(*sizes)
     a.put_(ind, v)
     return None

From 8541ca3f6df15ad5e7ef0feee96921f155618f75 Mon Sep 17 00:00:00 2001
From: Matthew Barber
Date: Tue, 18 Apr 2023 18:38:03 +0100
Subject: [PATCH 07/14] Raise `NotImplementedError` when attempting to modify `ndarray.flags`

---
 torch_np/_ndarray.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py
index 3b67afa6..338e48b9 100644
--- a/torch_np/_ndarray.py
+++ b/torch_np/_ndarray.py
@@ -56,6 +56,18 @@ def __getitem__(self, key):
         else:
             raise KeyError(f"No flag key '{key}'")
 
+    def __setattr__(self, attr, value):
+        if attr.islower() and attr.upper() in FLAGS:
+            self[attr.upper()] = value
+        else:
+            super().__setattr__(attr, value)
+
+    def __setitem__(self, key, value):
+        if key in FLAGS or key in SHORTHAND_TO_FLAGS.keys():
+            raise NotImplementedError("Modifying flags is not implemented")
+        else:
+            raise KeyError(f"No flag key '{key}'")
+
 
 def create_method(fn, name=None):
     name = name or fn.__name__

From f7ee3b8cdabee90b43d1671a281b2c63b67f95a2 Mon Sep 17 00:00:00 2001
From: Matthew Barber
Date: Tue, 18 Apr 2023 18:43:22 +0100
Subject: [PATCH 08/14] Add `ndarray.put()`, cast `v` to `a` in `put()`

Also partially un-xfails `test_multiarray.py::TestMethods::test_put`
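
The cast mirrors NumPy, where `put` casts `v` to `a`'s dtype. A
hypothetical session (illustrative only, not test-suite output):

    >>> import torch_np as tnp
    >>> a = tnp.zeros(3, dtype=tnp.int64)
    >>> a.put([0, 1], [1.7, 2.9])  # v is cast to a's dtype first
    >>> a
    array([1, 2, 0])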

---
 torch_np/_funcs_impl.py                            | 1 +
 torch_np/_ndarray.py                               | 3 +++
 torch_np/tests/numpy_tests/core/test_multiarray.py | 8 ++------
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py
index 9f58cff8..bacbfe89 100644
--- a/torch_np/_funcs_impl.py
+++ b/torch_np/_funcs_impl.py
@@ -900,6 +900,7 @@ def put(
     v: ArrayLike,
     mode: NotImplementedType = "raise",
 ):
+    v = v.type(a.dtype)
     numel_ratio = ind.numel() / v.numel()
     if numel_ratio.is_integer():
         sizes = [int(numel_ratio)]

diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py
index 338e48b9..c6ba391b 100644
--- a/torch_np/_ndarray.py
+++ b/torch_np/_ndarray.py
@@ -409,6 +409,9 @@ def __setitem__(self, index, value):
         value = _util.cast_if_needed(value, self.tensor.dtype)
         return self.tensor.__setitem__(index, value)
 
+    take = _funcs.take
+    put = _funcs.put
+
 
 # This is ideally the only place which talks to ndarray directly.
 # The rest goes through asarray (preferred) or array.

diff --git a/torch_np/tests/numpy_tests/core/test_multiarray.py b/torch_np/tests/numpy_tests/core/test_multiarray.py
index bc450c71..920448a5 100644
--- a/torch_np/tests/numpy_tests/core/test_multiarray.py
+++ b/torch_np/tests/numpy_tests/core/test_multiarray.py
@@ -2641,7 +2641,7 @@ def test_put(self):
         icodes = np.typecodes['AllInteger']
         fcodes = np.typecodes['AllFloat']
-        for dt in icodes + fcodes + 'O':
+        for dt in icodes + fcodes:
             tgt = np.array([0, 1, 0, 3, 0, 5], dtype=dt)
 
             # test 1-d
@@ -2667,14 +2667,10 @@ def test_put(self):
             a.put([1, 3, 5], [True]*3)
             assert_equal(a, tgt.reshape(2, 3))
 
-        # check must be writeable
-        a = np.zeros(6)
-        a.flags.writeable = False
-        assert_raises(ValueError, a.put, [1, 3, 5], [1, 3, 5])
-
         # when calling np.put, make sure a
         # TypeError is raised if the object
         # isn't an ndarray
+        pytest.xfail("XXX: Argument normalisation prevents catching this")
         bad_array = [1, 2, 3]
         assert_raises(TypeError, np.put, bad_array, [0, 2], 5)

From 5dc5a2502b3f6ed77107aa47590762c13860734a Mon Sep 17 00:00:00 2001
From: Matthew Barber
Date: Tue, 18 Apr 2023 18:51:10 +0100
Subject: [PATCH 09/14] Update `test_multiarray.py::TestWritebackIfCopy` skips/xfails

---
 torch_np/tests/numpy_tests/core/test_multiarray.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/torch_np/tests/numpy_tests/core/test_multiarray.py b/torch_np/tests/numpy_tests/core/test_multiarray.py
index 920448a5..c6910bfc 100644
--- a/torch_np/tests/numpy_tests/core/test_multiarray.py
+++ b/torch_np/tests/numpy_tests/core/test_multiarray.py
@@ -2670,7 +2670,7 @@ def test_put(self):
         # when calling np.put, make sure a
         # TypeError is raised if the object
         # isn't an ndarray
-        pytest.xfail("XXX: Argument normalisation prevents catching this")
+        pytest.xfail(reason="XXX: Argument normalisation prevents catching this")
         bad_array = [1, 2, 3]
         assert_raises(TypeError, np.put, bad_array, [0, 2], 5)
@@ -7377,7 +7377,6 @@ def test_1d_format(self):
 from numpy.testing import IS_PYPY
 
 
-@pytest.mark.skip(reason="not going to implement WRITEBACKIFCOPY")
 class TestWritebackIfCopy:
     # all these tests use the WRITEBACKIFCOPY mechanism
     def test_argmax_with_out(self):
@@ -7392,6 +7391,7 @@ def test_argmin_with_out(self):
         res = np.argmin(mat, 0, out=out)
         assert_equal(res, range(5))
 
+    @pytest.mark.xfail(reason="XXX: place()")
     def test_insert_noncontiguous(self):
         a = np.arange(6).reshape(2,3).T # force non-c-contiguous
         # uses arr_insert
@@ -7402,9 +7402,11 @@ def test_put_noncontiguous(self):
         a = np.arange(6).reshape(2,3).T # force non-c-contiguous
+        assert not a.flags["C_CONTIGUOUS"]  # sanity check
         np.put(a, [0, 2], [44, 55])
         assert_equal(a, np.array([[44, 3], [55, 4], [2, 5]]))
 
+    @pytest.mark.xfail(reason="XXX: putmask()")
     def test_putmask_noncontiguous(self):
         a = np.arange(6).reshape(2,3).T # force non-c-contiguous
         # uses arr_putmask
@@ -7417,6 +7419,7 @@ def test_take_mode_raise(self):
         np.take(a, [0, 2], out=out, mode='raise')
         assert_equal(out, np.array([0, 2]))
 
+    @pytest.mark.xfail(reason="XXX: choose()")
     def test_choose_mod_raise(self):
         a = np.array([[1, 0, 1], [0, 1, 0], [1, 0, 1]])
         out = np.empty((3,3), dtype='int')
@@ -7426,6 +7429,7 @@ def test_choose_mod_raise(self):
                               [-10, 10, -10],
                               [ 10, -10, 10]]))
 
+    @pytest.mark.xfail(reason="XXX: ndarray.flat")
     def test_flatiter__array__(self):
         a = np.arange(9).reshape(3,3)
         b = a.T.flat
@@ -7439,6 +7443,7 @@ def test_dot_out(self):
         b = np.dot(a, a, out=a)
         assert_equal(b, np.array([[15, 18, 21], [42, 54, 66], [69, 90, 111]]))
 
+    @pytest.mark.skip(reason="XXX: npy_create_writebackifcopy()")
     def test_view_assign(self):
         from numpy.core._multiarray_tests import npy_create_writebackifcopy, npy_resolve
@@ -7457,6 +7462,7 @@ def test_view_assign(self):
         arr_wb[...] = 100
         assert_equal(arr, -100)
 
+    @pytest.mark.skip(reason="XXX: npy_create_writebackifcopy()")
     @pytest.mark.leaks_references(
         reason="increments self in dealloc; ignore since deprecated path.")
     def test_dealloc_warning(self):
@@ -7467,6 +7473,7 @@ def test_dealloc_warning(self):
             _multiarray_tests.npy_abuse_writebackifcopy(v)
         assert len(sup.log) == 1
 
+    @pytest.mark.skip(reason="XXX: npy_create_writebackifcopy()")
     def test_view_discard_refcount(self):
         from numpy.core._multiarray_tests import npy_create_writebackifcopy, npy_discard

From 1d750bf13ef4f6d9130d1c985a426a3e7fbe1868 Mon Sep 17 00:00:00 2001
From: Matthew Barber
Date: Wed, 19 Apr 2023 10:16:33 +0100
Subject: [PATCH 10/14] `put()`: `a: ArrayLike` -> `a: NDArray`

Prevents normalising non-ndarray arguments
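
With `a: ArrayLike`, a plain list passed to `put()` was normalised into
a temporary tensor, so the mutation was silently discarded and the
expected `TypeError` could never fire. A sketch of the now-restored
NumPy-compatible failure (illustrative):

    >>> tnp.put([1, 2, 3], [0, 2], 5)
    Traceback (most recent call last):
        ...
    TypeError: ...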

---
 torch_np/_funcs_impl.py                            | 2 +-
 torch_np/tests/numpy_tests/core/test_multiarray.py | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py
index bacbfe89..2558005e 100644
--- a/torch_np/_funcs_impl.py
+++ b/torch_np/_funcs_impl.py
@@ -895,7 +895,7 @@ def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
 
 
 def put(
-    a: ArrayLike,
+    a: NDArray,
     ind: ArrayLike,
     v: ArrayLike,
     mode: NotImplementedType = "raise",

diff --git a/torch_np/tests/numpy_tests/core/test_multiarray.py b/torch_np/tests/numpy_tests/core/test_multiarray.py
index c6910bfc..ab66c443 100644
--- a/torch_np/tests/numpy_tests/core/test_multiarray.py
+++ b/torch_np/tests/numpy_tests/core/test_multiarray.py
@@ -2637,7 +2637,6 @@ def test_trace(self):
         ret = a.trace(out=out)
         assert ret is out
 
-    @pytest.mark.xfail(reason="TODO: implement put")
     def test_put(self):
         icodes = np.typecodes['AllInteger']
         fcodes = np.typecodes['AllFloat']
@@ -2670,7 +2669,6 @@ def test_put(self):
         # when calling np.put, make sure a
         # TypeError is raised if the object
         # isn't an ndarray
-        pytest.xfail(reason="XXX: Argument normalisation prevents catching this")
         bad_array = [1, 2, 3]
         assert_raises(TypeError, np.put, bad_array, [0, 2], 5)

From 8f5ade2201ae3955df6b970f41eb56a4600179c8 Mon Sep 17 00:00:00 2001
From: Matthew Barber
Date: Wed, 19 Apr 2023 14:07:56 +0100
Subject: [PATCH 11/14] Simplify what `test_put` tests for

* `list_at_ind` stuff should be covered when testing the normalizer
* Manually converting `ind` and `v` to `tnp.ndarray` is mostly redundant

---
 torch_np/tests/test_xps.py | 20 ++++----------------
 1 file changed, 4 insertions(+), 16 deletions(-)

diff --git a/torch_np/tests/test_xps.py b/torch_np/tests/test_xps.py
index 0e2ea279..5f43a60b 100644
--- a/torch_np/tests/test_xps.py
+++ b/torch_np/tests/test_xps.py
@@ -96,6 +96,9 @@ def test_integer_indexing(x, data):
     assert result.shape == result_shape
 
 
+@pytest.mark.filterwarnings(
+    "ignore:Creating a tensor from a list of numpy.ndarrays.*:UserWarning"
+)
 @given(
     np_x=nps.arrays(
         # We specifically use namespaced dtypes to prevent non-native byte-order issues
@@ -124,22 +127,7 @@ def test_put(np_x, data):
     assert_array_equal(tnp_x, tnp_x_copy)  # sanity check
 
     note(f"{tnp_x=}")
-    tnp_ind = []
-    list_at_ind = data.draw(
-        st.lists(st.booleans(), min_size=len(ind), max_size=len(ind)),
-        label="list_at_ind",
-    )
-    for np_indices, use_list in zip(ind, list_at_ind):
-        if use_list:
-            indices = np_indices.tolist()
-        else:
-            indices = tnp.asarray(np_indices).astype(np_indices.dtype.name)
-        tnp_ind.append(indices)
-    tnp_ind = tuple(tnp_ind)
-    note(f"{tnp_ind=}")
-    tnp_v = tnp.asarray(v.copy()).astype(v.dtype.name)
-    note(f"{tnp_v=}")
-    tnp.put(tnp_x, tnp_ind, tnp_v)
+    tnp.put(tnp_x, ind, v)
     note(f"(after put) {tnp_x=}")
 
     assert_array_equal(tnp_x, tnp.asarray(np_x).astype(tnp_x.dtype))

From 40f7ccbc7e8209bf376591477582792d95e6e2a6 Mon Sep 17 00:00:00 2001
From: Matthew Barber
Date: Wed, 19 Apr 2023 18:28:07 +0100
Subject: [PATCH 12/14] Broadcast rather than extend `v` in `put()`
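
For reference, `np.put` recycles a short `v` and ignores extra trailing
values (illustrative snippet; broadcast-incompatible shapes are handled
later in this series):

    >>> a = np.zeros(5, dtype=int)
    >>> np.put(a, [0, 1, 2, 3], [-1, -2])  # v is repeated, then trimmed
    >>> a
    array([-1, -2, -1, -2,  0])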

Co-authored-by: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>

---
 torch_np/_funcs_impl.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py
index 2558005e..4af39045 100644
--- a/torch_np/_funcs_impl.py
+++ b/torch_np/_funcs_impl.py
@@ -901,11 +901,16 @@ def put(
     mode: NotImplementedType = "raise",
 ):
     v = v.type(a.dtype)
-    numel_ratio = ind.numel() / v.numel()
-    if numel_ratio.is_integer():
-        sizes = [int(numel_ratio)]
-        sizes.extend([1 for _ in range(v.dim() - 1)])
-        v = v.repeat(*sizes)
+    # If ind is larger than v, broadcast v to the would-be resulting shape. Any
+    # unnecessary trailing elements are then trimmed.
+    if ind.numel() > v.numel():
+        result_shape = torch.broadcast_shapes(v.shape, ind.shape)
+        v = torch.broadcast_to(v, result_shape)
+    # Trim unnecessary elements, regardless of whether v was broadcast. Note
+    # np.put() trims v to match ind by default too.
+    if ind.numel() < v.numel():
+        v = v.flatten()
+        v = v[: ind.numel()]
     a.put_(ind, v)
     return None

From cae6d94e10a4ea6d68c41198d7ebd0c2cc179d75 Mon Sep 17 00:00:00 2001
From: Matthew Barber
Date: Tue, 25 Apr 2023 12:37:51 +0100
Subject: [PATCH 13/14] `put()`: support and test broadcast-incompatible args

---
 torch_np/_funcs_impl.py    |  6 ++++--
 torch_np/tests/test_xps.py | 18 ++++++++++++++----
 2 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py
index 4af39045..43da4487 100644
--- a/torch_np/_funcs_impl.py
+++ b/torch_np/_funcs_impl.py
@@ -904,8 +904,10 @@ def put(
     # If ind is larger than v, broadcast v to the would-be resulting shape. Any
     # unnecessary trailing elements are then trimmed.
     if ind.numel() > v.numel():
-        result_shape = torch.broadcast_shapes(v.shape, ind.shape)
-        v = torch.broadcast_to(v, result_shape)
+        ratio = (ind.numel() + v.numel() - 1) // v.numel()
+        sizes = [ratio]
+        sizes.extend([1 for _ in range(v.dim() - 1)])
+        v = v.repeat(*sizes)
     # Trim unnecessary elements, regardless of whether v was broadcast. Note
     # np.put() trims v to match ind by default too.
     if ind.numel() < v.numel():

diff --git a/torch_np/tests/test_xps.py b/torch_np/tests/test_xps.py
index 5f43a60b..f1e2e328 100644
--- a/torch_np/tests/test_xps.py
+++ b/torch_np/tests/test_xps.py
@@ -4,6 +4,7 @@
 These tests aren't specifically for testing Array API adoption!
 """
 import cmath
+import math
 import warnings
 
 import pytest
@@ -115,11 +116,20 @@ def test_put(np_x, data):
 
     tnp_x = tnp.asarray(np_x.copy()).astype(np_x.dtype.name)
 
-    result_shapes = st.shared(nps.array_shapes())
-    ind = data.draw(
-        nps.integer_array_indices(np_x.shape, result_shape=result_shapes), label="ind"
+    result_shape = data.draw(nps.array_shapes(), label="result_shape")
+    ind_strat = nps.integer_array_indices(
+        np_x.shape, result_shape=st.just(result_shape)
+    )
+    ind = data.draw(ind_strat | ind_strat.map(np.asarray), label="ind")
+    v = data.draw(
+        nps.arrays(
+            dtype=np_x.dtype,
+            shape=nps.array_shapes().filter(
+                lambda s: math.prod(s) > math.prod(result_shape)
+            ),
+        ),
+        label="v",
     )
-    v = data.draw(nps.arrays(dtype=np_x.dtype, shape=result_shapes), label="v")
 
     tnp_x_copy = tnp_x.copy()
     np.put(np_x, ind, v)

From 65c212713a940c27fa41638dab02649e18625ccd Mon Sep 17 00:00:00 2001
From: Matthew Barber
Date: Tue, 25 Apr 2023 13:04:36 +0100
Subject: [PATCH 14/14] `put()`: expand over repeat internally

Also test 0d indices

---
 torch_np/_funcs_impl.py    | 8 +++-----
 torch_np/tests/test_xps.py | 9 ++++++---
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py
index 43da4487..09c11f88 100644
--- a/torch_np/_funcs_impl.py
+++ b/torch_np/_funcs_impl.py
@@ -901,14 +901,12 @@ def put(
     mode: NotImplementedType = "raise",
 ):
     v = v.type(a.dtype)
-    # If ind is larger than v, broadcast v to the would-be resulting shape. Any
+    # If ind is larger than v, expand v to at least the size of ind. Any
     # unnecessary trailing elements are then trimmed.
     if ind.numel() > v.numel():
         ratio = (ind.numel() + v.numel() - 1) // v.numel()
-        sizes = [ratio]
-        sizes.extend([1 for _ in range(v.dim() - 1)])
-        v = v.repeat(*sizes)
+        v = v.unsqueeze(0).expand((ratio,) + v.shape)
-    # Trim unnecessary elements, regardless of whether v was broadcast. Note
+    # Trim unnecessary elements, regardless of whether v was expanded. Note
     # np.put() trims v to match ind by default too.
     if ind.numel() < v.numel():
         v = v.flatten()
         v = v[: ind.numel()]

diff --git a/torch_np/tests/test_xps.py b/torch_np/tests/test_xps.py
index f1e2e328..10700092 100644
--- a/torch_np/tests/test_xps.py
+++ b/torch_np/tests/test_xps.py
@@ -117,9 +117,12 @@ def test_put(np_x, data):
     tnp_x = tnp.asarray(np_x.copy()).astype(np_x.dtype.name)
 
     result_shape = data.draw(nps.array_shapes(), label="result_shape")
-    ind_strat = nps.integer_array_indices(
-        np_x.shape, result_shape=st.just(result_shape)
-    )
+    if result_shape == ():
+        ind_strat = st.integers(np_x.size)
+    else:
+        ind_strat = nps.integer_array_indices(
+            np_x.shape, result_shape=st.just(result_shape)
+        )
     ind = data.draw(ind_strat | ind_strat.map(np.asarray), label="ind")
     v = data.draw(
         nps.arrays(