Skip to content

Commit b379b0f

Browse files
Use broadcasted output shape in local_useless_switch optimization
Closes #270
1 parent 175e784 commit b379b0f

File tree

2 files changed

+51
-92
lines changed

2 files changed

+51
-92
lines changed

tests/tensor/test_opt.py

Lines changed: 45 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -6163,8 +6163,7 @@ class TestLocalUselessSwitch:
61636163
def setup_method(self):
61646164
self.mode = mode_opt.excluding("constant_folding")
61656165

6166-
def test_const0(self):
6167-
6166+
def test_const_0(self):
61686167
for dtype1 in ["int32", "int64"]:
61696168
for dtype2 in ["int32", "int64"]:
61706169
x = tt.matrix("x", dtype=dtype1)
@@ -6186,10 +6185,15 @@ def test_const0(self):
61866185
)
61876186
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
61886187
vy = np.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2)
6189-
assert np.all(f(vx, vy) == vy)
6188+
np_res = np.where(0, vx, vy)
6189+
assert np.array_equal(f(vx, vy), np_res)
61906190

6191-
def test_const1(self):
6191+
res_non_bool_np = np.where(np.ones(10), 0, 1)
6192+
non_bool_graph = tt.switch(np.ones(10), 0, 1)
6193+
non_bool_fn = function([], non_bool_graph, mode=self.mode)
6194+
assert np.array_equal(non_bool_fn(), res_non_bool_np)
61926195

6196+
def test_const_1(self):
61936197
for dtype1 in ["int32", "int64"]:
61946198
for dtype2 in ["int32", "int64"]:
61956199
x = tt.matrix("x", dtype=dtype1)
@@ -6211,10 +6215,10 @@ def test_const1(self):
62116215
)
62126216
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
62136217
vy = np.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2)
6214-
assert np.all(f(vx, vy) == vx)
6218+
np_res = np.where(1, vx, vy)
6219+
assert np.array_equal(f(vx, vy), np_res)
62156220

62166221
def test_left_is_right(self):
6217-
62186222
for dtype1 in ["int32", "int64"]:
62196223
x = tt.matrix("x", dtype=dtype1)
62206224
varc = tt.matrix("varc", dtype=dtype1)
@@ -6239,12 +6243,11 @@ def test_left_is_right(self):
62396243

62406244
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
62416245
vc = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
6242-
assert np.all(f1(vx) == vx)
6243-
assert np.all(f0(vx) == vx)
6244-
assert np.all(f2(vx, vc) == vx)
6246+
assert np.array_equal(f1(vx), vx)
6247+
assert np.array_equal(f0(vx), vx)
6248+
assert np.array_equal(f2(vx, vc), vx)
62456249

62466250
def test_shape_le_0(self):
6247-
62486251
for dtype1 in ["float32", "float64"]:
62496252
x = tt.matrix("x", dtype=dtype1)
62506253
z0 = tt.switch(tt.le(x.shape[0], 0), 0, x.shape[0])
@@ -6259,84 +6262,63 @@ def test_shape_le_0(self):
62596262
assert f0(vx) == 0
62606263
assert f1(vx) == 5
62616264

6262-
def test_broadcast1(self):
6265+
def test_broadcasting_1(self):
62636266
# test switch(cst, matrix, row)
62646267
x = tt.matrix("x", dtype="int32")
62656268
y = tt.vector("y", dtype="int64")
62666269

62676270
z = tt.switch(1, x, y)
62686271
f = function([x, y], z, mode=self.mode)
6269-
assert (
6270-
len(
6271-
[
6272-
node.op
6273-
for node in f.maker.fgraph.toposort()
6274-
if isinstance(node.op, tt.Elemwise)
6275-
and not isinstance(node.op.scalar_op, scal.basic.Cast)
6276-
]
6277-
)
6278-
== 0
6279-
)
6272+
6273+
assert isinstance(f.maker.fgraph.outputs[0].owner.op, tt.Elemwise)
6274+
assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, scal.basic.Cast)
6275+
assert not any(node.op == tt.switch for node in f.maker.fgraph.toposort())
6276+
62806277
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype="int32")
62816278
vy = np.array([10, 11, 12], dtype="int64")
6282-
assert np.all(f(vx, vy) == vx)
6279+
np_res = np.where(1, vx, vy)
6280+
assert np.array_equal(f(vx, vy), np_res)
62836281

62846282
z = tt.switch(0, x, y)
62856283
f = function([x, y], z, mode=self.mode)
6286-
assert (
6287-
len(
6288-
[
6289-
node.op
6290-
for node in f.maker.fgraph.toposort()
6291-
if isinstance(node.op, tt.Elemwise)
6292-
]
6293-
)
6294-
== 0
6295-
)
6284+
6285+
assert isinstance(f.maker.fgraph.outputs[0].owner.op, tt.Alloc)
6286+
assert f.maker.fgraph.inputs[1] == f.maker.fgraph.outputs[0].owner.inputs[0]
6287+
assert not any(node.op == tt.switch for node in f.maker.fgraph.toposort())
6288+
62966289
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype="int32")
62976290
vy = np.array([10, 11, 12], dtype="int64")
6298-
assert np.all(f(vx, vy) == vy)
6291+
np_res = np.where(0, vx, vy)
6292+
assert np.array_equal(f(vx, vy), np_res)
62996293

6300-
def test_broadcast2(self):
6294+
def test_broadcasting_2(self):
63016295
# test switch(cst, vector, matrix)
63026296

