-
Notifications
You must be signed in to change notification settings - Fork 4
Port indexing tests #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
217d377
0cf51b7
6c920fd
6163b4b
e5cc442
4b6316c
fcc49e8
7a50ba9
fbe9cdb
c6b8952
7c1b7e8
bd06c19
a7c93f0
e07e341
792f7f2
57d6914
934023b
70fd4b2
d82033e
a1afd1a
1eeaaf7
4adc36f
e9c3b9a
4cedb5b
817b9b4
2aa9cb3
617086d
1c26579
55d6609
48ef530
7ac2b5d
8e6b59a
81215c3
35a5da4
bb762e3
33c6be7
e151168
d9faf74
29e7fb7
7bbf1d4
c8731f9
062452b
7e53288
42e0c45
9b1fce1
ea679a8
d5e2408
36e343e
215cf55
9cf3de8
a75ae90
5afc0e7
581917b
173bba6
8e857b8
40dd254
ba66944
c51d6cc
37e9adb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
""" | ||
Vendored objects from numpy.lib.index_tricks | ||
""" | ||
__all__ = ["index_exp", "s_"] | ||
|
||
|
||
class IndexExpression: | ||
""" | ||
Written by Konrad Hinsen <[email protected]> | ||
last revision: 1999-7-23 | ||
|
||
Cosmetic changes by T. Oliphant 2001 | ||
""" | ||
|
||
def __init__(self, maketuple): | ||
self.maketuple = maketuple | ||
|
||
def __getitem__(self, item): | ||
if self.maketuple and not isinstance(item, tuple): | ||
return (item,) | ||
else: | ||
return item | ||
|
||
|
||
index_exp = IndexExpression(maketuple=True) | ||
s_ = IndexExpression(maketuple=False) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -381,14 +381,24 @@ def repeat(self, repeats, axis=None): | |
) | ||
|
||
### indexing ### | ||
def __getitem__(self, *args, **kwds): | ||
t_args = _helpers.ndarrays_to_tensors(*args) | ||
return ndarray._from_tensor_and_base( | ||
self._tensor.__getitem__(*t_args, **kwds), self | ||
) | ||
@staticmethod | ||
def _upcast_int_indices(index): | ||
if isinstance(index, torch.Tensor): | ||
if index.dtype in (torch.int8, torch.int16, torch.int32, torch.uint8): | ||
return index.to(torch.int64) | ||
elif isinstance(index, tuple): | ||
return tuple(ndarray._upcast_int_indices(i) for i in index) | ||
return index | ||
|
||
def __getitem__(self, index): | ||
index = _helpers.ndarrays_to_tensors(index) | ||
index = ndarray._upcast_int_indices(index) | ||
return ndarray._from_tensor_and_base(self._tensor.__getitem__(index), self) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh yes indeed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's discuss this one in today's meeting |
||
|
||
def __setitem__(self, index, value): | ||
value = asarray(value).get() | ||
index = _helpers.ndarrays_to_tensors(index) | ||
index = ndarray._upcast_int_indices(index) | ||
value = _helpers.ndarrays_to_tensors(value) | ||
return self._tensor.__setitem__(index, value) | ||
|
||
### sorting ### | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
from .utils import ( | ||
HAS_REFCOUNT, | ||
IS_WASM, | ||
_gen_alignment_data, | ||
assert_, | ||
assert_allclose, | ||
|
Uh oh!
There was an error while loading. Please reload this page.