Skip to content

Commit cf64dbd

Browse files
authored
Merge pull request #116 from honno/put
`tnp.put()` + testing
2 parents 2ce7d21 + 65c2127 commit cf64dbd

File tree

6 files changed

+111
-19
lines changed

6 files changed

+111
-19
lines changed

torch_np/_dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def sctype_from_string(s):
243243
return _aliases[s]
244244
if s in _python_types:
245245
return _python_types[s]
246-
raise TypeError(f"data type '{s}' not understood")
246+
raise TypeError(f"data type {s!r} not understood")
247247

248248

249249
def sctype_from_torch_dtype(torch_dtype):

torch_np/_funcs_impl.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,27 @@ def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
894894
return torch.take_along_dim(arr, indices, axis)
895895

896896

897+
def put(
898+
a: NDArray,
899+
ind: ArrayLike,
900+
v: ArrayLike,
901+
mode: NotImplementedType = "raise",
902+
):
903+
v = v.type(a.dtype)
904+
# If ind is larger than v, expand v to at least the size of ind. Any
905+
# unnecessary trailing elements are then trimmed.
906+
if ind.numel() > v.numel():
907+
ratio = (ind.numel() + v.numel() - 1) // v.numel()
908+
v = v.unsqueeze(0).expand((ratio,) + v.shape)
909+
# Trim unnecessary elements, regarldess if v was expanded or not. Note
910+
# np.put() trims v to match ind by default too.
911+
if ind.numel() < v.numel():
912+
v = v.flatten()
913+
v = v[: ind.numel()]
914+
a.put_(ind, v)
915+
return None
916+
917+
897918
def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
898919
(arr,), axis = _util.axis_none_ravel(arr, axis=axis)
899920
axis = _util.normalize_axis_index(axis, arr.ndim)

torch_np/_ndarray.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,18 @@ def __getitem__(self, key):
5656
else:
5757
raise KeyError(f"No flag key '{key}'")
5858

59+
def __setattr__(self, attr, value):
60+
if attr.islower() and attr.upper() in FLAGS:
61+
self[attr.upper()] = value
62+
else:
63+
super().__setattr__(attr, value)
64+
65+
def __setitem__(self, key, value):
66+
if key in FLAGS or key in SHORTHAND_TO_FLAGS.keys():
67+
raise NotImplementedError("Modifying flags is not implemented")
68+
else:
69+
raise KeyError(f"No flag key '{key}'")
70+
5971

6072
def create_method(fn, name=None):
6173
name = name or fn.__name__
@@ -397,6 +409,9 @@ def __setitem__(self, index, value):
397409
value = _util.cast_if_needed(value, self.tensor.dtype)
398410
return self.tensor.__setitem__(index, value)
399411

412+
take = _funcs.take
413+
put = _funcs.put
414+
400415

401416
# This is the ideally the only place which talks to ndarray directly.
402417
# The rest goes through asarray (preferred) or array.

torch_np/testing/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval="nan"):
679679

680680
## with errstate(all='ignore'):
681681
# ignore errors for non-numeric types
682-
with contextlib.suppress(TypeError):
682+
with contextlib.suppress(TypeError, RuntimeError):
683683
error = abs(x - y)
684684
if np.issubdtype(x.dtype, np.unsignedinteger):
685685
error2 = abs(y - x)

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2637,11 +2637,10 @@ def test_trace(self):
26372637
ret = a.trace(out=out)
26382638
assert ret is out
26392639

2640-
@pytest.mark.xfail(reason="TODO: implement put")
26412640
def test_put(self):
26422641
icodes = np.typecodes['AllInteger']
26432642
fcodes = np.typecodes['AllFloat']
2644-
for dt in icodes + fcodes + 'O':
2643+
for dt in icodes + fcodes:
26452644
tgt = np.array([0, 1, 0, 3, 0, 5], dtype=dt)
26462645

26472646
# test 1-d
@@ -2667,11 +2666,6 @@ def test_put(self):
26672666
a.put([1, 3, 5], [True]*3)
26682667
assert_equal(a, tgt.reshape(2, 3))
26692668

2670-
# check must be writeable
2671-
a = np.zeros(6)
2672-
a.flags.writeable = False
2673-
assert_raises(ValueError, a.put, [1, 3, 5], [1, 3, 5])
2674-
26752669
# when calling np.put, make sure a
26762670
# TypeError is raised if the object
26772671
# isn't an ndarray
@@ -7381,7 +7375,6 @@ def test_1d_format(self):
73817375
from numpy.testing import IS_PYPY
73827376

