    register_useless,
    topo_constant_folding,
)
+from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
from pytensor.tensor.shape import (
    Reshape,
    Shape,
@@ -757,40 +758,36 @@ def apply(self, fgraph):
pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)


+@register_useless
@register_canonicalize
@node_rewriter([Reshape])
-def local_useless_dimshuffle_in_reshape(fgraph, node):
+def local_useless_expand_dims_in_reshape(fgraph, node):
    """
-    Removes useless DimShuffle operation inside Reshape:
-
-      reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
-      reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
-      reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
-      reshape(col.dimshuffle(0), shp) => reshape(col, shp)
+    Removes useless expand_dims `DimShuffle` operations inside Reshape:
+      reshape(expand_dims(vector, axis=0), shp) => reshape(vector, shp)
+      reshape(expand_dims(matrix, axis=(0, 2)), shp) => reshape(matrix, shp)
+
+    Implicit (and useless) squeezes are kept in the graph, as they are
+    part of the canonical form of the graph.
    """
-    dimshuffled_x, new_shape = node.inputs
+    expanded_x, new_shape = node.inputs

    if not (
-        dimshuffled_x.owner is not None
-        and isinstance(dimshuffled_x.owner.op, DimShuffle)
+        expanded_x.owner is not None
+        and isinstance(expanded_x.owner.op, DimShuffle)
+        and expanded_x.owner.op.augment
    ):
        return False

-    [inp] = dimshuffled_x.owner.inputs
-    new_order = dimshuffled_x.owner.op.new_order
-    new_order_of_nonbroadcast = []
-    for i, s in zip(new_order, node.inputs[0].type.shape, strict=True):
-        if s != 1:
-            new_order_of_nonbroadcast.append(i)
-    no_change_in_order = all(
-        new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
-        for i in range(len(new_order_of_nonbroadcast) - 1)
-    )
-    if no_change_in_order:
-        ret = inp.reshape(new_shape)
-        copy_stack_trace(node.outputs[0], ret)
-        return [ret]
+    [x] = expanded_x.owner.inputs
+
+    new_order = tuple(o for o in expanded_x.owner.op.new_order if o != "x")
+    if new_order != tuple(range(x.type.ndim)):
+        x = x.dimshuffle(new_order)
+
+    new_reshaped_x = x.reshape(new_shape)
+    copy_stack_trace(node.outputs[0], new_reshaped_x)
+    return [new_reshaped_x]


@register_canonicalize("shape_unsafe")
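For orientation, here is a minimal usage sketch of the rewrite above (not part of the patch; it assumes only the public pytensor API):

```python
import pytensor
import pytensor.tensor as pt

v = pt.vector("v")
# The expand_dims is useless here: the reshape alone determines the output shape
out = pt.expand_dims(v, axis=0).reshape((2, -1))

# After canonicalization the compiled graph should be equivalent to v.reshape((2, -1))
fn = pytensor.function([v], out)
pytensor.dprint(fn)
```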
@@ -920,10 +917,10 @@ def local_useless_reshape(fgraph, node):
    shape_feature = getattr(fgraph, "shape_feature", None)

-    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for -1
-    # or cases where all but one dimension are provably preserved
+    # Match the case where at least (n-1) entries correspond to the original shape:
+    # Reshape(x, [x.shape[0], ..., x.shape[-1]]), or Reshape(x, [x.shape[0], y, x.shape[2], ..., x.shape[-1]]),
+    # where y can be -1 or anything with an unknown value, since the only valid reshape is still a no-op.
    output_shape_is = _unpack_shape_vector(output_shape)
-
    nb_m1 = 0
    shape_match = [False] * inp.type.ndim
    for dim in range(inp.type.ndim):
@@ -935,48 +932,136 @@ def local_useless_reshape(fgraph, node):
            nb_m1 += 1

    if nb_m1 <= 1 and all(shape_match):
-        return [inp]
+        return [inp]  # This is provably correct

    # There is one missing match, but all other dimensions match
+    # Such as x.type.shape == (3, 5, None) and output_shape == (3, 5, y)
    if (nb_m1 == 0) and (shape_match.count(False) == 1):
-        return [inp]
+        return [inp]  # This could mask a shape error

    return False
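A hedged sketch of the two return paths above (standard pytensor API; the expected results follow the comments in the diff):

```python
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")

# Every output entry provably matches the input shape: removal is always safe
safe = x.reshape((x.shape[0], x.shape[1]))

# (n-1) entries match and one is unknown: the only valid reshape is a no-op,
# so the rewrite still fires, but it could mask a user shape error
k = pt.scalar("k", dtype="int64")
risky = x.reshape((x.shape[0], k))

pytensor.dprint(pytensor.function([x], safe))      # expected: no Reshape left
pytensor.dprint(pytensor.function([x, k], risky))  # expected: no Reshape left either
```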


-@register_canonicalize
+@register_canonicalize("shape_unsafe")
@node_rewriter([Reshape])
def local_reshape_to_dimshuffle(fgraph, node):
-    r"""Replace broadcastable dimensions in `Reshape` nodes with `DimShuffle`\s.
+    r"""Remove `Reshape` operations over length-1 (broadcastable) dimensions.

-    The goal is to avoid using `Reshape` to add or remove broadcastable
-    dimensions, and to use `DimShuffle` instead, since `DimShuffle`\s can
-    cancel out and/or be removed later on.
+    It's always valid to squeeze an input before applying the same reshape operation.
+    Equivalently, it's always valid to remove `1` entries from the reshape shape
+    and replace them with an expand_dims after the rewritten reshape operation.
+
+    We canonicalize the graph in this way because it isolates the operation that is
+    unique to reshaping (mixing dimensions) from those that are more legibly encoded
+    by DimShuffle (squeeze and expand_dims). This allows further simplification by
+    other rewrites that target DimShuffle but not Reshape, and facilitates the
+    removal of useless reshape operations.

    For example:
-      - reshape(x, (1, n)) -> DimShuffle{x,0}(Reshape(x, (n,))
-      - reshape(x, (1, m, 1, n, 1, 1)) -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
+      - reshape(col, (m, n)) -> reshape(squeeze(col, axis=1), (m, n))
+      - reshape(col, (1, m, n)) -> expand_dims(reshape(squeeze(col, axis=1), (m, n)), axis=0)
+      - reshape(x, (1, m, 1, n, 1, 1)) -> expand_dims(reshape(x, (m, n)), axis=(0, 2, 4, 5))
+
    """
    inp, output_shape = node.inputs
    [output] = node.outputs

-    unpacked_shape = _unpack_shape_vector(output_shape)
-    expand_axes = []
-    new_output_shape = []
-    for i, dim in enumerate(unpacked_shape):
-        if isinstance(dim, Constant) and dim.data == 1:
-            expand_axes.append(i)
-        else:
-            new_output_shape.append(dim)
+    # Remove any broadcastable dimensions from the input
+    squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast]
+
+    # Trivial case: all dimensions of the input/output are known to be broadcastable,
+    # so there's nothing to reshape
+    if all(inp.type.broadcastable) or all(output.type.broadcastable):
+        new_output_shape = []
+        expand_axes = tuple(range(output.type.ndim))
+
+    else:
+        unpacked_shape = _unpack_shape_vector(output_shape)
+        new_output_shape = []
+        expand_axes = []
+        for i, dim_length in enumerate(unpacked_shape):
+            if isinstance(dim_length, Constant) and (
+                dim_length.data == 1
+                # -1 can be an implicit expand_dims, but it's tricky to prove,
+                # as we would need to check whether all other dimensions
+                # already explain the full size of the array.
+                # Example: np.zeros((2, 2, 2)).reshape((8, -1))
+                # We rely on the output static shape, which will already have
+                # figured it out in some (but not all) cases
+                or (dim_length.data == -1 and output.type.shape[i] == 1)
+            ):
+                expand_axes.append(i)
+            else:
+                new_output_shape.append(dim_length)
+
+    if squeeze_axes or expand_axes:
+        new_out = inp.squeeze(squeeze_axes)
+
+        if new_output_shape:
+            new_out = new_out.reshape(new_output_shape)
+            copy_stack_trace(output, new_out)
+
+        new_out = expand_dims(new_out, expand_axes)
+
+        if not new_output_shape:
+            # Eagerly merge consecutive squeeze and expand_dims
+            new_out = apply_local_dimshuffle_lift(fgraph, new_out)

-    if len(new_output_shape) != output.type.ndim:
-        inner = inp.reshape(new_output_shape)
-        copy_stack_trace(output, inner)
-        new_out = expand_dims(inner, expand_axes)
        copy_stack_trace(output, new_out)
        return [new_out]
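A sketch of the canonical form the docstring describes. `rewrite_graph` with `include=("canonicalize",)` stops before the specialize passes added below, which deliberately undo this form:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

col = pt.col("col")  # static shape (None, 1): axis 1 is known broadcastable
out = col.reshape((1, 3, 5))

# Expected canonical form:
# expand_dims(reshape(squeeze(col, axis=1), (3, 5)), axis=0)
canonical = rewrite_graph(out, include=("canonicalize",))
pytensor.dprint(canonical)
```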


+@register_specialize
+@node_rewriter([Reshape])
+def local_fuse_squeeze_reshape(fgraph, node):
+    r"""If there is a squeeze right before a reshape, merge them.
+
+    This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization.
+    """
+    x, new_shape = node.inputs
+
+    if (
+        x.owner is not None
+        and isinstance(x.owner.op, DimShuffle)
+        and x.owner.op.is_squeeze
+    ):
+        # A reshape can always subsume a squeeze.
+        x = x.owner.inputs[0]
+        return [x.reshape(new_shape)]
+
+
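A sketch of the squeeze-reshape fusion (illustrative shapes, standard pytensor API):

```python
import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(None, 1, None))
out = x.squeeze(1).reshape((-1,))

# At specialize time the squeeze should be subsumed by the reshape:
# reshape(squeeze(x, axis=1), (-1,)) => reshape(x, (-1,))
pytensor.dprint(pytensor.function([x], out))
```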
+@register_specialize
+@node_rewriter([DimShuffle])
+def local_fuse_expand_dims_reshape(fgraph, node):
+    r"""If there is an expand_dims right after a reshape, merge them.
+
+    This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization.
+    """
+    if not node.op.is_expand_dims:
+        return None
+
+    reshaped_x = node.inputs[0]
+
+    if not (reshaped_x.owner and isinstance(reshaped_x.owner.op, Reshape)):
+        return None
+
+    if len(fgraph.clients[reshaped_x]) > 1:
+        # The reshape is used elsewhere; don't fuse, as that can sometimes require a copy.
+        # Example: `x = pt.matrix(); y = x.T.reshape(-1); out = y[:, None] * y[None, :]`
+        return None
+
+    x, new_shape = reshaped_x.owner.inputs
+
+    # Insert the expanded dims into the target shape
+    new_shape = list(_unpack_shape_vector(new_shape))
+    for i in node.op.augment:
+        new_shape.insert(i, 1)
+
+    new_reshaped_x = x.reshape(new_shape)
+    copy_stack_trace(node.outputs[0], new_reshaped_x)
+    return [new_reshaped_x]
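And a sketch of the mirror expand_dims-reshape fusion (same assumptions):

```python
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
out = pt.expand_dims(x.reshape((-1,)), axis=0)

# The expand_dims should be folded into the target shape:
# expand_dims(reshape(x, (-1,)), axis=0) => reshape(x, (1, -1))
pytensor.dprint(pytensor.function([x], out))
```

Note the `fgraph.clients` guard above: when the reshape output also feeds other nodes, fusing would duplicate it, and for non-contiguous cases like `x.T.reshape(-1)` that can force an extra copy.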
+
+

@register_canonicalize
@register_specialize
@node_rewriter([Reshape])