diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index 84447670c2..efe92fe367 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -1,8 +1,20 @@ from pytensor.graph import node_rewriter -from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape +from pytensor.tensor import ( + broadcast_to, + join, + moveaxis, + specify_shape, + squeeze, +) from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.basic import register_xcanonicalize -from pytensor.xtensor.shape import Concat, Stack, Transpose, UnStack +from pytensor.xtensor.shape import ( + Concat, + Squeeze, + Stack, + Transpose, + UnStack, +) @register_xcanonicalize @@ -105,3 +117,18 @@ def lower_transpose(fgraph, node): x_tensor_transposed = x_tensor.transpose(perm) new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims) return [new_out] + + +@register_xcanonicalize +@node_rewriter([Squeeze]) +def local_squeeze_reshape(fgraph, node): + """Rewrite Squeeze to tensor.squeeze.""" + [x] = node.inputs + x_tensor = tensor_from_xtensor(x) + x_dims = x.type.dims + dims_to_remove = node.op.dims + axes_to_squeeze = tuple(x_dims.index(d) for d in dims_to_remove) + x_tensor_squeezed = squeeze(x_tensor, axis=axes_to_squeeze) + + new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims) + return [new_out] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 38b702db84..329e8a5163 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -301,3 +301,82 @@ def make_node(self, *inputs: Variable) -> Apply: def concat(xtensors, dim: str): return Concat(dim=dim)(*xtensors) + + +class Squeeze(XOp): + """Remove specified dimensions from an XTensorVariable. + + Only dimensions that are known statically to be size 1 will be removed. + Symbolic dimensions must be explicitly specified, and are assumed safe. + + Parameters + ---------- + dim : tuple of str + The names of the dimensions to remove. + """ + + __props__ = ("dims",) + + def __init__(self, dims): + self.dims = tuple(sorted(set(dims))) + + def make_node(self, x): + x = as_xtensor(x) + + # Validate that dims exist and are size-1 if statically known + dims_to_remove = [] + x_dims = x.type.dims + x_shape = x.type.shape + for d in self.dims: + if d not in x_dims: + raise ValueError(f"Dimension {d} not found in {x.type.dims}") + idx = x_dims.index(d) + dim_size = x_shape[idx] + if dim_size is not None and dim_size != 1: + raise ValueError(f"Dimension {d} has static size {dim_size}, not 1") + dims_to_remove.append(idx) + + new_dims = tuple( + d for i, d in enumerate(x.type.dims) if i not in dims_to_remove + ) + new_shape = tuple( + s for i, s in enumerate(x.type.shape) if i not in dims_to_remove + ) + + out = xtensor( + dtype=x.type.dtype, + shape=new_shape, + dims=new_dims, + ) + return Apply(self, [x], [out]) + + +def squeeze(x, dim=None): + """Remove dimensions of size 1 from an XTensorVariable. + + Parameters + ---------- + x : XTensorVariable + The input tensor + dim : str or None or iterable of str, optional + The name(s) of the dimension(s) to remove. If None, all dimensions of size 1 + (known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime. + + Returns + ------- + XTensorVariable + A new tensor with the specified dimension(s) removed. + """ + x = as_xtensor(x) + + if dim is None: + dims = tuple(d for d, s in zip(x.type.dims, x.type.shape) if s == 1) + elif isinstance(dim, str): + dims = (dim,) + else: + dims = tuple(dim) + + if not dims: + return x # no-op if nothing to squeeze + + return Squeeze(dims=dims)(x) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index d3a5724f11..fcb7e03757 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -8,10 +8,17 @@ from itertools import chain, combinations import numpy as np +import pytest from xarray import DataArray from xarray import concat as xr_concat -from pytensor.xtensor.shape import concat, stack, transpose, unstack +from pytensor.xtensor.shape import ( + concat, + squeeze, + stack, + transpose, + unstack, +) from pytensor.xtensor.type import xtensor from tests.xtensor.util import ( xr_arange_like, @@ -21,6 +28,9 @@ ) +pytest.importorskip("xarray") + + def powerset(iterable, min_group_size=0): "Subsequences of the iterable from shortest to longest." # powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3) @@ -254,3 +264,109 @@ def test_concat_scalar(): res = fn(x1_test, x2_test) expected_res = xr_concat([x1_test, x2_test], dim="new_dim") xr_assert_allclose(res, expected_res) + + +def test_squeeze_explicit_dims(): + """Test squeeze with explicit dimension(s).""" + + # Single dimension + x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1)) + y1 = squeeze(x1, "country") + fn1 = xr_function([x1], y1) + x1_test = xr_arange_like(x1) + xr_assert_allclose(fn1(x1_test), x1_test.squeeze("country")) + + # Multiple dimensions + x2 = xtensor("x2", dims=("a", "b", "c", "d"), shape=(2, 1, 1, 3)) + y2 = squeeze(x2, ["b", "c"]) + fn2 = xr_function([x2], y2) + x2_test = xr_arange_like(x2) + xr_assert_allclose(fn2(x2_test), x2_test.squeeze(["b", "c"])) + + # Order independence + x3 = xtensor("x3", dims=("a", "b", "c"), shape=(2, 1, 1)) + y3a = squeeze(x3, ["b", "c"]) + y3b = squeeze(x3, ["c", "b"]) + fn3a = xr_function([x3], y3a) + fn3b = xr_function([x3], y3b) + x3_test = xr_arange_like(x3) + xr_assert_allclose(fn3a(x3_test), fn3b(x3_test)) + + # Redundant dimensions + y3c = squeeze(x3, ["b", "b"]) + fn3c = xr_function([x3], y3c) + xr_assert_allclose(fn3c(x3_test), x3_test.squeeze(["b", "b"])) + + # Empty list = no-op + y3d = squeeze(x3, []) + fn3d = xr_function([x3], y3d) + xr_assert_allclose(fn3d(x3_test), x3_test) + + +def test_squeeze_implicit_dims(): + """Test squeeze with implicit dim=None (all size-1 dimensions).""" + + # All dimensions size 1 + x1 = xtensor("x1", dims=("a", "b"), shape=(1, 1)) + y1 = squeeze(x1) + fn1 = xr_function([x1], y1) + x1_test = xr_arange_like(x1) + xr_assert_allclose(fn1(x1_test), x1_test.squeeze()) + + # No dimensions size 1 = no-op + x2 = xtensor("x2", dims=("row", "col", "batch"), shape=(2, 3, 4)) + y2 = squeeze(x2) + fn2 = xr_function([x2], y2) + x2_test = xr_arange_like(x2) + xr_assert_allclose(fn2(x2_test), x2_test) + + # Symbolic shape where runtime shape is 1 → should squeeze + x3 = xtensor("x3", dims=("a", "b", "c")) # shape unknown + y3 = squeeze(x3, "b") + x3_test = xr_arange_like(xtensor(dims=x3.dims, shape=(2, 1, 3))) + fn3 = xr_function([x3], y3) + xr_assert_allclose(fn3(x3_test), x3_test.squeeze("b")) + + # Mixed static + symbolic shapes, where symbolic shape is 1 + x4 = xtensor("x4", dims=("a", "b", "c"), shape=(None, 1, 3)) + y4 = squeeze(x4, "b") + x4_test = xr_arange_like(xtensor(dims=x4.dims, shape=(4, 1, 3))) + fn4 = xr_function([x4], y4) + xr_assert_allclose(fn4(x4_test), x4_test.squeeze("b")) + + """ + This test documents that we intentionally don't squeeze dimensions with symbolic shapes + (static_shape=None) even when they are 1 at runtime, while xarray does squeeze them. + """ + # Create a tensor with a symbolic dimension that will be 1 at runtime + x = xtensor("x", dims=("a", "b", "c")) # shape unknown + y = squeeze(x) # implicit dim=None should not squeeze symbolic dimensions + x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 1, 3))) + fn = xr_function([x], y) + res = fn(x_test) + + # Our implementation should not squeeze the symbolic dimension + assert "b" in res.dims + # While xarray would squeeze it + assert "b" not in x_test.squeeze().dims + + +def test_squeeze_errors(): + """Test error cases for squeeze.""" + + # Non-existent dimension + x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1)) + with pytest.raises(ValueError, match="Dimension .* not found"): + squeeze(x1, "time") + + # Dimension size > 1 + with pytest.raises(ValueError, match="has static size .* not 1"): + squeeze(x1, "city") + + # Symbolic shape: dim is not 1 at runtime → should raise + x2 = xtensor("x2", dims=("a", "b", "c")) # shape unknown + y2 = squeeze(x2, "b") + x2_test = xr_arange_like(xtensor(dims=x2.dims, shape=(2, 2, 3))) + fn2 = xr_function([x2], y2) + with pytest.raises(Exception): + fn2(x2_test)