6303-
# This case is not optimized for now.
63046297
x = tt.vector("x", dtype="int32")
63056298
y = tt.matrix("y", dtype="int64")
63066299
z = tt.switch(1, x, y)
63076300
f = function([x, y], z, mode=self.mode)
6308-
assert (
6309-
len(
6310-
[
6311-
node.op
6312-
for node in f.maker.fgraph.toposort()
6313-
if isinstance(node.op, tt.Elemwise)
6314-
and not isinstance(node.op.scalar_op, scal.basic.Cast)
6315-
]
6316-
)
6317-
== 0
6318-
)
6301+
6302+
assert isinstance(f.maker.fgraph.outputs[0].owner.op, tt.Alloc)
6303+
assert not any(node.op == tt.switch for node in f.maker.fgraph.toposort())
6304+
63196305
vx = np.array([4, 5, 6], dtype="int32")
63206306
vy = np.array([[7, 8, 9], [10, 11, 12]], dtype="int64")
6321-
assert np.all(f(vx, vy) == vx)
6307+
np_res = np.where(1, vx, vy)
6308+
assert np.array_equal(f(vx, vy), np_res)
63226309

63236310
z = tt.switch(0, x, y)
63246311
f = function([x, y], z, mode=self.mode)
6325-
assert (
6326-
len(
6327-
[
6328-
node.op
6329-
for node in f.maker.fgraph.toposort()
6330-
if isinstance(node.op, tt.Elemwise)
6331-
]
6332-
)
6333-
== 0
6334-
)
6312+
6313+
assert isinstance(f.maker.fgraph.outputs[0].owner.op, DeepCopyOp)
6314+
assert not any(node.op == tt.switch for node in f.maker.fgraph.toposort())
6315+
63356316
vx = np.array([4, 5, 6], dtype="int32")
63366317
vy = np.array([[7, 8, 9], [10, 11, 12]], dtype="int64")
6337-
assert np.all(f(vx, vy) == vy)
6318+
np_res = np.where(0, vx, vy)
6319+
assert np.array_equal(f(vx, vy), np_res)
63386320

6339-
def test_broadcast3(self):
6321+
def test_broadcasting_3(self):
63406322
# test switch(matrix, same_vector, same_vector)
63416323

63426324
x = tt.matrix("x", dtype="int32")
@@ -6346,16 +6328,9 @@ def test_broadcast3(self):
63466328
vx = np.array([[0, 1], [1, 0]], dtype="int32")
63476329
vy = np.array([7, 8], dtype="int64")
63486330
utt.assert_allclose(f(vx, vy), np.where(vx, vy, vy))
6349-
assert (
6350-
len(
6351-
[
6352-
node.op
6353-
for node in f.maker.fgraph.toposort()
6354-
if isinstance(node.op, tt.Elemwise)
6355-
]
6356-
)
6357-
== 0
6358-
)
6331+
6332+
assert isinstance(f.maker.fgraph.outputs[0].owner.op, tt.Alloc)
6333+
assert not any(node.op == tt.switch for node in f.maker.fgraph.toposort())
63596334

63606335

63616336
class TestLocalMergeSwitchSameCond:

theano/tensor/opt.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
ProdWithoutZeros,
9898
Sum,
9999
)
100+
from theano.tensor.extra_ops import broadcast_shape
100101
from theano.tensor.sort import TopKOp
101102
from theano.tensor.subtensor import (
102103
AdvancedIncSubtensor,
@@ -4327,7 +4328,9 @@ def local_useless_switch(fgraph, node):
43274328
T.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
43284329
"""
43294330
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ts.Switch):
4331+
43304332
cond = tt.extract_constant(node.inputs[0], only_process_constants=True)
4333+
43314334
if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
43324335
cond, np.number
43334336
):
@@ -4336,37 +4339,18 @@ def local_useless_switch(fgraph, node):
43364339
else:
43374340
correct_out = node.inputs[1]
43384341

4339-
if correct_out.ndim != node.outputs[0].ndim:
4340-
# TODO: broadcast?
4341-
return False
43424342
if correct_out.dtype != node.outputs[0].dtype:
43434343
out = tt.cast(correct_out, node.outputs[0].dtype)
43444344
else:
43454345
out = correct_out
43464346

4347-
if out.type.broadcastable != node.outputs[0].type.broadcastable:
4348-
# We need to copy data to the new dimensions during execution
4349-
4350-
# We should not depend on node.outputs as this would
4351-
# make the new node depend on the old one that will
4352-
# get optimized again. So this create a cycle.
4353-
shps = []
4354-
for idx, (b1, b2), in enumerate(
4355-
zip(out.type.broadcastable, node.outputs[0].type.broadcastable)
4356-
):
4357-
if b1 == b2:
4358-
shps.append(out.shape[idx])
4359-
elif not node.inputs[1].type.broadcastable[idx]:
4360-
shps.append(node.inputs[1].shape[idx])
4361-
else:
4362-
shps.append(node.inputs[2].shape[idx])
4363-
out = alloc(out, *shps)
4364-
else:
4365-
out = out
4347+
out_shape = broadcast_shape(*node.inputs)
4348+
out = alloc(out, *out_shape)
43664349

43674350
# Copy over stacktrace from selected output to new output
43684351
copy_stack_trace(node.outputs + correct_out, out)
43694352
return [out]
4353+
43704354
# if left is right -> left
43714355
if node.inputs[1] is node.inputs[2]:
43724356
# Note: No need to copy over stacktrace, because the input node

0 commit comments

Comments (0)