@@ -364,8 +364,8 @@ def set_shape(self, r, s, override=False):
364
364
else :
365
365
shape_vars .append (self .unpack (s [i ], r ))
366
366
assert all (
367
- not hasattr (r .type , "broadcastable " )
368
- or not r .type .broadcastable [i ]
367
+ not hasattr (r .type , "shape " )
368
+ or r .type .shape [i ] != 1
369
369
or self .lscalar_one .equals (shape_vars [i ])
370
370
or self .lscalar_one .equals (extract_constant (shape_vars [i ]))
371
371
for i in range (r .type .ndim )
@@ -447,9 +447,9 @@ def update_shape(self, r, other_r):
447
447
merged_shape .append (other_shape [i ])
448
448
assert all (
449
449
(
450
- not hasattr (r .type , "broadcastable " )
451
- or not r .type .broadcastable [i ]
452
- and not other_r .type .broadcastable [i ]
450
+ not hasattr (r .type , "shape " )
451
+ or r .type .shape [i ] != 1
452
+ and other_r .type .shape [i ] != 1
453
453
)
454
454
or self .lscalar_one .equals (merged_shape [i ])
455
455
or self .lscalar_one .equals (
@@ -474,8 +474,8 @@ def set_shape_i(self, r, i, s_i):
474
474
else :
475
475
new_shape .append (s_j )
476
476
assert all (
477
- not hasattr (r .type , "broadcastable " )
478
- or not r .type .broadcastable [idx ]
477
+ not hasattr (r .type , "shape " )
478
+ or r .type .shape [idx ] != 1
479
479
or self .lscalar_one .equals (new_shape [idx ])
480
480
or self .lscalar_one .equals (extract_constant (new_shape [idx ]))
481
481
for idx in range (r .type .ndim )
@@ -781,7 +781,11 @@ def f(fgraph, node):
781
781
# We should try to figure out why we lost the information about this
782
782
# constant value... but in the meantime, better not apply this
783
783
# rewrite.
784
- if rval .broadcastable == node .outputs [0 ].broadcastable :
784
+ if rval .type .ndim == node .outputs [0 ].type .ndim and all (
785
+ s1 == s1
786
+ for s1 , s2 in zip (rval .type .shape , node .outputs [0 ].type .shape )
787
+ if s1 == 1 or s2 == 1
788
+ ):
785
789
return [rval ]
786
790
else :
787
791
return False
@@ -816,7 +820,11 @@ def local_useless_reshape(fgraph, node):
816
820
if (
817
821
inp .type .ndim == 1
818
822
and output .type .ndim == 1
819
- and inp .type .broadcastable == output .type .broadcastable
823
+ and all (
824
+ s1 == s2
825
+ for s1 , s2 in zip (inp .type .shape , output .type .shape )
826
+ if s1 == 1 or s2 == 1
827
+ )
820
828
):
821
829
return [inp ]
822
830
@@ -862,7 +870,7 @@ def local_useless_reshape(fgraph, node):
862
870
shape_match [dim ] = True
863
871
continue
864
872
865
- # Match 1 if input.broadcastable [dim] is True
873
+ # Match 1 if input.type.shape [dim] == 1
866
874
cst_outshp_i = extract_constant (outshp_i , only_process_constants = 1 )
867
875
if inp .type .shape [dim ] == 1 and cst_outshp_i == 1 :
868
876
shape_match [dim ] = True
@@ -931,7 +939,11 @@ def local_reshape_to_dimshuffle(fgraph, node):
931
939
if index != output .type .ndim :
932
940
inner = op .__class__ (len (new_output_shape ))(inp , new_output_shape )
933
941
copy_stack_trace (output , inner )
934
- new_node = [DimShuffle (inner .type .broadcastable , dimshuffle_new_order )(inner )]
942
+ new_node = [
943
+ DimShuffle (tuple (s == 1 for s in inner .type .shape ), dimshuffle_new_order )(
944
+ inner
945
+ )
946
+ ]
935
947
copy_stack_trace (output , new_node )
936
948
return new_node
937
949
@@ -1096,10 +1108,9 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
1096
1108
1097
1109
new_order = node .inputs [0 ].owner .op .new_order
1098
1110
inp = node .inputs [0 ].owner .inputs [0 ]
1099
- broadcastables = node .inputs [0 ].broadcastable
1100
1111
new_order_of_nonbroadcast = []
1101
- for i , bd in zip (new_order , broadcastables ):
1102
- if not bd :
1112
+ for i , s in zip (new_order , node . inputs [ 0 ]. type . shape ):
1113
+ if s != 1 :
1103
1114
new_order_of_nonbroadcast .append (i )
1104
1115
no_change_in_order = all (
1105
1116
new_order_of_nonbroadcast [i ] <= new_order_of_nonbroadcast [i + 1 ]
@@ -1123,7 +1134,11 @@ def local_useless_unbroadcast(fgraph, node):
1123
1134
"""
1124
1135
if isinstance (node .op , Unbroadcast ):
1125
1136
x = node .inputs [0 ]
1126
- if x .broadcastable == node .outputs [0 ].broadcastable :
1137
+ if x .type .ndim == node .outputs [0 ].type .ndim and all (
1138
+ s1 == s2
1139
+ for s1 , s2 in zip (x .type .shape , node .outputs [0 ].type .shape )
1140
+ if s1 == 1 or s2 == 1
1141
+ ):
1127
1142
# No broadcastable flag was modified
1128
1143
# No need to copy over stack trace,
1129
1144
# because x should already have a stack trace.
0 commit comments