@@ -6163,8 +6163,7 @@ class TestLocalUselessSwitch:
6163
6163
def setup_method (self ):
6164
6164
self .mode = mode_opt .excluding ("constant_folding" )
6165
6165
6166
- def test_const0 (self ):
6167
-
6166
+ def test_const_0 (self ):
6168
6167
for dtype1 in ["int32" , "int64" ]:
6169
6168
for dtype2 in ["int32" , "int64" ]:
6170
6169
x = tt .matrix ("x" , dtype = dtype1 )
@@ -6186,10 +6185,15 @@ def test_const0(self):
6186
6185
)
6187
6186
vx = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = dtype1 )
6188
6187
vy = np .array ([[7 , 8 , 9 ], [10 , 11 , 12 ]], dtype = dtype2 )
6189
- assert np .all (f (vx , vy ) == vy )
6188
+ np_res = np .where (0 , vx , vy )
6189
+ assert np .array_equal (f (vx , vy ), np_res )
6190
6190
6191
- def test_const1 (self ):
6191
+ res_non_bool_np = np .where (np .ones (10 ), 0 , 1 )
6192
+ non_bool_graph = tt .switch (np .ones (10 ), 0 , 1 )
6193
+ non_bool_fn = function ([], non_bool_graph , mode = self .mode )
6194
+ assert np .array_equal (non_bool_fn (), res_non_bool_np )
6192
6195
6196
+ def test_const_1 (self ):
6193
6197
for dtype1 in ["int32" , "int64" ]:
6194
6198
for dtype2 in ["int32" , "int64" ]:
6195
6199
x = tt .matrix ("x" , dtype = dtype1 )
@@ -6211,10 +6215,10 @@ def test_const1(self):
6211
6215
)
6212
6216
vx = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = dtype1 )
6213
6217
vy = np .array ([[7 , 8 , 9 ], [10 , 11 , 12 ]], dtype = dtype2 )
6214
- assert np .all (f (vx , vy ) == vx )
6218
+ np_res = np .where (1 , vx , vy )
6219
+ assert np .array_equal (f (vx , vy ), np_res )
6215
6220
6216
6221
def test_left_is_right (self ):
6217
-
6218
6222
for dtype1 in ["int32" , "int64" ]:
6219
6223
x = tt .matrix ("x" , dtype = dtype1 )
6220
6224
varc = tt .matrix ("varc" , dtype = dtype1 )
@@ -6239,12 +6243,11 @@ def test_left_is_right(self):
6239
6243
6240
6244
vx = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = dtype1 )
6241
6245
vc = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = dtype1 )
6242
- assert np .all (f1 (vx ) == vx )
6243
- assert np .all (f0 (vx ) == vx )
6244
- assert np .all (f2 (vx , vc ) == vx )
6246
+ assert np .array_equal (f1 (vx ), vx )
6247
+ assert np .array_equal (f0 (vx ), vx )
6248
+ assert np .array_equal (f2 (vx , vc ), vx )
6245
6249
6246
6250
def test_shape_le_0 (self ):
6247
-
6248
6251
for dtype1 in ["float32" , "float64" ]:
6249
6252
x = tt .matrix ("x" , dtype = dtype1 )
6250
6253
z0 = tt .switch (tt .le (x .shape [0 ], 0 ), 0 , x .shape [0 ])
@@ -6259,84 +6262,63 @@ def test_shape_le_0(self):
6259
6262
assert f0 (vx ) == 0
6260
6263
assert f1 (vx ) == 5
6261
6264
6262
- def test_broadcast1 (self ):
6265
+ def test_broadcasting_1 (self ):
6263
6266
# test switch(cst, matrix, row)
6264
6267
x = tt .matrix ("x" , dtype = "int32" )
6265
6268
y = tt .vector ("y" , dtype = "int64" )
6266
6269
6267
6270
z = tt .switch (1 , x , y )
6268
6271
f = function ([x , y ], z , mode = self .mode )
6269
- assert (
6270
- len (
6271
- [
6272
- node .op
6273
- for node in f .maker .fgraph .toposort ()
6274
- if isinstance (node .op , tt .Elemwise )
6275
- and not isinstance (node .op .scalar_op , scal .basic .Cast )
6276
- ]
6277
- )
6278
- == 0
6279
- )
6272
+
6273
+ assert isinstance (f .maker .fgraph .outputs [0 ].owner .op , tt .Elemwise )
6274
+ assert isinstance (f .maker .fgraph .outputs [0 ].owner .op .scalar_op , scal .basic .Cast )
6275
+ assert not any (node .op == tt .switch for node in f .maker .fgraph .toposort ())
6276
+
6280
6277
vx = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = "int32" )
6281
6278
vy = np .array ([10 , 11 , 12 ], dtype = "int64" )
6282
- assert np .all (f (vx , vy ) == vx )
6279
+ np_res = np .where (1 , vx , vy )
6280
+ assert np .array_equal (f (vx , vy ), np_res )
6283
6281
6284
6282
z = tt .switch (0 , x , y )
6285
6283
f = function ([x , y ], z , mode = self .mode )
6286
- assert (
6287
- len (
6288
- [
6289
- node .op
6290
- for node in f .maker .fgraph .toposort ()
6291
- if isinstance (node .op , tt .Elemwise )
6292
- ]
6293
- )
6294
- == 0
6295
- )
6284
+
6285
+ assert isinstance (f .maker .fgraph .outputs [0 ].owner .op , tt .Alloc )
6286
+ assert f .maker .fgraph .inputs [1 ] == f .maker .fgraph .outputs [0 ].owner .inputs [0 ]
6287
+ assert not any (node .op == tt .switch for node in f .maker .fgraph .toposort ())
6288
+
6296
6289
vx = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = "int32" )
6297
6290
vy = np .array ([10 , 11 , 12 ], dtype = "int64" )
6298
- assert np .all (f (vx , vy ) == vy )
6291
+ np_res = np .where (0 , vx , vy )
6292
+ assert np .array_equal (f (vx , vy ), np_res )
6299
6293
6300
- def test_broadcast2 (self ):
6294
+ def test_broadcasting_2 (self ):
6301
6295
# test switch(cst, vector, matrix)
6302
6296
6303
- # This case is not optimized for now.
6304
6297
x = tt .vector ("x" , dtype = "int32" )
6305
6298
y = tt .matrix ("y" , dtype = "int64" )
6306
6299
z = tt .switch (1 , x , y )
6307
6300
f = function ([x , y ], z , mode = self .mode )
6308
- assert (
6309
- len (
6310
- [
6311
- node .op
6312
- for node in f .maker .fgraph .toposort ()
6313
- if isinstance (node .op , tt .Elemwise )
6314
- and not isinstance (node .op .scalar_op , scal .basic .Cast )
6315
- ]
6316
- )
6317
- == 0
6318
- )
6301
+
6302
+ assert isinstance (f .maker .fgraph .outputs [0 ].owner .op , tt .Alloc )
6303
+ assert not any (node .op == tt .switch for node in f .maker .fgraph .toposort ())
6304
+
6319
6305
vx = np .array ([4 , 5 , 6 ], dtype = "int32" )
6320
6306
vy = np .array ([[7 , 8 , 9 ], [10 , 11 , 12 ]], dtype = "int64" )
6321
- assert np .all (f (vx , vy ) == vx )
6307
+ np_res = np .where (1 , vx , vy )
6308
+ assert np .array_equal (f (vx , vy ), np_res )
6322
6309
6323
6310
z = tt .switch (0 , x , y )
6324
6311
f = function ([x , y ], z , mode = self .mode )
6325
- assert (
6326
- len (
6327
- [
6328
- node .op
6329
- for node in f .maker .fgraph .toposort ()
6330
- if isinstance (node .op , tt .Elemwise )
6331
- ]
6332
- )
6333
- == 0
6334
- )
6312
+
6313
+ assert isinstance (f .maker .fgraph .outputs [0 ].owner .op , DeepCopyOp )
6314
+ assert not any (node .op == tt .switch for node in f .maker .fgraph .toposort ())
6315
+
6335
6316
vx = np .array ([4 , 5 , 6 ], dtype = "int32" )
6336
6317
vy = np .array ([[7 , 8 , 9 ], [10 , 11 , 12 ]], dtype = "int64" )
6337
- assert np .all (f (vx , vy ) == vy )
6318
+ np_res = np .where (0 , vx , vy )
6319
+ assert np .array_equal (f (vx , vy ), np_res )
6338
6320
6339
- def test_broadcast3 (self ):
6321
+ def test_broadcasting_3 (self ):
6340
6322
# test switch(matrix, same_vector, same_vector)
6341
6323
6342
6324
x = tt .matrix ("x" , dtype = "int32" )
@@ -6346,16 +6328,9 @@ def test_broadcast3(self):
6346
6328
vx = np .array ([[0 , 1 ], [1 , 0 ]], dtype = "int32" )
6347
6329
vy = np .array ([7 , 8 ], dtype = "int64" )
6348
6330
utt .assert_allclose (f (vx , vy ), np .where (vx , vy , vy ))
6349
- assert (
6350
- len (
6351
- [
6352
- node .op
6353
- for node in f .maker .fgraph .toposort ()
6354
- if isinstance (node .op , tt .Elemwise )
6355
- ]
6356
- )
6357
- == 0
6358
- )
6331
+
6332
+ assert isinstance (f .maker .fgraph .outputs [0 ].owner .op , tt .Alloc )
6333
+ assert not any (node .op == tt .switch for node in f .maker .fgraph .toposort ())
6359
6334
6360
6335
6361
6336
class TestLocalMergeSwitchSameCond :
0 commit comments