Skip to content

Commit a1739f6

Browse files
Replace some use of broadcastable with shape in tests.tensor.test_subtensor
1 parent 8761c77 commit a1739f6

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

tests/tensor/test_subtensor.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def test_ok_elem_2(self):
465465
def test_ok_row(self):
466466
n = self.shared(np.arange(6, dtype=self.dtype).reshape((2, 3)))
467467
t = n[1]
468-
assert not any(n.type.broadcastable)
468+
assert not any(s == 1 for s in n.type.shape)
469469
assert isinstance(t.owner.op, Subtensor)
470470
tval = self.eval_output_and_check(t)
471471
assert tval.shape == (3,)
@@ -475,7 +475,7 @@ def test_ok_col(self):
475475
n = self.shared(np.arange(6, dtype=self.dtype).reshape((2, 3)))
476476
t = n[:, 0]
477477
assert isinstance(t.owner.op, Subtensor)
478-
assert not any(n.type.broadcastable)
478+
assert not any(s == 1 for s in n.type.shape)
479479
tval = self.eval_output_and_check(t)
480480
assert tval.shape == (2,)
481481
assert np.all(tval == [0, 3])
@@ -1773,15 +1773,17 @@ def test_index_into_vec_w_vec(self):
17731773
def test_index_into_vec_w_matrix(self):
17741774
a = self.v[self.ix2]
17751775
assert a.dtype == self.v.dtype, (a.dtype, self.v.dtype)
1776-
assert a.broadcastable == self.ix2.broadcastable, (
1777-
a.broadcastable,
1778-
self.ix2.broadcastable,
1776+
assert a.type.ndim == self.ix2.type.ndim
1777+
assert all(
1778+
s1 == s2
1779+
for s1, s2 in zip(a.type.shape, self.ix2.type.shape)
1780+
if s1 == 1 or s2 == 1
17791781
)
17801782

17811783
def test_index_into_mat_w_row(self):
17821784
a = self.m[self.ixr]
17831785
assert a.dtype == self.m.dtype, (a.dtype, self.m.dtype)
1784-
assert a.broadcastable == (True, False, False)
1786+
assert a.type.shape == (1, None, None)
17851787

17861788
def test_index_w_int_and_vec(self):
17871789
# like test_ok_list, but with a single index on the first one
@@ -2447,7 +2449,7 @@ def test_AdvancedSubtensor_bool(self):
24472449
)
24482450

24492451
abs_res = n[~isinf(n)]
2450-
assert abs_res.broadcastable == (False,)
2452+
assert abs_res.type.shape == (None,)
24512453

24522454

24532455
@config.change_flags(compute_test_value="raise")
@@ -2468,9 +2470,7 @@ def idx_as_tensor(x):
24682470
def bcast_shape_tuple(x):
24692471
if not hasattr(x, "shape"):
24702472
return x
2471-
return tuple(
2472-
s if not bcast else 1 for s, bcast in zip(tuple(x.shape), x.broadcastable)
2473-
)
2473+
return tuple(s if ss != 1 else 1 for s, ss in zip(tuple(x.shape), x.type.shape))
24742474

24752475

24762476
test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True]))

0 commit comments

Comments
 (0)