diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index a4d8fb21..eff3ee53 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -155,6 +155,13 @@ def copy(self, order="C"): tensor = self._tensor.clone() return ndarray._from_tensor_and_base(tensor, None) + def view(self, dtype): + # XXX: 1) make dtype_to_torch decorator understand positional args + # 2) rid of .base and _from_tensor_and_base + torch_dtype = _dtypes.dtype(dtype).torch_dtype + tview = self._tensor.view(torch_dtype) + return ndarray._from_tensor_and_base(tview, self) + def tolist(self): return self._tensor.tolist() diff --git a/torch_np/tests/numpy_tests/core/test_numeric.py b/torch_np/tests/numpy_tests/core/test_numeric.py index c735fa81..33e78475 100644 --- a/torch_np/tests/numpy_tests/core/test_numeric.py +++ b/torch_np/tests/numpy_tests/core/test_numeric.py @@ -452,7 +452,6 @@ def test_logical_and_or_xor(self): assert_array_equal(self.im ^ False, self.im) -@pytest.mark.xfail(reason="TODO: needs fancy indexing") class TestBoolCmp: def setup_method(self): self.f = np.ones(256, dtype=np.float32) @@ -514,6 +513,7 @@ def test_float(self): r3 = 0 != self.f[i:] assert_array_equal(r, r2) assert_array_equal(r, r3) + # check bool == 0x1 assert_array_equal(r.view(np.int8), r.astype(np.int8)) assert_array_equal(r2.view(np.int8), r2.astype(np.int8)) @@ -540,6 +540,7 @@ def test_double(self): r3 = 0 != self.d[i:] assert_array_equal(r, r2) assert_array_equal(r, r3) + # check bool == 0x1 assert_array_equal(r.view(np.int8), r.astype(np.int8)) assert_array_equal(r2.view(np.int8), r2.astype(np.int8)) diff --git a/torch_np/tests/numpy_tests/lib/test_twodim_base.py b/torch_np/tests/numpy_tests/lib/test_twodim_base.py index bfd65530..a178014b 100644 --- a/torch_np/tests/numpy_tests/lib/test_twodim_base.py +++ b/torch_np/tests/numpy_tests/lib/test_twodim_base.py @@ -365,7 +365,6 @@ def test_mask_indices(): assert_array_equal(a[iu1], array([1, 2, 5])) -@pytest.mark.xfail(reason="TODO: fancy indexing") def test_tril_indices(): # indices without and with offset il1 = tril_indices(4) @@ -413,7 +412,6 @@ def test_tril_indices(): [-10, -10, -10, -10, -10]])) -@pytest.mark.xfail(reason="TODO: fancy indexing") class TestTriuIndices: def test_triu_indices(self): iu1 = triu_indices(4)