Commit 65b96c1

Canonicalize squeeze out of reshape and specialize back

1 parent dbf5f38 · commit 65b96c1

4 files changed: +200 −58 lines

pytensor/tensor/rewriting/shape.py (+133 −48)

```diff
@@ -36,6 +36,7 @@
     register_useless,
     topo_constant_folding,
 )
+from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.shape import (
     Reshape,
     Shape,
```
```diff
@@ -757,40 +758,36 @@ def apply(self, fgraph):
 pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
 
 
+@register_useless
 @register_canonicalize
 @node_rewriter([Reshape])
-def local_useless_dimshuffle_in_reshape(fgraph, node):
+def local_useless_expand_dims_in_reshape(fgraph, node):
     """
-    Removes useless DimShuffle operation inside Reshape:
-
-        reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
-        reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
-        reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
-        reshape(col.dimshuffle(0), shp) => reshape(col, shp)
+    Removes useless expand_dims `DimShuffle` operations inside Reshape:
+        reshape(expand_dims(vector, axis=0), shp) => reshape(vector, shp)
+        reshape(expand_dims(matrix, axis=(0, 2)), shp) => reshape(matrix, shp)
 
+    Implicit (and useless) squeezes are kept in the graph, as they are
+    part of the canonical form of the graph.
     """
-    dimshuffled_x, new_shape = node.inputs
+    expanded_x, new_shape = node.inputs
 
     if not (
-        dimshuffled_x.owner is not None
-        and isinstance(dimshuffled_x.owner.op, DimShuffle)
+        expanded_x.owner is not None
+        and isinstance(expanded_x.owner.op, DimShuffle)
+        and expanded_x.owner.op.augment
     ):
         return False
 
-    [inp] = dimshuffled_x.owner.inputs
-    new_order = dimshuffled_x.owner.op.new_order
-    new_order_of_nonbroadcast = []
-    for i, s in zip(new_order, node.inputs[0].type.shape, strict=True):
-        if s != 1:
-            new_order_of_nonbroadcast.append(i)
-    no_change_in_order = all(
-        new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
-        for i in range(len(new_order_of_nonbroadcast) - 1)
-    )
-    if no_change_in_order:
-        ret = inp.reshape(new_shape)
-        copy_stack_trace(node.outputs[0], ret)
-        return [ret]
+    [x] = expanded_x.owner.inputs
+
+    new_order = tuple(o for o in expanded_x.owner.op.new_order if o != "x")
+    if new_order != tuple(range(x.type.ndim)):
+        x = x.dimshuffle(new_order)
+
+    new_reshaped_x = x.reshape(new_shape)
+    copy_stack_trace(node.outputs[0], new_reshaped_x)
+    return [new_reshaped_x]
 
 
 @register_canonicalize("shape_unsafe")
```
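As a quick illustration of the renamed rewrite: a leading `expand_dims` contributes nothing to a subsequent `reshape`, so canonicalization can drop it. A minimal sketch (the names and shapes are illustrative, and it assumes the top-level `rewrite_graph` helper):

```python
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.vector("x")
# The expand_dims is subsumed by the reshape that follows it
out = pt.expand_dims(x, 0).reshape((2, -1))
new_out = rewrite_graph(out, include=("canonicalize",))
# Expected to be equivalent to x.reshape((2, -1)); a transposition mixed in
# with the expand_dims would instead be kept as a separate dimshuffle.
```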
```diff
@@ -920,10 +917,10 @@ def local_useless_reshape(fgraph, node):
 
     shape_feature = getattr(fgraph, "shape_feature", None)
 
-    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for -1
-    # or cases where all but one dimension are provably preserved
+    # Match case where at least (n-1) entries correspond to the original shape:
+    # Reshape(x, [x.shape[0], ..., x.shape[-1]]), or Reshape(x, [x.shape[0], y, x.shape[2], ..., x.shape[-1]])
+    # where y can be -1 or anything with an unknown value, since the only valid reshape is then a no-op.
     output_shape_is = _unpack_shape_vector(output_shape)
-
     nb_m1 = 0
     shape_match = [False] * inp.type.ndim
     for dim in range(inp.type.ndim):
```
```diff
@@ -935,12 +932,13 @@ def local_useless_reshape(fgraph, node):
             nb_m1 += 1
 
     if nb_m1 <= 1 and all(shape_match):
-        return [inp]
+        return [inp]  # This is provably correct
 
     # There is one missing match, but all other dimensions match
+    # Such as x.type.shape == (3, 5, None) and output_shape == (3, 5, y)
     if (nb_m1 == 0) and (shape_match.count(False) == 1):
-        return [inp]
+        return [inp]  # This could mask a shape error
 
     return False
 
 
```
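A hedged sketch of the two match cases the new comments distinguish (names and shapes are illustrative):

```python
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.tensor3("x", shape=(3, 5, None))

# Every entry matches up to a single -1: provably a no-op, removed outright.
provable = rewrite_graph(x.reshape((3, 5, -1)), include=("canonicalize",))

# (n-1) entries match and the remaining one is an unknown scalar: also
# removed, but tagged "shape_unsafe", since an inconsistent `y` would have
# raised at runtime in the unrewritten graph.
y = pt.iscalar("y")
unsafe = rewrite_graph(x.reshape((3, 5, y)), include=("canonicalize",))
```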

```diff
@@ -947,33 +945,69 @@
-@register_canonicalize
+@register_canonicalize("shape_unsafe")
 @node_rewriter([Reshape])
 def local_reshape_to_dimshuffle(fgraph, node):
-    r"""Replace broadcastable dimensions in `Reshape` nodes with `DimShuffle`\s.
+    r"""Remove `Reshape` operations over length-1 (broadcastable) dimensions.
 
-    The goal is to avoid using `Reshape` to add or remove broadcastable
-    dimensions, and to use `DimShuffle` instead, since `DimShuffle`\s can
-    cancel out and/or be removed later on.
+    It's always valid to squeeze an input before doing the same reshape operation.
+    Equivalently, it's always valid to remove `1` entries from the reshape shape
+    and replace them by an expand_dims after the rewritten reshape operation.
+
+    We chose to canonicalize the graph in this way as it allows isolating
+    operations that are unique to the reshaping operation (mixing dimensions)
+    from those that can be more legibly encoded by DimShuffle (squeeze and expand_dims).
+    This can allow further simplifications by other rewrites that target
+    DimShuffle but not Reshape, as well as facilitate the removal of useless reshape operations.
 
     For example:
-      - reshape(x, (1, n)) -> DimShuffle{x,0}(Reshape(x, (n,)))
-      - reshape(x, (1, m, 1, n, 1, 1)) -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
+      - reshape(col, (m, n)) -> reshape(squeeze(col, axis=1), (m, n))
+      - reshape(col, (1, m, n)) -> expand_dims(reshape(squeeze(col, axis=1), (m, n)), axis=0)
+      - reshape(x, (1, m, 1, n, 1, 1)) -> expand_dims(reshape(x, (m, n)), axis=(0, 2, 4, 5))
+
     """
     inp, output_shape = node.inputs
     [output] = node.outputs
 
-    unpacked_shape = _unpack_shape_vector(output_shape)
-    expand_axes = []
-    new_output_shape = []
-    for i, dim in enumerate(unpacked_shape):
-        if isinstance(dim, Constant) and dim.data == 1:
-            expand_axes.append(i)
-        else:
-            new_output_shape.append(dim)
+    # Remove any broadcastable dimensions from the input
+    squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast]
+
+    # Trivial case, all dimensions of input/output are known to be broadcastable:
+    # there's nothing to reshape
+    if all(inp.type.broadcastable) or all(output.type.broadcastable):
+        new_output_shape = []
+        expand_axes = tuple(range(output.type.ndim))
+
+    else:
+        unpacked_shape = _unpack_shape_vector(output_shape)
+        new_output_shape = []
+        expand_axes = []
+        for i, dim_length in enumerate(unpacked_shape):
+            if isinstance(dim_length, Constant) and (
+                dim_length.data == 1
+                # -1 can be an implicit expand_dims, but it's tricky to prove
+                # as we would need to check whether all other dimensions
+                # already explain the full size of the array.
+                # Example: np.zeros((2, 2, 2)).reshape((8, -1))
+                # We rely on the output static shape which will already have figured
+                # it out for some (but not all) cases
+                or (dim_length.data == -1 and output.type.shape[i] == 1)
+            ):
+                expand_axes.append(i)
+            else:
+                new_output_shape.append(dim_length)
+
+    if squeeze_axes or expand_axes:
+        new_out = inp.squeeze(squeeze_axes)
+
+        if new_output_shape:
+            new_out = new_out.reshape(new_output_shape)
+            copy_stack_trace(output, new_out)
+
+        new_out = expand_dims(new_out, expand_axes)
+
+        if not new_output_shape:
+            # Eagerly merge consecutive squeeze and expand_dims
+            new_out = apply_local_dimshuffle_lift(fgraph, new_out)
 
-    if len(new_output_shape) != output.type.ndim:
-        inner = inp.reshape(new_output_shape)
-        copy_stack_trace(output, inner)
-        new_out = expand_dims(inner, expand_axes)
         copy_stack_trace(output, new_out)
         return [new_out]
 
 
```
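The docstring's last example, written out as a runnable sketch; since the inner reshape then provably preserves the shape, `local_useless_reshape` should remove it as well (names and shapes assumed):

```python
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.matrix("x", shape=(3, 5))
out = x.reshape((1, 3, 1, 5, 1, 1))
new_out = rewrite_graph(out, include=("canonicalize",))
# The length-1 entries are peeled off into an expand_dims and the remaining
# reshape((3, 5)) becomes useless, leaving something equivalent to
# pt.expand_dims(x, (0, 2, 4, 5)).
```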

```diff
@@ -980,3 +1014,54 @@
+@register_specialize
+@node_rewriter([Reshape])
+def local_fuse_squeeze_reshape(fgraph, node):
+    r"""If there is a squeeze right before a reshape, merge them.
+
+    This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization.
+    """
+    x, new_shape = node.inputs
+
+    if (
+        x.owner is not None
+        and isinstance(x.owner.op, DimShuffle)
+        and x.owner.op.is_squeeze
+    ):
+        # A reshape can always subsume a squeeze.
+        x = x.owner.inputs[0]
+        return [x.reshape(new_shape)]
+
+
+@register_specialize
+@node_rewriter([DimShuffle])
+def local_fuse_expand_dims_reshape(fgraph, node):
+    r"""If there is an expand_dims right after a reshape, merge them.
+
+    This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization.
+    """
+    if not node.op.is_expand_dims:
+        return None
+
+    reshaped_x = node.inputs[0]
+
+    if not (reshaped_x.owner and isinstance(reshaped_x.owner.op, Reshape)):
+        return None
+
+    if len(fgraph.clients[reshaped_x]) > 1:
+        # The reshape is used elsewhere, don't fuse as it can sometimes require a copy.
+        # Example: `x = pt.matrix(); y = x.T.reshape(-1); out = y[:, None] * y[None, :]`
+        return None
+
+    x, new_shape = reshaped_x.owner.inputs
+
+    # Add expand_dims to shape
+    new_shape = list(_unpack_shape_vector(new_shape))
+    for i in node.op.augment:
+        new_shape.insert(i, 1)
+
+    new_reshaped_x = x.reshape(new_shape)
+    copy_stack_trace(node.outputs[0], new_reshaped_x)
+    return [new_reshaped_x]
+
+
 @register_canonicalize
 @register_specialize
 @node_rewriter([Reshape])
```
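Mirroring the new `test_expand_dims_squeeze_reshape_fusion` below, the two specialize rewrites collapse a squeeze → reshape → expand_dims chain into a single reshape:

```python
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.tensor("x", shape=(1, 9))
out = x.squeeze(0).reshape((3, 3))[..., None]  # squeeze -> reshape -> expand_dims
fused = rewrite_graph(out, include=("specialize",))
# Both DimShuffles fold into the Reshape: equivalent to x.reshape((3, 3, 1))
```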

tests/tensor/rewriting/test_basic.py (−1)
```diff
@@ -332,7 +332,6 @@ def test_basic_tile(self):
 
         mode = rewrite_mode.including(
             "local_dimshuffle_lift",
-            "local_useless_dimshuffle_in_reshape",
             "local_alloc_sink_dimshuffle",
         )
         f = function([x], [y], mode=mode)
```

tests/tensor/rewriting/test_elemwise.py (+12 −8)
```diff
@@ -56,7 +56,10 @@
 from pytensor.tensor.math import round as pt_round
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift
-from pytensor.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape
+from pytensor.tensor.rewriting.shape import (
+    local_fuse_squeeze_reshape,
+    local_useless_expand_dims_in_reshape,
+)
 from pytensor.tensor.shape import reshape
 from pytensor.tensor.type import (
     TensorType,
```
```diff
@@ -182,7 +185,7 @@ def test_dimshuffle_lift_multi_out_elemwise(self):
         assert not local_dimshuffle_lift.transform(g, g.outputs[0].owner)
 
 
-def test_local_useless_dimshuffle_in_reshape():
+def test_local_useless_expand_dims_in_reshape():
     vec = TensorType(dtype="float64", shape=(None,))("vector")
     mat = TensorType(dtype="float64", shape=(None, None))("mat")
     row = TensorType(dtype="float64", shape=(1, None))("row")
```
```diff
@@ -204,7 +207,11 @@ def test_local_useless_dimshuffle_in_reshape():
         clone=False,
     )
     assert len(g.apply_nodes) == 4 * 3
-    useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
+    useless_dimshuffle_in_reshape = out2in(
+        local_useless_expand_dims_in_reshape,
+        # Useless squeeze in reshape is not a canonicalization anymore
+        local_fuse_squeeze_reshape,
+    )
     useless_dimshuffle_in_reshape.rewrite(g)
     assert equal_computations(
         g.outputs,
```
```diff
@@ -218,15 +225,12 @@ def test_local_useless_dimshuffle_in_reshape():
     # Check stacktrace was copied over correctly after rewrite was applied
     assert check_stack_trace(g, ops_to_check="all")
 
-    # Check that the rewrite does not get applied when the order
-    # of dimensions has changed.
+    # Check that the rewrite does not mess with meaningful transpositions before the reshape
     reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
     h = FunctionGraph([mat], [reshape_dimshuffle_mat2], clone=False)
     assert len(h.apply_nodes) == 3
     useless_dimshuffle_in_reshape.rewrite(h)
-    assert equal_computations(
-        h.outputs, [reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)]
-    )
+    assert equal_computations(h.outputs, [reshape(mat.dimshuffle(1, 0), mat.shape)])
 
 
 class TestFusion:
```

tests/tensor/rewriting/test_shape.py (+55 −1)
```diff
@@ -6,7 +6,7 @@
 import pytensor.tensor as pt
 from pytensor import shared
 from pytensor.compile.function import function
-from pytensor.compile.mode import get_default_mode, get_mode
+from pytensor.compile.mode import Mode, get_default_mode, get_mode
 from pytensor.compile.ops import deep_copy_op
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Variable, equal_computations
```
```diff
@@ -426,6 +426,60 @@ def test_basic(self):
 
         assert check_stack_trace(g, ops_to_check=(DimShuffle, Reshape))
 
+    def test_expand_dims(self):
+        x = pt.scalar()
+        # This reshape does an implicit expand_dims
+        out = x.reshape((1, -1))
+        assert isinstance(out.owner.op, Reshape)
+        new_out = rewrite_graph(out, include=("canonicalize",))
+        assert equal_computations([new_out], [pt.expand_dims(x, (0, 1))])
+
+    def test_squeeze_of_alloc(self):
+        # This shows up in the graph of repeat
+        x = pt.vector("x", shape=(9,))
+        bcast_x = pt.alloc(x, 1, 12, x.shape[0])
+
+        # This reshape does an implicit squeeze
+        out = bcast_x.reshape((12, x.shape[0]))
+
+        new_out = rewrite_graph(out, include=("canonicalize", "ShapeOpt"))
+        assert equal_computations([new_out], [pt.alloc(x, 12, 9)], strict_dtype=False)
+
+
+def test_expand_dims_squeeze_reshape_fusion():
+    x = pt.tensor("x", shape=(1, 9))
+    reshape_x = x.squeeze(0).reshape((3, 3))[..., None]
+
+    assert isinstance(reshape_x.owner.op, DimShuffle)
+    assert isinstance(reshape_x.owner.inputs[0].owner.op, Reshape)
+    assert isinstance(reshape_x.owner.inputs[0].owner.inputs[0].owner.op, DimShuffle)
+
+    out = rewrite_graph(reshape_x, include=("specialize",))
+
+    # In this case we cannot get rid of the reshape, squeeze or expand_dims,
+    # so we fuse them all in one reshape
+    assert equal_computations([out], [x.reshape((3, 3, 1))])
+
+
+def test_implicit_broadcasting_via_repeat():
+    x = pt.vector("x", shape=(3,), dtype=int)
+    y = pt.vector("y", shape=(9,), dtype=int)
+    out = x[None, :].repeat(9, axis=0) <= y[:, None].repeat(3, axis=1)
+    # There are two Reshapes in the graph
+    assert isinstance(out.owner.inputs[0].owner.op, Reshape)
+    assert isinstance(out.owner.inputs[1].owner.op, Reshape)
+
+    new_out = rewrite_graph(out, include=("canonicalize", "specialize"))
+    assert equal_computations([new_out], [x[None] <= y[:, None]])
+
+    no_rewrite_mode = Mode(linker="py", optimizer=None)
+    x_test = np.arange(3) + 1
+    y_test = np.arange(9)
+    np.testing.assert_allclose(
+        new_out.eval({x: x_test, y: y_test}, mode=no_rewrite_mode),
+        out.eval({x: x_test, y: y_test}, mode=no_rewrite_mode),
+    )
+
 
 def test_local_reshape_lift():
     x = tensor4()
```
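Putting the two phases together on a toy graph (a sketch under the default rewrite database; `col` and the shapes are illustrative):

```python
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

col = pt.tensor("col", shape=(5, 1))
out = col.reshape((1, 5))
# Canonicalize splits the squeeze/expand_dims out of the Reshape, after which
# the inner reshape is provably useless; what should survive is a single DimShuffle.
canonical = rewrite_graph(out, include=("canonicalize",))
# Specialize would fuse any remaining squeeze/expand_dims back into a Reshape,
# but here there is no Reshape left to fuse into.
final = rewrite_graph(canonical, include=("specialize",))
```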
