
Commit df4183d

Use static-only broadcasting rules to compute shape of broadcasting (#345)
1 parent b9c4f20 commit df4183d

File tree

5 files changed: +60 −114 lines

pytensor/tensor/extra_ops.py
tests/tensor/rewriting/test_basic.py
tests/tensor/rewriting/test_math.py
tests/tensor/test_extra_ops.py
tests/unittest_tools.py
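
This commit replaces the dynamic, graph-based broadcasting checks in `broadcast_shape` with NumPy-style static rules: a dimension broadcasts only when its length is statically known to be 1, and all other lengths along that dimension must be equal, which is enforced at runtime by a single Assert. A minimal sketch of the user-facing effect (the variable names and shapes here are illustrative, not from the commit):

import pytensor.tensor as at
from pytensor.tensor.extra_ops import broadcast_shape

# Dimension 0 of `x` is *statically* length 1, so it broadcasts against 5.
x = at.tensor(dtype="float64", shape=(1, None), name="x")
y = at.tensor(dtype="float64", shape=(5, None), name="y")
dim0, dim1 = broadcast_shape(x, y)  # dim0 should be the constant 5

# Dimension 1 is unknown on both inputs: the two lengths are now *asserted*
# to be equal at runtime instead of being combined with a dynamic maximum.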

pytensor/tensor/extra_ops.py

Lines changed: 31 additions & 100 deletions

@@ -1,9 +1,7 @@
 from collections.abc import Collection
-from functools import reduce
 from typing import Iterable, Set, Tuple, Union
 
 import numpy as np
-import numpy.core.numeric
 from numpy.core.multiarray import normalize_axis_index
 
 import pytensor
@@ -14,7 +12,7 @@
     disconnected_type,
     grad_undefined,
 )
-from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
+from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.op import Op
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
@@ -23,12 +21,12 @@
 from pytensor.raise_op import Assert
 from pytensor.scalar import int32 as int_t
 from pytensor.scalar import upcast
-from pytensor.scalar.basic import Composite
 from pytensor.tensor import basic as at
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as at_abs
-from pytensor.tensor.math import all as at_all
+from pytensor.tensor.math import all as pt_all
+from pytensor.tensor.math import eq as pt_eq
 from pytensor.tensor.math import ge, lt, maximum, minimum, prod
 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
@@ -536,7 +534,7 @@ def bincount(x, weights=None, minlength=None, assert_nonneg=False):
 
     if assert_nonneg:
         assert_op = Assert("Input to bincount has negative values!")
-        x = assert_op(x, at_all(x >= 0))
+        x = assert_op(x, pt_all(x >= 0))
 
     max_value = at.cast(x.max() + 1, "int64")
 
@@ -1436,6 +1434,13 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
     return RavelMultiIndex(mode=mode, order=order)(*args)
 
 
+_broadcast_assert = Assert(
+    "Could not broadcast dimensions. Broadcasting is only allowed along "
+    "axes that have a statically known length 1. Use `specify_shape` to "
+    "inform PyTensor of a known shape."
+)
+
+
 def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
     """Compute the shape resulting from broadcasting arrays.
 
@@ -1510,119 +1515,45 @@ def broadcast_shape_iter(
     result_dims = []
 
     for dim_shapes in zip(*array_shapes):
-        # Get the shapes in this dimension that are not definitively
-        # broadcastable (i.e. not symbolically known to be broadcastable)
-        maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
+        # Get the shapes in this dimension that are not broadcastable
+        # (i.e. not symbolically known to be broadcastable)
+        non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
 
-        if len(maybe_non_bcast_shapes) == 0:
+        if len(non_bcast_shapes) == 0:
             # Every shape was broadcastable in this dimension
             result_dims.append(one_at)
-        elif len(maybe_non_bcast_shapes) == 1:
+        elif len(non_bcast_shapes) == 1:
             # Only one shape might not be broadcastable in this dimension
-            result_dims.extend(maybe_non_bcast_shapes)
+            result_dims.extend(non_bcast_shapes)
         else:
             # More than one shape might not be broadcastable in this dimension
-
             nonconst_nb_shapes: Set[int] = set()
             const_nb_shapes: Set[Variable] = set()
-            for shape in maybe_non_bcast_shapes:
+            for shape in non_bcast_shapes:
                 if isinstance(shape, Constant):
                     const_nb_shapes.add(shape.value.item())
                 else:
                     nonconst_nb_shapes.add(shape)
 
             if len(const_nb_shapes) > 1:
-                raise ValueError("Could not broadcast dimensions")
-            elif len(const_nb_shapes) == 1:
-                (const_nb_shape,) = const_nb_shapes
-
-                assert const_nb_shape != 1
-
-                const_nt_shape_var = pytensor.scalar.ScalarConstant(
-                    pytensor.scalar.int64, const_nb_shape
+                raise ValueError(
+                    f"Could not broadcast dimensions. Incompatible shapes were {array_shapes}."
                 )
 
-                if len(nonconst_nb_shapes) > 0:
-                    # All the potential non-broadcast shapes need to either
-                    # be broadcastable or equal to the one non-broadcastable
-                    # constant `const_nt_shape_var`.
-                    assert_dim = Assert("Could not broadcast dimensions")
-
-                    scalar_nonconst_nb_shapes = [
-                        at.scalar_from_tensor(s)
-                        if isinstance(s.type, TensorType)
-                        else s
-                        for s in nonconst_nb_shapes
-                    ]
-
-                    dummy_nonconst_nb_shapes = [
-                        aes.get_scalar_type(dtype=v.dtype)()
-                        for v in scalar_nonconst_nb_shapes
-                    ]
-                    assert_cond = reduce(
-                        aes.and_,
-                        (
-                            aes.or_(
-                                aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var)
-                            )
-                            for nbv in dummy_nonconst_nb_shapes
-                        ),
-                    )
-                    assert_cond_op = Composite(dummy_nonconst_nb_shapes, [assert_cond])
-
-                    bcast_dim = assert_dim(
-                        const_nt_shape_var, assert_cond_op(*scalar_nonconst_nb_shapes)
-                    )
-                else:
-                    bcast_dim = const_nt_shape_var
+            if len(const_nb_shapes) == 1:
+                (first_length,) = const_nb_shapes
+                other_lengths = nonconst_nb_shapes
+                first_length = aes.as_scalar(first_length)
             else:
-                # There are no constant, non-broadcastable shapes in this
-                # dimension.
-
-                all_dims_equal = all(
-                    # TODO FIXME: This is a largely deficient, and expensive, means
-                    # of comparing graphs (and especially shapes)
-                    equal_computations([maybe_non_bcast_shapes[0]], [dim])
-                    for dim in maybe_non_bcast_shapes[1:]
-                )
+                first_length, *other_lengths = nonconst_nb_shapes
 
-                if all_dims_equal:
-                    result_dims.append(maybe_non_bcast_shapes[0])
-                    continue
-
-                scalar_maybe_non_bcast_shapes = [
-                    at.scalar_from_tensor(s) if isinstance(s.type, TensorType) else s
-                    for s in maybe_non_bcast_shapes
-                ]
-                dummy_maybe_non_bcast_shapes = [
-                    aes.get_scalar_type(dtype=v.dtype)()
-                    for v in scalar_maybe_non_bcast_shapes
-                ]
-                non_bcast_vec = [
-                    aes.switch(aes.eq(nbv, 1), -one_at, nbv)
-                    for nbv in dummy_maybe_non_bcast_shapes
-                ]
-                dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
-                dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max])
-
-                dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes)
-
-                assert_dim = Assert("Could not broadcast dimensions")
-                assert_cond = reduce(
-                    aes.and_,
-                    (
-                        aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max))
-                        for nbv in non_bcast_vec
-                    ),
-                )
-                assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond])
-
-                bcast_dim = assert_dim(
-                    dim_max_op(*scalar_maybe_non_bcast_shapes),
-                    assert_cond_op(*scalar_maybe_non_bcast_shapes),
-                )
+            if len(other_lengths) == 0:
+                result_dims.append(first_length)
+                continue
 
-            result_dims.append(bcast_dim)
+            # Add assert that all remaining shapes are equal
+            condition = pt_all([pt_eq(first_length, other) for other in other_lengths])
+            result_dims.append(_broadcast_assert(first_length, condition))
 
     return tuple(result_dims)
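The per-dimension rule above is easiest to see in plain Python. The following is an illustrative analogue only: the real code operates on symbolic scalars and wraps the result in `_broadcast_assert` instead of raising eagerly.

def static_broadcast_dim(dims):
    """Combine the lengths of one axis under static-only broadcasting rules."""
    # Lengths statically known to be 1 broadcast away.
    non_bcast = [d for d in dims if d != 1]
    if not non_bcast:
        return 1
    first, *others = non_bcast
    # All remaining lengths must be equal; symbolically this becomes the
    # `pt_all([pt_eq(first_length, other) ...])` condition fed to the Assert.
    if any(d != first for d in others):
        raise AssertionError("Could not broadcast dimensions.")
    return first

assert static_broadcast_dim([1, 3, 1]) == 3
assert static_broadcast_dim([1, 1]) == 1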
tests/tensor/rewriting/test_basic.py

Lines changed: 8 additions & 3 deletions

@@ -1703,8 +1703,12 @@ def verify_op_count(f, count, cls):
         ],
     )
     def test_basic(self, expr, x_shape, y_shape):
-        x = at.tensor(dtype="int64", shape=(None,) * len(x_shape), name="x")
-        y = at.tensor(dtype="int64", shape=(None,) * len(y_shape), name="y")
+        x = at.tensor(
+            dtype="int64", shape=(1 if val == 1 else None for val in x_shape), name="x"
+        )
+        y = at.tensor(
+            dtype="int64", shape=(1 if val == 1 else None for val in y_shape), name="y"
+        )
         z = expr(x, y)
 
         z_opt = pytensor.function(
@@ -1878,7 +1882,8 @@ def test_multi_input_single_alloc(self):
             mode=self.fast_run_mode,
         )
         self.verify_op_count(func, 0, Alloc)
-        self.verify_op_count(func, 1, Assert)
+        # The second assert is from the shape check...
+        self.verify_op_count(func, 2, Assert)
 
     def test_misc(self):
         x = row(dtype=self.dtype)

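Note how the test now encodes known length-1 dimensions in the static type rather than leaving every dimension unknown; under the new rules that is what makes the broadcast provable at rewrite time. A sketch with hypothetical shape values:

x_shape = (1, 5)
# `shape=(None,) * len(x_shape)` would hide the length-1 dimension, and the
# static-only rules would then refuse to treat it as broadcastable.
x = at.tensor(
    dtype="int64", shape=(1 if val == 1 else None for val in x_shape), name="x"
)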
tests/tensor/rewriting/test_math.py

Lines changed: 9 additions & 5 deletions

@@ -608,9 +608,10 @@ def test_mul_div_cases(self):
                 ((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, "float64"),
                 ((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"),
                 # must broadcast as their is a dimshuffle in the computation
-                ((dx / dv) / dx, [dx, dv], [dxv, dvv], 1, "float64"),
+                # The broadcast leads to an extra elemwise to check compatibility
+                ((dx / dv) / dx, [dx, dv], [dxv, dvv], 2, "float64"),
                 # topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float64, row)>), Alloc]
-                ((fx / fv) / fx, [fx, fv], [fxv, fvv], 1, "float32"),
+                ((fx / fv) / fx, [fx, fv], [fxv, fvv], 2, "float32"),
                 # topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float32, row)>), Alloc]
             ]
         ):
@@ -621,9 +622,12 @@ def test_mul_div_cases(self):
             elem = [t for t in topo if isinstance(t.op, Elemwise)]
             assert len(elem) == nb_elemwise
             assert isinstance(elem[0].op, (Elemwise,))
-            assert isinstance(
-                elem[0].op.scalar_op,
-                (aes.basic.Reciprocal, aes.basic.TrueDiv),
+            assert any(
+                isinstance(
+                    el.op.scalar_op,
+                    (aes.basic.Reciprocal, aes.basic.TrueDiv),
+                )
+                for el in elem
             )
             assert out_dtype == out.dtype
 
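The expected Elemwise count rises from 1 to 2 because the broadcast now introduces a second Elemwise that computes the shape-equality condition for the Assert, so the division is no longer guaranteed to be `elem[0]`. A sketch of the counting pattern, assuming `f` is the compiled function from the test:

topo = f.maker.fgraph.toposort()
elem = [node for node in topo if isinstance(node.op, Elemwise)]
# One Elemwise is the division; the other checks shape compatibility,
# hence `any(...)` instead of inspecting elem[0] directly.
assert any(
    isinstance(node.op.scalar_op, (aes.basic.Reciprocal, aes.basic.TrueDiv))
    for node in elem
)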
tests/tensor/test_extra_ops.py

Lines changed: 11 additions & 5 deletions

@@ -1086,7 +1086,9 @@ def shape_tuple(x, use_bcast=True):
     assert any(
         isinstance(node.op, Assert) for node in applys_between([x_at, y_at], b_at)
     )
-    assert np.array_equal([z.eval() for z in b_at], b.shape)
+    # This should fail because it would need dynamic broadcasting
+    with pytest.raises(AssertionError):
+        assert np.array_equal([z.eval() for z in b_at], b.shape)
     b_at = broadcast_shape(shape_tuple(x_at), shape_tuple(y_at), arrays_are_shapes=True)
     assert np.array_equal([z.eval() for z in b_at], b.shape)
 
@@ -1183,8 +1185,8 @@ def test_broadcast_shape_constants():
 @pytest.mark.parametrize(
     ("s1_vals", "s2_vals", "exp_res"),
     [
-        ((2, 2), (1, 2), (2, 2)),
-        ((0, 2), (1, 2), (0, 2)),
+        ((2, 2), (1, 2), AssertionError),
+        ((0, 2), (1, 2), AssertionError),
         ((1, 2, 1), (2, 1, 2, 1), (2, 1, 2, 1)),
     ],
 )
@@ -1203,7 +1205,11 @@ def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res):
     res = broadcast_shape(s1s, s2s, arrays_are_shapes=True)
     res = at.as_tensor(res)
 
-    assert tuple(res.eval(eval_point)) == exp_res
+    if exp_res is AssertionError:
+        with pytest.raises(AssertionError):
+            res.eval(eval_point)
+    else:
+        assert tuple(res.eval(eval_point)) == exp_res
 
 
 def test_broadcast_shape_symbolic_one_symbolic():
@@ -1395,7 +1401,7 @@ def test_inplace(self):
 
 
 def test_broadcast_arrays():
-    x, y = at.dvector(), at.dmatrix()
+    x, y = at.tensor(shape=(1,), dtype="float64"), at.dmatrix()
     x_bcast, y_bcast = broadcast_arrays(x, y)
 
     py_mode = Mode("py", None)

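The new Assert message recommends `specify_shape` when the user knows a dimension has length 1 but the variable's type does not say so. A minimal sketch of that escape hatch (assuming `specify_shape` is exposed on `pytensor.tensor`, as in recent versions):

import pytensor.tensor as at
from pytensor.tensor.extra_ops import broadcast_arrays

x = at.dvector("x")             # length unknown: no longer broadcasts
x1 = at.specify_shape(x, (1,))  # statically length 1: broadcasting allowed
y = at.dmatrix("y")
x_bcast, y_bcast = broadcast_arrays(x1, y)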
tests/unittest_tools.py

Lines changed: 1 addition & 1 deletion

@@ -255,7 +255,7 @@ def _compile_and_check(
         # Check that the Op is removed from the compiled function.
         if check_topo:
             topo_shape = shapes_function.maker.fgraph.toposort()
-            assert not any(isinstance(t.op, cls) for t in topo_shape)
+            assert not any(t in outputs for t in topo_shape)
             topo_out = outputs_function.maker.fgraph.toposort()
             assert any(isinstance(t.op, cls) for t in topo_out)
         # Check that the shape produced agrees with the actual shape.
