@@ -404,19 +404,44 @@ def _find_diag_from_eye_mul(potential_mul_input):
404
404
eye_input = [
405
405
mul_input
406
406
for mul_input in inputs_to_mul
407
- if mul_input .owner and isinstance (mul_input .owner .op , Eye )
407
+ if mul_input .owner
408
+ and (
409
+ isinstance (mul_input .owner .op , Eye )
410
+ or
411
+ # This whole condition checks if there is an Eye hiding inside a DimShuffle.
412
+ # This arises from batched elementwise multiplication between a tensor and an eye, e.g.:
413
+ # tensor(shape=(None, 3, 3) * eye(3). This is still potentially valid for diag rewrites.
414
+ (
415
+ isinstance (mul_input .owner .op , DimShuffle )
416
+ and mul_input .owner .inputs [0 ].owner is not None
417
+ and isinstance (mul_input .owner .inputs [0 ].owner .op , Eye )
418
+ )
419
+ )
408
420
]
409
-
410
- # Check if 1's are being put on the main diagonal only (k = 0)
411
- if eye_input and getattr (eye_input [0 ].owner .inputs [- 1 ], "data" , - 1 ).item () != 0 :
421
+ if not eye_input :
412
422
return None
413
423
414
- # If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite
415
- if eye_input and eye_input [0 ].broadcastable [- 2 :] != (False , False ):
424
+ eye_input = eye_input [0 ]
425
+
426
+ # If this multiplication came from a batched operation, it will be wrapped in a DimShuffle
427
+ if isinstance (eye_input .owner .op , DimShuffle ):
428
+ inner_eye = eye_input .owner .inputs [0 ]
429
+ if not isinstance (inner_eye .owner .op , Eye ):
430
+ return None
431
+ # Check if 1's are being put on the main diagonal only (k = 0)
432
+ # and if the identity matrix is degenerate (column or row matrix)
433
+ if getattr (
434
+ inner_eye .owner .inputs [- 1 ], "data" , - 1
435
+ ).item () != 0 or inner_eye .broadcastable [- 2 :] != (False , False ):
436
+ return None
437
+
438
+ elif getattr (
439
+ eye_input .owner .inputs [- 1 ], "data" , - 1
440
+ ).item () != 0 or eye_input .broadcastable [- 2 :] != (False , False ):
416
441
return None
417
442
418
443
# Get all non Eye inputs (scalars/matrices/vectors)
419
- non_eye_inputs = list (set (inputs_to_mul ) - set ( eye_input ) )
444
+ non_eye_inputs = list (set (inputs_to_mul ) - { eye_input } )
420
445
return eye_input , non_eye_inputs
421
446
422
447
@@ -448,15 +473,22 @@ def rewrite_det_diag_to_prod_diag(fgraph, node):
448
473
inputs = node .inputs [0 ]
449
474
450
475
# Check for use of pt.diag first
451
- if inputs .owner and isinstance (inputs .owner .op , AllocDiag2 ):
476
+ if (
477
+ inputs .owner
478
+ and isinstance (inputs .owner .op , AllocDiag2 )
479
+ and inputs .owner .op .offset == 0
480
+ ):
452
481
diag_input = inputs .owner .inputs [0 ]
482
+ diag_input .dprint ()
453
483
det_val = diag_input .prod (axis = - 1 )
454
484
return [det_val ]
455
485
456
486
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
457
487
inputs_or_none = _find_diag_from_eye_mul (inputs )
488
+
458
489
if inputs_or_none is None :
459
490
return None
491
+
460
492
eye_input , non_eye_inputs = inputs_or_none
461
493
462
494
# Dealing with only one other input
0 commit comments