Commit 94c2e4c

Replace use of broadcastable with shape in aesara.tensor.elemwise
1 parent f1dc089 commit 94c2e4c

File tree

1 file changed: +24 lines, -29 lines

aesara/tensor/elemwise.py

Lines changed: 24 additions & 29 deletions
@@ -62,17 +62,17 @@ class DimShuffle(ExternalCOp):
     If `j = new_order[i]` is an index, the output's ith dimension
     will be the input's jth dimension.
     If `new_order[i]` is `x`, the output's ith dimension will
-    be 1 and Broadcast operations will be allowed to do broadcasting
+    be 1 and broadcast operations will be allowed to do broadcasting
     over that dimension.
 
-    If `input.broadcastable[i] == False` then `i` must be found in new_order.
+    If `input.type.shape[i] != 1` then `i` must be found in `new_order`.
     Broadcastable dimensions, on the other hand, can be discarded.
 
     .. code-block:: python
 
         DimShuffle((False, False, False), ['x', 2, 'x', 0, 1])
 
-    This op will only work on 3d tensors with no broadcastable
+    This `Op` will only work on 3d tensors with no broadcastable
     dimensions. The first dimension will be broadcastable,
     then we will have the third dimension of the input tensor as
     the second of the resulting tensor, etc. If the tensor has
@@ -83,7 +83,7 @@ class DimShuffle(ExternalCOp):
 
         DimShuffle((True, False), [1])
 
-    This op will only work on 2d tensors with the first dimension
+    This `Op` will only work on 2d tensors with the first dimension
     broadcastable.
     The second dimension of the input tensor will be the first dimension of
     the resulting tensor.
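To make the docstring's first example concrete, here is a rough NumPy analogue of what `DimShuffle((False, False, False), ['x', 2, 'x', 0, 1])` computes (illustration only; the real `Op` operates on variables inside an Aesara graph, and the shapes below are invented):

import numpy as np

a = np.zeros((20, 30, 40))  # a 3d tensor with no length-1 dimensions
# 'x' inserts a broadcastable dimension; the integers pick input dims 2, 0, 1
b = a.transpose(2, 0, 1)[np.newaxis, :, np.newaxis, :, :]
print(b.shape)  # (1, 40, 1, 20, 30)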
@@ -186,7 +186,7 @@ def __setstate__(self, state):
 
     def make_node(self, _input):
         input = as_tensor_variable(_input)
-        ib = tuple(input.type.broadcastable)
+        ib = tuple(s == 1 for s in input.type.shape)
         if ib != self.input_broadcastable:
             if len(ib) != len(self.input_broadcastable):
                 raise TypeError(
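The substitution in this hunk is the one the whole commit repeats: under the static-shape convention, a dimension is broadcastable exactly when its known length is 1. A minimal sketch with made-up values:

shape = (None, 1, 5)  # None marks a statically unknown length
broadcastable = tuple(s == 1 for s in shape)
print(broadcastable)  # (False, True, False)

Note that an unknown length (None) compares unequal to 1, so it maps to non-broadcastable, matching the old flag semantics.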
@@ -258,7 +258,7 @@ def grad(self, inp, grads):
         (x,) = inp
         (gz,) = grads
         gz = as_tensor_variable(gz)
-        grad_order = ["x"] * len(x.type.broadcastable)
+        grad_order = ["x"] * x.type.ndim
         for i, v in enumerate(self.new_order):
             if v != "x":
                 grad_order[v] = i
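The loop above builds the inverse permutation: if `new_order[i] = v` sent input dimension v to output dimension i in the forward pass, the gradient must route output dimension i back to input dimension v. A standalone sketch with illustrative values:

new_order = ["x", 2, "x", 0, 1]   # forward shuffle of a 3d input
grad_order = ["x"] * 3            # x.type.ndim == 3 in this example
for i, v in enumerate(new_order):
    if v != "x":
        grad_order[v] = i
print(grad_order)  # [3, 4, 1]: input dim 0 came back from output dim 3, etc.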
@@ -269,7 +269,7 @@ def grad(self, inp, grads):
             return [inp[0].zeros_like(dtype=config.floatX)]
         else:
             return [
-                DimShuffle(gz.type.broadcastable, grad_order)(
+                DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)(
                     Elemwise(scalar_identity)(gz)
                 )
             ]
@@ -406,7 +406,7 @@ def get_output_info(self, dim_shuffle, *inputs):
                 # TODO: use LComplete instead
                 args.append(
                     dim_shuffle(
-                        input.type.broadcastable,
+                        tuple(1 if s == 1 else None for s in input.type.shape),
                         ["x"] * difference + list(range(length)),
                     )(input)
                 )
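The replacement argument keeps only the information the shuffle actually needs: which dimensions are statically length 1. Everything else, including dimensions whose length is known but not 1, collapses to None (unknown). A sketch with invented values:

shape = (3, 1, None)
pattern = tuple(1 if s == 1 else None for s in shape)
print(pattern)  # (None, 1, None)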
@@ -452,11 +452,11 @@ def get_most_specialized_shape(shapes):
         inplace_pattern = self.inplace_pattern
         if inplace_pattern:
             for overwriter, overwritten in inplace_pattern.items():
-                for ob, ib in zip(
+                for out_s, in_s in zip(
                     out_shapes[overwriter],
-                    inputs[overwritten].type.broadcastable,
+                    inputs[overwritten].type.shape,
                 ):
-                    if ib and not ob == 1:
+                    if in_s == 1 and out_s != 1:
                         raise ValueError(
                             "Operation cannot be done inplace on an input "
                             "with broadcasted dimensions."
@@ -578,8 +578,8 @@ def L_op(self, inputs, outs, ograds):
             # TODO: only count dimensions that were effectively broadcasted
             to_sum = [
                 j
-                for j, bcast in enumerate(ipt.type.broadcastable)
-                if bcast and not outs[0].broadcastable[j]
+                for j, in_s in enumerate(ipt.type.shape)
+                if in_s == 1 and outs[0].type.shape[j] != 1
             ]
 
             if to_sum:
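The `to_sum` indices answer a bookkeeping question: if an input dimension of length 1 was broadcast against a longer output dimension in the forward pass, the incoming gradient must be summed back down over that axis. A NumPy illustration of the same arithmetic (values invented):

import numpy as np

in_shape, out_shape = (1, 3), (4, 3)
ograd = np.ones(out_shape)  # gradient w.r.t. the broadcast output
to_sum = [j for j, in_s in enumerate(in_shape) if in_s == 1 and out_shape[j] != 1]
igrad = ograd.sum(axis=tuple(to_sum), keepdims=True)
print(igrad.shape)  # (1, 3), back to the input's shape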
@@ -614,7 +614,7 @@ def as_scalar(t):
                 f"{str(self.scalar_op)}.grad returned {str(type(scalar_igrads))} instead of list or tuple"
             )
 
-        nd = len(inputs[0].type.broadcastable)  # this is the same for everyone
+        nd = inputs[0].type.ndim  # this is the same for everyone
 
         def transform(r):
             # From a graph of ScalarOps, make a graph of Broadcast ops.
@@ -897,7 +897,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
         # for each input:
         # same as range(ndim), but with 'x' at all broadcastable positions
         orders = [
-            [x and "x" or i for i, x in enumerate(input.type.broadcastable)]
+            [s == 1 and "x" or i for i, s in enumerate(input.type.shape)]
             for input in inputs
         ]
 
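The `s == 1 and "x" or i` expression preserves the old `x and "x" or i` trick: a statically length-1 position yields "x", any other position yields its own index (the `or` still behaves at index 0 because `False or 0` is 0). A sketch with an invented shape:

shape = (1, 5, 1)
order = [s == 1 and "x" or i for i, s in enumerate(shape)]
print(order)  # ['x', 1, 'x']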
@@ -920,7 +920,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             [
                 f"PyArray_ISFORTRAN({arr})"
                 for arr, var in z
-                if not all(var.broadcastable)
+                if not all(s == 1 for s in var.type.shape)
             ]
         )
         # If it is a scalar, make it c contig to prevent problem with
@@ -1005,7 +1005,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             or
             # Use simpler code when output ndim == 0 or 1
             # or for broadcated scalar.
-            all(node.outputs[0].broadcastable)
+            all(s == 1 for s in node.outputs[0].type.shape)
         ):
             if nnested:
                 all_code = [("", "")] * (nnested - 1) + [("", code)] + [""]
@@ -1077,7 +1077,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             all(o.ndim >= 1 for o in node.outputs)
             and
             # Don't use the contig code for broadcasted scalar.
-            not all(node.outputs[0].broadcastable)
+            not all(s == 1 for s in node.outputs[0].type.shape)
         ):
             contig = None
             try:
@@ -1110,7 +1110,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             """
             index = ""
             for x, var in zip(inames + onames, inputs + node.outputs):
-                if not all(var.broadcastable):
+                if not all(s == 1 for s in var.type.shape):
                     contig += (
                         """
             dtype_%(x)s * %(x)s_ptr = (dtype_%(x)s*) PyArray_DATA(%(x)s);
@@ -1144,18 +1144,19 @@ def _c_all(self, node, nodename, inames, onames, sub):
             )
         if contig is not None:
             z = list(zip(inames + onames, inputs + node.outputs))
+            all_broadcastable = all(s == 1 for s in var.type.shape)
             cond1 = " && ".join(
                 [
                     "PyArray_ISCONTIGUOUS(%s)" % arr
                     for arr, var in z
-                    if not all(var.broadcastable)
+                    if not all_broadcastable
                 ]
             )
             cond2 = " && ".join(
                 [
                     "PyArray_ISFORTRAN(%s)" % arr
                     for arr, var in z
-                    if not all(var.broadcastable)
+                    if not all_broadcastable
                 ]
             )
             loop = (
@@ -1388,13 +1389,7 @@ def infer_shape(self, fgraph, node, shapes):
         axis = self.axis
         if axis is None:
             return ((),)
-        return (
-            [
-                ishape[i]
-                for (i, b) in enumerate(node.inputs[0].type.broadcastable)
-                if i not in axis
-            ],
-        )
+        return ([ishape[i] for i in range(node.inputs[0].type.ndim) if i not in axis],)
 
     def _c_all(self, node, name, inames, onames, sub):
 
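The collapsed one-liner computes the reduction's output shape by dropping the reduced axes by index, which never required the broadcastable flags in the first place. A sketch with illustrative values:

ishape = (2, 3, 4)  # stand-in for the input's symbolic shape entries
axis = (1,)
out_shape = [ishape[i] for i in range(len(ishape)) if i not in axis]
print(out_shape)  # [2, 4]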
@@ -1419,7 +1414,7 @@ def _c_all(self, node, name, inames, onames, sub):
 
         axis = self.axis
         if axis is None:
-            axis = list(range(len(input.type.broadcastable)))
+            axis = list(range(input.type.ndim))
 
         if len(axis) == 0:
             # The acc_dtype is never a downcast compared to the input dtype