73837377

7384-
@pytest.mark.skip(reason="not going to implement WRITEBACKIFCOPY")
73857378
class TestWritebackIfCopy:
73867379
# all these tests use the WRITEBACKIFCOPY mechanism
73877380
def test_argmax_with_out(self):
@@ -7396,6 +7389,7 @@ def test_argmin_with_out(self):
73967389
res = np.argmin(mat, 0, out=out)
73977390
assert_equal(res, range(5))
73987391

7392+
@pytest.mark.xfail(reason="XXX: place()")
73997393
def test_insert_noncontiguous(self):
74007394
a = np.arange(6).reshape(2,3).T # force non-c-contiguous
74017395
# uses arr_insert
@@ -7406,9 +7400,11 @@ def test_insert_noncontiguous(self):
74067400

74077401
def test_put_noncontiguous(self):
74087402
a = np.arange(6).reshape(2,3).T # force non-c-contiguous
7403+
assert not a.flags["C_CONTIGUOUS"] # sanity check
74097404
np.put(a, [0, 2], [44, 55])
74107405
assert_equal(a, np.array([[44, 3], [55, 4], [2, 5]]))
74117406

7407+
@pytest.mark.xfail(reason="XXX: putmask()")
74127408
def test_putmask_noncontiguous(self):
74137409
a = np.arange(6).reshape(2,3).T # force non-c-contiguous
74147410
# uses arr_putmask
@@ -7421,6 +7417,7 @@ def test_take_mode_raise(self):
74217417
np.take(a, [0, 2], out=out, mode='raise')
74227418
assert_equal(out, np.array([0, 2]))
74237419

7420+
@pytest.mark.xfail(reason="XXX: choose()")
74247421
def test_choose_mod_raise(self):
74257422
a = np.array([[1, 0, 1], [0, 1, 0], [1, 0, 1]])
74267423
out = np.empty((3,3), dtype='int')
@@ -7430,6 +7427,7 @@ def test_choose_mod_raise(self):
74307427
[-10, 10, -10],
74317428
[ 10, -10, 10]]))
74327429

7430+
@pytest.mark.xfail(reason="XXX: ndarray.flat")
74337431
def test_flatiter__array__(self):
74347432
a = np.arange(9).reshape(3,3)
74357433
b = a.T.flat
@@ -7443,6 +7441,7 @@ def test_dot_out(self):
74437441
b = np.dot(a, a, out=a)
74447442
assert_equal(b, np.array([[15, 18, 21], [42, 54, 66], [69, 90, 111]]))
74457443

7444+
@pytest.mark.skip(reason="XXX: npy_create_writebackifcopy()")
74467445
def test_view_assign(self):
74477446
from numpy.core._multiarray_tests import npy_create_writebackifcopy, npy_resolve
74487447

@@ -7461,6 +7460,7 @@ def test_view_assign(self):
74617460
arr_wb[...] = 100
74627461
assert_equal(arr, -100)
74637462

7463+
@pytest.mark.skip(reason="XXX: npy_create_writebackifcopy()")
74647464
@pytest.mark.leaks_references(
74657465
reason="increments self in dealloc; ignore since deprecated path.")
74667466
def test_dealloc_warning(self):
@@ -7471,6 +7471,7 @@ def test_dealloc_warning(self):
74717471
_multiarray_tests.npy_abuse_writebackifcopy(v)
74727472
assert len(sup.log) == 1
74737473

7474+
@pytest.mark.skip(reason="XXX: npy_create_writebackifcopy()")
74747475
def test_view_discard_refcount(self):
74757476
from numpy.core._multiarray_tests import npy_create_writebackifcopy, npy_discard
74767477

