@@ -3231,7 +3231,7 @@ def test_local_prod_of_div(self):
 class TestLocalReduce:
     def setup_method(self):
         self.mode = get_default_mode().including(
-            "canonicalize", "specialize", "uncanonicalize", "local_max_and_argmax"
+            "canonicalize", "specialize", "uncanonicalize"
         )

     def test_local_reduce_broadcast_all_0(self):
@@ -3304,62 +3304,94 @@ def test_local_reduce_broadcast_some_1(self):
             isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
         )

-    def test_local_reduce_join(self):
+
+class TestReduceJoin:
+    def setup_method(self):
+        self.mode = get_default_mode().including(
+            "canonicalize", "specialize", "uncanonicalize"
+        )
+
+    @pytest.mark.parametrize(
+        "op, nin", [(pt_sum, 3), (pt_max, 2), (pt_min, 2), (prod, 3)]
+    )
+    def test_local_reduce_join(self, op, nin):
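+        # Each parametrization reduces a join of two or three matrices along
+        # axis 0, checks the result against NumPy, and asserts that the
+        # rewritten graph collapses to at most two nodes ending in an Elemwise.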
         vx = matrix()
         vy = matrix()
         vz = matrix()
         x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
         y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
         z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)
-        # Test different reduction scalar operation
-        for out, res in [
-            (pt_max((vx, vy), 0), np.max((x, y), 0)),
-            (pt_min((vx, vy), 0), np.min((x, y), 0)),
-            (pt_sum((vx, vy, vz), 0), np.sum((x, y, z), 0)),
-            (prod((vx, vy, vz), 0), np.prod((x, y, z), 0)),
-            (prod((vx, vy.T, vz), 0), np.prod((x, y.T, z), 0)),
-        ]:
-            f = function([vx, vy, vz], out, on_unused_input="ignore", mode=self.mode)
-            assert (f(x, y, z) == res).all(), out
-            topo = f.maker.fgraph.toposort()
-            assert len(topo) <= 2, out
-            assert isinstance(topo[-1].op, Elemwise), out

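+        # `nin` selects how many of the inputs take part in the join:
+        # max and min are exercised with two inputs, sum and prod with three.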
+        inputs = (vx, vy, vz)[:nin]
+        test_values = (x, y, z)[:nin]
+
+        out = op(inputs, axis=0)
+        f = function(inputs, out, mode=self.mode)
+        np.testing.assert_allclose(
+            f(*test_values), getattr(np, op.__name__)(test_values, axis=0)
+        )
+        topo = f.maker.fgraph.toposort()
+        assert len(topo) <= 2
+        assert isinstance(topo[-1].op, Elemwise)
+
+    def test_type(self):
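+        # The first stack/sum combination below should be rewritten (the graph
+        # ends in an Elemwise); the remaining combinations should be left as a
+        # real reduction (no trailing Elemwise).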
         # Test different axes for the join and the reduction.
         # We must force the dtype, otherwise this test will fail
         # on 32 bit systems.
         A = shared(np.array([1, 2, 3, 4, 5], dtype="int64"))

         f = function([], pt_sum(pt.stack([A, A]), axis=0), mode=self.mode)
-        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
+        np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
         topo = f.maker.fgraph.toposort()
         assert isinstance(topo[-1].op, Elemwise)

         # Test a case that was affected by an old PyTensor bug
         f = function([], pt_sum(pt.stack([A, A]), axis=1), mode=self.mode)

-        utt.assert_allclose(f(), [15, 15])
+        np.testing.assert_allclose(f(), [15, 15])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)

         # This case could be rewritten
         A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
         f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=1), mode=self.mode)
-        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
+        np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)

         A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
         f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=0), mode=self.mode)
-        utt.assert_allclose(f(), [15, 15])
+        np.testing.assert_allclose(f(), [15, 15])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)

+    def test_not_supported_axis_none(self):
         # Test that the rewrite does not crash in one case where it
         # is not applied. Reported at
         # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
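+        # With axis=None the reduction runs over the whole joined tensor, which
+        # this rewrite does not handle; only the numerical result is checked.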
+        vx = matrix()
+        vy = matrix()
+        vz = matrix()
+        x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
+        y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
+        z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)
+
         out = pt_sum([vx, vy, vz], axis=None)
-        f = function([vx, vy, vz], out)
+        f = function([vx, vy, vz], out, mode=self.mode)
+        np.testing.assert_allclose(f(x, y, z), np.sum([x, y, z]))
+
+    def test_not_supported_unequal_shapes(self):
+        # Not the same shape along the join axis
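+        # The inputs differ in length along the join axis, so the rewrite is
+        # not expected to apply; only the correctness of the sum is checked.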
+        vx = matrix(shape=(1, 3))
+        vy = matrix(shape=(2, 3))
+        x = np.asarray([[1, 0, 1]], dtype=config.floatX)
+        y = np.asarray([[4, 0, 1], [2, 1, 1]], dtype=config.floatX)
+        out = pt_sum(join(0, vx, vy), axis=0)
+
+        f = function([vx, vy], out, mode=self.mode)
+        np.testing.assert_allclose(
+            f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0)
+        )


 def test_local_useless_adds():