@@ -621,65 +621,43 @@ def local_mul_switch_sink(fgraph, node):
621
621
part of the graph.
622
622
623
623
"""
624
- for idx , i in enumerate (node .inputs ):
625
- if i .owner and i .owner .op == switch :
626
- switch_node = i .owner
627
- try :
628
- if (
629
- get_underlying_scalar_constant_value (
630
- switch_node .inputs [1 ], only_process_constants = True
631
- )
632
- == 0.0
633
- ):
634
- listmul = node .inputs [:idx ] + node .inputs [idx + 1 :]
635
- fmul = mul (* ([* listmul , switch_node .inputs [2 ]]))
636
-
637
- # Copy over stacktrace for elementwise multiplication op
638
- # from previous elementwise multiplication op.
639
- # An error in the multiplication (e.g. errors due to
640
- # inconsistent shapes), will point to the
641
- # multiplication op.
642
- copy_stack_trace (node .outputs , fmul )
643
-
644
- fct = [switch (switch_node .inputs [0 ], 0 , fmul )]
645
- fct [0 ].tag .values_eq_approx = values_eq_approx_remove_nan
646
-
647
- # Copy over stacktrace for switch op from both previous
648
- # elementwise multiplication op and previous switch op,
649
- # because an error in this part can be caused by either
650
- # of the two previous ops.
651
- copy_stack_trace (node .outputs + switch_node .outputs , fct )
652
- return fct
653
- except NotScalarConstantError :
654
- pass
655
- try :
656
- if (
657
- get_underlying_scalar_constant_value (
658
- switch_node .inputs [2 ], only_process_constants = True
659
- )
660
- == 0.0
661
- ):
662
- listmul = node .inputs [:idx ] + node .inputs [idx + 1 :]
663
- fmul = mul (* ([* listmul , switch_node .inputs [1 ]]))
664
- # Copy over stacktrace for elementwise multiplication op
665
- # from previous elementwise multiplication op.
666
- # An error in the multiplication (e.g. errors due to
667
- # inconsistent shapes), will point to the
668
- # multiplication op.
669
- copy_stack_trace (node .outputs , fmul )
670
-
671
- fct = [switch (switch_node .inputs [0 ], fmul , 0 )]
672
- fct [0 ].tag .values_eq_approx = values_eq_approx_remove_nan
673
-
674
- # Copy over stacktrace for switch op from both previous
675
- # elementwise multiplication op and previous switch op,
676
- # because an error in this part can be caused by either
677
- # of the two previous ops.
678
- copy_stack_trace (node .outputs + switch_node .outputs , fct )
679
- return fct
680
- except NotScalarConstantError :
681
- pass
682
- return False
624
+ for mul_inp_idx , mul_inp in enumerate (node .inputs ):
625
+ if mul_inp .owner and mul_inp .owner .op == switch :
626
+ switch_node = mul_inp .owner
627
+ # Look for a zero as the first or second branch of the switch
628
+ for branch in range (2 ):
629
+ zero_switch_input = switch_node .inputs [1 + branch ]
630
+ if not get_unique_constant_value (zero_switch_input ) == 0.0 :
631
+ continue
632
+
633
+ switch_cond = switch_node .inputs [0 ]
634
+ other_switch_input = switch_node .inputs [1 + (1 - branch )]
635
+
636
+ listmul = list (node .inputs )
637
+ listmul [mul_inp_idx ] = other_switch_input
638
+ fmul = mul (* listmul )
639
+
640
+ # Copy over stacktrace for elementwise multiplication op
641
+ # from previous elementwise multiplication op.
642
+ # An error in the multiplication (e.g. errors due to
643
+ # inconsistent shapes), will point to the
644
+ # multiplication op.
645
+ copy_stack_trace (node .outputs , fmul )
646
+
647
+ if branch == 0 :
648
+ fct = switch (switch_cond , zero_switch_input , fmul )
649
+ else :
650
+ fct = switch (switch_cond , fmul , zero_switch_input )
651
+
652
+ # Tell debug_mode than the output is correct, even if nan disappear
653
+ fct .tag .values_eq_approx = values_eq_approx_remove_nan
654
+
655
+ # Copy over stacktrace for switch op from both previous
656
+ # elementwise multiplication op and previous switch op,
657
+ # because an error in this part can be caused by either
658
+ # of the two previous ops.
659
+ copy_stack_trace (node .outputs + switch_node .outputs , fct )
660
+ return [fct ]
683
661
684
662
685
663
@register_canonicalize
@@ -699,62 +677,39 @@ def local_div_switch_sink(fgraph, node):
699
677
See `local_mul_switch_sink` for more details.
700
678
701
679
"""
702
- op = node .op
703
- if node .inputs [0 ].owner and node .inputs [0 ].owner .op == switch :
704
- switch_node = node .inputs [0 ].owner
705
- try :
706
- if (
707
- get_underlying_scalar_constant_value (
708
- switch_node .inputs [1 ], only_process_constants = True
709
- )
710
- == 0.0
711
- ):
712
- fdiv = op (switch_node .inputs [2 ], node .inputs [1 ])
713
- # Copy over stacktrace for elementwise division op
714
- # from previous elementwise multiplication op.
715
- # An error in the division (e.g. errors due to
716
- # inconsistent shapes or division by zero),
717
- # will point to the new division op.
718
- copy_stack_trace (node .outputs , fdiv )
719
-
720
- fct = [switch (switch_node .inputs [0 ], 0 , fdiv )]
721
- fct [0 ].tag .values_eq_approx = values_eq_approx_remove_nan
722
-
723
- # Copy over stacktrace for switch op from both previous
724
- # elementwise division op and previous switch op,
725
- # because an error in this part can be caused by either
726
- # of the two previous ops.
727
- copy_stack_trace (node .outputs + switch_node .outputs , fct )
728
- return fct
729
- except NotScalarConstantError :
730
- pass
731
- try :
732
- if (
733
- get_underlying_scalar_constant_value (
734
- switch_node .inputs [2 ], only_process_constants = True
735
- )
736
- == 0.0
737
- ):
738
- fdiv = op (switch_node .inputs [1 ], node .inputs [1 ])
739
- # Copy over stacktrace for elementwise division op
740
- # from previous elementwise multiplication op.
741
- # An error in the division (e.g. errors due to
742
- # inconsistent shapes or division by zero),
743
- # will point to the new division op.
744
- copy_stack_trace (node .outputs , fdiv )
745
-
746
- fct = [switch (switch_node .inputs [0 ], fdiv , 0 )]
747
- fct [0 ].tag .values_eq_approx = values_eq_approx_remove_nan
680
+ num , denom = node .inputs
748
681
749
- # Copy over stacktrace for switch op from both previous
750
- # elementwise division op and previous switch op,
751
- # because an error in this part can be caused by either
752
- # of the two previous ops.
753
- copy_stack_trace (node .outputs + switch_node .outputs , fct )
754
- return fct
755
- except NotScalarConstantError :
756
- pass
757
- return False
682
+ if num .owner and num .owner .op == switch :
683
+ switch_node = num .owner
684
+ # Look for a zero as the first or second branch of the switch
685
+ for branch in range (2 ):
686
+ zero_switch_input = switch_node .inputs [1 + branch ]
687
+ if not get_unique_constant_value (zero_switch_input ) == 0.0 :
688
+ continue
689
+
690
+ switch_cond = switch_node .inputs [0 ]
691
+ other_switch_input = switch_node .inputs [1 + (1 - branch )]
692
+
693
+ fdiv = node .op (other_switch_input , denom )
694
+
695
+ # Copy over stacktrace for elementwise division op
696
+ # from previous elementwise multiplication op.
697
+ # An error in the division (e.g. errors due to
698
+ # inconsistent shapes or division by zero),
699
+ # will point to the new division op.
700
+ copy_stack_trace (node .outputs , fdiv )
701
+
702
+ fct = switch (switch_cond , zero_switch_input , fdiv )
703
+
704
+ # Tell debug_mode than the output is correct, even if nan disappear
705
+ fct .tag .values_eq_approx = values_eq_approx_remove_nan
706
+
707
+ # Copy over stacktrace for switch op from both previous
708
+ # elementwise division op and previous switch op,
709
+ # because an error in this part can be caused by either
710
+ # of the two previous ops.
711
+ copy_stack_trace (node .outputs + switch_node .outputs , fct )
712
+ return [fct ]
758
713
759
714
760
715
class AlgebraicCanonizer (NodeRewriter ):
0 commit comments