Port indexing tests #23

Merged
merged 59 commits into from
Feb 23, 2023
Commits
217d377
Mirror `test_indexing.py`
honno Jan 20, 2023
0cf51b7
Assert for `TypeError` where appropriate for index validation tests
honno Jan 23, 2023
6c920fd
xfail `test_index_no_array_to_index`
honno Jan 23, 2023
6163b4b
Partially xfail `test_empty_tuple_index`
honno Jan 23, 2023
e5cc442
xfail `test_void_scalar_empty_tuple`
honno Jan 23, 2023
4b6316c
xfail `test_same_kind_index_casting`
honno Jan 24, 2023
fcc49e8
Allow for `RuntimeError` as well in `test_single_int_index`
honno Jan 24, 2023
7a50ba9
Smoke test for advanced integer indexing
honno Jan 25, 2023
fbe9cdb
Upcast int tensor indices
honno Jan 25, 2023
c6b8952
Normalise indices for setitem
honno Jan 25, 2023
7c1b7e8
Update xfail reason for `test_same_kind_index_casting`
honno Jan 25, 2023
bd06c19
Allow `RuntimeError` in `test_boolean_assignment_value_mismatch`
honno Jan 31, 2023
a7c93f0
xfail `test_boolean_assignment_needs_api`
honno Jan 31, 2023
e07e341
xfail indexing tests for negative slice steps
honno Jan 31, 2023
792f7f2
Add `torch_np/tests/__init__.py` to fix relative imports
honno Feb 9, 2023
57d6914
Use correct helper in `__setitem__()`
honno Feb 9, 2023
934023b
xfail `test_too_many_fancy_indices_special_case`
honno Feb 9, 2023
70fd4b2
Include `RuntimeError` in `test_trivial_fancy_out_of_bounds` testing
honno Feb 9, 2023
d82033e
xfail view stuff in `test_indexing.py`
honno Feb 9, 2023
a1afd1a
xfail `test_subclass_writeable`
honno Feb 15, 2023
1eeaaf7
xfail `test_memory_order`
honno Feb 15, 2023
4adc36f
xfail `test_scalar_return_type`
honno Feb 15, 2023
e9c3b9a
xfail `test_small_regressions`
honno Feb 15, 2023
4cedb5b
xfail `test_broken_sequence_not_nd_index`
honno Feb 15, 2023
817b9b4
xfail `test_character_assignment`
honno Feb 15, 2023
2aa9cb3
Selectively xfail `test_too_many_advanced_indices`
honno Feb 15, 2023
617086d
Mirror `np.s_` and `np.index_exp`
honno Feb 15, 2023
1c26579
Allow for `RuntimeError` in more tests
honno Feb 15, 2023
55d6609
Broaden `test_broadcast_error_reports_correct_shape` for torch
honno Feb 16, 2023
48ef530
xfail `test_boolean_index_cast_assign`
honno Feb 16, 2023
7ac2b5d
xfail `test_object_assign`
honno Feb 16, 2023
8e6b59a
xfail `test_cast_equivalence`
honno Feb 16, 2023
81215c3
Partially xfail `test_nontuple_ndindex`
honno Feb 17, 2023
35a5da4
xfail before use of currently-unimplemented `take()`
honno Feb 17, 2023
bb762e3
xfail `test_non_integer_sequence_multiplication`
honno Feb 17, 2023
33c6be7
Convert wrapped array shape inputs into tensors for `x.reshape()`
honno Feb 20, 2023
e151168
Add `tnp.True_` and `tnp.False_` aliases
honno Feb 20, 2023
d9faf74
Partially xfail `test_bool_as_int_argument_errors`
honno Feb 20, 2023
29e7fb7
Remove finicky regex assertions in `test_boolean_indexing_fast_path`
honno Feb 20, 2023
7bbf1d4
Partially xfail `test_array_to_index_error`
honno Feb 20, 2023
c8731f9
xfail `TestNonIntegerArrayLike::test_basic`
honno Feb 20, 2023
062452b
xfail `TestMultipleEllipsisError::test_basic`
honno Feb 20, 2023
7e53288
Recognize dtype name aliases in `dtype()`
honno Feb 20, 2023
42e0c45
xfail `TestMultiIndexingAutomated`
honno Feb 20, 2023
9b1fce1
Skip `test_full`
honno Feb 20, 2023
ea679a8
Skip the view related tests
honno Feb 20, 2023
d5e2408
Clarify that it's torch which deprecates uint8s for indexing
honno Feb 20, 2023
36e343e
Skip rather than xfail cases which clearly won't be implemented
honno Feb 20, 2023
215cf55
Favour `t.to()` over `t.type()` for dtype conversions
honno Feb 20, 2023
9cf3de8
Regression test for recursive behaviour in `ndarrays_to_tensors()`
honno Feb 20, 2023
a75ae90
Accept uint8s as advanced integer indices
honno Feb 20, 2023
5afc0e7
Prefer tuples to lists
honno Feb 21, 2023
581917b
Vendor `s_` and `index_exp`
honno Feb 21, 2023
173bba6
Just delete tests which aren't even interesting to keep skipped
honno Feb 21, 2023
8e857b8
Style
honno Feb 21, 2023
40dd254
Comment on why `newshape` is normalised in `_detail.implementations.reshape`
honno Feb 21, 2023
ba66944
xfail `test_index_is_larger`
honno Feb 21, 2023
c51d6cc
Use `ndarrays_to_tensors()` for normalising setitem value
honno Feb 21, 2023
37e9adb
Update `test_boolean_assignment_value_mismatch` expected errors
honno Feb 21, 2023
4 changes: 4 additions & 0 deletions torch_np/__init__.py
@@ -1,6 +1,7 @@
from ._wrapper import * # isort: skip # XXX: currently this prevents circular imports
from . import random
from ._binary_ufuncs import *
from ._detail._index_tricks import *
from ._detail._util import AxisError, UFuncTypeError
from ._dtypes import *
from ._getlimits import errstate, finfo, iinfo
@@ -15,3 +16,6 @@
inf = float("inf")
nan = float("nan")
from math import pi # isort: skip

False_ = asarray(False, bool_)
True_ = asarray(True, bool_)
26 changes: 26 additions & 0 deletions torch_np/_detail/_index_tricks.py
@@ -0,0 +1,26 @@
"""
Vendored objects from numpy.lib.index_tricks
"""
__all__ = ["index_exp", "s_"]


class IndexExpression:
    """
    Written by Konrad Hinsen <[email protected]>
    last revision: 1999-7-23

    Cosmetic changes by T. Oliphant 2001
    """

    def __init__(self, maketuple):
        self.maketuple = maketuple

    def __getitem__(self, item):
        if self.maketuple and not isinstance(item, tuple):
            return (item,)
        else:
            return item


index_exp = IndexExpression(maketuple=True)
s_ = IndexExpression(maketuple=False)
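Since `s_` and `index_exp` differ only in whether a single (non-tuple) item gets wrapped, a quick sketch — reusing the vendored class verbatim from the diff above — shows the behaviour:

```python
class IndexExpression:
    def __init__(self, maketuple):
        self.maketuple = maketuple

    def __getitem__(self, item):
        # wrap single (non-tuple) items in a tuple when maketuple is set;
        # tuples always pass through unchanged
        if self.maketuple and not isinstance(item, tuple):
            return (item,)
        return item

index_exp = IndexExpression(maketuple=True)
s_ = IndexExpression(maketuple=False)

print(s_[2::2])          # slice(2, None, 2)
print(index_exp[2::2])   # (slice(2, None, 2),)
print(s_[1:3, 4])        # (slice(1, 3, None), 4)
```

This mirrors NumPy's `np.s_`/`np.index_exp`: both turn subscript syntax into plain slice objects that can be stored and reused.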
5 changes: 4 additions & 1 deletion torch_np/_detail/implementations.py
@@ -1,5 +1,6 @@
import torch

from .. import _helpers
from . import _dtypes_impl, _util

# ### equality, equivalence, allclose ###
@@ -503,8 +504,10 @@ def reshape(tensor, *shape, order="C"):
    if order != "C":
        raise NotImplementedError
    newshape = shape[0] if len(shape) == 1 else shape
    # convert any tnp.ndarray inputs into tensors before passing to torch.Tensor.reshape
    t_newshape = _helpers.ndarrays_to_tensors(newshape)
    # if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh)
    result = tensor.reshape(newshape)
    result = tensor.reshape(t_newshape)
    return result


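The `*shape` collapsing in the reshape hunk can be illustrated without torch; `normalize_newshape` below is a hypothetical stand-in for just the first line of the wrapper:

```python
def normalize_newshape(*shape):
    # a single argument is taken to be the whole shape (tuple or int);
    # multiple arguments are gathered into a tuple, mirroring
    # `newshape = shape[0] if len(shape) == 1 else shape`
    return shape[0] if len(shape) == 1 else shape

print(normalize_newshape((2, 3)))  # (2, 3)
print(normalize_newshape(2, 3))    # (2, 3)
```

Both NumPy call styles, `x.reshape(sh)` and `x.reshape(*sh)`, therefore reach torch as the same shape tuple.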
36 changes: 19 additions & 17 deletions torch_np/_dtypes.py
@@ -161,23 +161,23 @@ class bool_(generic):


# name aliases : FIXME (OS, bitness)
intp = int64
int_ = int64
intc = int32

byte = int8
short = int16
longlong = int64 # XXX: is this correct?

ubyte = uint8

half = float16
single = float32
double = float64
float_ = float64

csingle = complex64
cdouble = complex128
_name_aliases = {
    "intp": int64,
    "int_": int64,
    "intc": int32,
    "byte": int8,
    "short": int16,
    "longlong": int64,  # XXX: is this correct?
    "ubyte": uint8,
    "half": float16,
    "single": float32,
    "double": float64,
    "float_": float64,
    "csingle": complex64,
    "cdouble": complex128,
}
for name, obj in _name_aliases.items():
    globals()[name] = obj


# Replicate this NumPy-defined way of grouping scalar types,
@@ -232,6 +232,8 @@ def sctype_from_string(s):
    """Normalize a string value: a type 'name' or a typecode or a width alias."""
    if s in _names:
        return _names[s]
    if s in _name_aliases.keys():
        return _name_aliases[s]
    if s in _typecodes:
        return _typecodes[s]
    if s in _aliases:
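A toy version of the alias-table refactor above: the name aliases live in one dict, get exported into the module namespace via `globals()`, and the same dict backs string lookup in `sctype_from_string()`. Plain strings stand in for the real scalar-type objects here.

```python
int64, int32, float64 = "int64", "int32", "float64"  # stand-ins for real scalar types

_name_aliases = {
    "intp": int64,
    "int_": int64,
    "intc": int32,
    "double": float64,
}

# export each alias as a module-level name, as the diff does
for name, obj in _name_aliases.items():
    globals()[name] = obj

def sctype_from_string(s):
    # mirrors the new branch added to sctype_from_string()
    if s in _name_aliases:
        return _name_aliases[s]
    raise TypeError(f"data type '{s}' not understood")

print(sctype_from_string("double"))  # float64
```

Keeping one table means `tnp.intp` and `tnp.dtype("intp")` can never drift apart, which is the point of the refactor.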
20 changes: 17 additions & 3 deletions torch_np/_helpers.py
@@ -74,9 +74,23 @@ def ndarrays_to_tensors(*inputs):
    """Convert all ndarrays from `inputs` to tensors. (other things are intact)"""
    from ._ndarray import asarray, ndarray

    return tuple(
        [value.get() if isinstance(value, ndarray) else value for value in inputs]
    )
    if len(inputs) == 0:
        raise ValueError()
    elif len(inputs) == 1:
        input_ = inputs[0]
        if isinstance(input_, ndarray):
            return input_.get()
        elif isinstance(input_, tuple):
            result = []
            for sub_input in input_:
                sub_result = ndarrays_to_tensors(sub_input)
                result.append(sub_result)
            return tuple(result)
        else:
            return input_
    else:
        assert isinstance(inputs, tuple)  # sanity check
        return ndarrays_to_tensors(inputs)


def to_tensors(*inputs):
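The recursion in the new `ndarrays_to_tensors()` can be exercised with a stand-in class — `FakeArray` below is hypothetical; the real helper dispatches on `tnp.ndarray` and unwraps via `.get()`:

```python
class FakeArray:
    """Stand-in for tnp.ndarray: .get() returns the wrapped 'tensor'."""
    def __init__(self, tensor):
        self._tensor = tensor
    def get(self):
        return self._tensor

def ndarrays_to_tensors(*inputs):
    # single ndarray -> tensor; tuples are walked recursively;
    # anything else passes through untouched
    if len(inputs) == 0:
        raise ValueError("expected at least one input")
    if len(inputs) == 1:
        x = inputs[0]
        if isinstance(x, FakeArray):
            return x.get()
        if isinstance(x, tuple):
            return tuple(ndarrays_to_tensors(i) for i in x)
        return x
    # several positional arguments: recurse on them as one tuple
    return ndarrays_to_tensors(inputs)

a = FakeArray("tensor-a")
print(ndarrays_to_tensors(a))             # tensor-a
print(ndarrays_to_tensors((a, 1, (a,))))  # ('tensor-a', 1, ('tensor-a',))
print(ndarrays_to_tensors(a, 1))          # ('tensor-a', 1)
```

The recursive walk is what lets nested tuple indices such as `x[(arr, (arr, 2))]` be normalised in one pass, which the old flat comprehension could not do.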
22 changes: 16 additions & 6 deletions torch_np/_ndarray.py
@@ -381,14 +381,24 @@ def repeat(self, repeats, axis=None):
        )

    ### indexing ###
    def __getitem__(self, *args, **kwds):
        t_args = _helpers.ndarrays_to_tensors(*args)
        return ndarray._from_tensor_and_base(
            self._tensor.__getitem__(*t_args, **kwds), self
        )

    @staticmethod
    def _upcast_int_indices(index):
        if isinstance(index, torch.Tensor):
            if index.dtype in (torch.int8, torch.int16, torch.int32, torch.uint8):
                return index.to(torch.int64)
        elif isinstance(index, tuple):
            return tuple(ndarray._upcast_int_indices(i) for i in index)
        return index

    def __getitem__(self, index):
        index = _helpers.ndarrays_to_tensors(index)
        index = ndarray._upcast_int_indices(index)
        return ndarray._from_tensor_and_base(self._tensor.__getitem__(index), self)
Review comments on `__getitem__` (Collaborators):

@ev-br this is wrong, as `__getitem__` may return a fresh tensor when using advanced indexing. This is related to #47. Of course, it was wrong before this PR, so no need to do anything here :)

Oh yes indeed. If the outcome of gh-47 is to keep the base, it'll need

    base = self if result._base is self._tensor._base else None

let's discuss this one in today's meeting

    def __setitem__(self, index, value):
        value = asarray(value).get()
        index = _helpers.ndarrays_to_tensors(index)
        index = ndarray._upcast_int_indices(index)
        value = _helpers.ndarrays_to_tensors(value)
        return self._tensor.__setitem__(index, value)

    ### sorting ###
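The index-upcasting helper is easy to model without torch installed; `FakeIndexTensor` below is a hypothetical stand-in for `torch.Tensor`, with dtypes represented as strings:

```python
SMALL_INT_DTYPES = {"int8", "int16", "int32", "uint8"}

class FakeIndexTensor:
    """Stand-in for torch.Tensor, carrying only a dtype."""
    def __init__(self, dtype):
        self.dtype = dtype
    def to(self, dtype):
        # torch.Tensor.to returns a converted copy
        return FakeIndexTensor(dtype)

def upcast_int_indices(index):
    # promote small/unsigned integer index tensors to int64 and walk
    # tuple indices recursively, mirroring ndarray._upcast_int_indices
    if isinstance(index, FakeIndexTensor):
        if index.dtype in SMALL_INT_DTYPES:
            return index.to("int64")
    elif isinstance(index, tuple):
        return tuple(upcast_int_indices(i) for i in index)
    return index

idx = (FakeIndexTensor("uint8"), slice(None), FakeIndexTensor("int64"))
out = upcast_int_indices(idx)
print([getattr(i, "dtype", i) for i in out])
```

Upcasting before handing the index to torch matches NumPy semantics, where any integer dtype (including uint8, which torch deprecates for indexing) is a valid advanced index.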
2 changes: 2 additions & 0 deletions torch_np/_wrapper.py
@@ -291,6 +291,8 @@ def empty_like(prototype, dtype=None, order="K", subok=False, shape=None):
@_decorators.dtype_to_torch
def full(shape, fill_value, dtype=None, order="C", *, like=None):
    _util.subok_not_ok(like)
    if isinstance(shape, int):
        shape = (shape,)
    if order != "C":
        raise NotImplementedError
    fill_value = asarray(fill_value).get()
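The `full()` change is the same int-to-tuple shape normalisation NumPy performs, sketched standalone here (`normalize_shape` is a made-up name for just that branch):

```python
def normalize_shape(shape):
    # np.full(3, v) and np.full((3,), v) are equivalent, so a bare int
    # shape is promoted to a 1-tuple before building the array
    if isinstance(shape, int):
        shape = (shape,)
    return shape

print(normalize_shape(3))       # (3,)
print(normalize_shape((2, 3)))  # (2, 3)
```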
2 changes: 2 additions & 0 deletions torch_np/testing/__init__.py
@@ -1,4 +1,6 @@
from .utils import (
    HAS_REFCOUNT,
    IS_WASM,
    _gen_alignment_data,
    assert_,
    assert_allclose,
4 changes: 1 addition & 3 deletions torch_np/testing/utils.py
@@ -17,8 +17,6 @@
from tempfile import mkdtemp, mkstemp
from warnings import WarningMessage

from pytest import raises as assert_raises

import torch_np as np
from torch_np import arange, array
from torch_np import asarray as asanyarray
@@ -31,8 +29,8 @@
    "assert_array_equal",
    "assert_array_less",
    "assert_string_equal",
    "assert_",
    "assert_array_almost_equal",
    "assert_raises",
    "build_err_msg",
    "decorate_methods",
    "print_assert_equal",