
Commit 6898f74

Remove BroadcastTo
1 parent 5f809cf commit 6898f74

8 files changed: 14 additions, 541 deletions

pytensor/link/jax/dispatch/extra_ops.py

Lines changed: 0 additions & 14 deletions
@@ -3,10 +3,8 @@
 import jax.numpy as jnp
 
 from pytensor.link.jax.dispatch.basic import jax_funcify
-from pytensor.tensor.basic import infer_static_shape
 from pytensor.tensor.extra_ops import (
     Bartlett,
-    BroadcastTo,
     CumOp,
     FillDiagonal,
     FillDiagonalOffset,
@@ -102,18 +100,6 @@ def ravelmultiindex(*inp, mode=mode, order=order):
     return ravelmultiindex
 
 
-@jax_funcify.register(BroadcastTo)
-def jax_funcify_BroadcastTo(op, node, **kwargs):
-    shape = node.inputs[1:]
-    static_shape = infer_static_shape(shape)[1]
-
-    def broadcast_to(x, *shape):
-        shape = tuple(st if st is not None else s for s, st in zip(shape, static_shape))
-        return jnp.broadcast_to(x, shape)
-
-    return broadcast_to
-
-
 @jax_funcify.register(FillDiagonal)
 def jax_funcify_FillDiagonal(op, **kwargs):
     def filldiagonal(value, diagonal):
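
Note (not part of the diff): the removed JAX dispatcher resolved static shapes with infer_static_shape because jnp.broadcast_to needs concrete shape values when the function is traced. A minimal standalone JAX sketch of that constraint, using only public JAX APIs:

    import jax
    import jax.numpy as jnp
    import numpy as np

    @jax.jit
    def bcast(x):
        # The target shape is a compile-time constant, so tracing succeeds.
        return jnp.broadcast_to(x, (3, 2))

    print(bcast(np.arange(2.0)).shape)  # (3, 2)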

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 0 additions & 25 deletions
@@ -2,15 +2,13 @@
 
 import numba
 import numpy as np
-from numba.misc.special import literal_unroll
 
 from pytensor import config
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
 from pytensor.raise_op import CheckAndRaise
 from pytensor.tensor.extra_ops import (
     Bartlett,
-    BroadcastTo,
     CumOp,
     FillDiagonal,
     FillDiagonalOffset,
@@ -353,29 +351,6 @@ def searchsorted(a, v):
     return searchsorted
 
 
-@numba_funcify.register(BroadcastTo)
-def numba_funcify_BroadcastTo(op, node, **kwargs):
-    create_zeros_tuple = numba_basic.create_tuple_creator(
-        lambda _: 0, len(node.inputs) - 1
-    )
-
-    # TODO broadcastable checks
-    @numba_basic.numba_njit
-    def broadcast_to(x, *shape):
-        scalars_shape = create_zeros_tuple()
-
-        i = 0
-        for s_i in literal_unroll(shape):
-            scalars_shape = numba_basic.tuple_setitem(
-                scalars_shape, i, numba_basic.to_scalar(s_i)
-            )
-            i += 1
-
-        return np.broadcast_to(x, scalars_shape)
-
-    return broadcast_to
-
-
 @numba_funcify.register(CheckAndRaise)
 def numba_funcify_CheckAndRaise(op, node, **kwargs):
     error = op.exc_type
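
Note (not part of the diff): the removed Numba kernel packed its scalar shape inputs into a homogeneous integer tuple (via literal_unroll, since the scalars could have mixed integer dtypes) before calling np.broadcast_to, which Numba only accepts with an integer-tuple shape. A minimal standalone sketch, assuming a Numba release that supports np.broadcast_to in nopython mode:

    import numba
    import numpy as np

    @numba.njit
    def bcast(x, d0, d1):
        # .copy() turns the zero-stride broadcast view into an ordinary array.
        return np.broadcast_to(x, (d0, d1)).copy()

    print(bcast(np.arange(2.0), 3, 2))  # 3x2 array with the row [0., 1.] repeated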

pytensor/tensor/extra_ops.py

Lines changed: 2 additions & 143 deletions
@@ -23,7 +23,7 @@
 from pytensor.scalar import upcast
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor import basic as at
-from pytensor.tensor.basic import get_vector_length, second
+from pytensor.tensor.basic import alloc, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -1584,141 +1584,6 @@ def broadcast_shape_iter(
     return tuple(result_dims)
 
 
-class BroadcastTo(COp):
-    """An `Op` for `numpy.broadcast_to`."""
-
-    _output_type_depends_on_input_value = True
-
-    __props__ = ()
-
-    view_map = {0: [0]}
-
-    def __call__(self, a, shape, **kwargs):
-        return super().__call__(a, *shape, **kwargs)
-
-    def make_node(self, a, *shape):
-        a = at.as_tensor_variable(a)
-
-        shape, static_shape = at.infer_static_shape(shape)
-
-        if len(shape) < a.ndim:
-            raise ValueError(
-                f"Broadcast target shape has {len(shape)} dims, which is shorter than input with {a.ndim} dims"
-            )
-
-        out = TensorType(dtype=a.type.dtype, shape=static_shape)()
-
-        # Attempt to prevent in-place operations on this view-based output
-        out.tag.indestructible = True
-
-        return Apply(self, [a] + shape, [out])
-
-    def perform(self, node, inputs, output_storage):
-        a, *shape = inputs
-        z = output_storage[0]
-        z[0] = np.broadcast_to(a, shape)
-
-    def grad(self, inputs, outputs_gradients):
-        a, *shape = inputs
-        (dout,) = outputs_gradients
-
-        # Determine the dimensions that were added by broadcasting
-        new_dims = list(range(dout.ndim - a.ndim))
-
-        d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)
-
-        # Determine the dimensions that were broadcast
-        _, static_shape = at.infer_static_shape(shape)
-
-        # TODO: This needs to be performed at run-time when static shape
-        # information isn't available.
-        bcast_sums = [
-            i
-            for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
-            if a_s == 1 and s_s != 1
-        ]
-
-        if bcast_sums:
-            d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)
-
-        return [d_wrt_a] + [
-            grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1)
-        ]
-
-    def infer_shape(self, fgraph, node, ins_shapes):
-        return [node.inputs[1:]]
-
-    def c_code(self, node, name, inputs, outputs, sub):
-        inp_dims = node.inputs[0].ndim
-        out_dims = node.outputs[0].ndim
-        new_dims = out_dims - inp_dims
-
-        (x, *shape) = inputs
-        (out,) = outputs
-        fail = sub["fail"]
-
-        # TODO: Could just use `PyArray_Return`, no?
-        dims_array = ", ".join(
-            [
-                f"((dtype_{shape}*)(PyArray_DATA({shape})))[0]"
-                for i, shape in enumerate(shape)
-            ]
-        )
-
-        src = (
-            """
-        npy_intp itershape[%(out_dims)s] = {%(dims_array)s};
-
-        NpyIter *iter;
-        PyArrayObject *ops[1] = {%(x)s};
-        npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK;
-        npy_uint32 op_flags[1] = {NPY_ITER_READONLY};
-        PyArray_Descr *op_dtypes[1] = {NULL};
-        int oa_ndim = %(out_dims)s;
-        int* op_axes[1] = {NULL};
-        npy_intp buffersize = 0;
-
-        for(int i = 0; i < %(inp_dims)s; i++)
-        {
-            if ((PyArray_DIMS(%(x)s)[i] != 1) && (PyArray_DIMS(%(x)s)[i] != itershape[i + %(new_dims)s]))
-            {
-                PyErr_Format(PyExc_ValueError,
-                             "Shape mismatch in broadcast_to: target shape[%%i] = %%lld is incompatible with input shape = %%lld.",
-                             i,
-                             (long long int) itershape[i + %(new_dims)s],
-                             (long long int) PyArray_DIMS(%(x)s)[i]
-                );
-                %(fail)s
-            }
-        }
-
-        iter = NpyIter_AdvancedNew(
-            1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize
-        );
-        %(out)s = NpyIter_GetIterView(iter, 0);
-
-        if(%(out)s == NULL){
-            NpyIter_Deallocate(iter);
-            %(fail)s;
-        }
-
-        if (NpyIter_Deallocate(iter) != NPY_SUCCEED) {
-            %(fail)s;
-        }
-
-        """
-            % locals()
-        )
-
-        return src
-
-    def c_code_cache_version(self):
-        return (2,)
-
-
-broadcast_to_ = BroadcastTo()
-
-
 def geomspace(start, end, steps, base=10.0):
     from pytensor.tensor.math import log
 
@@ -1762,13 +1627,7 @@ def broadcast_to(
     broadcasted array may refer to a single memory location.
 
     """
-    x = at.as_tensor(x)
-    shape_len = get_vector_length(shape)
-
-    if x.ndim == 0 and shape_len == 0:
-        return x
-
-    return broadcast_to_(x, shape)
+    return alloc(x, *shape)
 
 
 def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
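
Note (not part of the diff): the user-facing helper keeps its NumPy-like signature; only the graph it builds changes, from a dedicated BroadcastTo Op to alloc. A minimal usage sketch, assuming the public pytensor API shown above:

    import numpy as np
    import pytensor
    import pytensor.tensor as at
    from pytensor.tensor.extra_ops import broadcast_to

    x = at.vector("x")
    out = broadcast_to(x, (3, 2))  # same call signature as before this commit

    f = pytensor.function([x], out)
    print(f(np.arange(2.0).astype(x.dtype)))  # rows [0., 1.] repeated three times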

pytensor/tensor/rewriting/extra_ops.py

Lines changed: 1 addition & 47 deletions
@@ -2,7 +2,7 @@
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.tensor.basic import Alloc, as_tensor_variable
 from pytensor.tensor.elemwise import Elemwise
-from pytensor.tensor.extra_ops import BroadcastTo, Repeat, Unique
+from pytensor.tensor.extra_ops import Repeat, Unique
 from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless
 
 
@@ -60,39 +60,6 @@ def local_Unique_Alloc_lift(fgraph, node):
     return [new_x]
 
 
-@register_useless
-@register_canonicalize
-@node_rewriter([Unique])
-def local_Unique_BroadcastTo_lift(fgraph, node):
-    """Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.
-
-    This isn't really so much a lift as a "reduction/consumption".
-    """
-    if not isinstance(node.op, Unique):
-        return False
-
-    if (
-        node.op.return_index
-        or node.op.return_inverse
-        or node.op.return_counts
-        or node.op.axis is not None
-    ):
-        return False
-
-    bcast_var = node.inputs[0]
-
-    if not (bcast_var.owner and isinstance(bcast_var.owner.op, BroadcastTo)):
-        return False
-
-    bcasted_var, *bcast_shape = bcast_var.owner.inputs
-
-    new_unique, *_ = node.op.make_node(bcasted_var).outputs
-
-    old_out = node.outputs[0]
-    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
-    return [new_x]
-
-
 @register_useless
 @register_canonicalize
 @node_rewriter([Unique])
@@ -161,16 +128,3 @@ def local_Unique_second(fgraph, node):
     old_out = node.outputs[0]
     new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
     return [new_x]
-
-
-@register_useless
-@register_canonicalize
-@node_rewriter([BroadcastTo])
-def local_remove_scalar_BroadcastTo(fgraph, node):
-    bcast_shape = node.inputs[1:]
-
-    if not bcast_shape:
-        bcasted_var = node.inputs[0]
-        # If this isn't true, the graph is invalid
-        assert bcasted_var.ndim == 0
-        return [bcasted_var]
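
Note (not part of the diff): with broadcast_to now lowering to alloc, graphs that previously contained BroadcastTo nodes contain Alloc nodes instead, so the existing local_Unique_Alloc_lift kept above already covers the pattern the removed local_Unique_BroadcastTo_lift targeted. A minimal sketch of the new graph structure, assuming the pytensor API referenced in this diff:

    import pytensor.tensor as at
    from pytensor.tensor.basic import Alloc
    from pytensor.tensor.extra_ops import broadcast_to

    x = at.vector("x")
    y = broadcast_to(x, (3, 2))

    # The broadcast is now expressed as an Alloc node in the graph.
    assert isinstance(y.owner.op, Alloc)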

tests/link/jax/test_extra_ops.py

Lines changed: 1 addition & 24 deletions
@@ -7,7 +7,7 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import get_test_value
 from pytensor.tensor import extra_ops as at_extra_ops
-from pytensor.tensor.type import matrix, vector
+from pytensor.tensor.type import matrix
 from tests.link.jax.test_basic import compare_jax_and_py
 
 
@@ -63,29 +63,6 @@ def test_extra_ops():
     )
 
 
-@pytest.mark.parametrize(
-    "x, shape",
-    [
-        (
-            set_test_value(
-                vector("x"), np.random.random(size=(2,)).astype(config.floatX)
-            ),
-            [at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)],
-        ),
-        (
-            set_test_value(
-                vector("x"), np.random.random(size=(2,)).astype(config.floatX)
-            ),
-            [at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)],
-        ),
-    ],
-)
-def test_BroadcastTo(x, shape):
-    out = at_extra_ops.broadcast_to(x, shape)
-    fgraph = FunctionGraph(outputs=[out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-
 @pytest.mark.xfail(
     version_parse(jax.__version__) >= version_parse("0.2.12"),
     reason="Omnistaging cannot be disabled",

tests/link/numba/test_extra_ops.py

Lines changed: 0 additions & 35 deletions
@@ -36,41 +36,6 @@ def test_Bartlett(val):
     )
 
 
-@pytest.mark.parametrize(
-    "x, shape",
-    [
-        (
-            set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            [set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]],
-        ),
-        (
-            set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            [at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)],
-        ),
-        (
-            set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            at.as_tensor([set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]]),
-        ),
-        (
-            set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            [at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)],
-        ),
-    ],
-)
-def test_BroadcastTo(x, shape):
-    g = extra_ops.BroadcastTo()(x, shape)
-    g_fg = FunctionGraph(outputs=[g])
-
-    compare_numba_and_py(
-        g_fg,
-        [
-            i.tag.test_value
-            for i in g_fg.inputs
-            if not isinstance(i, (SharedVariable, Constant))
-        ],
-    )
-
-
 @pytest.mark.parametrize(
     "val, axis, mode",
     [
