Skip to content

Commit a076966

Browse files
committed
Simplify local_squeeze_reshape function by removing redundant checks and streamlining logic.
1 parent b3e859c commit a076966

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22
from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape
33
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
44
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
5-
from pytensor.xtensor.shape import Concat, ExpandDims, Squeeze, Stack, Transpose, UnStack
5+
from pytensor.xtensor.shape import (
6+
Concat,
7+
ExpandDims,
8+
Squeeze,
9+
Stack,
10+
Transpose,
11+
UnStack,
12+
)
613

714

815
@register_xcanonicalize
@@ -143,11 +150,7 @@ def local_squeeze_reshape(fgraph, node):
143150

144151
# Get the index of the dimension to remove
145152
if dim is not None:
146-
if dim not in x.type.dims:
147-
return False
148153
dim_idx = x.type.dims.index(dim)
149-
if x.type.shape[dim_idx] != 1:
150-
return False
151154
else:
152155
# Find all dimensions of size 1
153156
dim_idx = [i for i, s in enumerate(x.type.shape) if s == 1]

0 commit comments

Comments
 (0)