Skip to content

Add squeeze for labeled tensors #1434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
b3e859c
Add ExpandDims and Squeeze operations with tests and rewrite rules
AllenDowney May 30, 2025
d824870
fix: match xarray's expand_dims behavior with None dimension - Update…
AllenDowney May 30, 2025
a076966
Simplify local_squeeze_reshape function by removing redundant checks …
AllenDowney May 30, 2025
e2ffe1c
Ruff
AllenDowney May 30, 2025
7489489
Restoring the commit with xarray-based tests
AllenDowney Jun 2, 2025
22adb6f
Merge fork/add_expand_dims_squeeze, keeping our changes
AllenDowney Jun 2, 2025
a7e2bf8
Updating squeeze
AllenDowney Jun 2, 2025
332139d
Working on expand_dims
AllenDowney Jun 2, 2025
4b2f0f7
Working on squeeze
AllenDowney Jun 3, 2025
2120b1a
Organizing squeeze tests
AllenDowney Jun 4, 2025
2bb1fce
Working on expand_dims
AllenDowney Jun 4, 2025
dd13fc7
Merge branch 'labeled_tensors' into add_expand_dims_squeeze
AllenDowney Jun 4, 2025
7a308b9
lint
AllenDowney Jun 4, 2025
9c1a0b7
Update pytensor/xtensor/shape.py
AllenDowney Jun 5, 2025
1024798
Cleaning up squeeze
AllenDowney Jun 5, 2025
260b9b6
Removing expand_dims
AllenDowney Jun 5, 2025
915a368
Removing unneded check
AllenDowney Jun 5, 2025
05dac9e
Update pytensor/xtensor/shape.py
AllenDowney Jun 5, 2025
98d297e
Update pytensor/xtensor/shape.py
AllenDowney Jun 5, 2025
3202c4c
Update pytensor/xtensor/rewriting/shape.py
AllenDowney Jun 5, 2025
8d4fdd5
Update pytensor/xtensor/shape.py
AllenDowney Jun 5, 2025
f000bbb
All but one requested change
AllenDowney Jun 5, 2025
2de9566
Picking a nit
AllenDowney Jun 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions pytensor/xtensor/rewriting/shape.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
from pytensor.graph import node_rewriter
from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape
from pytensor.tensor import (
broadcast_to,
join,
moveaxis,
specify_shape,
squeeze,
)
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
from pytensor.xtensor.shape import Concat, Stack, Transpose, UnStack
from pytensor.xtensor.shape import (
Concat,
Squeeze,
Stack,
Transpose,
UnStack,
)


@register_xcanonicalize
Expand Down Expand Up @@ -105,3 +117,18 @@ def lower_transpose(fgraph, node):
x_tensor_transposed = x_tensor.transpose(perm)
new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims)
return [new_out]


@register_xcanonicalize
@node_rewriter([Squeeze])
def local_squeeze_reshape(fgraph, node):
    """Lower an xtensor ``Squeeze`` node to a plain ``tensor.squeeze`` call.

    The dimension names stored on the Op are translated into positional axes
    of the input, the underlying tensor is squeezed along those axes, and the
    result is wrapped back into an xtensor carrying the dims already declared
    on the node's output.
    """
    (xvar,) = node.inputs
    raw = tensor_from_xtensor(xvar)

    # Translate each dimension name into its position in the input.
    in_dims = xvar.type.dims
    axes = tuple(in_dims.index(name) for name in node.op.dims)

    squeezed = squeeze(raw, axis=axes)
    out_dims = node.outputs[0].type.dims
    return [xtensor_from_tensor(squeezed, dims=out_dims)]
79 changes: 79 additions & 0 deletions pytensor/xtensor/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,82 @@ def make_node(self, *inputs: Variable) -> Apply:

def concat(xtensors, dim: str):
    """Build a ``Concat`` op for ``dim`` and apply it to ``xtensors``.

    Each element of ``xtensors`` is passed to the op as a separate
    positional input.
    """
    concat_op = Concat(dim=dim)
    return concat_op(*xtensors)


