Skip to content

Un-xfail tests which needed fancy indexing; add .view #61

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions torch_np/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ def copy(self, order="C"):
tensor = self._tensor.clone()
return ndarray._from_tensor_and_base(tensor, None)

def view(self, dtype):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should to the refactor where methods are implemented in terms of free functions sooner than later.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I was holding this refactor to minimize the number of conflicts for gh-23. Now that it's in, it's about time indeed.

# XXX: 1) make dtype_to_torch decorator understand positional args
# 2) rid of .base and _from_tensor_and_base
torch_dtype = _dtypes.dtype(dtype).torch_dtype
tview = self._tensor.view(torch_dtype)
return ndarray._from_tensor_and_base(tview, self)

def tolist(self):
return self._tensor.tolist()

Expand Down
3 changes: 2 additions & 1 deletion torch_np/tests/numpy_tests/core/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,6 @@ def test_logical_and_or_xor(self):
assert_array_equal(self.im ^ False, self.im)


@pytest.mark.xfail(reason="TODO: needs fancy indexing")
class TestBoolCmp:
def setup_method(self):
self.f = np.ones(256, dtype=np.float32)
Expand Down Expand Up @@ -514,6 +513,7 @@ def test_float(self):
r3 = 0 != self.f[i:]
assert_array_equal(r, r2)
assert_array_equal(r, r3)

# check bool == 0x1
assert_array_equal(r.view(np.int8), r.astype(np.int8))
assert_array_equal(r2.view(np.int8), r2.astype(np.int8))
Expand All @@ -540,6 +540,7 @@ def test_double(self):
r3 = 0 != self.d[i:]
assert_array_equal(r, r2)
assert_array_equal(r, r3)

# check bool == 0x1
assert_array_equal(r.view(np.int8), r.astype(np.int8))
assert_array_equal(r2.view(np.int8), r2.astype(np.int8))
Expand Down
2 changes: 0 additions & 2 deletions torch_np/tests/numpy_tests/lib/test_twodim_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,6 @@ def test_mask_indices():
assert_array_equal(a[iu1], array([1, 2, 5]))


@pytest.mark.xfail(reason="TODO: fancy indexing")
def test_tril_indices():
# indices without and with offset
il1 = tril_indices(4)
Expand Down Expand Up @@ -413,7 +412,6 @@ def test_tril_indices():
[-10, -10, -10, -10, -10]]))


@pytest.mark.xfail(reason="TODO: fancy indexing")
class TestTriuIndices:
def test_triu_indices(self):
iu1 = triu_indices(4)
Expand Down