@@ -62,17 +62,17 @@ class DimShuffle(ExternalCOp):
     If `j = new_order[i]` is an index, the output's ith dimension
     will be the input's jth dimension.
     If `new_order[i]` is `x`, the output's ith dimension will
-    be 1 and Broadcast operations will be allowed to do broadcasting
+    be 1 and broadcast operations will be allowed to do broadcasting
     over that dimension.
 
-    If `input.broadcastable[i] == False` then `i` must be found in new_order.
+    If `input.type.shape[i] != 1` then `i` must be found in `new_order`.
     Broadcastable dimensions, on the other hand, can be discarded.
 
     .. code-block:: python
 
         DimShuffle((False, False, False), ['x', 2, 'x', 0, 1])
 
-    This op will only work on 3d tensors with no broadcastable
+    This `Op` will only work on 3d tensors with no broadcastable
     dimensions. The first dimension will be broadcastable,
     then we will have the third dimension of the input tensor as
     the second of the resulting tensor, etc. If the tensor has
@@ -83,7 +83,7 @@ class DimShuffle(ExternalCOp):
 
         DimShuffle((True, False), [1])
 
-    This op will only work on 2d tensors with the first dimension
+    This `Op` will only work on 2d tensors with the first dimension
     broadcastable.
     The second dimension of the input tensor will be the first dimension of
     the resulting tensor.
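
The two docstring examples above describe a pure shape mapping. A minimal plain-Python sketch of that rule (the `dimshuffled_shape` helper is hypothetical, not part of the library):

    # Each entry of `new_order` either selects an input dimension by index
    # or, with 'x', inserts a new broadcastable (length-1) dimension.
    def dimshuffled_shape(in_shape, new_order):
        return tuple(1 if j == "x" else in_shape[j] for j in new_order)

    assert dimshuffled_shape((5, 6, 7), ["x", 2, "x", 0, 1]) == (1, 7, 1, 5, 6)
    assert dimshuffled_shape((1, 8), [1]) == (8,)
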
@@ -186,7 +186,7 @@ def __setstate__(self, state):
 
     def make_node(self, _input):
         input = as_tensor_variable(_input)
-        ib = tuple(input.type.broadcastable)
+        ib = tuple(s == 1 for s in input.type.shape)
         if ib != self.input_broadcastable:
             if len(ib) != len(self.input_broadcastable):
                 raise TypeError(
@@ -258,7 +258,7 @@ def grad(self, inp, grads):
         (x,) = inp
         (gz,) = grads
         gz = as_tensor_variable(gz)
-        grad_order = ["x"] * len(x.type.broadcastable)
+        grad_order = ["x"] * x.type.ndim
         for i, v in enumerate(self.new_order):
             if v != "x":
                 grad_order[v] = i
@@ -269,7 +269,7 @@ def grad(self, inp, grads):
             return [inp[0].zeros_like(dtype=config.floatX)]
         else:
             return [
-                DimShuffle(gz.type.broadcastable, grad_order)(
+                DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)(
                     Elemwise(scalar_identity)(gz)
                 )
             ]
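
The replacement in this hunk is the central substitution of the whole commit: with static shapes, a dimension is broadcastable exactly when its known length is 1. A standalone sketch of the equivalence, assuming `type.shape` encodes unknown lengths as `None`:

    shape = (1, None, 3)  # None: length unknown at compile time
    broadcastable = tuple(s == 1 for s in shape)
    assert broadcastable == (True, False, False)
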
@@ -406,7 +406,7 @@ def get_output_info(self, dim_shuffle, *inputs):
                 # TODO: use LComplete instead
                 args.append(
                     dim_shuffle(
-                        input.type.broadcastable,
+                        tuple(1 if s == 1 else None for s in input.type.shape),
                         ["x"] * difference + list(range(length)),
                     )(input)
                 )
@@ -452,11 +452,11 @@ def get_most_specialized_shape(shapes):
         inplace_pattern = self.inplace_pattern
         if inplace_pattern:
             for overwriter, overwritten in inplace_pattern.items():
-                for ob, ib in zip(
+                for out_s, in_s in zip(
                     out_shapes[overwriter],
-                    inputs[overwritten].type.broadcastable,
+                    inputs[overwritten].type.shape,
                 ):
-                    if ib and not ob == 1:
+                    if in_s == 1 and out_s != 1:
                         raise ValueError(
                             "Operation cannot be done inplace on an input "
                             "with broadcasted dimensions."
@@ -578,8 +578,8 @@ def L_op(self, inputs, outs, ograds):
             # TODO: only count dimensions that were effectively broadcasted
             to_sum = [
                 j
-                for j, bcast in enumerate(ipt.type.broadcastable)
-                if bcast and not outs[0].broadcastable[j]
+                for j, in_s in enumerate(ipt.type.shape)
+                if in_s == 1 and outs[0].type.shape[j] != 1
             ]
 
             if to_sum:
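
The `to_sum` axes are where broadcasting duplicated an input value, so the gradient must be reduced back to the input's shape. A NumPy sketch of the same bookkeeping, assuming fully known shapes (NumPy is used here only for illustration):

    import numpy as np

    in_shape, out_shape = (1, 4), (3, 4)
    to_sum = [j for j, s in enumerate(in_shape) if s == 1 and out_shape[j] != 1]
    g_out = np.ones(out_shape)  # incoming gradient, in the output's shape
    g_in = g_out.sum(axis=tuple(to_sum), keepdims=True)
    assert g_in.shape == in_shape
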
@@ -614,7 +614,7 @@ def as_scalar(t):
                 f"{str(self.scalar_op)}.grad returned {str(type(scalar_igrads))} instead of list or tuple"
             )
 
-        nd = len(inputs[0].type.broadcastable)  # this is the same for everyone
+        nd = inputs[0].type.ndim  # this is the same for everyone
 
         def transform(r):
             # From a graph of ScalarOps, make a graph of Broadcast ops.
@@ -897,7 +897,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
         # for each input:
         # same as range(ndim), but with 'x' at all broadcastable positions
         orders = [
-            [x and "x" or i for i, x in enumerate(input.type.broadcastable)]
+            [s == 1 and "x" or i for i, s in enumerate(input.type.shape)]
             for input in inputs
         ]
 
@@ -920,7 +920,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             [
                 f"PyArray_ISFORTRAN({arr})"
                 for arr, var in z
-                if not all(var.broadcastable)
+                if not all(s == 1 for s in var.type.shape)
             ]
         )
         # If it is a scalar, make it c contig to prevent problem with
@@ -1005,7 +1005,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             or
             # Use simpler code when output ndim == 0 or 1
             # or for broadcated scalar.
-            all(node.outputs[0].broadcastable)
+            all(s == 1 for s in node.outputs[0].type.shape)
         ):
             if nnested:
                 all_code = [("", "")] * (nnested - 1) + [("", code)] + [""]
@@ -1077,7 +1077,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             all(o.ndim >= 1 for o in node.outputs)
             and
             # Don't use the contig code for broadcasted scalar.
-            not all(node.outputs[0].broadcastable)
+            not all(s == 1 for s in node.outputs[0].type.shape)
         ):
             contig = None
             try:
@@ -1110,7 +1110,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             """
             index = ""
             for x, var in zip(inames + onames, inputs + node.outputs):
-                if not all(var.broadcastable):
+                if not all(s == 1 for s in var.type.shape):
                     contig += (
                         """
             dtype_%(x)s * %(x)s_ptr = (dtype_%(x)s*) PyArray_DATA(%(x)s);
@@ -1144,18 +1144,19 @@ def _c_all(self, node, nodename, inames, onames, sub):
             )
         if contig is not None:
             z = list(zip(inames + onames, inputs + node.outputs))
+            all_broadcastable = all(s == 1 for s in var.type.shape)
             cond1 = " && ".join(
                 [
                     "PyArray_ISCONTIGUOUS(%s)" % arr
                     for arr, var in z
-                    if not all(var.broadcastable)
+                    if not all_broadcastable
                 ]
             )
             cond2 = " && ".join(
                 [
                     "PyArray_ISFORTRAN(%s)" % arr
                     for arr, var in z
-                    if not all(var.broadcastable)
+                    if not all_broadcastable
                 ]
             )
             loop = (
@@ -1388,13 +1389,7 @@ def infer_shape(self, fgraph, node, shapes):
         axis = self.axis
         if axis is None:
             return ((),)
-        return (
-            [
-                ishape[i]
-                for (i, b) in enumerate(node.inputs[0].type.broadcastable)
-                if i not in axis
-            ],
-        )
+        return ([ishape[i] for i in range(node.inputs[0].type.ndim) if i not in axis],)
 
     def _c_all(self, node, name, inames, onames, sub):
 
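
The collapsed one-liner keeps the shape entries of the non-reduced dimensions only, and needs just the rank rather than the broadcast pattern. A sketch of the comprehension on concrete values:

    ishape, axis = (2, 3, 4), (1,)
    kept = [ishape[i] for i in range(len(ishape)) if i not in axis]
    assert kept == [2, 4]
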
@@ -1419,7 +1414,7 @@ def _c_all(self, node, name, inames, onames, sub):
 
         axis = self.axis
         if axis is None:
-            axis = list(range(len(input.type.broadcastable)))
+            axis = list(range(input.type.ndim))
 
         if len(axis) == 0:
             # The acc_dtype is never a downcast compared to the input dtype