
Commit 5f809cf

Simplify rewrites by assuming Elemwise / Alloc shapes are correct
1 parent 2c4a3e7 commit 5f809cf

2 files changed, +86 -150 lines

pytensor/tensor/rewriting/basic.py

Lines changed: 47 additions & 118 deletions
@@ -23,7 +23,7 @@
 """
 
 import logging
-from typing import TYPE_CHECKING, Optional, Union
+from typing import Union
 
 import numpy as np
 
@@ -65,21 +65,17 @@
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.extra_ops import broadcast_shape, broadcast_to
+from pytensor.tensor.extra_ops import broadcast_arrays
 from pytensor.tensor.math import Sum, add
 from pytensor.tensor.math import all as at_all
 from pytensor.tensor.math import eq
-from pytensor.tensor.shape import Shape_i
+from pytensor.tensor.shape import Shape_i, shape_padleft
 from pytensor.tensor.sort import TopKOp
 from pytensor.tensor.type import DenseTensorType, TensorType
 from pytensor.tensor.var import TensorConstant, TensorVariable
 from pytensor.utils import NoDuplicateOptWarningFilter
 
 
-if TYPE_CHECKING:
-    from pytensor.tensor.rewriting.shape import ShapeFeature
-
-
 _logger = logging.getLogger("pytensor.tensor.rewriting.basic")
 _logger.addFilter(NoDuplicateOptWarningFilter())
 
@@ -261,31 +257,16 @@ def local_scalar_tensor_scalar(fgraph, node):
 def local_elemwise_alloc(fgraph, node):
     r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.
 
-    `Alloc`\s are effectively a type of `Elemwise` operation
-    (e.g. ``Elemwise{second}(y, x)`` is the same as ``Alloc(x, *y.shape)``), so
-    this rewrite uses that fact to reduce `Elemwise`\s on `Alloc`\s to
-    `Elemwise`\s of the `Alloc`\s first/value input (i.e. the value it
-    broadcasts).
-
-    In other words, this rewrite causes `Elemwise` `Op`\s to "absorb" redundant
-    `Alloc`\s.
-
     The rewrite essentially performs the following replacement:
-    ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``,
-    when ``y.shape`` for some input ``y`` (or the combined shapes of the
-    non-`Alloc`\s) is sufficient to maintain the same/correct output shape.
+    ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``
 
-    In it's current form, it also explicitly accounts for `DimShuffle`\s of
+    In its current form, it also explicitly accounts for `DimShuffle`\s of
     `Alloc`\s. This is largely due to `local_alloc_sink_dimshuffle`, which
     introduces them as a canonicalization of `Alloc`'s with leading
     broadcastable dimensions.
     """
-    # Rewrite is only applicable when there are at least two inputs
     if len(node.inputs) == 1:
-        return False
-
-    if len(node.outputs) > 1:
-        return False
+        return None
 
     def dimshuffled_alloc(i):
         return (
@@ -305,76 +286,40 @@ def dimshuffled_alloc(i):
     if len(alloc_idxs) == 0:
         return False
 
-    # Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a
-    # baseline for the dimensions.
-    ref_var_idx = None
-    for idx, i in enumerate(node.inputs):
-        if i.type.broadcastable == node.outputs[0].type.broadcastable:
-            # Prefer an input that is not an `Alloc` nor a `DimShuffle` of an
-            # `Alloc`, so that all `Alloc`s can be rewritten.
-            if idx not in alloc_idxs:
-                ref_var_idx = idx
-                break
-
-    # If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one
-    if ref_var_idx is None:
-        for idx, i in enumerate(node.inputs):
-            # XXX: This broadcastable comparison doesn't work
-            if (
-                i.type.broadcastable == node.outputs[0].type.broadcastable
-            ) and idx in alloc_idxs:
-                ref_var_idx = idx
-                break
-
-    if not hasattr(fgraph, "shape_feature"):
-        return False
-
-    input_shapes = [
-        tuple(fgraph.shape_feature.get_shape(i, j) for j in range(i.type.ndim))
-        for i in node.inputs
-    ]
-    bcasted_shape = broadcast_shape(
-        *input_shapes,
-        arrays_are_shapes=True,
-    )
-
     new_inputs = list(node.inputs)
     for idx in alloc_idxs:
         i = node.inputs[idx]
 
-        # Remove `Alloc`
+        # Remove simple `Alloc`
         if isinstance(i.owner.op, Alloc):
-            new_alloc = broadcast_to(i.owner.inputs[0], bcasted_shape)
+            new_inp = i.owner.inputs[0]
 
-        # TODO FIXME: This shouldn't be handled here.
-        # `DimShuffle`s should be lifted through `Alloc`s
-        # by other, more general rewrites.
-        # Remove `Alloc` in `DimShuffle`
+        # Remove `Dimshuffle(Alloc)`
         elif isinstance(i.owner.op, DimShuffle):
            old_alloc = i.owner.inputs[0]
-            new_alloc = old_alloc.owner.inputs[0]
+            old_alloc_inp = old_alloc.owner.inputs[0]
+            missing_ndims = old_alloc.type.ndim - old_alloc_inp.type.ndim
+            if missing_ndims > 0:
+                # The `Alloc` added new dimensions to the left.
+                # We replace those cases with a `DimShuffle` here.
+                # Nested dimshuffles will be merged later by other rewrites.
+                old_alloc_inp = shape_padleft(old_alloc_inp, missing_ndims)
             # We need to keep the old `DimShuffle`. It could swap axes or
             # add dimensions anywhere.
-            if new_alloc.ndim != old_alloc.ndim:
-                # The `Alloc` can add dimensions to the value.
-                # We replace those cases with a `DimShuffle` here.
-                nb_dim_to_add = old_alloc.ndim - new_alloc.ndim
-                new_alloc = new_alloc.dimshuffle(
-                    ["x"] * nb_dim_to_add + list(range(new_alloc.ndim))
-                )
-            new_alloc = broadcast_to(i.owner.op(new_alloc), bcasted_shape)
+            new_inp = i.owner.op(old_alloc_inp)
 
-        copy_stack_trace(i, new_alloc)
-        new_inputs[idx] = new_alloc
+        copy_stack_trace(i, new_inp)
+        new_inputs[idx] = new_inp
 
-    # If this assert is triggered, it means we are recreating an equivalent graph
-    # which would result in cyclical merge rewrites.
-    if all(new is old for new, old in zip(new_inputs, node.inputs)):
-        return
+    new_outs = node.op(*new_inputs, return_list=True)
 
-    ret = node.op(*new_inputs, return_list=True)
-    copy_stack_trace(node.outputs, ret)
-    return ret
+    if new_outs[0].type.broadcastable != node.outputs[0].type.broadcastable:
+        new_outs = [
+            alloc_like(new_out, node.outputs[0], fgraph) for new_out in new_outs
+        ]
+
+    copy_stack_trace(node.outputs, new_outs)
+    return new_outs
 
 
 @register_canonicalize("shape_unsafe")
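To make the effect of `local_elemwise_alloc` concrete, here is a minimal, illustrative sketch that is not part of the commit; it only uses the public PyTensor API, the variable names and shapes are arbitrary, and the exact graph printout depends on the PyTensor version:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.matrix("y")

# The Alloc materializes `x` as a (3, 5) matrix before the addition,
# but ordinary Elemwise broadcasting already yields the same output shape.
out = pt.alloc(x, 3, 5) + y

# After rewriting, the compiled graph should add `x` and `y` directly,
# without allocating the broadcasted intermediate.
fn = pytensor.function([x, y], out)
pytensor.printing.debugprint(fn)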
@@ -406,6 +351,7 @@ def local_fill_sink(fgraph, node):
 
     # The newly created node c doesn't has 'clients',
     # so this iteration is took place with node.outputs[0]
+    # TODO: This should just be a WalkingGraphRewrite!
     replacements = {node.outputs[0]: c}
     for client, cl_idx in fgraph.clients[node.outputs[0]]:
         if (
@@ -438,23 +384,15 @@ def local_fill_to_alloc(fgraph, node):
     with their dependencies on those tensors' shapes, and sometimes those
     shapes can be computed without needing to compute the tensors themselves.
 
-    XXX: This rewrite can produce inconsistent results, so do *not* consider
-    making it a canonicalization until those inconsistencies are
-    resolved/justified.
+    Like `local_fill_sink` this rewrites assumes non-broadcastable shapes are equivalent,
+    which could mask shape errors.
     """
     shape_ref, values_ref = node.inputs
     out_type = node.outputs[0].type
 
     if values_ref.type.broadcastable == out_type.broadcastable:
         # The assumption here is that `values_ref` already has the same shape
         # as `shape_ref`, so a `fill`/`Alloc` is unnecessary.
-
-        # XXX FIXME TODO: The only way this can be determined is if one
-        # absolutely knows that the shapes of `shape_ref` and `values_ref` are
-        # equal.
-        # This is an old rewrite, and it's only a
-        # "specialization/stabilization", so we're going to leave it be for
-        # now.
         return [values_ref]
 
     if shape_ref.type.broadcastable == out_type.broadcastable:
@@ -465,6 +403,9 @@ def local_fill_to_alloc(fgraph, node):
         copy_stack_trace(node.outputs[0], o)
         return [o]
 
+    # The case that is not covered is when `shape_ref` is broadcasted by `values_ref`
+    # TODO: Return broadcast_to(values_ref, broadcast_shapes(values_ref.shape, shape_ref.shape))
+
     return
 
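As a rough usage sketch, again not taken from the diff: `fill(a, b)` means "`b` broadcast to `a`'s shape", and when the two already share a broadcastable pattern the rewritten graph should keep only the value. Names below are arbitrary and the printed graph will vary by version:

import pytensor
import pytensor.tensor as pt

a = pt.matrix("a")
b = pt.matrix("b")

# Matching broadcastable patterns: the rewrite assumes the shapes agree
# and drops the fill, leaving just `b` as the output.
out = pt.fill(a, b)

fn = pytensor.function([a, b], out)
pytensor.printing.debugprint(fn)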

@@ -1014,36 +955,30 @@ def local_sum_make_vector(fgraph, node):
     return [element_sum]
 
 
-@register_useless("local_remove_switch_const_cond")
-@register_canonicalize("fast_compile", "local_remove_switch_const_cond")
-@register_specialize
-@node_rewriter([Elemwise])
+@register_useless("shape_unsafe")
+@register_canonicalize("fast_compile", "shape_unsafe")
+@register_specialize("shape_unsafe")
+@node_rewriter([switch])
 def local_useless_switch(fgraph, node):
     """
     This rewrite makes the following changes in a graph:
 
-        at.switch(cond, left, right) ->
-                if cond is constant and cond == 0: right
-                if cond is constant and cond != 0: left
-                if left is right -> left
+    switch(cond, left, right) ->
+        if cond is constant and cond == 0: right
+        if cond is constant and cond != 0: left
+        if left is right -> left
 
     and
 
-        at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
+    switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
 
     """
-    if not isinstance(node.op.scalar_op, aes.Switch):
-        return False
-
-    shape_feature: Optional["ShapeFeature"] = getattr(fgraph, "shape_feature", None)
-
-    if shape_feature is None:
-        return False
 
     left = node.inputs[1]
     right = node.inputs[2]
     cond_var = node.inputs[0]
     cond = extract_constant(cond_var, only_process_constants=True)
+    out_bcast = node.outputs[0].type.broadcastable
 
     if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
         cond, (np.number, np.bool_)
@@ -1058,14 +993,8 @@ def local_useless_switch(fgraph, node):
         else:
             out = correct_out
 
-        input_shapes = [
-            tuple(shape_feature.get_shape(inp, i) for i in range(inp.type.ndim))
-            for inp in node.inputs
-        ]
-
-        out_shape = broadcast_shape(*input_shapes, arrays_are_shapes=True)
-
-        out = alloc(out, *out_shape)
+        if out.type.broadcastable != out_bcast:
+            out = broadcast_arrays(out, *node.inputs)[0]
 
         # Copy over stacktrace from selected output to new output
         copy_stack_trace(node.outputs + correct_out, out)
@@ -1075,10 +1004,10 @@ def local_useless_switch(fgraph, node):
     if left == right:
         # Note: No need to copy over stacktrace, because the input node
        # already has its own stacktrace
-        if cond.type.is_super(left.type):
+        if left.type.broadcastable == out_bcast:
             return [left]
 
-        ret = fill(cond, left)
+        ret = broadcast_arrays(left, cond)[0]
 
         # Copy over stacktrace from switch output and correct branch
         copy_stack_trace(node.outputs + left, ret)
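For illustration only, not part of the commit, here is a constant-condition switch like the one the docstring describes; `pt` is the conventional alias and the printed graph depends on the installed version:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")

# A constant, non-zero condition selects the `left` branch, so the rewrite
# should reduce the whole switch to `x`, broadcasting only if the output's
# broadcastable pattern demands it.
out = pt.switch(1, x, y)

fn = pytensor.function([x, y], out)
pytensor.printing.debugprint(fn)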

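Since the rewrites touched here are registered under the "shape_unsafe" tag, a user who prefers stricter shape checking can presumably exclude them by tag when compiling. This is a hedged sketch that relies on the standard Mode.excluding mechanism and is not something stated in the commit itself:

import pytensor
from pytensor.compile.mode import get_default_mode

# Exclude every rewrite tagged "shape_unsafe" (e.g. local_useless_switch and
# local_fill_to_alloc as registered in this commit) for a single function.
safe_mode = get_default_mode().excluding("shape_unsafe")

# fn = pytensor.function([...], [...], mode=safe_mode)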