Skip to content

tnp.put() + testing #116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torch_np/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def sctype_from_string(s):
return _aliases[s]
if s in _python_types:
return _python_types[s]
raise TypeError(f"data type '{s}' not understood")
raise TypeError(f"data type {s!r} not understood")


def sctype_from_torch_dtype(torch_dtype):
Expand Down
21 changes: 21 additions & 0 deletions torch_np/_funcs_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,27 @@ def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
return torch.take_along_dim(arr, indices, axis)


def put(
a: NDArray,
ind: ArrayLike,
v: ArrayLike,
mode: NotImplementedType = "raise",
):
v = v.type(a.dtype)
# If ind is larger than v, expand v to at least the size of ind. Any
# unnecessary trailing elements are then trimmed.
if ind.numel() > v.numel():
ratio = (ind.numel() + v.numel() - 1) // v.numel()
v = v.unsqueeze(0).expand((ratio,) + v.shape)
# Trim unnecessary elements, regarldess if v was expanded or not. Note
# np.put() trims v to match ind by default too.
if ind.numel() < v.numel():
v = v.flatten()
v = v[: ind.numel()]
a.put_(ind, v)
return None


def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
(arr,), axis = _util.axis_none_ravel(arr, axis=axis)
axis = _util.normalize_axis_index(axis, arr.ndim)
Expand Down
15 changes: 15 additions & 0 deletions torch_np/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ def __getitem__(self, key):
else:
raise KeyError(f"No flag key '{key}'")

def __setattr__(self, attr, value):
if attr.islower() and attr.upper() in FLAGS:
self[attr.upper()] = value
else:
super().__setattr__(attr, value)

def __setitem__(self, key, value):
if key in FLAGS or key in SHORTHAND_TO_FLAGS.keys():
raise NotImplementedError("Modifying flags is not implemented")
else:
raise KeyError(f"No flag key '{key}'")


def create_method(fn, name=None):
name = name or fn.__name__
Expand Down Expand Up @@ -397,6 +409,9 @@ def __setitem__(self, index, value):
value = _util.cast_if_needed(value, self.tensor.dtype)
return self.tensor.__setitem__(index, value)

take = _funcs.take
put = _funcs.put


# This is the ideally the only place which talks to ndarray directly.
# The rest goes through asarray (preferred) or array.
Expand Down
2 changes: 1 addition & 1 deletion torch_np/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval="nan"):

## with errstate(all='ignore'):
# ignore errors for non-numeric types
with contextlib.suppress(TypeError):
with contextlib.suppress(TypeError, RuntimeError):
error = abs(x - y)
if np.issubdtype(x.dtype, np.unsignedinteger):
error2 = abs(y - x)
Expand Down
17 changes: 9 additions & 8 deletions torch_np/tests/numpy_tests/core/test_multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2637,11 +2637,10 @@ def test_trace(self):
ret = a.trace(out=out)
assert ret is out

@pytest.mark.xfail(reason="TODO: implement put")
def test_put(self):
icodes = np.typecodes['AllInteger']
fcodes = np.typecodes['AllFloat']
for dt in icodes + fcodes + 'O':
for dt in icodes + fcodes:
tgt = np.array([0, 1, 0, 3, 0, 5], dtype=dt)

# test 1-d
Expand All @@ -2667,11 +2666,6 @@ def test_put(self):
a.put([1, 3, 5], [True]*3)
assert_equal(a, tgt.reshape(2, 3))

# check must be writeable
a = np.zeros(6)
a.flags.writeable = False
assert_raises(ValueError, a.put, [1, 3, 5], [1, 3, 5])

# when calling np.put, make sure a
# TypeError is raised if the object
# isn't an ndarray
Expand Down Expand Up @@ -7381,7 +7375,6 @@ def test_1d_format(self):
from numpy.testing import IS_PYPY


@pytest.mark.skip(reason="not going to implement WRITEBACKIFCOPY")
class TestWritebackIfCopy:
# all these tests use the WRITEBACKIFCOPY mechanism
def test_argmax_with_out(self):
Expand All @@ -7396,6 +7389,7 @@ def test_argmin_with_out(self):
res = np.argmin(mat, 0, out=out)
assert_equal(res, range(5))

@pytest.mark.xfail(reason="XXX: place()")
def test_insert_noncontiguous(self):
a = np.arange(6).reshape(2,3).T # force non-c-contiguous
# uses arr_insert
Expand All @@ -7406,9 +7400,11 @@ def test_insert_noncontiguous(self):

def test_put_noncontiguous(self):
a = np.arange(6).reshape(2,3).T # force non-c-contiguous
assert not a.flags["C_CONTIGUOUS"] # sanity check
np.put(a, [0, 2], [44, 55])
assert_equal(a, np.array([[44, 3], [55, 4], [2, 5]]))

@pytest.mark.xfail(reason="XXX: putmask()")
def test_putmask_noncontiguous(self):
a = np.arange(6).reshape(2,3).T # force non-c-contiguous
# uses arr_putmask
Expand All @@ -7421,6 +7417,7 @@ def test_take_mode_raise(self):
np.take(a, [0, 2], out=out, mode='raise')
assert_equal(out, np.array([0, 2]))

@pytest.mark.xfail(reason="XXX: choose()")
def test_choose_mod_raise(self):
a = np.array([[1, 0, 1], [0, 1, 0], [1, 0, 1]])
out = np.empty((3,3), dtype='int')
Expand All @@ -7430,6 +7427,7 @@ def test_choose_mod_raise(self):
[-10, 10, -10],
[ 10, -10, 10]]))

