Commit 7d72236

Replace use of broadcastable with shape in aesara.tensor.shape
1 parent 46a46af commit 7d72236

2 files changed: +45 -31 lines

aesara/tensor/rewriting/shape.py

Lines changed: 30 additions & 15 deletions
@@ -364,8 +364,8 @@ def set_shape(self, r, s, override=False):
                 else:
                     shape_vars.append(self.unpack(s[i], r))
             assert all(
-                not hasattr(r.type, "broadcastable")
-                or not r.type.broadcastable[i]
+                not hasattr(r.type, "shape")
+                or r.type.shape[i] != 1
                 or self.lscalar_one.equals(shape_vars[i])
                 or self.lscalar_one.equals(extract_constant(shape_vars[i]))
                 for i in range(r.type.ndim)
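
Note: the assertion hunks in this feature class all follow the same substitution: a dimension is broadcastable exactly when its static shape is 1, so `not r.type.broadcastable[i]` becomes `r.type.shape[i] != 1`. A minimal sketch of that correspondence, assuming the shape-based `TensorType(dtype, shape=...)` constructor already used elsewhere in Aesara; the variable is hypothetical:

    # Sketch only: shows that the broadcastable flags are derived from the
    # static shape (a static 1 corresponds to a True flag).
    from aesara.tensor.type import TensorType

    x = TensorType("float64", shape=(1, None))()   # row-like: first dim fixed to 1
    assert x.type.broadcastable == (True, False)
    assert all((s == 1) == b for s, b in zip(x.type.shape, x.type.broadcastable))
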
@@ -447,9 +447,9 @@ def update_shape(self, r, other_r):
                 merged_shape.append(other_shape[i])
         assert all(
             (
-                not hasattr(r.type, "broadcastable")
-                or not r.type.broadcastable[i]
-                and not other_r.type.broadcastable[i]
+                not hasattr(r.type, "shape")
+                or r.type.shape[i] != 1
+                and other_r.type.shape[i] != 1
             )
             or self.lscalar_one.equals(merged_shape[i])
             or self.lscalar_one.equals(
@@ -474,8 +474,8 @@ def set_shape_i(self, r, i, s_i):
             else:
                 new_shape.append(s_j)
         assert all(
-            not hasattr(r.type, "broadcastable")
-            or not r.type.broadcastable[idx]
+            not hasattr(r.type, "shape")
+            or r.type.shape[idx] != 1
             or self.lscalar_one.equals(new_shape[idx])
             or self.lscalar_one.equals(extract_constant(new_shape[idx]))
             for idx in range(r.type.ndim)
@@ -781,7 +781,11 @@ def f(fgraph, node):
         # We should try to figure out why we lost the information about this
         # constant value... but in the meantime, better not apply this
         # rewrite.
-        if rval.broadcastable == node.outputs[0].broadcastable:
+        if rval.type.ndim == node.outputs[0].type.ndim and all(
+            s1 == s2
+            for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
+            if s1 == 1 or s2 == 1
+        ):
             return [rval]
         else:
             return False
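
Note: the generator expression introduced here, and reused in the reshape and Unbroadcast rewrites below, only compares the dimensions that either type pins to length 1, which reproduces the old broadcastable-pattern equality. A hedged, pure-Python sketch of the check with a hypothetical helper name:

    # Hypothetical helper mirroring the check above: two static shapes share a
    # broadcastable pattern when they agree on every dimension that either side
    # fixes to 1; dimensions where neither side is 1 are ignored.
    def same_broadcast_pattern(shape_a, shape_b):
        if len(shape_a) != len(shape_b):
            return False
        return all(s1 == s2 for s1, s2 in zip(shape_a, shape_b) if s1 == 1 or s2 == 1)

    assert same_broadcast_pattern((None, 1), (None, 1))
    assert same_broadcast_pattern((3, None), (None, 5))          # no 1s to compare
    assert not same_broadcast_pattern((None, 1), (None, None))   # 1 vs. unknown
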
@@ -816,7 +820,11 @@ def local_useless_reshape(fgraph, node):
     if (
         inp.type.ndim == 1
         and output.type.ndim == 1
-        and inp.type.broadcastable == output.type.broadcastable
+        and all(
+            s1 == s2
+            for s1, s2 in zip(inp.type.shape, output.type.shape)
+            if s1 == 1 or s2 == 1
+        )
     ):
         return [inp]

@@ -862,7 +870,7 @@ def local_useless_reshape(fgraph, node):
                 shape_match[dim] = True
                 continue
 
-            # Match 1 if input.broadcastable[dim] is True
+            # Match 1 if input.type.shape[dim] == 1
             cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
             if inp.type.shape[dim] == 1 and cst_outshp_i == 1:
                 shape_match[dim] = True
@@ -931,7 +939,11 @@ def local_reshape_to_dimshuffle(fgraph, node):
         if index != output.type.ndim:
             inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
             copy_stack_trace(output, inner)
-            new_node = [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
+            new_node = [
+                DimShuffle(tuple(s == 1 for s in inner.type.shape), dimshuffle_new_order)(
+                    inner
+                )
+            ]
             copy_stack_trace(output, new_node)
             return new_node
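
Note: `DimShuffle` still takes a tuple of booleans describing its input pattern, so the rewrite now derives that tuple from the static shape instead of reading `inner.type.broadcastable` directly. A small sketch of the conversion, using a made-up static shape:

    # Made-up static shape standing in for inner.type.shape.
    static_shape = (1, None, 3)
    input_broadcastable = tuple(s == 1 for s in static_shape)
    assert input_broadcastable == (True, False, False)
    # DimShuffle(input_broadcastable, new_order) then behaves like the removed
    # DimShuffle(inner.type.broadcastable, new_order) call.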

@@ -1096,10 +1108,9 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
 
     new_order = node.inputs[0].owner.op.new_order
     inp = node.inputs[0].owner.inputs[0]
-    broadcastables = node.inputs[0].broadcastable
     new_order_of_nonbroadcast = []
-    for i, bd in zip(new_order, broadcastables):
-        if not bd:
+    for i, s in zip(new_order, node.inputs[0].type.shape):
+        if s != 1:
             new_order_of_nonbroadcast.append(i)
     no_change_in_order = all(
         new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
@@ -1123,7 +1134,11 @@ def local_useless_unbroadcast(fgraph, node):
     """
     if isinstance(node.op, Unbroadcast):
         x = node.inputs[0]
-        if x.broadcastable == node.outputs[0].broadcastable:
+        if x.type.ndim == node.outputs[0].type.ndim and all(
+            s1 == s2
+            for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape)
+            if s1 == 1 or s2 == 1
+        ):
             # No broadcastable flag was modified
             # No need to copy over stack trace,
             # because x should already have a stack trace.

tests/tensor/test_shape.py

Lines changed: 15 additions & 16 deletions
@@ -55,13 +55,13 @@
 
 def test_shape_basic():
     s = shape([])
-    assert s.type.broadcastable == (True,)
+    assert s.type.shape == (1,)
 
     s = shape([10])
-    assert s.type.broadcastable == (True,)
+    assert s.type.shape == (1,)
 
     s = shape(lscalar())
-    assert s.type.broadcastable == (False,)
+    assert s.type.shape == (0,)
 
     class MyType(Type):
         def filter(self, *args, **kwargs):
@@ -71,7 +71,7 @@ def __eq__(self, other):
             return isinstance(other, MyType) and other.thingy == self.thingy
 
     s = shape(Variable(MyType(), None))
-    assert s.type.broadcastable == (False,)
+    assert s.type.shape == (None,)
 
     s = shape(np.array(1))
     assert np.array_equal(eval_outputs([s]), [])
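
Note: the updated assertions read the static type of the `Shape` output directly: it is a 1-d integer vector whose single static dimension is the rank of the input when that rank is known, and `None` for the `MyType` variable whose rank is unknown. A short usage sketch; the `matrix()` case is an assumption following the same rule:

    from aesara.tensor.shape import shape
    from aesara.tensor.type import lscalar, matrix

    assert shape([10]).type.shape == (1,)       # a list literal is a 1-d tensor
    assert shape(lscalar()).type.shape == (0,)  # a scalar has an empty shape vector
    assert shape(matrix()).type.shape == (2,)   # assumed: rank 2 gives a length-2 vector
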
@@ -119,15 +119,14 @@ def test_basics(self):
         b = dmatrix()
         d = dmatrix()
 
-        # basic to 1 dim(without list)
-        c = reshape(b, as_tensor_variable(6), ndim=1)
-        f = self.function([b], c)
-
         b_val1 = np.asarray([[0, 1, 2], [3, 4, 5]])
         c_val1 = np.asarray([0, 1, 2, 3, 4, 5])
         b_val2 = b_val1.T
         c_val2 = np.asarray([0, 3, 1, 4, 2, 5])
 
+        # basic to 1 dim(without list)
+        c = reshape(b, as_tensor_variable(6), ndim=1)
+        f = self.function([b], c)
         f_out1 = f(b_val1)
         f_out2 = f(b_val2)
         assert np.array_equal(f_out1, c_val1), (f_out1, c_val1)
@@ -191,10 +190,10 @@ def just_vals(v):
             f(np.asarray([[0, 1, 2], [3, 4, 5]])),
             np.asarray([[[0], [1], [2]], [[3], [4], [5]]]),
         )
-        assert f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable == (
-            False,
-            False,
-            True,
+        assert f.maker.fgraph.toposort()[-1].outputs[0].type.shape == (
+            None,
+            None,
+            1,
         )
 
         # test broadcast flag for constant value of 1 if it cannot be
@@ -205,10 +204,10 @@ def just_vals(v):
             f(np.asarray([[0, 1, 2], [3, 4, 5]])),
             np.asarray([[[0], [1]], [[2], [3]], [[4], [5]]]),
         )
-        assert f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable == (
-            False,
-            False,
-            True,
+        assert f.maker.fgraph.toposort()[-1].outputs[0].type.shape == (
+            None,
+            None,
+            1,
         )
 
     def test_m1(self):
