Skip to content

Commit a7e2bf8

Browse files
committed
Updating squeeze
1 parent 22adb6f commit a7e2bf8

File tree

3 files changed

+122
-47
lines changed

3 files changed

+122
-47
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -154,38 +154,28 @@ def local_expand_dims_reshape(fgraph, node):
154154
@register_xcanonicalize
155155
@node_rewriter([Squeeze])
156156
def local_squeeze_reshape(fgraph, node):
157-
"""Rewrite rule to convert squeeze to pytensor.tensor.squeeze."""
157+
"""Rewrite rule to convert Squeeze to pytensor.tensor.squeeze."""
158158
if not isinstance(node.op, Squeeze):
159159
return False
160160

161-
x = node.inputs[0]
161+
[x] = node.inputs
162+
in_dims = x.type.dims
162163
dim = node.op.dim
163164

164-
# Convert single dimension to iterable for consistent handling
165-
dims_to_remove = [dim] if isinstance(dim, str) else dim
166-
167-
if dims_to_remove is not None:
168-
# Validate dimensions exist and have size 1
169-
dim_indices = []
170-
for d in dims_to_remove:
171-
if d not in x.type.dims:
172-
return False
173-
dim_idx = x.type.dims.index(d)
174-
# Only check shape != 1 if the shape is not None (symbolic)
175-
if x.type.shape[dim_idx] is not None and x.type.shape[dim_idx] != 1:
176-
return False
177-
dim_indices.append(dim_idx)
165+
# Determine which axes to squeeze
166+
if dim is None:
167+
# Infer axes by comparing input and output dims
168+
out_dims = node.outputs[0].type.dims
169+
axes_to_squeeze = tuple(i for i, d in enumerate(in_dims) if d not in out_dims)
178170
else:
179-
# Find all dimensions of size 1
180-
dim_indices = [i for i, s in enumerate(x.type.shape) if s == 1]
181-
if not dim_indices:
182-
return False
171+
dims_to_remove = [dim] if isinstance(dim, str) else dim
172+
axes_to_squeeze = tuple(in_dims.index(d) for d in dims_to_remove)
183173

184-
# Create new dimensions list
185-
new_dims = [d for i, d in enumerate(x.type.dims) if i not in dim_indices]
174+
# Nothing to squeeze? Just return input unchanged
175+
if not axes_to_squeeze:
176+
return [x]
186177

187-
# Convert to tensor and use pytensor.tensor.squeeze
188178
x_tensor = tensor_from_xtensor(x)
189-
x_tensor_squeezed = squeeze(x_tensor, axis=tuple(dim_indices))
190-
new_out = xtensor_from_tensor(x_tensor_squeezed, dims=tuple(new_dims))
179+
x_tensor_squeezed = squeeze(x_tensor, axis=axes_to_squeeze)
180+
new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
191181
return [new_out]

