@@ -40,7 +40,7 @@ def test_index_no_floats(self):
40
40
assert_raises (IndexError , lambda : a [- 1.4 , 0 , 0 ])
41
41
assert_raises (IndexError , lambda : a [0 , - 1.4 , 0 ])
42
42
# Note torch validates index arguments "depth-first", so will prioritise
43
- # raising TypeError, e.g.
43
+ # raising TypeError over IndexError , e.g.
44
44
#
45
45
# >>> a = np.array([[[5]]])
46
46
# >>> a[0.0:, 0.0]
@@ -52,8 +52,8 @@ def test_index_no_floats(self):
52
52
# TypeError: slice indices must be integers or None or have an
53
53
# __index__ method
54
54
#
55
- assert_raises (TypeError , lambda : a [0.0 :, 0.0 ])
56
- assert_raises (TypeError , lambda : a [0.0 :, 0.0 ,:])
55
+ assert_raises (( IndexError , TypeError ) , lambda : a [0.0 :, 0.0 ])
56
+ assert_raises (( IndexError , TypeError ) , lambda : a [0.0 :, 0.0 ,:])
57
57
58
58
def test_slicing_no_floats (self ):
59
59
a = np .array ([[5 ]])
@@ -199,7 +199,8 @@ def test_single_int_index(self):
199
199
# Index out of bounds produces IndexError
200
200
assert_raises (IndexError , a .__getitem__ , 1 << 30 )
201
201
# Index overflow produces IndexError
202
- assert_raises (IndexError , a .__getitem__ , 1 << 64 )
202
+ # Note torch raises RuntimeError here
203
+ assert_raises ((IndexError , RuntimeError ), a .__getitem__ , 1 << 64 )
203
204
204
205
def test_single_bool_index (self ):
205
206
# Single boolean index
0 commit comments