class Squeeze(XOp):
    """Remove specified dimensions from an XTensorVariable.

    Every requested dimension must exist on the input. A dimension whose
    static size is known must have size 1; a dimension whose static size is
    unknown (``None``) is accepted without a check when the node is built.

    Parameters
    ----------
    dims : str or iterable of str
        The names of the dimensions to remove. Duplicates are dropped and
        the names are stored as a sorted tuple. A single bare string is
        treated as one dimension name.
    """

    __props__ = ("dims",)

    def __init__(self, dims):
        if isinstance(dims, str):
            # A bare string is itself iterable; without this guard
            # ``set("country")`` would yield one "dimension" per character.
            dims = (dims,)
        # Deduplicate and sort so the stored tuple has a canonical form.
        self.dims = tuple(sorted(set(dims)))

    def make_node(self, x):
        """Validate the requested dims against ``x`` and build the Apply node.

        Raises
        ------
        ValueError
            If a requested dimension is not present in ``x``, or if its
            static size is known and is not 1.
        """
        x = as_xtensor(x)
        x_dims = x.type.dims
        x_shape = x.type.shape

        # Positions of the dims to drop; a set gives O(1) membership below.
        axes_to_remove = set()
        for d in self.dims:
            if d not in x_dims:
                raise ValueError(f"Dimension {d} not found in {x.type.dims}")
            idx = x_dims.index(d)
            dim_size = x_shape[idx]
            # ``None`` means the size is not known statically: let it through.
            if dim_size is not None and dim_size != 1:
                raise ValueError(f"Dimension {d} has static size {dim_size}, not 1")
            axes_to_remove.add(idx)

        new_dims = tuple(d for i, d in enumerate(x_dims) if i not in axes_to_remove)
        new_shape = tuple(
            s for i, s in enumerate(x_shape) if i not in axes_to_remove
        )

        out = xtensor(
            dtype=x.type.dtype,
            shape=new_shape,
            dims=new_dims,
        )
        return Apply(self, [x], [out])


def squeeze(x, dim=None):
    """Drop size-1 dimensions from an XTensorVariable.

    Parameters
    ----------
    x : XTensorVariable
        The input tensor.
    dim : str or None or iterable of str, optional
        Name(s) of the dimension(s) to drop. With ``None``, every dimension
        whose static size is 1 is dropped; dimensions whose static size is
        unknown are kept, even if they have size 1 at runtime.

    Returns
    -------
    XTensorVariable
        A tensor without the selected dimension(s). When nothing is
        selected, the converted input is returned as is.
    """
    x = as_xtensor(x)

    if isinstance(dim, str):
        selected = (dim,)
    elif dim is not None:
        selected = tuple(dim)
    else:
        # Only dims statically known to have size 1 qualify implicitly.
        selected = tuple(
            name for name, size in zip(x.type.dims, x.type.shape) if size == 1
        )

    if selected:
        return Squeeze(dims=selected)(x)
    # Nothing to squeeze: hand back the input unchanged.
    return x
118 changes: 117 additions & 1 deletion tests/xtensor/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,17 @@
from itertools import chain, combinations

import numpy as np
import pytest
from xarray import DataArray
from xarray import concat as xr_concat

from pytensor.xtensor.shape import concat, stack, transpose, unstack
from pytensor.xtensor.shape import (
concat,
squeeze,
stack,
transpose,
unstack,
)
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import (
xr_arange_like,
Expand All @@ -21,6 +28,9 @@
)


# NOTE(review): `from xarray import ...` earlier in this module already runs
# before this line, so the skip can only take effect if this call is moved
# ahead of that import.
pytest.importorskip("xarray")


def powerset(iterable, min_group_size=0):
"Subsequences of the iterable from shortest to longest."
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
Expand Down Expand Up @@ -254,3 +264,109 @@ def test_concat_scalar():
res = fn(x1_test, x2_test)
expected_res = xr_concat([x1_test, x2_test], dim="new_dim")
xr_assert_allclose(res, expected_res)


def test_squeeze_explicit_dims():
"""Test squeeze with explicit dimension(s)."""

# Single dimension
# The expected value is obtained by calling ``.squeeze`` on the test input
# itself (presumably an xarray DataArray produced by xr_arange_like).
x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1))
y1 = squeeze(x1, "country")
fn1 = xr_function([x1], y1)
x1_test = xr_arange_like(x1)
xr_assert_allclose(fn1(x1_test), x1_test.squeeze("country"))

# Multiple dimensions
x2 = xtensor("x2", dims=("a", "b", "c", "d"), shape=(2, 1, 1, 3))
y2 = squeeze(x2, ["b", "c"])
fn2 = xr_function([x2], y2)
x2_test = xr_arange_like(x2)
xr_assert_allclose(fn2(x2_test), x2_test.squeeze(["b", "c"]))

# Order independence
# Only checks that both orderings agree with each other; neither result is
# compared against a ``.squeeze`` call on the test input here.
# NOTE(review): reviewer feedback suggests compiling both outputs in one call,
# e.g. xr_function([x3], [y3a, y3b]), so compilation is triggered only once.
x3 = xtensor("x3", dims=("a", "b", "c"), shape=(2, 1, 1))
y3a = squeeze(x3, ["b", "c"])
y3b = squeeze(x3, ["c", "b"])
fn3a = xr_function([x3], y3a)
fn3b = xr_function([x3], y3b)
x3_test = xr_arange_like(x3)
xr_assert_allclose(fn3a(x3_test), fn3b(x3_test))
Comment on lines +286 to +293
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Combine this with the previous test, and test both of them against xarray. You don't need one function per case: the function can have two outputs, which should make the test faster, as it only triggers the compilation machinery once.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have too many questions about this comment. If you want to make this change after merging, that might be more efficient than explaining.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your previous check was already testing a squeeze of multiple dimensions, so you can combine this which also checks multiple dimensions + the fact that order doesn't matter. This test is a superset of the previous one.

