File tree Expand file tree Collapse file tree 2 files changed +9
-1
lines changed Expand file tree Collapse file tree 2 files changed +9
-1
lines changed Original file line number Diff line number Diff line change @@ -159,6 +159,13 @@ def copy(self, order="C"):
159
159
tensor = self ._tensor .clone ()
160
160
return ndarray ._from_tensor_and_base (tensor , None )
161
161
162
+ def view (self , dtype ):
163
+ # XXX: 1) make dtype_to_torch decorator understand positional args
164
+ # 2) rid of .base and _from_tensor_and_base
165
+ torch_dtype = _dtypes .dtype (dtype ).torch_dtype
166
+ tview = self ._tensor .view (torch_dtype )
167
+ return ndarray ._from_tensor_and_base (tview , self )
168
+
162
169
def tolist (self ):
163
170
return self ._tensor .tolist ()
164
171
Original file line number Diff line number Diff line change @@ -453,7 +453,6 @@ def test_logical_and_or_xor(self):
453
453
assert_array_equal (self .im ^ False , self .im )
454
454
455
455
456
- @pytest .mark .xfail (reason = "TODO: needs fancy indexing" )
457
456
class TestBoolCmp :
458
457
def setup_method (self ):
459
458
self .f = np .ones (256 , dtype = np .float32 )
@@ -515,6 +514,7 @@ def test_float(self):
515
514
r3 = 0 != self .f [i :]
516
515
assert_array_equal (r , r2 )
517
516
assert_array_equal (r , r3 )
517
+
518
518
# check bool == 0x1
519
519
assert_array_equal (r .view (np .int8 ), r .astype (np .int8 ))
520
520
assert_array_equal (r2 .view (np .int8 ), r2 .astype (np .int8 ))
@@ -541,6 +541,7 @@ def test_double(self):
541
541
r3 = 0 != self .d [i :]
542
542
assert_array_equal (r , r2 )
543
543
assert_array_equal (r , r3 )
544
+
544
545
# check bool == 0x1
545
546
assert_array_equal (r .view (np .int8 ), r .astype (np .int8 ))
546
547
assert_array_equal (r2 .view (np .int8 ), r2 .astype (np .int8 ))
You can’t perform that action at this time.
0 commit comments