pytensor/xtensor/shape.py

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -364,45 +364,62 @@ class Squeeze(XOp):
364364
365365
Parameters
366366
----------
367-
dim : str or None or iterable of str
368-
The name(s) of the dimension(s) to remove. If None, all dimensions of size 1 will be removed.
367+
dim : str, None, or iterable of str
368+
The name(s) of the dimension(s) to remove. If None, all dimensions
369+
that are statically known to have size 1 will be removed.
370+
Dimensions with symbolic shape will not be removed unless explicitly named.
371+
372+
Note: Unlike NumPy/xarray, if dim is None, only dimensions known to
373+
be size 1 at graph construction time will be removed; dimensions whose
374+
static size is unknown are kept even if they happen to be size 1 at runtime.
369375
"""
370376

377+
__props__ = ("dim",)
378+
371379
def __init__(self, dim=None):
372-
self.dim = dim
380+
if dim is None:
381+
self.dim = None
382+
else:
383+
dims = [dim] if isinstance(dim, str) else dim
384+
if not all(isinstance(d, str) for d in dims):
385+
raise TypeError(f"All dimension names must be strings: got {dims}")
386+
# Deduplicate and sort to make __props__ deterministic and hashable
387+
self.dim = tuple(sorted(set(dims)))
388+
389+
if not self.dim:
390+
warnings.warn(
391+
"Squeeze received an empty dim list — no dimensions will be removed."
392+
)
373393

374394
def make_node(self, x):
375395
x = as_xtensor(x)
376396

377-
# Convert single dimension to iterable for consistent handling
378-
dims_to_remove = [self.dim] if isinstance(self.dim, str) else self.dim
397+
if self.dim is None:
398+
# Auto-detect static size-1 dimensions
399+
dims_to_remove = [d for d, s in zip(x.type.dims, x.type.shape) if s == 1]
400+
if not dims_to_remove:
401+
raise ValueError("No dimensions of size 1 to remove")
402+
else:
403+
dims_to_remove = list(self.dim)
379404

380-
if dims_to_remove is not None:
381-
# Validate dimensions exist and have size 1
405+
# Validate existence and static shape (when possible)
382406
for dim in dims_to_remove:
383407
if dim not in x.type.dims:
384408
raise ValueError(f"Dimension {dim} not found")
385409
dim_idx = x.type.dims.index(dim)
386-
# Only raise an error if the shape is statically known and not 1.
387-
# If the shape is None (symbolic), defer the error to runtime.
388-
if x.type.shape[dim_idx] is not None and x.type.shape[dim_idx] != 1:
389-
raise ValueError(
390-
f"Dimension {dim} has size {x.type.shape[dim_idx]}, not 1"
391-
)
392-
# Get indices of dimensions to remove
393-
dim_indices = [x.type.dims.index(dim) for dim in dims_to_remove]
394-
else:
395-
# Find all dimensions of size 1
396-
dim_indices = [i for i, s in enumerate(x.type.shape) if s == 1]
397-
if not dim_indices:
398-
raise ValueError("No dimensions of size 1 to remove")
410+
shape = x.type.shape[dim_idx]
411+
if shape is not None and shape != 1:
412+
raise ValueError(f"Dimension {dim} has size {shape}, not 1")
413+
414+
dim_indices = [x.type.dims.index(dim) for dim in dims_to_remove]
399415

400-
# Create new dimensions and shape lists
401416
new_dims = [d for i, d in enumerate(x.type.dims) if i not in dim_indices]
402417
new_shape = [s for i, s in enumerate(x.type.shape) if i not in dim_indices]
403418

404419
output = xtensor(
405-
dtype=x.type.dtype, shape=tuple(new_shape), dims=tuple(new_dims)
420+
dtype=x.type.dtype,
421+
shape=tuple(new_shape),
422+
dims=tuple(new_dims),
406423
)
407424
return Apply(self, [x], [output])
408425

tests/xtensor/test_shape.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,3 +386,71 @@ def test_squeeze():
386386
x3d = xtensor("x3d", dims=("row", "col", "batch"), shape=(2, 3, 4))
387387
with pytest.raises(ValueError):
388388
squeeze(x3d)
389+
390+
391+
def test_squeeze_additional_cases():
392+
# Redundant dimensions: squeeze(["b", "b"]) should behave like squeeze(["b"])
393+
x1 = xtensor("x1", dims=("a", "b", "c"), shape=(2, 1, 1))
394+
y1 = squeeze(x1, ["b", "b"])
395+
fn1 = xr_function([x1], y1)
396+
x1_test = xr_arange_like(x1)
397+
expected1 = x1_test.squeeze(["b"])
398+
xr_assert_allclose(fn1(x1_test), expected1)
399+
400+
# Symbolic shape: dim is 1 at runtime → should squeeze successfully
401+
x2 = xtensor("x2", dims=("a", "b", "c")) # shape unknown
402+
y2 = squeeze(x2, "b")
403+
fn2 = xr_function([x2], y2)
404+
x2_test = xr_arange_like(xtensor(dims=x2.dims, shape=(2, 1, 3)))
405+
expected2 = x2_test.squeeze("b")
406+
xr_assert_allclose(fn2(x2_test), expected2)
407+
408+
# Symbolic shape: dim is not 1 at runtime → should raise
409+
x3 = xtensor("x3", dims=("a", "b", "c")) # shape unknown
410+
y3 = squeeze(x3, "b")
411+
fn3 = xr_function([x3], y3)
412+
x3_test = xr_arange_like(xtensor(dims=x3.dims, shape=(2, 2, 3)))
413+
with pytest.raises(Exception):
414+
fn3(x3_test)
415+
416+
# Reversibility: squeeze then expand_dims should restore original
417+
# TODO: uncomment when we have expand_dims
418+
# x4 = xtensor("x4", dims=("batch", "time", "feature"), shape=(2, 1, 3))
419+
# y4 = squeeze(x4, "time")
420+
# z4 = expand_dims(y4, "time")
421+
# fn4 = xr_function([x4], z4)
422+
# x4_test = xr_arange_like(x4)
423+
# xr_assert_allclose(fn4(x4_test), x4_test)
424+
425+
426+
def test_squeeze_extra_cases():
427+
# 1. Order of dims shouldn't affect result
428+
x1 = xtensor("x1", dims=("a", "b", "c"), shape=(2, 1, 1))
429+
y1 = squeeze(x1, ["b", "c"])
430+
y2 = squeeze(x1, ["c", "b"])
431+
fn1 = xr_function([x1], y1)
432+
fn2 = xr_function([x1], y2)
433+
x1_test = xr_arange_like(x1)
434+
xr_assert_allclose(fn1(x1_test), fn2(x1_test))
435+
436+
# 2. Empty list of dims = no-op
437+
x2 = xtensor("x2", dims=("a", "b", "c"), shape=(2, 1, 1))
438+
y2 = squeeze(x2, [])
439+
fn2 = xr_function([x2], y2)
440+
x2_test = xr_arange_like(x2)
441+
xr_assert_allclose(fn2(x2_test), x2_test)
442+
443+
# 3. Implicit squeeze of all static size-1 dims via dim=None (no dim argument)
444+
x3 = xtensor("x3", dims=("a", "b"), shape=(1, 1))
445+
y3 = squeeze(x3)
446+
fn3 = xr_function([x3], y3)
447+
x3_test = xr_arange_like(x3)
448+
xr_assert_allclose(fn3(x3_test), x3_test.squeeze())
449+
450+
# 4. Static + symbolic shape mix: squeeze the statically known size-1 dim "b" while "a" is symbolic
451+
x4 = xtensor("x4", dims=("a", "b", "c"), shape=(None, 1, 3))
452+
y4 = squeeze(x4, "b")
453+
x4_test = xr_arange_like(xtensor(dims=x4.dims, shape=(4, 1, 3)))
454+
fn4 = xr_function([x4], y4)
455+
expected4 = x4_test.squeeze("b")
456+
xr_assert_allclose(fn4(x4_test), expected4)

0 commit comments

Comments
 (0)