Commit 46a46af

Replace use of broadcastable with shape in aesara.tensor.basic
1 parent 94c2e4c commit 46a46af
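
The change swaps checks on a tensor type's boolean broadcastable pattern for checks on its static shape, where an entry of 1 plays the role of a broadcastable dimension and None marks a statically unknown length. A minimal sketch of that correspondence, assuming Aesara's TensorType as it is used in the tests below:

    from aesara.tensor.type import TensorType

    t = TensorType("float64", shape=(1, None, 3))
    print(t.shape)          # (1, None, 3): None means the length is not known statically
    print(t.broadcastable)  # (True, False, False): only statically length-1 dims broadcast

    # The old checks map onto the new ones roughly as:
    #   x.type.broadcastable[i]      ->  x.type.shape[i] == 1
    #   not x.type.broadcastable[i]  ->  x.type.shape[i] != 1  (None or a known length > 1)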

2 files changed (+57 -52 lines): aesara/tensor/basic.py, tests/tensor/test_basic.py

aesara/tensor/basic.py

Lines changed: 31 additions & 26 deletions
@@ -28,7 +28,7 @@
 from aesara.graph.fg import FunctionGraph
 from aesara.graph.op import Op
 from aesara.graph.rewriting.utils import rewrite_graph
-from aesara.graph.type import Type
+from aesara.graph.type import HasShape, Type
 from aesara.link.c.op import COp
 from aesara.link.c.params_type import ParamsType
 from aesara.misc.safe_asarray import _asarray
@@ -348,8 +348,8 @@ def get_scalar_constant_value(
         if isinstance(inp, Constant):
             return np.asarray(np.shape(inp.data)[i])
         # The shape of a broadcastable dimension is 1
-        if hasattr(inp.type, "broadcastable") and inp.type.broadcastable[i]:
-            return np.asarray(1)
+        if isinstance(inp.type, HasShape) and inp.type.shape[i] is not None:
+            return np.asarray(inp.type.shape[i])

         # Don't act as the constant_folding optimization here as this
         # fct is used too early in the optimization phase. This would
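
The old guard could only ever report a length of 1, since that is all a broadcastable flag encodes; the new one returns whatever length is statically recorded on the type. A small standalone sketch of the new check, re-created here for illustration rather than taken from the commit:

    import numpy as np

    from aesara.graph.type import HasShape
    from aesara.tensor.type import TensorType

    inp_type = TensorType("float64", shape=(7, None))
    for i in range(2):
        if isinstance(inp_type, HasShape) and inp_type.shape[i] is not None:
            # i == 0 prints array(7); i == 1 is skipped because its length is unknown
            print(i, np.asarray(inp_type.shape[i]))
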
@@ -502,21 +502,16 @@ def get_scalar_constant_value(
             owner.inputs[1], max_recur=max_recur
         )
         grandparent = leftmost_parent.owner.inputs[0]
-        gp_broadcastable = grandparent.type.broadcastable
+        gp_shape = grandparent.type.shape
         ndim = grandparent.type.ndim
         if grandparent.owner and isinstance(
             grandparent.owner.op, Unbroadcast
         ):
-            ggp_broadcastable = grandparent.owner.inputs[0].broadcastable
-            l = [
-                b1 or b2
-                for b1, b2 in zip(ggp_broadcastable, gp_broadcastable)
-            ]
-            gp_broadcastable = tuple(l)
+            ggp_shape = grandparent.owner.inputs[0].type.shape
+            l = [s1 == 1 or s2 == 1 for s1, s2 in zip(ggp_shape, gp_shape)]
+            gp_shape = tuple(l)

-        assert ndim == len(gp_broadcastable)
-
-        if not (idx < len(gp_broadcastable)):
+        if not (idx < ndim):
             msg = (
                 "get_scalar_constant_value detected "
                 f"deterministic IndexError: x.shape[{int(idx)}] "
@@ -528,8 +523,9 @@ def get_scalar_constant_value(
             msg += f" x={v}"
         raise ValueError(msg)

-        if gp_broadcastable[idx]:
-            return np.asarray(1)
+        gp_shape_val = gp_shape[idx]
+        if gp_shape_val is not None and gp_shape_val > -1:
+            return np.asarray(gp_shape_val)

         if isinstance(grandparent, Constant):
             return np.asarray(np.shape(grandparent.data)[idx])
@@ -1511,15 +1507,16 @@ def grad(self, inputs, grads):
         axis_kept = []
         for i, (ib, gb) in enumerate(
             zip(
-                inputs[0].broadcastable,
+                inputs[0].type.shape,
                 # We need the dimensions corresponding to x
-                grads[0].broadcastable[-inputs[0].ndim :],
+                grads[0].type.shape[-inputs[0].ndim :],
             )
         ):
-            if ib and not gb:
+            if ib == 1 and gb != 1:
                 axis_broadcasted.append(i + n_axes_to_sum)
             else:
                 axis_kept.append(i)
+
         gx = gz.sum(axis=axis + axis_broadcasted)
         if axis_broadcasted:
             new_order = ["x"] * x.ndim
@@ -1865,11 +1862,14 @@ def transpose(x, axes=None):

    """
    _x = as_tensor_variable(x)
+
    if axes is None:
-        axes = list(range((_x.ndim - 1), -1, -1))
-    ret = DimShuffle(_x.broadcastable, axes)(_x)
-    if _x.name and axes == list(range((_x.ndim - 1), -1, -1)):
+        axes = list(range((_x.type.ndim - 1), -1, -1))
+    ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
+
+    if _x.name and axes == list(range((_x.type.ndim - 1), -1, -1)):
        ret.name = _x.name + ".T"
+
    return ret


@@ -3207,11 +3207,11 @@ def _rec_perform(self, node, x, y, inverse, out, curdim):
         if xs0 == ys0:
             for i in range(xs0):
                 self._rec_perform(node, x[i], y[i], inverse, out[i], curdim + 1)
-        elif ys0 == 1 and node.inputs[1].type.broadcastable[curdim]:
+        elif ys0 == 1 and node.inputs[1].type.shape[curdim] == 1:
             # Broadcast y
             for i in range(xs0):
                 self._rec_perform(node, x[i], y[0], inverse, out[i], curdim + 1)
-        elif xs0 == 1 and node.inputs[0].type.broadcastable[curdim]:
+        elif xs0 == 1 and node.inputs[0].type.shape[curdim] == 1:
             # Broadcast x
             for i in range(ys0):
                 self._rec_perform(node, x[0], y[i], inverse, out[i], curdim + 1)
@@ -3270,7 +3270,7 @@ def grad(self, inp, grads):
         broadcasted_dims = [
             dim
             for dim in range(gz.type.ndim)
-            if x.type.broadcastable[dim] and not gz.type.broadcastable[dim]
+            if x.type.shape[dim] == 1 and gz.type.shape[dim] != 1
         ]
         gx = Sum(axis=broadcasted_dims)(gx)

@@ -3285,8 +3285,13 @@ def grad(self, inp, grads):
                 newdims.append(i)
                 i += 1

-        gx = DimShuffle(gx.type.broadcastable, newdims)(gx)
-        assert gx.type.broadcastable == x.type.broadcastable
+        gx = DimShuffle(tuple(s == 1 for s in gx.type.shape), newdims)(gx)
+        assert gx.type.ndim == x.type.ndim
+        assert all(
+            s1 == s2
+            for s1, s2 in zip(gx.type.shape, x.type.shape)
+            if s1 == 1 or s2 == 1
+        )

         # if x is an integer type, then so is the output.
         # this means f(x+eps) = f(x) so the gradient with respect
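
Two patterns recur in this file. DimShuffle still takes a boolean broadcastable pattern for its input, so the static shape is translated back with tuple(s == 1 for s in x.type.shape). And the old exact-equality assertion on broadcastable patterns becomes a weaker consistency check: only dimensions that either side marks as statically 1 have to agree, while unknown (None) dimensions are not compared. A sketch of that check as a hypothetical helper, not part of the commit:

    def shapes_consistent(shape_a, shape_b):
        # Mirror the new assert in grad(): compare only where one side claims a
        # static length of 1; dimensions where neither side is 1 are ignored.
        return len(shape_a) == len(shape_b) and all(
            s1 == s2 for s1, s2 in zip(shape_a, shape_b) if s1 == 1 or s2 == 1
        )

    print(shapes_consistent((1, None), (1, 5)))        # True: dim 0 agrees, dim 1 not compared
    print(shapes_consistent((1, None), (None, None)))  # False: dim 0 is 1 on one side only
    print(shapes_consistent((None, 4), (None, None)))  # True: nothing is statically 1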

tests/tensor/test_basic.py

Lines changed: 26 additions & 26 deletions
@@ -458,10 +458,10 @@ def test_make_vector_fail(self):
         res = MakeVector("int32")(a, b)

         res = MakeVector()(a)
-        assert res.broadcastable == (True,)
+        assert res.type.shape == (1,)

         res = MakeVector()()
-        assert res.broadcastable == (False,)
+        assert res.type.shape == (0,)

     def test_infer_shape(self):
         adscal = dscalar()
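
The first two assertions above also show that a static shape carries strictly more information than the old flags: an empty MakeVector output is now typed with shape (0,), which broadcastable could only describe as "not broadcastable". A quick sketch, using lscalar inputs (an assumption here) to match MakeVector's default int64 dtype:

    from aesara.tensor.basic import MakeVector
    from aesara.tensor.type import lscalar

    a = lscalar()
    print(MakeVector()(a).type.shape)  # (1,): one input, a statically length-1 vector
    print(MakeVector()().type.shape)   # (0,): no inputs, a statically empty vector
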
@@ -1665,18 +1665,18 @@ def test_broadcastable_flag_assignment_mixed_otheraxes(self):
         a = self.shared(a_val, shape=(None, None, 1))
         b = self.shared(b_val, shape=(1, None, 1))
         c = self.join_op(1, a, b)
-        assert c.type.broadcastable[0] and c.type.broadcastable[2]
-        assert not c.type.broadcastable[1]
+        assert c.type.shape[0] == 1 and c.type.shape[2] == 1
+        assert c.type.shape[1] != 1

         # Opt can remplace the int by an Aesara constant
         c = self.join_op(constant(1), a, b)
-        assert c.type.broadcastable[0] and c.type.broadcastable[2]
-        assert not c.type.broadcastable[1]
+        assert c.type.shape[0] == 1 and c.type.shape[2] == 1
+        assert c.type.shape[1] != 1

         # In case futur opt insert other useless stuff
         c = self.join_op(cast(constant(1), dtype="int32"), a, b)
-        assert c.type.broadcastable[0] and c.type.broadcastable[2]
-        assert not c.type.broadcastable[1]
+        assert c.type.shape[0] == 1 and c.type.shape[2] == 1
+        assert c.type.shape[1] != 1

         f = function([], c, mode=self.mode)
         topo = f.maker.fgraph.toposort()
@@ -1703,7 +1703,7 @@ def test_broadcastable_flag_assignment_mixed_thisaxes(self):
         a = self.shared(a_val, shape=(None, None, 1))
         b = self.shared(b_val, shape=(1, None, 1))
         c = self.join_op(0, a, b)
-        assert not c.type.broadcastable[0]
+        assert c.type.shape[0] != 1

         f = function([], c, mode=self.mode)
         topo = f.maker.fgraph.toposort()
@@ -1736,7 +1736,7 @@ def test_broadcastable_flags_all_broadcastable_on_joinaxis(self):
         a = self.shared(a_val, shape=(1, None, 1))
         b = self.shared(b_val, shape=(1, None, 1))
         c = self.join_op(0, a, b)
-        assert not c.type.broadcastable[0]
+        assert c.type.shape[0] != 1

         f = function([], c, mode=self.mode)
         topo = f.maker.fgraph.toposort()
@@ -1754,9 +1754,9 @@ def test_broadcastable_single_input_broadcastable_dimension(self):
         a_val = rng.random((1, 4, 1)).astype(self.floatX)
         a = self.shared(a_val, shape=(1, None, 1))
         b = self.join_op(0, a)
-        assert b.type.broadcastable[0]
-        assert b.type.broadcastable[2]
-        assert not b.type.broadcastable[1]
+        assert b.type.shape[0] == 1
+        assert b.type.shape[2] == 1
+        assert b.type.shape[1] != 1

         f = function([], b, mode=self.mode)
         topo = f.maker.fgraph.toposort()
@@ -1782,13 +1782,13 @@ def test_broadcastable_flags_many_dims_and_inputs(self):
         d = TensorType(dtype=self.floatX, shape=(1, None, 1, 1, None, 1))()
         e = TensorType(dtype=self.floatX, shape=(1, None, 1, None, None, 1))()
         f = self.join_op(0, a, b, c, d, e)
-        fb = f.type.broadcastable
+        fb = tuple(s == 1 for s in f.type.shape)
         assert not fb[0] and fb[1] and fb[2] and fb[3] and not fb[4] and fb[5]
         g = self.join_op(1, a, b, c, d, e)
-        gb = g.type.broadcastable
+        gb = tuple(s == 1 for s in g.type.shape)
         assert gb[0] and not gb[1] and gb[2] and gb[3] and not gb[4] and gb[5]
         h = self.join_op(4, a, b, c, d, e)
-        hb = h.type.broadcastable
+        hb = tuple(s == 1 for s in h.type.shape)
         assert hb[0] and hb[1] and hb[2] and hb[3] and not hb[4] and hb[5]

         f = function([a, b, c, d, e], f, mode=self.mode)
@@ -1981,8 +1981,8 @@ def test_TensorFromScalar():
     s = aes.constant(56)
     t = tensor_from_scalar(s)
     assert t.owner.op is tensor_from_scalar
-    assert t.type.broadcastable == (), t.type.broadcastable
-    assert t.type.ndim == 0, t.type.ndim
+    assert t.type.shape == ()
+    assert t.type.ndim == 0
     assert t.type.dtype == s.type.dtype

     v = eval_outputs([t])
@@ -2129,23 +2129,23 @@ def test_flatten_broadcastable():

     inp = TensorType("float64", shape=(None, None, None, None))()
     out = flatten(inp, ndim=2)
-    assert out.broadcastable == (False, False)
+    assert out.type.shape == (None, None)

     inp = TensorType("float64", shape=(None, None, None, 1))()
     out = flatten(inp, ndim=2)
-    assert out.broadcastable == (False, False)
+    assert out.type.shape == (None, None)

     inp = TensorType("float64", shape=(None, 1, None, 1))()
     out = flatten(inp, ndim=2)
-    assert out.broadcastable == (False, False)
+    assert out.type.shape == (None, None)

     inp = TensorType("float64", shape=(None, 1, 1, 1))()
     out = flatten(inp, ndim=2)
-    assert out.broadcastable == (False, True)
+    assert out.type.shape == (None, 1)

     inp = TensorType("float64", shape=(1, None, 1, 1))()
     out = flatten(inp, ndim=3)
-    assert out.broadcastable == (True, False, True)
+    assert out.type.shape == (1, None, 1)


 def test_flatten_ndim_invalid():
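
These assertions encode how flatten propagates static shapes: the leading ndim - 1 dimensions are kept, and the collapsed trailing dimension is statically 1 only when every collapsed input dimension is statically 1 (otherwise, in these tests, it is unknown). A quick sketch reproducing two of the cases above:

    from aesara.tensor.basic import flatten
    from aesara.tensor.type import TensorType

    x = TensorType("float64", shape=(None, 1, 1, 1))()
    print(flatten(x, ndim=2).type.shape)  # (None, 1): every collapsed dim is statically 1

    y = TensorType("float64", shape=(None, 1, None, 1))()
    print(flatten(y, ndim=2).type.shape)  # (None, None): one collapsed dim is unknown
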
@@ -2938,8 +2938,8 @@ def permute_fixed(s_input):

     def test_3b_2(self):
         # Test permute_row_elements on a more complex broadcasting pattern:
-        # input.type.broadcastable = (False, True, False),
-        # p.type.broadcastable = (False, False).
+        # input.type.shape = (None, 1, None),
+        # p.type.shape = (None, None).

         input = TensorType("floatX", shape=(None, 1, None))()
         p = imatrix()
@@ -4046,7 +4046,7 @@ def test_broadcasted(self):
         B = np.asarray(np.random.random((4, 1)), dtype="float32")
         for m in self.modes:
             f = function([a, b], choose(a, b, mode=m))
-            assert choose(a, b, mode=m).broadcastable[0]
+            assert choose(a, b, mode=m).type.shape[0] == 1
             t_c = f(A, B)
             n_c = np.choose(A, B, mode=m)
             assert np.allclose(t_c, n_c)