@pytest.mark.xfail(reason="XXX: ndarray.flat")
def test_flatiter__array__(self):
a = np.arange(9).reshape(3,3)
b = a.T.flat
Expand All @@ -7443,6 +7441,7 @@ def test_dot_out(self):
b = np.dot(a, a, out=a)
assert_equal(b, np.array([[15, 18, 21], [42, 54, 66], [69, 90, 111]]))

@pytest.mark.skip(reason="XXX: npy_create_writebackifcopy()")
def test_view_assign(self):
from numpy.core._multiarray_tests import npy_create_writebackifcopy, npy_resolve

Expand All @@ -7461,6 +7460,7 @@ def test_view_assign(self):
arr_wb[...] = 100
assert_equal(arr, -100)

@pytest.mark.skip(reason="XXX: npy_create_writebackifcopy()")
@pytest.mark.leaks_references(
reason="increments self in dealloc; ignore since deprecated path.")
def test_dealloc_warning(self):
Expand All @@ -7471,6 +7471,7 @@ def test_dealloc_warning(self):
_multiarray_tests.npy_abuse_writebackifcopy(v)
assert len(sup.log) == 1

@pytest.mark.skip(reason="XXX: npy_create_writebackifcopy()")
def test_view_discard_refcount(self):
from numpy.core._multiarray_tests import npy_create_writebackifcopy, npy_discard

Expand Down
73 changes: 64 additions & 9 deletions torch_np/tests/test_xps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,42 @@
These tests aren't specifically for testing Array API adoption!
"""
import cmath
import math
import warnings

import pytest

pytest.importorskip("hypothesis")

from hypothesis import given
import numpy as np
import torch
from hypothesis import given, note
from hypothesis import strategies as st
from hypothesis.errors import HypothesisWarning
from hypothesis.extra import numpy as nps
from hypothesis.extra.array_api import make_strategies_namespace

import torch_np as np
import torch_np as tnp
from torch_np._dtypes import sctypes
from torch_np.testing import assert_array_equal

__all__ = ["xps"]

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=HypothesisWarning)
np.bool = np.bool_
xps = make_strategies_namespace(np, api_version="2022.12")
tnp.bool = tnp.bool_
xps = make_strategies_namespace(tnp, api_version="2022.12")


default_dtypes = [np.bool, np.int64, np.float64, np.complex128]
default_dtypes = [tnp.bool, tnp.int64, tnp.float64, tnp.complex128]
kind_to_strat = {
"b": xps.boolean_dtypes(),
"i": xps.integer_dtypes(),
"u": xps.unsigned_integer_dtypes(sizes=8),
"f": xps.floating_dtypes(),
"c": xps.complex_dtypes(),
}
scalar_dtype_strat = st.one_of(kind_to_strat.values()).map(np.dtype)
scalar_dtype_strat = st.one_of(kind_to_strat.values()).map(tnp.dtype)


@pytest.mark.skip(reason="flaky")
Expand All @@ -55,14 +61,14 @@ def test_full(shape, data):
else:
values_dtypes_strat = st.just(_dtype)
values_strat = values_dtypes_strat.flatmap(
lambda d: values_strat.map(lambda v: np.asarray(v, dtype=d))
lambda d: values_strat.map(lambda v: tnp.asarray(v, dtype=d))
)
fill_value = data.draw(values_strat, label="fill_value")
out = np.full(shape, fill_value, **kw)
out = tnp.full(shape, fill_value, **kw)
assert out.dtype == _dtype
assert out.shape == shape
if cmath.isnan(fill_value):
assert np.isnan(out).all()
assert tnp.isnan(out).all()
else:
assert (out == fill_value).all()

Expand All @@ -89,3 +95,52 @@ def test_integer_indexing(x, data):
idx = data.draw(integer_array_indices(x.shape, result_shape), label="idx")
result = x[idx]
assert result.shape == result_shape


@pytest.mark.filterwarnings(
"ignore:Creating a tensor from a list of numpy.ndarrays.*:UserWarning"
)
@given(
np_x=nps.arrays(
# We specifically use namespaced dtypes to prevent non-native byte-order issues
dtype=scalar_dtype_strat.map(lambda d: getattr(np, d.name)),
shape=nps.array_shapes(),
),
data=st.data(),
)
def test_put(np_x, data):
# We cast arrays from torch_np.asarray as currently it doesn't carry over
# dtypes. XXX: Remove the below sanity check and subsequent casting when
# this is fixed.
assert tnp.asarray(np.zeros(5, dtype=np.int16)).dtype != tnp.int16

tnp_x = tnp.asarray(np_x.copy()).astype(np_x.dtype.name)

result_shape = data.draw(nps.array_shapes(), label="result_shape")
if result_shape == ():
ind_strat = st.integers(np_x.size)
else:
ind_strat = nps.integer_array_indices(
np_x.shape, result_shape=st.just(result_shape)
)
ind = data.draw(ind_strat | ind_strat.map(np.asarray), label="ind")
v = data.draw(
nps.arrays(
dtype=np_x.dtype,
shape=nps.array_shapes().filter(
lambda s: math.prod(s) > math.prod(result_shape)
),
),
label="v",
)

tnp_x_copy = tnp_x.copy()
np.put(np_x, ind, v)
note(f"(after put) {np_x=}")
assert_array_equal(tnp_x, tnp_x_copy) # sanity check

note(f"{tnp_x=}")
tnp.put(tnp_x, ind, v)
note(f"(after put) {tnp_x=}")

assert_array_equal(tnp_x, tnp.asarray(np_x).astype(tnp_x.dtype))