Commit 8bdf8b6

Merge pull request #23 from honno/test-indexing
Port indexing tests
2 parents: 4c5eb7d + 37e9adb

13 files changed: +1324 / -31 lines

torch_np/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,7 @@
 from ._wrapper import *  # isort: skip  # XXX: currently this prevents circular imports
 from . import random
 from ._binary_ufuncs import *
+from ._detail._index_tricks import *
 from ._detail._util import AxisError, UFuncTypeError
 from ._dtypes import *
 from ._getlimits import errstate, finfo, iinfo
@@ -15,3 +16,6 @@
 inf = float("inf")
 nan = float("nan")
 from math import pi  # isort: skip
+
+False_ = asarray(False, bool_)
+True_ = asarray(True, bool_)
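
A short usage sketch of the two new module-level constants (hypothetical session; only the names added above plus the package import are assumed):

    import torch_np as tnp

    # False_ and True_ are 0-d bool_ arrays, analogous to NumPy's np.False_ / np.True_
    print(tnp.True_, tnp.False_)
    print(tnp.True_.dtype)   # expected to report the bool dtype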

torch_np/_detail/_index_tricks.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+"""
+Vendored objects from numpy.lib.index_tricks
+"""
+__all__ = ["index_exp", "s_"]
+
+
+class IndexExpression:
+    """
+    Written by Konrad Hinsen <[email protected]>
+    last revision: 1999-7-23
+
+    Cosmetic changes by T. Oliphant 2001
+    """
+
+    def __init__(self, maketuple):
+        self.maketuple = maketuple
+
+    def __getitem__(self, item):
+        if self.maketuple and not isinstance(item, tuple):
+            return (item,)
+        else:
+            return item
+
+
+index_exp = IndexExpression(maketuple=True)
+s_ = IndexExpression(maketuple=False)
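
A brief behaviour sketch of the vendored helpers, assuming they are re-exported at the package level via the __init__.py change above (hypothetical session):

    import torch_np as tnp

    # s_ hands the index object back unchanged; index_exp always wraps it in a tuple
    assert tnp.s_[1:5] == slice(1, 5)
    assert tnp.index_exp[1:5] == (slice(1, 5),)

    a = tnp.arange(10)
    a[tnp.s_[2:8:2]]   # equivalent to a[2:8:2]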

torch_np/_detail/implementations.py

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,6 @@
 import torch
 
+from .. import _helpers
 from . import _dtypes_impl, _util
 
 # ### equality, equivalence, allclose ###
@@ -503,8 +504,10 @@ def reshape(tensor, *shape, order="C"):
     if order != "C":
         raise NotImplementedError
     newshape = shape[0] if len(shape) == 1 else shape
+    # convert any tnp.ndarray inputs into tensors before passing to torch.Tensor.reshape
+    t_newshape = _helpers.ndarrays_to_tensors(newshape)
    # if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh)
-    result = tensor.reshape(newshape)
+    result = tensor.reshape(t_newshape)
     return result
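
A minimal sketch of the two call forms the comment above refers to, assuming the public ndarray.reshape method forwards to this helper (hypothetical session):

    import torch_np as tnp

    a = tnp.arange(6)
    b = a.reshape((2, 3))   # shape given as a single tuple
    c = a.reshape(2, 3)     # shape given as separate integers; both forms match NumPy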

torch_np/_dtypes.py

Lines changed: 19 additions & 17 deletions
@@ -161,23 +161,23 @@ class bool_(generic):
 
 
 # name aliases : FIXME (OS, bitness)
-intp = int64
-int_ = int64
-intc = int32
-
-byte = int8
-short = int16
-longlong = int64  # XXX: is this correct?
-
-ubyte = uint8
-
-half = float16
-single = float32
-double = float64
-float_ = float64
-
-csingle = complex64
-cdouble = complex128
+_name_aliases = {
+    "intp": int64,
+    "int_": int64,
+    "intc": int32,
+    "byte": int8,
+    "short": int16,
+    "longlong": int64,  # XXX: is this correct?
+    "ubyte": uint8,
+    "half": float16,
+    "single": float32,
+    "double": float64,
+    "float_": float64,
+    "csingle": complex64,
+    "cdouble": complex128,
+}
+for name, obj in _name_aliases.items():
+    globals()[name] = obj
 
 
 # Replicate this NumPy-defined way of grouping scalar types,
@@ -232,6 +232,8 @@ def sctype_from_string(s):
     """Normalize a string value: a type 'name' or a typecode or a width alias."""
     if s in _names:
         return _names[s]
+    if s in _name_aliases.keys():
+        return _name_aliases[s]
     if s in _typecodes:
         return _typecodes[s]
     if s in _aliases:
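
A quick check of what the alias table provides, written as a sketch against this branch (torch_np._dtypes is the module patched above):

    from torch_np import _dtypes

    # aliases are injected into the module namespace by the globals() loop
    assert _dtypes.double is _dtypes.float64
    assert _dtypes.intp is _dtypes.int64

    # the same table now also backs string lookups
    assert _dtypes.sctype_from_string("double") is _dtypes.float64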

torch_np/_helpers.py

Lines changed: 17 additions & 3 deletions
@@ -74,9 +74,23 @@ def ndarrays_to_tensors(*inputs):
     """Convert all ndarrays from `inputs` to tensors. (other things are intact)"""
     from ._ndarray import asarray, ndarray
 
-    return tuple(
-        [value.get() if isinstance(value, ndarray) else value for value in inputs]
-    )
+    if len(inputs) == 0:
+        return ValueError()
+    elif len(inputs) == 1:
+        input_ = inputs[0]
+        if isinstance(input_, ndarray):
+            return input_.get()
+        elif isinstance(input_, tuple):
+            result = []
+            for sub_input in input_:
+                sub_result = ndarrays_to_tensors(sub_input)
+                result.append(sub_result)
+            return tuple(result)
+        else:
+            return input_
+    else:
+        assert isinstance(inputs, tuple)  # sanity check
+        return ndarrays_to_tensors(inputs)
 
 
 def to_tensors(*inputs):
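
A behaviour sketch of the recursive helper above (hypothetical session; only names visible in this diff are assumed):

    import torch
    import torch_np as tnp
    from torch_np import _helpers

    a = tnp.asarray([1, 2, 3])

    t = _helpers.ndarrays_to_tensors(a)            # a single ndarray -> its backing tensor
    assert isinstance(t, torch.Tensor)

    index = _helpers.ndarrays_to_tensors((a, 0))   # tuples are converted element-wise
    assert isinstance(index[0], torch.Tensor) and index[1] == 0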

torch_np/_ndarray.py

Lines changed: 16 additions & 6 deletions
@@ -381,14 +381,24 @@ def repeat(self, repeats, axis=None):
         )
 
     ### indexing ###
-    def __getitem__(self, *args, **kwds):
-        t_args = _helpers.ndarrays_to_tensors(*args)
-        return ndarray._from_tensor_and_base(
-            self._tensor.__getitem__(*t_args, **kwds), self
-        )
+    @staticmethod
+    def _upcast_int_indices(index):
+        if isinstance(index, torch.Tensor):
+            if index.dtype in (torch.int8, torch.int16, torch.int32, torch.uint8):
+                return index.to(torch.int64)
+        elif isinstance(index, tuple):
+            return tuple(ndarray._upcast_int_indices(i) for i in index)
+        return index
+
+    def __getitem__(self, index):
+        index = _helpers.ndarrays_to_tensors(index)
+        index = ndarray._upcast_int_indices(index)
+        return ndarray._from_tensor_and_base(self._tensor.__getitem__(index), self)
 
     def __setitem__(self, index, value):
-        value = asarray(value).get()
+        index = _helpers.ndarrays_to_tensors(index)
+        index = ndarray._upcast_int_indices(index)
+        value = _helpers.ndarrays_to_tensors(value)
         return self._tensor.__setitem__(index, value)
 
     ### sorting ###
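
A hedged usage sketch of what the integer-index upcasting enables: indexing with narrow integer dtypes, which torch only accepts once widened to int64 (hypothetical session; the dtype spelling is an assumption):

    import torch_np as tnp

    a = tnp.arange(5)
    idx = tnp.asarray([0, 3], dtype=tnp.int8)   # int8 indices are widened to int64 internally
    print(a[idx])                               # expected: elements 0 and 3

    a[idx] = 7                                  # __setitem__ follows the same path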

torch_np/_wrapper.py

Lines changed: 2 additions & 0 deletions
@@ -291,6 +291,8 @@ def empty_like(prototype, dtype=None, order="K", subok=False, shape=None):
 @_decorators.dtype_to_torch
 def full(shape, fill_value, dtype=None, order="C", *, like=None):
     _util.subok_not_ok(like)
+    if isinstance(shape, int):
+        shape = (shape,)
     if order != "C":
         raise NotImplementedError
     fill_value = asarray(fill_value).get()
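
A sketch of the behaviour the new branch enables, mirroring NumPy's acceptance of a bare integer shape (hypothetical session):

    import torch_np as tnp

    assert tnp.full(3, 5).shape == (3,)        # scalar shape now treated as (3,)
    assert tnp.full((2, 3), 5).shape == (2, 3)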

torch_np/testing/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,6 @@
 from .utils import (
+    HAS_REFCOUNT,
+    IS_WASM,
     _gen_alignment_data,
     assert_,
     assert_allclose,
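
These flags are re-exported for the ported test suite; a typical usage pattern would look roughly like the following (the test body is a placeholder, not code from this commit):

    import pytest

    from torch_np.testing import HAS_REFCOUNT

    @pytest.mark.skipif(not HAS_REFCOUNT, reason="Python runtime lacks reference counting")
    def test_refcount_sensitive_case():
        ...   # real tests would inspect sys.getrefcount here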

torch_np/testing/utils.py

Lines changed: 1 addition & 3 deletions
@@ -17,8 +17,6 @@
 from tempfile import mkdtemp, mkstemp
 from warnings import WarningMessage
 
-from pytest import raises as assert_raises
-
 import torch_np as np
 from torch_np import arange, array
 from torch_np import asarray as asanyarray
@@ -31,8 +29,8 @@
     "assert_array_equal",
     "assert_array_less",
     "assert_string_equal",
+    "assert_",
     "assert_array_almost_equal",
-    "assert_raises",
     "build_err_msg",
     "decorate_methods",
     "print_assert_equal",

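With assert_raises no longer re-exported from this module, ported tests are expected to take it from pytest directly; a minimal sketch (the test itself is hypothetical):

    from pytest import raises as assert_raises

    import torch_np as tnp

    def test_out_of_bounds_index():
        a = tnp.arange(3)
        with assert_raises(IndexError):
            a[10]
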
0 commit comments