Then the point about combining multiple outputs is to do xr_function([x3], [y3a, y3b]) instead of defining two separate functions. It's a small optimization, although it has the side benefit of testing that multiple outputs aren't messed up either.


# Continuation of test_squeeze_explicit_dims: x3 and x3_test come from the
# "Order independence" case earlier in that test.

# Redundant dimensions
# Naming the same dim twice is compared against the same call on the test input.
y3c = squeeze(x3, ["b", "b"])
fn3c = xr_function([x3], y3c)
xr_assert_allclose(fn3c(x3_test), x3_test.squeeze(["b", "b"]))

# Empty list = no-op
# With no dims requested, the result is expected to equal the input.
y3d = squeeze(x3, [])
fn3d = xr_function([x3], y3d)
xr_assert_allclose(fn3d(x3_test), x3_test)


def test_squeeze_implicit_dims():
"""Test squeeze with implicit dim=None (all size-1 dimensions)."""

# All dimensions size 1
# Both dims are statically 1, so the implicit call should drop them all.
x1 = xtensor("x1", dims=("a", "b"), shape=(1, 1))
y1 = squeeze(x1)
fn1 = xr_function([x1], y1)
x1_test = xr_arange_like(x1)
xr_assert_allclose(fn1(x1_test), x1_test.squeeze())

# No dimensions size 1 = no-op
# No dim has static size 1, so the output is expected to equal the input.
x2 = xtensor("x2", dims=("row", "col", "batch"), shape=(2, 3, 4))
y2 = squeeze(x2)
fn2 = xr_function([x2], y2)
x2_test = xr_arange_like(x2)
xr_assert_allclose(fn2(x2_test), x2_test)

# Unknown static shape, dim named explicitly, runtime size is 1 → should squeeze
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't call these symbolic shapes. They are just unknown static shapes. It's a confusing label.

# The variable is declared without a shape; the concrete test input is built
# from a separate xtensor with shape (2, 1, 3), so "b" is 1 only at runtime.
x3 = xtensor("x3", dims=("a", "b", "c")) # shape unknown
y3 = squeeze(x3, "b")
x3_test = xr_arange_like(xtensor(dims=x3.dims, shape=(2, 1, 3)))
fn3 = xr_function([x3], y3)
xr_assert_allclose(fn3(x3_test), x3_test.squeeze("b"))

# Mixed known and unknown static shapes: "a" is unknown, while the squeezed
# dim "b" is statically 1
x4 = xtensor("x4", dims=("a", "b", "c"), shape=(None, 1, 3))
y4 = squeeze(x4, "b")
x4_test = xr_arange_like(xtensor(dims=x4.dims, shape=(4, 1, 3)))
fn4 = xr_function([x4], y4)
xr_assert_allclose(fn4(x4_test), x4_test.squeeze("b"))
Comment on lines +330 to +335
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is not interesting, remove? Or what was the reason for it?


"""
This test documents that we intentionally don't squeeze dimensions with symbolic shapes
(static_shape=None) even when they are 1 at runtime, while xarray does squeeze them.
"""
# NOTE(review): the triple-quoted string just above appears to sit mid-function
# (no new ``def`` precedes it), so it is a bare expression statement rather
# than a docstring.
# Create a tensor whose static shape is unknown; "b" will be 1 at runtime
x = xtensor("x", dims=("a", "b", "c")) # shape unknown
y = squeeze(x) # implicit dim=None should not squeeze dims of unknown static size
x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 1, 3)))
fn = xr_function([x], y)
res = fn(x_test)

# Our implementation should not squeeze the dim of unknown static size
assert "b" in res.dims
# While xarray would squeeze it
assert "b" not in x_test.squeeze().dims


def test_squeeze_errors():
    """Check that invalid squeeze requests are rejected."""

    x_static = xtensor("x1", dims=("city", "country"), shape=(3, 1))

    # Naming a dimension the tensor does not have
    with pytest.raises(ValueError, match="Dimension .* not found"):
        squeeze(x_static, "time")

    # Naming a dimension whose static size is larger than 1
    with pytest.raises(ValueError, match="has static size .* not 1"):
        squeeze(x_static, "city")

    # Unknown static shape: building the graph succeeds, but evaluating it
    # with a size-2 "b" should raise
    x_unknown = xtensor("x2", dims=("a", "b", "c"))  # shape unknown
    squeezed = squeeze(x_unknown, "b")
    bad_input = xr_arange_like(xtensor(dims=x_unknown.dims, shape=(2, 2, 3)))
    compiled = xr_function([x_unknown], squeezed)
    with pytest.raises(Exception):
        compiled(bad_input)
Loading