|
1 | 1 | from collections.abc import Collection
|
2 |
| -from functools import reduce |
3 | 2 | from typing import Iterable, Set, Tuple, Union
|
4 | 3 |
|
5 | 4 | import numpy as np
|
6 |
| -import numpy.core.numeric |
7 | 5 | from numpy.core.multiarray import normalize_axis_index
|
8 | 6 |
|
9 | 7 | import pytensor
|
|
14 | 12 | disconnected_type,
|
15 | 13 | grad_undefined,
|
16 | 14 | )
|
17 |
| -from pytensor.graph.basic import Apply, Constant, Variable, equal_computations |
| 15 | +from pytensor.graph.basic import Apply, Constant, Variable |
18 | 16 | from pytensor.graph.op import Op
|
19 | 17 | from pytensor.link.c.op import COp
|
20 | 18 | from pytensor.link.c.params_type import ParamsType
|
|
23 | 21 | from pytensor.raise_op import Assert
|
24 | 22 | from pytensor.scalar import int32 as int_t
|
25 | 23 | from pytensor.scalar import upcast
|
26 |
| -from pytensor.scalar.basic import Composite |
27 | 24 | from pytensor.tensor import basic as at
|
28 | 25 | from pytensor.tensor import get_vector_length
|
29 | 26 | from pytensor.tensor.exceptions import NotScalarConstantError
|
30 | 27 | from pytensor.tensor.math import abs as at_abs
|
31 |
| -from pytensor.tensor.math import all as at_all |
| 28 | +from pytensor.tensor.math import all as pt_all |
| 29 | +from pytensor.tensor.math import eq as pt_eq |
32 | 30 | from pytensor.tensor.math import ge, lt, maximum, minimum, prod
|
33 | 31 | from pytensor.tensor.math import sum as at_sum
|
34 | 32 | from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
|
@@ -536,7 +534,7 @@ def bincount(x, weights=None, minlength=None, assert_nonneg=False):
|
536 | 534 |
|
537 | 535 | if assert_nonneg:
|
538 | 536 | assert_op = Assert("Input to bincount has negative values!")
|
539 |
| - x = assert_op(x, at_all(x >= 0)) |
| 537 | + x = assert_op(x, pt_all(x >= 0)) |
540 | 538 |
|
541 | 539 | max_value = at.cast(x.max() + 1, "int64")
|
542 | 540 |
|
@@ -1436,6 +1434,13 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
|
1436 | 1434 | return RavelMultiIndex(mode=mode, order=order)(*args)
|
1437 | 1435 |
|
1438 | 1436 |
|
| 1437 | +_broadcast_assert = Assert( |
| 1438 | + "Could not broadcast dimensions. Broadcasting is only allowed along " |
| 1439 | + "axes that have a statically known length 1. Use `specify_shape` to " |
| 1440 | + "inform PyTensor of a known shape." |
| 1441 | +) |
| 1442 | + |
| 1443 | + |
1439 | 1444 | def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
|
1440 | 1445 | """Compute the shape resulting from broadcasting arrays.
|
1441 | 1446 |
|
@@ -1510,119 +1515,45 @@ def broadcast_shape_iter(
|
1510 | 1515 | result_dims = []
|
1511 | 1516 |
|
1512 | 1517 | for dim_shapes in zip(*array_shapes):
|
1513 |
| - # Get the shapes in this dimension that are not definitively |
1514 |
| - # broadcastable (i.e. not symbolically known to be broadcastable) |
1515 |
| - maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at] |
| 1518 | + # Get the shapes in this dimension that are not broadcastable |
| 1519 | + # (i.e. not symbolically known to be broadcastable) |
| 1520 | + non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at] |
1516 | 1521 |
|
1517 |
| - if len(maybe_non_bcast_shapes) == 0: |
| 1522 | + if len(non_bcast_shapes) == 0: |
1518 | 1523 | # Every shape was broadcastable in this dimension
|
1519 | 1524 | result_dims.append(one_at)
|
1520 |
| - elif len(maybe_non_bcast_shapes) == 1: |
| 1525 | + elif len(non_bcast_shapes) == 1: |
1521 | 1526 | # Only one shape might not be broadcastable in this dimension
|
1522 |
| - result_dims.extend(maybe_non_bcast_shapes) |
| 1527 | + result_dims.extend(non_bcast_shapes) |
1523 | 1528 | else:
|
1524 | 1529 | # More than one shape might not be broadcastable in this dimension
|
1525 |
| - |
1526 | 1530 | nonconst_nb_shapes: Set[int] = set()
|
1527 | 1531 | const_nb_shapes: Set[Variable] = set()
|
1528 |
| - for shape in maybe_non_bcast_shapes: |
| 1532 | + for shape in non_bcast_shapes: |
1529 | 1533 | if isinstance(shape, Constant):
|
1530 | 1534 | const_nb_shapes.add(shape.value.item())
|
1531 | 1535 | else:
|
1532 | 1536 | nonconst_nb_shapes.add(shape)
|
1533 | 1537 |
|
1534 | 1538 | if len(const_nb_shapes) > 1:
|
1535 |
| - raise ValueError("Could not broadcast dimensions") |
1536 |
| - elif len(const_nb_shapes) == 1: |
1537 |
| - (const_nb_shape,) = const_nb_shapes |
1538 |
| - |
1539 |
| - assert const_nb_shape != 1 |
1540 |
| - |
1541 |
| - const_nt_shape_var = pytensor.scalar.ScalarConstant( |
1542 |
| - pytensor.scalar.int64, const_nb_shape |
| 1539 | + raise ValueError( |
| 1540 | + f"Could not broadcast dimensions. Incompatible shapes were {array_shapes}." |
1543 | 1541 | )
|
1544 | 1542 |
|
1545 |
| - if len(nonconst_nb_shapes) > 0: |
1546 |
| - # All the potential non-broadcast shapes need to either |
1547 |
| - # be broadcastable or equal to the one non-broadcastable |
1548 |
| - # constant `const_nt_shape_var`. |
1549 |
| - assert_dim = Assert("Could not broadcast dimensions") |
1550 |
| - |
1551 |
| - scalar_nonconst_nb_shapes = [ |
1552 |
| - at.scalar_from_tensor(s) |
1553 |
| - if isinstance(s.type, TensorType) |
1554 |
| - else s |
1555 |
| - for s in nonconst_nb_shapes |
1556 |
| - ] |
1557 |
| - |
1558 |
| - dummy_nonconst_nb_shapes = [ |
1559 |
| - aes.get_scalar_type(dtype=v.dtype)() |
1560 |
| - for v in scalar_nonconst_nb_shapes |
1561 |
| - ] |
1562 |
| - assert_cond = reduce( |
1563 |
| - aes.and_, |
1564 |
| - ( |
1565 |
| - aes.or_( |
1566 |
| - aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var) |
1567 |
| - ) |
1568 |
| - for nbv in dummy_nonconst_nb_shapes |
1569 |
| - ), |
1570 |
| - ) |
1571 |
| - assert_cond_op = Composite(dummy_nonconst_nb_shapes, [assert_cond]) |
1572 |
| - |
1573 |
| - bcast_dim = assert_dim( |
1574 |
| - const_nt_shape_var, assert_cond_op(*scalar_nonconst_nb_shapes) |
1575 |
| - ) |
1576 |
| - else: |
1577 |
| - bcast_dim = const_nt_shape_var |
| 1543 | + if len(const_nb_shapes) == 1: |
| 1544 | + (first_length,) = const_nb_shapes |
| 1545 | + other_lengths = nonconst_nb_shapes |
| 1546 | + first_length = aes.as_scalar(first_length) |
1578 | 1547 | else:
|
1579 |
| - # There are no constant, non-broadcastable shapes in this |
1580 |
| - # dimension. |
1581 |
| - |
1582 |
| - all_dims_equal = all( |
1583 |
| - # TODO FIXME: This is a largely deficient, and expensive, means |
1584 |
| - # of comparing graphs (and especially shapes) |
1585 |
| - equal_computations([maybe_non_bcast_shapes[0]], [dim]) |
1586 |
| - for dim in maybe_non_bcast_shapes[1:] |
1587 |
| - ) |
| 1548 | + first_length, *other_lengths = nonconst_nb_shapes |
1588 | 1549 |
|
1589 |
| - if all_dims_equal: |
1590 |
| - result_dims.append(maybe_non_bcast_shapes[0]) |
1591 |
| - continue |
1592 |
| - |
1593 |
| - scalar_maybe_non_bcast_shapes = [ |
1594 |
| - at.scalar_from_tensor(s) if isinstance(s.type, TensorType) else s |
1595 |
| - for s in maybe_non_bcast_shapes |
1596 |
| - ] |
1597 |
| - dummy_maybe_non_bcast_shapes = [ |
1598 |
| - aes.get_scalar_type(dtype=v.dtype)() |
1599 |
| - for v in scalar_maybe_non_bcast_shapes |
1600 |
| - ] |
1601 |
| - non_bcast_vec = [ |
1602 |
| - aes.switch(aes.eq(nbv, 1), -one_at, nbv) |
1603 |
| - for nbv in dummy_maybe_non_bcast_shapes |
1604 |
| - ] |
1605 |
| - dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec)) |
1606 |
| - dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max]) |
1607 |
| - |
1608 |
| - dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes) |
1609 |
| - |
1610 |
| - assert_dim = Assert("Could not broadcast dimensions") |
1611 |
| - assert_cond = reduce( |
1612 |
| - aes.and_, |
1613 |
| - ( |
1614 |
| - aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max)) |
1615 |
| - for nbv in non_bcast_vec |
1616 |
| - ), |
1617 |
| - ) |
1618 |
| - assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond]) |
1619 |
| - |
1620 |
| - bcast_dim = assert_dim( |
1621 |
| - dim_max_op(*scalar_maybe_non_bcast_shapes), |
1622 |
| - assert_cond_op(*scalar_maybe_non_bcast_shapes), |
1623 |
| - ) |
| 1550 | + if len(other_lengths) == 0: |
| 1551 | + result_dims.append(first_length) |
| 1552 | + continue |
1624 | 1553 |
|
1625 |
| - result_dims.append(bcast_dim) |
| 1554 | + # Add assert that all remaining shapes are equal |
| 1555 | + condition = pt_all([pt_eq(first_length, other) for other in other_lengths]) |
| 1556 | + result_dims.append(_broadcast_assert(first_length, condition)) |
1626 | 1557 |
|
1627 | 1558 | return tuple(result_dims)
|
1628 | 1559 |
|
|
0 commit comments