torch_np/tests/test_xps.py

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,42 @@
44
These tests aren't specifically for testing Array API adoption!
55
"""
66
import cmath
7+
import math
78
import warnings
89

910
import pytest
1011

1112
pytest.importorskip("hypothesis")
1213

13-
from hypothesis import given
14+
import numpy as np
15+
import torch
16+
from hypothesis import given, note
1417
from hypothesis import strategies as st
1518
from hypothesis.errors import HypothesisWarning
19+
from hypothesis.extra import numpy as nps
1620
from hypothesis.extra.array_api import make_strategies_namespace
1721

18-
import torch_np as np
22+
import torch_np as tnp
23+
from torch_np._dtypes import sctypes
24+
from torch_np.testing import assert_array_equal
1925

2026
__all__ = ["xps"]
2127

2228
with warnings.catch_warnings():
2329
warnings.filterwarnings("ignore", category=HypothesisWarning)
24-
np.bool = np.bool_
25-
xps = make_strategies_namespace(np, api_version="2022.12")
30+
tnp.bool = tnp.bool_
31+
xps = make_strategies_namespace(tnp, api_version="2022.12")
2632

2733

28-
default_dtypes = [np.bool, np.int64, np.float64, np.complex128]
34+
default_dtypes = [tnp.bool, tnp.int64, tnp.float64, tnp.complex128]
2935
kind_to_strat = {
3036
"b": xps.boolean_dtypes(),
3137
"i": xps.integer_dtypes(),
3238
"u": xps.unsigned_integer_dtypes(sizes=8),
3339
"f": xps.floating_dtypes(),
3440
"c": xps.complex_dtypes(),
3541
}
36-
scalar_dtype_strat = st.one_of(kind_to_strat.values()).map(np.dtype)
42+
scalar_dtype_strat = st.one_of(kind_to_strat.values()).map(tnp.dtype)
3743

3844

3945
@pytest.mark.skip(reason="flaky")
@@ -55,14 +61,14 @@ def test_full(shape, data):
5561
else:
5662
values_dtypes_strat = st.just(_dtype)
5763
values_strat = values_dtypes_strat.flatmap(
58-
lambda d: values_strat.map(lambda v: np.asarray(v, dtype=d))
64+
lambda d: values_strat.map(lambda v: tnp.asarray(v, dtype=d))
5965
)
6066
fill_value = data.draw(values_strat, label="fill_value")
61-
out = np.full(shape, fill_value, **kw)
67+
out = tnp.full(shape, fill_value, **kw)
6268
assert out.dtype == _dtype
6369
assert out.shape == shape
6470
if cmath.isnan(fill_value):
65-
assert np.isnan(out).all()
71+
assert tnp.isnan(out).all()
6672
else:
6773
assert (out == fill_value).all()
6874

@@ -89,3 +95,52 @@ def test_integer_indexing(x, data):
8995
idx = data.draw(integer_array_indices(x.shape, result_shape), label="idx")
9096
result = x[idx]
9197
assert result.shape == result_shape
98+
99+
100+
@pytest.mark.filterwarnings(
101+
"ignore:Creating a tensor from a list of numpy.ndarrays.*:UserWarning"
102+
)
103+
@given(
104+
np_x=nps.arrays(
105+
# We specifically use namespaced dtypes to prevent non-native byte-order issues
106+
dtype=scalar_dtype_strat.map(lambda d: getattr(np, d.name)),
107+
shape=nps.array_shapes(),
108+
),
109+
data=st.data(),
110+
)
111+
def test_put(np_x, data):
112+
# We cast arrays from torch_np.asarray as currently it doesn't carry over
113+
# dtypes. XXX: Remove the below sanity check and subsequent casting when
114+
# this is fixed.
115+
assert tnp.asarray(np.zeros(5, dtype=np.int16)).dtype != tnp.int16
116+
117+
tnp_x = tnp.asarray(np_x.copy()).astype(np_x.dtype.name)
118+
119+
result_shape = data.draw(nps.array_shapes(), label="result_shape")
120+
if result_shape == ():
121+
ind_strat = st.integers(np_x.size)
122+
else:
123+
ind_strat = nps.integer_array_indices(
124+
np_x.shape, result_shape=st.just(result_shape)
125+
)
126+
ind = data.draw(ind_strat | ind_strat.map(np.asarray), label="ind")
127+
v = data.draw(
128+
nps.arrays(
129+
dtype=np_x.dtype,
130+
shape=nps.array_shapes().filter(
131+
lambda s: math.prod(s) > math.prod(result_shape)
132+
),
133+
),
134+
label="v",
135+
)
136+
137+
tnp_x_copy = tnp_x.copy()
138+
np.put(np_x, ind, v)
139+
note(f"(after put) {np_x=}")
140+
assert_array_equal(tnp_x, tnp_x_copy) # sanity check
141+
142+
note(f"{tnp_x=}")
143+
tnp.put(tnp_x, ind, v)
144+
note(f"(after put) {tnp_x=}")
145+
146+
assert_array_equal(tnp_x, tnp.asarray(np_x).astype(tnp_x.dtype))

0 commit comments

Comments
 (0)